gorm/schema/schema.go

332 lines
10 KiB
Go
Raw Normal View History

2020-01-31 07:22:37 +03:00
package schema
import (
2020-05-05 16:28:38 +03:00
"context"
2020-02-02 09:40:44 +03:00
"errors"
"fmt"
2020-01-31 07:22:37 +03:00
"go/ast"
"reflect"
"sync"
2020-06-02 04:16:07 +03:00
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
2020-01-31 07:22:37 +03:00
)
2020-02-02 09:40:44 +03:00
// ErrUnsupportedDataType is the sentinel error wrapped (via %w) by the parse
// functions in this file when the destination is nil or does not resolve to a
// struct type. Match it with errors.Is.
var ErrUnsupportedDataType = errors.New("unsupported data type")
2020-01-31 07:22:37 +03:00
type Schema struct {
2020-02-23 16:22:35 +03:00
Name string
ModelType reflect.Type
Table string
PrioritizedPrimaryField *Field
DBNames []string
PrimaryFields []*Field
2020-05-25 06:11:09 +03:00
PrimaryFieldDBNames []string
2020-02-23 16:22:35 +03:00
Fields []*Field
FieldsByName map[string]*Field
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
2020-02-23 16:22:35 +03:00
Relationships Relationships
2020-05-19 16:50:06 +03:00
CreateClauses []clause.Interface
QueryClauses []clause.Interface
UpdateClauses []clause.Interface
DeleteClauses []clause.Interface
2020-02-23 16:22:35 +03:00
BeforeCreate, AfterCreate bool
BeforeUpdate, AfterUpdate bool
BeforeDelete, AfterDelete bool
BeforeSave, AfterSave bool
AfterFind bool
err error
initialized chan struct{}
2020-02-23 16:22:35 +03:00
namer Namer
cacheStore *sync.Map
}
// String renders the schema for log output. A named model type is shown as
// "<package path>.<type name>"; a model type without a name falls back to
// "<schema name>(<table>)".
func (schema Schema) String() string {
	typeName := schema.ModelType.Name()
	if typeName != "" {
		return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), typeName)
	}
	return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
}
2020-05-18 08:07:11 +03:00
func (schema Schema) MakeSlice() reflect.Value {
2020-11-10 13:38:24 +03:00
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
2020-05-18 08:07:11 +03:00
results := reflect.New(slice.Type())
results.Elem().Set(slice)
return results
}
func (schema Schema) LookUpField(name string) *Field {
if field, ok := schema.FieldsByDBName[name]; ok {
return field
}
if field, ok := schema.FieldsByName[name]; ok {
return field
}
return nil
2020-01-31 07:22:37 +03:00
}
// Tabler lets a model supply its own table name. When a pointer to the model
// type implements it, schema parsing uses TableName() in place of the
// namer-derived default.
type Tabler interface {
	// TableName returns the table name to use for the model.
	TableName() string
}
2021-05-10 04:51:50 +03:00
// Parse get data type from dialector
2020-02-24 03:51:35 +03:00
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
}
// ParseWithSpecialTableName get data type from dialector with extra schema table
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
value := reflect.ValueOf(dest)
if value.Kind() == reflect.Ptr && value.IsNil() {
value = reflect.New(value.Type().Elem())
}
modelType := reflect.Indirect(value).Type()
2021-09-17 09:04:19 +03:00
if modelType.Kind() == reflect.Interface {
modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
}
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
2020-01-31 07:22:37 +03:00
modelType = modelType.Elem()
}
if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
2020-02-24 03:51:35 +03:00
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
2020-01-31 07:22:37 +03:00
}
// Cache the Schema for performance,
// Use the modelType or modelType + schemaTable (if it present) as cache key.
var schemaCacheKey interface{}
if specialTableName != "" {
schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
} else {
schemaCacheKey = modelType
}
2022-07-26 15:01:20 +03:00
// Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
2021-05-10 04:51:50 +03:00
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
2020-01-31 07:22:37 +03:00
}
modelValue := reflect.New(modelType)
tableName := namer.TableName(modelType.Name())
if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName()
}
if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table
}
if specialTableName != "" && specialTableName != tableName {
tableName = specialTableName
}
2020-01-31 07:22:37 +03:00
schema := &Schema{
Name: modelType.Name(),
2020-01-31 07:22:37 +03:00
ModelType: modelType,
Table: tableName,
2020-01-31 07:22:37 +03:00
FieldsByName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
2020-02-01 16:48:06 +03:00
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
2020-01-31 07:22:37 +03:00
}
2021-05-10 04:51:50 +03:00
// When the schema initialization is completed, the channel will be closed
defer close(schema.initialized)
2022-07-26 15:01:20 +03:00
// Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
2021-05-10 04:51:50 +03:00
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
2020-01-31 07:22:37 +03:00
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
2020-02-01 16:48:06 +03:00
if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
2020-02-01 16:48:06 +03:00
} else {
schema.Fields = append(schema.Fields, field)
2020-01-31 09:31:15 +03:00
}
}
2020-01-31 07:22:37 +03:00
}
for _, field := range schema.Fields {
2020-02-02 09:40:44 +03:00
if field.DBName == "" && field.DataType != "" {
field.DBName = namer.ColumnName(schema.Table, field.Name)
2020-01-31 09:31:15 +03:00
}
2020-01-31 07:22:37 +03:00
if field.DBName != "" {
2020-01-31 09:31:15 +03:00
// nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
2020-02-18 17:56:37 +03:00
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
schema.DBNames = append(schema.DBNames, field.DBName)
}
2020-01-31 07:22:37 +03:00
schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field
2020-02-01 16:48:06 +03:00
if v != nil && v.PrimaryKey {
for idx, f := range schema.PrimaryFields {
if f == v {
schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
}
}
}
if field.PrimaryKey {
schema.PrimaryFields = append(schema.PrimaryFields, field)
}
2020-01-31 07:22:37 +03:00
}
}
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
2020-01-31 07:22:37 +03:00
schema.FieldsByName[field.Name] = field
}
2020-02-15 11:04:21 +03:00
field.setupValuerAndSetter()
2020-01-31 07:22:37 +03:00
}
prioritizedPrimaryField := schema.LookUpField("id")
if prioritizedPrimaryField == nil {
prioritizedPrimaryField = schema.LookUpField("ID")
}
if prioritizedPrimaryField != nil {
if prioritizedPrimaryField.PrimaryKey {
schema.PrioritizedPrimaryField = prioritizedPrimaryField
2020-02-01 16:48:06 +03:00
} else if len(schema.PrimaryFields) == 0 {
prioritizedPrimaryField.PrimaryKey = true
schema.PrioritizedPrimaryField = prioritizedPrimaryField
schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField)
2020-01-31 07:22:37 +03:00
}
2020-02-01 16:48:06 +03:00
}
2020-01-31 07:22:37 +03:00
if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 {
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
}
2020-05-25 06:11:09 +03:00
for _, field := range schema.PrimaryFields {
schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
}
for _, field := range schema.Fields {
2020-02-20 05:13:26 +03:00
if field.HasDefaultValue && field.DefaultValueInterface == nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
2020-02-20 05:13:26 +03:00
}
}
if field := schema.PrioritizedPrimaryField; field != nil {
2020-07-20 13:59:28 +03:00
switch field.GORMDataType {
2020-02-20 05:13:26 +03:00
case Int, Uint:
if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
if !field.HasDefaultValue || field.DefaultValueInterface != nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
}
field.HasDefaultValue = true
field.AutoIncrement = true
}
case String:
if _, ok := field.TagSettings["PRIMARYKEY"]; !ok {
if !field.HasDefaultValue || field.DefaultValueInterface != nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
}
field.HasDefaultValue = true
}
2020-02-20 05:13:26 +03:00
}
}
2020-02-23 16:22:35 +03:00
callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
for _, name := range callbacks {
if methodValue := modelValue.MethodByName(name); methodValue.IsValid() {
2020-02-23 16:22:35 +03:00
switch methodValue.Type().String() {
2020-05-31 18:55:56 +03:00
case "func(*gorm.DB) error": // TODO hack
2020-02-23 16:22:35 +03:00
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
default:
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name)
2020-02-23 16:22:35 +03:00
}
}
}
// Cache the schema
if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
defer func() {
if schema.err != nil {
logger.Default.Error(context.Background(), schema.err.Error())
cacheStore.Delete(modelType)
}
}()
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields {
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if schema.parseRelation(field); schema.err != nil {
return schema, schema.err
} else {
schema.FieldsByName[field.Name] = field
}
}
fieldValue := reflect.New(field.IndirectFieldType)
fieldInterface := fieldValue.Interface()
if fc, ok := fieldInterface.(CreateClausesInterface); ok {
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
}
if fc, ok := fieldInterface.(QueryClausesInterface); ok {
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
}
if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
}
if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
}
}
2020-01-31 07:22:37 +03:00
}
2020-02-24 03:51:35 +03:00
return schema, schema.err
2020-01-31 07:22:37 +03:00
}
// getOrParse returns the schema cached for dest's model type, or parses it via
// Parse when no entry exists. Slice, array and pointer layers are stripped
// from dest's type first; a non-struct result yields an error wrapping
// ErrUnsupportedDataType.
//
// A cache hit is returned as-is with a nil error: unlike Parse, this does not
// wait on the schema's initialized channel or report its recorded error.
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
	typ := reflect.ValueOf(dest).Type()

unwrap:
	for {
		switch typ.Kind() {
		case reflect.Slice, reflect.Array, reflect.Ptr:
			typ = typ.Elem()
		default:
			break unwrap
		}
	}

	if typ.Kind() != reflect.Struct {
		pkgPath := typ.PkgPath()
		if pkgPath == "" {
			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
		}
		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, pkgPath, typ.Name())
	}

	cached, hit := cacheStore.Load(typ)
	if !hit {
		return Parse(dest, cacheStore, namer)
	}
	return cached.(*Schema), nil
}