2020-01-31 07:22:37 +03:00
package schema
import (
2020-05-05 16:28:38 +03:00
"context"
2020-02-02 09:40:44 +03:00
"errors"
2020-02-01 07:46:52 +03:00
"fmt"
2020-01-31 07:22:37 +03:00
"go/ast"
"reflect"
"sync"
2020-02-01 15:18:25 +03:00
2020-06-02 04:16:07 +03:00
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
2020-01-31 07:22:37 +03:00
)
2020-02-02 09:40:44 +03:00
// ErrUnsupportedDataType is the sentinel error wrapped (with %w) by Parse and
// getOrParse when the destination does not resolve to a struct type; Parse
// also returns it for a nil destination. Match it with errors.Is.
var ErrUnsupportedDataType = errors . New ( "unsupported data type" )
2020-01-31 07:22:37 +03:00
type Schema struct {
2020-02-23 16:22:35 +03:00
Name string
ModelType reflect . Type
Table string
PrioritizedPrimaryField * Field
DBNames [ ] string
PrimaryFields [ ] * Field
2020-05-25 06:11:09 +03:00
PrimaryFieldDBNames [ ] string
2020-02-23 16:22:35 +03:00
Fields [ ] * Field
FieldsByName map [ string ] * Field
FieldsByDBName map [ string ] * Field
2020-06-05 16:23:20 +03:00
FieldsWithDefaultDBValue [ ] * Field // fields with default value assigned by database
2020-02-23 16:22:35 +03:00
Relationships Relationships
2020-05-19 16:50:06 +03:00
CreateClauses [ ] clause . Interface
QueryClauses [ ] clause . Interface
UpdateClauses [ ] clause . Interface
DeleteClauses [ ] clause . Interface
2020-02-23 16:22:35 +03:00
BeforeCreate , AfterCreate bool
BeforeUpdate , AfterUpdate bool
BeforeDelete , AfterDelete bool
BeforeSave , AfterSave bool
AfterFind bool
err error
2020-11-27 09:32:20 +03:00
initialized chan struct { }
2020-02-23 16:22:35 +03:00
namer Namer
cacheStore * sync . Map
2020-02-01 07:46:52 +03:00
}
func ( schema Schema ) String ( ) string {
2020-02-01 19:03:56 +03:00
if schema . ModelType . Name ( ) == "" {
2021-06-10 05:21:28 +03:00
return fmt . Sprintf ( "%s(%s)" , schema . Name , schema . Table )
2020-02-01 19:03:56 +03:00
}
2021-06-10 05:21:28 +03:00
return fmt . Sprintf ( "%s.%s" , schema . ModelType . PkgPath ( ) , schema . ModelType . Name ( ) )
2020-02-01 07:46:52 +03:00
}
2020-05-18 08:07:11 +03:00
func ( schema Schema ) MakeSlice ( ) reflect . Value {
2020-11-10 13:38:24 +03:00
slice := reflect . MakeSlice ( reflect . SliceOf ( reflect . PtrTo ( schema . ModelType ) ) , 0 , 20 )
2020-05-18 08:07:11 +03:00
results := reflect . New ( slice . Type ( ) )
results . Elem ( ) . Set ( slice )
return results
}
2020-02-01 07:46:52 +03:00
func ( schema Schema ) LookUpField ( name string ) * Field {
if field , ok := schema . FieldsByDBName [ name ] ; ok {
return field
}
if field , ok := schema . FieldsByName [ name ] ; ok {
return field
}
return nil
2020-01-31 07:22:37 +03:00
}
2020-06-02 10:48:19 +03:00
// Tabler is implemented by models that supply their own table name. Parse
// checks a pointer to the model type for this interface and, when it is
// implemented, uses TableName() instead of the name derived from the namer
// (an embeddedNamer's table still takes precedence over both).
type Tabler interface {
TableName ( ) string
}
2021-05-10 04:51:50 +03:00
// Parse get data type from dialector
2020-02-24 03:51:35 +03:00
func Parse ( dest interface { } , cacheStore * sync . Map , namer Namer ) ( * Schema , error ) {
2020-08-03 05:30:25 +03:00
if dest == nil {
return nil , fmt . Errorf ( "%w: %+v" , ErrUnsupportedDataType , dest )
}
2020-02-24 03:51:35 +03:00
modelType := reflect . ValueOf ( dest ) . Type ( )
2020-06-23 09:38:36 +03:00
for modelType . Kind ( ) == reflect . Slice || modelType . Kind ( ) == reflect . Array || modelType . Kind ( ) == reflect . Ptr {
2020-01-31 07:22:37 +03:00
modelType = modelType . Elem ( )
}
if modelType . Kind ( ) != reflect . Struct {
2020-02-01 07:46:52 +03:00
if modelType . PkgPath ( ) == "" {
2020-02-24 03:51:35 +03:00
return nil , fmt . Errorf ( "%w: %+v" , ErrUnsupportedDataType , dest )
2020-02-01 07:46:52 +03:00
}
2021-06-10 05:21:28 +03:00
return nil , fmt . Errorf ( "%w: %s.%s" , ErrUnsupportedDataType , modelType . PkgPath ( ) , modelType . Name ( ) )
2020-01-31 07:22:37 +03:00
}
if v , ok := cacheStore . Load ( modelType ) ; ok {
2020-11-27 09:32:20 +03:00
s := v . ( * Schema )
2021-05-10 04:51:50 +03:00
// Wait for the initialization of other goroutines to complete
2020-11-27 09:32:20 +03:00
<- s . initialized
2020-12-04 06:28:38 +03:00
return s , s . err
2020-01-31 07:22:37 +03:00
}
2020-06-02 10:48:19 +03:00
modelValue := reflect . New ( modelType )
tableName := namer . TableName ( modelType . Name ( ) )
if tabler , ok := modelValue . Interface ( ) . ( Tabler ) ; ok {
tableName = tabler . TableName ( )
}
2020-09-24 06:32:38 +03:00
if en , ok := namer . ( embeddedNamer ) ; ok {
tableName = en . Table
}
2020-06-02 10:48:19 +03:00
2020-01-31 07:22:37 +03:00
schema := & Schema {
2020-02-01 07:46:52 +03:00
Name : modelType . Name ( ) ,
2020-01-31 07:22:37 +03:00
ModelType : modelType ,
2020-06-02 10:48:19 +03:00
Table : tableName ,
2020-01-31 07:22:37 +03:00
FieldsByName : map [ string ] * Field { } ,
FieldsByDBName : map [ string ] * Field { } ,
2020-02-01 16:48:06 +03:00
Relationships : Relationships { Relations : map [ string ] * Relationship { } } ,
2020-02-01 07:46:52 +03:00
cacheStore : cacheStore ,
2020-02-01 15:18:25 +03:00
namer : namer ,
2020-11-27 09:32:20 +03:00
initialized : make ( chan struct { } ) ,
2020-01-31 07:22:37 +03:00
}
2021-05-10 04:51:50 +03:00
// When the schema initialization is completed, the channel will be closed
defer close ( schema . initialized )
if v , loaded := cacheStore . LoadOrStore ( modelType , schema ) ; loaded {
s := v . ( * Schema )
// Wait for the initialization of other goroutines to complete
<- s . initialized
return s , s . err
}
2020-01-31 07:22:37 +03:00
2020-02-01 07:46:52 +03:00
defer func ( ) {
if schema . err != nil {
2020-05-05 16:28:38 +03:00
logger . Default . Error ( context . Background ( ) , schema . err . Error ( ) )
2020-02-01 07:46:52 +03:00
cacheStore . Delete ( modelType )
2020-01-31 07:22:37 +03:00
}
2020-02-01 07:46:52 +03:00
} ( )
2020-01-31 07:22:37 +03:00
2020-02-01 07:46:52 +03:00
for i := 0 ; i < modelType . NumField ( ) ; i ++ {
if fieldStruct := modelType . Field ( i ) ; ast . IsExported ( fieldStruct . Name ) {
2020-02-01 16:48:06 +03:00
if field := schema . ParseField ( fieldStruct ) ; field . EmbeddedSchema != nil {
2020-02-01 07:46:52 +03:00
schema . Fields = append ( schema . Fields , field . EmbeddedSchema . Fields ... )
2020-02-01 16:48:06 +03:00
} else {
schema . Fields = append ( schema . Fields , field )
2020-01-31 09:31:15 +03:00
}
}
2020-01-31 07:22:37 +03:00
}
for _ , field := range schema . Fields {
2020-02-02 09:40:44 +03:00
if field . DBName == "" && field . DataType != "" {
2020-02-01 19:03:56 +03:00
field . DBName = namer . ColumnName ( schema . Table , field . Name )
2020-01-31 09:31:15 +03:00
}
2020-01-31 07:22:37 +03:00
if field . DBName != "" {
2020-01-31 09:31:15 +03:00
// nonexistence or shortest path or first appear prioritized if has permission
2020-09-09 11:32:29 +03:00
if v , ok := schema . FieldsByDBName [ field . DBName ] ; ! ok || ( ( field . Creatable || field . Updatable || field . Readable ) && len ( field . BindNames ) < len ( v . BindNames ) ) {
2020-02-18 17:56:37 +03:00
if _ , ok := schema . FieldsByDBName [ field . DBName ] ; ! ok {
schema . DBNames = append ( schema . DBNames , field . DBName )
}
2020-01-31 07:22:37 +03:00
schema . FieldsByDBName [ field . DBName ] = field
schema . FieldsByName [ field . Name ] = field
2020-02-01 16:48:06 +03:00
if v != nil && v . PrimaryKey {
for idx , f := range schema . PrimaryFields {
if f == v {
schema . PrimaryFields = append ( schema . PrimaryFields [ 0 : idx ] , schema . PrimaryFields [ idx + 1 : ] ... )
}
}
}
if field . PrimaryKey {
schema . PrimaryFields = append ( schema . PrimaryFields , field )
}
2020-01-31 07:22:37 +03:00
}
}
2020-12-06 06:06:52 +03:00
if of , ok := schema . FieldsByName [ field . Name ] ; ! ok || of . TagSettings [ "-" ] == "-" {
2020-01-31 07:22:37 +03:00
schema . FieldsByName [ field . Name ] = field
}
2020-02-15 11:04:21 +03:00
field . setupValuerAndSetter ( )
2020-01-31 07:22:37 +03:00
}
2020-08-30 15:57:58 +03:00
prioritizedPrimaryField := schema . LookUpField ( "id" )
if prioritizedPrimaryField == nil {
prioritizedPrimaryField = schema . LookUpField ( "ID" )
}
if prioritizedPrimaryField != nil {
if prioritizedPrimaryField . PrimaryKey {
schema . PrioritizedPrimaryField = prioritizedPrimaryField
2020-02-01 16:48:06 +03:00
} else if len ( schema . PrimaryFields ) == 0 {
2020-08-30 15:57:58 +03:00
prioritizedPrimaryField . PrimaryKey = true
schema . PrioritizedPrimaryField = prioritizedPrimaryField
schema . PrimaryFields = append ( schema . PrimaryFields , prioritizedPrimaryField )
2020-01-31 07:22:37 +03:00
}
2020-02-01 16:48:06 +03:00
}
2020-01-31 07:22:37 +03:00
2020-06-05 16:23:20 +03:00
if schema . PrioritizedPrimaryField == nil && len ( schema . PrimaryFields ) == 1 {
schema . PrioritizedPrimaryField = schema . PrimaryFields [ 0 ]
}
2020-05-25 06:11:09 +03:00
for _ , field := range schema . PrimaryFields {
schema . PrimaryFieldDBNames = append ( schema . PrimaryFieldDBNames , field . DBName )
}
2020-06-05 16:23:20 +03:00
for _ , field := range schema . FieldsByDBName {
2020-02-20 05:13:26 +03:00
if field . HasDefaultValue && field . DefaultValueInterface == nil {
2020-06-05 16:23:20 +03:00
schema . FieldsWithDefaultDBValue = append ( schema . FieldsWithDefaultDBValue , field )
2020-02-20 05:13:26 +03:00
}
}
2020-06-05 16:23:20 +03:00
if field := schema . PrioritizedPrimaryField ; field != nil {
2020-07-20 13:59:28 +03:00
switch field . GORMDataType {
2020-02-20 05:13:26 +03:00
case Int , Uint :
2020-07-08 13:50:49 +03:00
if _ , ok := field . TagSettings [ "AUTOINCREMENT" ] ; ! ok {
2020-07-22 14:03:19 +03:00
if ! field . HasDefaultValue || field . DefaultValueInterface != nil {
schema . FieldsWithDefaultDBValue = append ( schema . FieldsWithDefaultDBValue , field )
}
2020-07-08 13:50:49 +03:00
field . HasDefaultValue = true
field . AutoIncrement = true
}
2020-02-20 05:13:26 +03:00
}
}
2020-02-23 16:22:35 +03:00
callbacks := [ ] string { "BeforeCreate" , "AfterCreate" , "BeforeUpdate" , "AfterUpdate" , "BeforeSave" , "AfterSave" , "BeforeDelete" , "AfterDelete" , "AfterFind" }
for _ , name := range callbacks {
2020-06-02 10:48:19 +03:00
if methodValue := modelValue . MethodByName ( name ) ; methodValue . IsValid ( ) {
2020-02-23 16:22:35 +03:00
switch methodValue . Type ( ) . String ( ) {
2020-05-31 18:55:56 +03:00
case "func(*gorm.DB) error" : // TODO hack
2020-02-23 16:22:35 +03:00
reflect . Indirect ( reflect . ValueOf ( schema ) ) . FieldByName ( name ) . SetBool ( true )
default :
2021-07-13 11:38:44 +03:00
logger . Default . Warn ( context . Background ( ) , "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html" , schema , name , name )
2020-02-23 16:22:35 +03:00
}
}
}
2020-12-04 06:28:38 +03:00
if _ , embedded := schema . cacheStore . Load ( embeddedCacheKey ) ; ! embedded {
for _ , field := range schema . Fields {
if field . DataType == "" && ( field . Creatable || field . Updatable || field . Readable ) {
if schema . parseRelation ( field ) ; schema . err != nil {
return schema , schema . err
2021-02-07 09:24:11 +03:00
} else {
schema . FieldsByName [ field . Name ] = field
2020-08-18 06:21:40 +03:00
}
2020-12-04 06:28:38 +03:00
}
2020-08-17 11:31:09 +03:00
2020-12-04 06:28:38 +03:00
fieldValue := reflect . New ( field . IndirectFieldType )
2021-07-14 16:45:23 +03:00
fieldInterface := fieldValue . Interface ( )
if fc , ok := fieldInterface . ( CreateClausesInterface ) ; ok {
2020-12-04 06:28:38 +03:00
field . Schema . CreateClauses = append ( field . Schema . CreateClauses , fc . CreateClauses ( field ) ... )
}
2020-08-17 11:31:09 +03:00
2021-07-14 16:45:23 +03:00
if fc , ok := fieldInterface . ( QueryClausesInterface ) ; ok {
2020-12-04 06:28:38 +03:00
field . Schema . QueryClauses = append ( field . Schema . QueryClauses , fc . QueryClauses ( field ) ... )
}
2020-08-17 11:31:09 +03:00
2021-07-14 16:45:23 +03:00
if fc , ok := fieldInterface . ( UpdateClausesInterface ) ; ok {
2020-12-04 06:28:38 +03:00
field . Schema . UpdateClauses = append ( field . Schema . UpdateClauses , fc . UpdateClauses ( field ) ... )
}
2021-07-14 16:45:23 +03:00
if fc , ok := fieldInterface . ( DeleteClausesInterface ) ; ok {
2020-12-04 06:28:38 +03:00
field . Schema . DeleteClauses = append ( field . Schema . DeleteClauses , fc . DeleteClauses ( field ) ... )
2020-08-17 11:31:09 +03:00
}
2020-02-01 07:46:52 +03:00
}
2020-01-31 07:22:37 +03:00
}
2020-02-24 03:51:35 +03:00
return schema , schema . err
2020-01-31 07:22:37 +03:00
}
2020-11-27 09:32:20 +03:00
func getOrParse ( dest interface { } , cacheStore * sync . Map , namer Namer ) ( * Schema , error ) {
modelType := reflect . ValueOf ( dest ) . Type ( )
for modelType . Kind ( ) == reflect . Slice || modelType . Kind ( ) == reflect . Array || modelType . Kind ( ) == reflect . Ptr {
modelType = modelType . Elem ( )
}
if modelType . Kind ( ) != reflect . Struct {
if modelType . PkgPath ( ) == "" {
return nil , fmt . Errorf ( "%w: %+v" , ErrUnsupportedDataType , dest )
}
2021-06-10 05:21:28 +03:00
return nil , fmt . Errorf ( "%w: %s.%s" , ErrUnsupportedDataType , modelType . PkgPath ( ) , modelType . Name ( ) )
2020-11-27 09:32:20 +03:00
}
if v , ok := cacheStore . Load ( modelType ) ; ok {
return v . ( * Schema ) , nil
}
return Parse ( dest , cacheStore , namer )
}