2020-01-31 07:22:37 +03:00
package schema
import (
2020-05-05 16:28:38 +03:00
"context"
2020-02-02 09:40:44 +03:00
"errors"
2020-02-01 07:46:52 +03:00
"fmt"
2020-01-31 07:22:37 +03:00
"go/ast"
"reflect"
2023-04-11 08:10:38 +03:00
"strings"
2020-01-31 07:22:37 +03:00
"sync"
2020-02-01 15:18:25 +03:00
2020-06-02 04:16:07 +03:00
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
2020-01-31 07:22:37 +03:00
)
2023-10-10 09:50:29 +03:00
// callbackType names a model hook method. Each value is used both as the
// method name looked up on the model and as the name of the bool field on
// Schema that records whether the model implements that hook.
type callbackType string

// Hook names, paired as before/after for each operation.
const (
	callbackTypeBeforeCreate callbackType = "BeforeCreate"
	callbackTypeAfterCreate  callbackType = "AfterCreate"

	callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
	callbackTypeAfterUpdate  callbackType = "AfterUpdate"

	callbackTypeBeforeSave callbackType = "BeforeSave"
	callbackTypeAfterSave  callbackType = "AfterSave"

	callbackTypeBeforeDelete callbackType = "BeforeDelete"
	callbackTypeAfterDelete  callbackType = "AfterDelete"

	callbackTypeAfterFind callbackType = "AfterFind"
)
2020-02-02 09:40:44 +03:00
// ErrUnsupportedDataType reports that a value handed to schema parsing does
// not resolve to a struct type. Parse wraps it with %w together with a
// description of the offending value, so match it with errors.Is.
var ErrUnsupportedDataType = errors.New("unsupported data type")
2020-01-31 07:22:37 +03:00
type Schema struct {
2020-02-23 16:22:35 +03:00
Name string
ModelType reflect . Type
Table string
PrioritizedPrimaryField * Field
DBNames [ ] string
PrimaryFields [ ] * Field
2020-05-25 06:11:09 +03:00
PrimaryFieldDBNames [ ] string
2020-02-23 16:22:35 +03:00
Fields [ ] * Field
FieldsByName map [ string ] * Field
2023-04-11 08:10:38 +03:00
FieldsByBindName map [ string ] * Field // embedded fields is 'Embed.Field'
2020-02-23 16:22:35 +03:00
FieldsByDBName map [ string ] * Field
2020-06-05 16:23:20 +03:00
FieldsWithDefaultDBValue [ ] * Field // fields with default value assigned by database
2020-02-23 16:22:35 +03:00
Relationships Relationships
2020-05-19 16:50:06 +03:00
CreateClauses [ ] clause . Interface
QueryClauses [ ] clause . Interface
UpdateClauses [ ] clause . Interface
DeleteClauses [ ] clause . Interface
2020-02-23 16:22:35 +03:00
BeforeCreate , AfterCreate bool
BeforeUpdate , AfterUpdate bool
BeforeDelete , AfterDelete bool
BeforeSave , AfterSave bool
AfterFind bool
err error
2020-11-27 09:32:20 +03:00
initialized chan struct { }
2020-02-23 16:22:35 +03:00
namer Namer
cacheStore * sync . Map
2020-02-01 07:46:52 +03:00
}
func ( schema Schema ) String ( ) string {
2020-02-01 19:03:56 +03:00
if schema . ModelType . Name ( ) == "" {
2021-06-10 05:21:28 +03:00
return fmt . Sprintf ( "%s(%s)" , schema . Name , schema . Table )
2020-02-01 19:03:56 +03:00
}
2021-06-10 05:21:28 +03:00
return fmt . Sprintf ( "%s.%s" , schema . ModelType . PkgPath ( ) , schema . ModelType . Name ( ) )
2020-02-01 07:46:52 +03:00
}
2020-05-18 08:07:11 +03:00
func ( schema Schema ) MakeSlice ( ) reflect . Value {
2020-11-10 13:38:24 +03:00
slice := reflect . MakeSlice ( reflect . SliceOf ( reflect . PtrTo ( schema . ModelType ) ) , 0 , 20 )
2020-05-18 08:07:11 +03:00
results := reflect . New ( slice . Type ( ) )
results . Elem ( ) . Set ( slice )
return results
}
2020-02-01 07:46:52 +03:00
func ( schema Schema ) LookUpField ( name string ) * Field {
if field , ok := schema . FieldsByDBName [ name ] ; ok {
return field
}
if field , ok := schema . FieldsByName [ name ] ; ok {
return field
}
return nil
2020-01-31 07:22:37 +03:00
}
2023-04-11 08:10:38 +03:00
// LookUpFieldByBindName looks for the closest field in the embedded struct.
//
// bindNames is the bind path of a field and name is the field name to resolve
// next to it. Candidate keys are built from the enclosing prefixes of the path
// (every element except the last), longest first, as prefix + "." + name; the
// first key present in FieldsByBindName wins:
//
//	type Struct struct {
//		Embedded struct {
//			ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
//		}
//		ID string
//	}
//
// The final attempt uses an empty prefix, i.e. the key "." + name.
// NOTE(review): top-level fields are keyed by their plain bind name (e.g.
// "ID"), so that attempt does not appear to match them; callers presumably
// fall back to LookUpField for top-level names — confirm before relying on it.
//
// It returns nil when bindNames is empty or no candidate key matches.
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
	// With an empty bindNames the loop body never runs and nil is returned.
	for depth := len(bindNames) - 1; depth >= 0; depth-- {
		key := strings.Join(bindNames[:depth], ".") + "." + name
		if field, found := schema.FieldsByBindName[key]; found {
			return field
		}
	}
	return nil
}
2020-06-02 10:48:19 +03:00
// Tabler is implemented by models that pick their own table name. During
// parsing, its TableName result replaces the name the Namer derives from the
// model's type name.
type Tabler interface {
	TableName() string
}
2022-10-07 16:18:37 +03:00
// TablerWithNamer is like Tabler, but the model receives the active Namer so
// it can build its table name with the configured naming strategy. During
// parsing it is checked after Tabler, so it takes precedence over it.
type TablerWithNamer interface {
	TableName(Namer) string
}
2021-05-10 04:51:50 +03:00
// Parse get data type from dialector
2020-02-24 03:51:35 +03:00
func Parse ( dest interface { } , cacheStore * sync . Map , namer Namer ) ( * Schema , error ) {
2021-10-25 06:26:44 +03:00
return ParseWithSpecialTableName ( dest , cacheStore , namer , "" )
2021-10-12 20:59:28 +03:00
}
2021-10-25 06:26:44 +03:00
// ParseWithSpecialTableName get data type from dialector with extra schema table
func ParseWithSpecialTableName ( dest interface { } , cacheStore * sync . Map , namer Namer , specialTableName string ) ( * Schema , error ) {
2020-08-03 05:30:25 +03:00
if dest == nil {
return nil , fmt . Errorf ( "%w: %+v" , ErrUnsupportedDataType , dest )
}
2021-10-12 16:19:08 +03:00
value := reflect . ValueOf ( dest )
if value . Kind ( ) == reflect . Ptr && value . IsNil ( ) {
value = reflect . New ( value . Type ( ) . Elem ( ) )
}
modelType := reflect . Indirect ( value ) . Type ( )
2021-09-17 09:04:19 +03:00
if modelType . Kind ( ) == reflect . Interface {
modelType = reflect . Indirect ( reflect . ValueOf ( dest ) ) . Elem ( ) . Type ( )
}
2020-06-23 09:38:36 +03:00
for modelType . Kind ( ) == reflect . Slice || modelType . Kind ( ) == reflect . Array || modelType . Kind ( ) == reflect . Ptr {
2020-01-31 07:22:37 +03:00
modelType = modelType . Elem ( )
}
if modelType . Kind ( ) != reflect . Struct {
2020-02-01 07:46:52 +03:00
if modelType . PkgPath ( ) == "" {
2020-02-24 03:51:35 +03:00
return nil , fmt . Errorf ( "%w: %+v" , ErrUnsupportedDataType , dest )
2020-02-01 07:46:52 +03:00
}
2021-06-10 05:21:28 +03:00
return nil , fmt . Errorf ( "%w: %s.%s" , ErrUnsupportedDataType , modelType . PkgPath ( ) , modelType . Name ( ) )
2020-01-31 07:22:37 +03:00
}
2021-10-25 06:26:44 +03:00
// Cache the Schema for performance,
// Use the modelType or modelType + schemaTable (if it present) as cache key.
var schemaCacheKey interface { }
if specialTableName != "" {
schemaCacheKey = fmt . Sprintf ( "%p-%s" , modelType , specialTableName )
} else {
schemaCacheKey = modelType
}
2022-07-26 15:01:20 +03:00
// Load exist schema cache, return if exists
2021-10-25 06:26:44 +03:00
if v , ok := cacheStore . Load ( schemaCacheKey ) ; ok {
2020-11-27 09:32:20 +03:00
s := v . ( * Schema )
2021-05-10 04:51:50 +03:00
// Wait for the initialization of other goroutines to complete
2020-11-27 09:32:20 +03:00
<- s . initialized
2020-12-04 06:28:38 +03:00
return s , s . err
2020-01-31 07:22:37 +03:00
}
2020-06-02 10:48:19 +03:00
modelValue := reflect . New ( modelType )
tableName := namer . TableName ( modelType . Name ( ) )
if tabler , ok := modelValue . Interface ( ) . ( Tabler ) ; ok {
tableName = tabler . TableName ( )
}
2022-10-07 16:18:37 +03:00
if tabler , ok := modelValue . Interface ( ) . ( TablerWithNamer ) ; ok {
tableName = tabler . TableName ( namer )
}
2020-09-24 06:32:38 +03:00
if en , ok := namer . ( embeddedNamer ) ; ok {
tableName = en . Table
}
2021-10-25 06:26:44 +03:00
if specialTableName != "" && specialTableName != tableName {
tableName = specialTableName
}
2020-06-02 10:48:19 +03:00
2020-01-31 07:22:37 +03:00
schema := & Schema {
2023-04-11 08:10:38 +03:00
Name : modelType . Name ( ) ,
ModelType : modelType ,
Table : tableName ,
FieldsByName : map [ string ] * Field { } ,
FieldsByBindName : map [ string ] * Field { } ,
FieldsByDBName : map [ string ] * Field { } ,
Relationships : Relationships { Relations : map [ string ] * Relationship { } } ,
cacheStore : cacheStore ,
namer : namer ,
initialized : make ( chan struct { } ) ,
2020-01-31 07:22:37 +03:00
}
2021-05-10 04:51:50 +03:00
// When the schema initialization is completed, the channel will be closed
defer close ( schema . initialized )
2022-07-26 15:01:20 +03:00
// Load exist schema cache, return if exists
2021-10-25 06:26:44 +03:00
if v , ok := cacheStore . Load ( schemaCacheKey ) ; ok {
2021-05-10 04:51:50 +03:00
s := v . ( * Schema )
// Wait for the initialization of other goroutines to complete
<- s . initialized
return s , s . err
}
2020-01-31 07:22:37 +03:00
2020-02-01 07:46:52 +03:00
for i := 0 ; i < modelType . NumField ( ) ; i ++ {
if fieldStruct := modelType . Field ( i ) ; ast . IsExported ( fieldStruct . Name ) {
2020-02-01 16:48:06 +03:00
if field := schema . ParseField ( fieldStruct ) ; field . EmbeddedSchema != nil {
2020-02-01 07:46:52 +03:00
schema . Fields = append ( schema . Fields , field . EmbeddedSchema . Fields ... )
2020-02-01 16:48:06 +03:00
} else {
schema . Fields = append ( schema . Fields , field )
2020-01-31 09:31:15 +03:00
}
}
2020-01-31 07:22:37 +03:00
}
for _ , field := range schema . Fields {
2020-02-02 09:40:44 +03:00
if field . DBName == "" && field . DataType != "" {
2020-02-01 19:03:56 +03:00
field . DBName = namer . ColumnName ( schema . Table , field . Name )
2020-01-31 09:31:15 +03:00
}
2023-04-11 08:10:38 +03:00
bindName := field . BindName ( )
2020-01-31 07:22:37 +03:00
if field . DBName != "" {
2020-01-31 09:31:15 +03:00
// nonexistence or shortest path or first appear prioritized if has permission
2020-09-09 11:32:29 +03:00
if v , ok := schema . FieldsByDBName [ field . DBName ] ; ! ok || ( ( field . Creatable || field . Updatable || field . Readable ) && len ( field . BindNames ) < len ( v . BindNames ) ) {
2020-02-18 17:56:37 +03:00
if _ , ok := schema . FieldsByDBName [ field . DBName ] ; ! ok {
schema . DBNames = append ( schema . DBNames , field . DBName )
}
2020-01-31 07:22:37 +03:00
schema . FieldsByDBName [ field . DBName ] = field
schema . FieldsByName [ field . Name ] = field
2023-04-11 08:10:38 +03:00
schema . FieldsByBindName [ bindName ] = field
2020-02-01 16:48:06 +03:00
if v != nil && v . PrimaryKey {
for idx , f := range schema . PrimaryFields {
if f == v {
schema . PrimaryFields = append ( schema . PrimaryFields [ 0 : idx ] , schema . PrimaryFields [ idx + 1 : ] ... )
}
}
}
if field . PrimaryKey {
schema . PrimaryFields = append ( schema . PrimaryFields , field )
}
2020-01-31 07:22:37 +03:00
}
}
2020-12-06 06:06:52 +03:00
if of , ok := schema . FieldsByName [ field . Name ] ; ! ok || of . TagSettings [ "-" ] == "-" {
2020-01-31 07:22:37 +03:00
schema . FieldsByName [ field . Name ] = field
}
2023-04-11 08:10:38 +03:00
if of , ok := schema . FieldsByBindName [ bindName ] ; ! ok || of . TagSettings [ "-" ] == "-" {
schema . FieldsByBindName [ bindName ] = field
}
2020-02-15 11:04:21 +03:00
field . setupValuerAndSetter ( )
2020-01-31 07:22:37 +03:00
}
2020-08-30 15:57:58 +03:00
prioritizedPrimaryField := schema . LookUpField ( "id" )
if prioritizedPrimaryField == nil {
prioritizedPrimaryField = schema . LookUpField ( "ID" )
}
if prioritizedPrimaryField != nil {
if prioritizedPrimaryField . PrimaryKey {
schema . PrioritizedPrimaryField = prioritizedPrimaryField
2020-02-01 16:48:06 +03:00
} else if len ( schema . PrimaryFields ) == 0 {
2020-08-30 15:57:58 +03:00
prioritizedPrimaryField . PrimaryKey = true
schema . PrioritizedPrimaryField = prioritizedPrimaryField
schema . PrimaryFields = append ( schema . PrimaryFields , prioritizedPrimaryField )
2020-01-31 07:22:37 +03:00
}
2020-02-01 16:48:06 +03:00
}
2020-01-31 07:22:37 +03:00
2023-03-10 11:50:03 +03:00
if schema . PrioritizedPrimaryField == nil {
if len ( schema . PrimaryFields ) == 1 {
schema . PrioritizedPrimaryField = schema . PrimaryFields [ 0 ]
} else if len ( schema . PrimaryFields ) > 1 {
// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
for _ , field := range schema . PrimaryFields {
if field . AutoIncrement {
schema . PrioritizedPrimaryField = field
break
}
}
}
2020-06-05 16:23:20 +03:00
}
2020-05-25 06:11:09 +03:00
for _ , field := range schema . PrimaryFields {
schema . PrimaryFieldDBNames = append ( schema . PrimaryFieldDBNames , field . DBName )
}
2021-11-08 15:20:55 +03:00
for _ , field := range schema . Fields {
2022-12-01 15:26:59 +03:00
if field . DataType != "" && field . HasDefaultValue && field . DefaultValueInterface == nil {
2020-06-05 16:23:20 +03:00
schema . FieldsWithDefaultDBValue = append ( schema . FieldsWithDefaultDBValue , field )
2020-02-20 05:13:26 +03:00
}
}
2020-06-05 16:23:20 +03:00
if field := schema . PrioritizedPrimaryField ; field != nil {
2020-07-20 13:59:28 +03:00
switch field . GORMDataType {
2020-02-20 05:13:26 +03:00
case Int , Uint :
2020-07-08 13:50:49 +03:00
if _ , ok := field . TagSettings [ "AUTOINCREMENT" ] ; ! ok {
2020-07-22 14:03:19 +03:00
if ! field . HasDefaultValue || field . DefaultValueInterface != nil {
schema . FieldsWithDefaultDBValue = append ( schema . FieldsWithDefaultDBValue , field )
}
2020-07-08 13:50:49 +03:00
field . HasDefaultValue = true
field . AutoIncrement = true
}
2020-02-20 05:13:26 +03:00
}
}
2023-10-10 09:50:29 +03:00
callbackTypes := [ ] callbackType {
callbackTypeBeforeCreate , callbackTypeAfterCreate ,
callbackTypeBeforeUpdate , callbackTypeAfterUpdate ,
callbackTypeBeforeSave , callbackTypeAfterSave ,
callbackTypeBeforeDelete , callbackTypeAfterDelete ,
callbackTypeAfterFind ,
}
for _ , cbName := range callbackTypes {
if methodValue := callBackToMethodValue ( modelValue , cbName ) ; methodValue . IsValid ( ) {
2020-02-23 16:22:35 +03:00
switch methodValue . Type ( ) . String ( ) {
2020-05-31 18:55:56 +03:00
case "func(*gorm.DB) error" : // TODO hack
2023-10-10 09:50:29 +03:00
reflect . Indirect ( reflect . ValueOf ( schema ) ) . FieldByName ( string ( cbName ) ) . SetBool ( true )
2020-02-23 16:22:35 +03:00
default :
2023-10-10 09:50:29 +03:00
logger . Default . Warn ( context . Background ( ) , "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html" , schema , cbName , cbName )
2020-02-23 16:22:35 +03:00
}
}
}
2021-10-25 06:26:44 +03:00
// Cache the schema
if v , loaded := cacheStore . LoadOrStore ( schemaCacheKey , schema ) ; loaded {
s := v . ( * Schema )
// Wait for the initialization of other goroutines to complete
<- s . initialized
return s , s . err
2021-09-07 10:30:14 +03:00
}
defer func ( ) {
if schema . err != nil {
logger . Default . Error ( context . Background ( ) , schema . err . Error ( ) )
cacheStore . Delete ( modelType )
}
} ( )
2020-12-04 06:28:38 +03:00
if _ , embedded := schema . cacheStore . Load ( embeddedCacheKey ) ; ! embedded {
for _ , field := range schema . Fields {
if field . DataType == "" && ( field . Creatable || field . Updatable || field . Readable ) {
if schema . parseRelation ( field ) ; schema . err != nil {
return schema , schema . err
2021-02-07 09:24:11 +03:00
} else {
schema . FieldsByName [ field . Name ] = field
2023-04-11 08:10:38 +03:00
schema . FieldsByBindName [ field . BindName ( ) ] = field
2020-08-18 06:21:40 +03:00
}
2020-12-04 06:28:38 +03:00
}
2020-08-17 11:31:09 +03:00
2020-12-04 06:28:38 +03:00
fieldValue := reflect . New ( field . IndirectFieldType )
2021-07-14 16:45:23 +03:00
fieldInterface := fieldValue . Interface ( )
if fc , ok := fieldInterface . ( CreateClausesInterface ) ; ok {
2020-12-04 06:28:38 +03:00
field . Schema . CreateClauses = append ( field . Schema . CreateClauses , fc . CreateClauses ( field ) ... )
}
2020-08-17 11:31:09 +03:00
2021-07-14 16:45:23 +03:00
if fc , ok := fieldInterface . ( QueryClausesInterface ) ; ok {
2020-12-04 06:28:38 +03:00
field . Schema . QueryClauses = append ( field . Schema . QueryClauses , fc . QueryClauses ( field ) ... )
}
2020-08-17 11:31:09 +03:00
2021-07-14 16:45:23 +03:00
if fc , ok := fieldInterface . ( UpdateClausesInterface ) ; ok {
2020-12-04 06:28:38 +03:00
field . Schema . UpdateClauses = append ( field . Schema . UpdateClauses , fc . UpdateClauses ( field ) ... )
}
2021-07-14 16:45:23 +03:00
if fc , ok := fieldInterface . ( DeleteClausesInterface ) ; ok {
2020-12-04 06:28:38 +03:00
field . Schema . DeleteClauses = append ( field . Schema . DeleteClauses , fc . DeleteClauses ( field ) ... )
2020-08-17 11:31:09 +03:00
}
2020-02-01 07:46:52 +03:00
}
2020-01-31 07:22:37 +03:00
}
2020-02-24 03:51:35 +03:00
return schema , schema . err
2020-01-31 07:22:37 +03:00
}
2020-11-27 09:32:20 +03:00
2023-10-10 09:50:29 +03:00
// callBackToMethodValue returns the method of modelType named by cbType, or
// the zero reflect.Value when cbType is unknown or the method does not exist.
//
// Every branch deliberately passes a compile-time constant to MethodByName;
// do not replace this switch with MethodByName(string(cbType)).
// Prior to go1.22 any use of MethodByName would cause the linker to abandon
// dead code elimination for the entire binary. As of go1.22 the compiler
// special-cases a string constant being passed to MethodByName, which for
// large binaries gives a significant reduction in size.
// https://github.com/golang/go/issues/62257
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
	var method reflect.Value // zero Value: returned for unknown callback types
	switch cbType {
	case callbackTypeBeforeCreate:
		method = modelType.MethodByName(string(callbackTypeBeforeCreate))
	case callbackTypeAfterCreate:
		method = modelType.MethodByName(string(callbackTypeAfterCreate))
	case callbackTypeBeforeUpdate:
		method = modelType.MethodByName(string(callbackTypeBeforeUpdate))
	case callbackTypeAfterUpdate:
		method = modelType.MethodByName(string(callbackTypeAfterUpdate))
	case callbackTypeBeforeSave:
		method = modelType.MethodByName(string(callbackTypeBeforeSave))
	case callbackTypeAfterSave:
		method = modelType.MethodByName(string(callbackTypeAfterSave))
	case callbackTypeBeforeDelete:
		method = modelType.MethodByName(string(callbackTypeBeforeDelete))
	case callbackTypeAfterDelete:
		method = modelType.MethodByName(string(callbackTypeAfterDelete))
	case callbackTypeAfterFind:
		method = modelType.MethodByName(string(callbackTypeAfterFind))
	}
	return method
}
2020-11-27 09:32:20 +03:00
func getOrParse ( dest interface { } , cacheStore * sync . Map , namer Namer ) ( * Schema , error ) {
modelType := reflect . ValueOf ( dest ) . Type ( )
for modelType . Kind ( ) == reflect . Slice || modelType . Kind ( ) == reflect . Array || modelType . Kind ( ) == reflect . Ptr {
modelType = modelType . Elem ( )
}
if modelType . Kind ( ) != reflect . Struct {
if modelType . PkgPath ( ) == "" {
return nil , fmt . Errorf ( "%w: %+v" , ErrUnsupportedDataType , dest )
}
2021-06-10 05:21:28 +03:00
return nil , fmt . Errorf ( "%w: %s.%s" , ErrUnsupportedDataType , modelType . PkgPath ( ) , modelType . Name ( ) )
2020-11-27 09:32:20 +03:00
}
if v , ok := cacheStore . Load ( modelType ) ; ok {
return v . ( * Schema ) , nil
}
return Parse ( dest , cacheStore , namer )
}