forked from mirror/gorm
Fix nested embedded struct, close #3278
This commit is contained in:
parent
9fcc337bd1
commit
dc48e04896
|
@ -301,14 +301,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
|||
field.Updatable = false
|
||||
field.Readable = false
|
||||
|
||||
cacheStore := schema.cacheStore
|
||||
if _, embedded := schema.cacheStore.Load("embedded_cache_store"); !embedded {
|
||||
cacheStore = &sync.Map{}
|
||||
cacheStore.Store("embedded_cache_store", true)
|
||||
}
|
||||
cacheStore := &sync.Map{}
|
||||
cacheStore.Store(embeddedCacheKey, true)
|
||||
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
|
||||
for _, ef := range field.EmbeddedSchema.Fields {
|
||||
ef.Schema = schema
|
||||
ef.OwnerSchema = field.EmbeddedSchema
|
||||
|
|
|
@ -41,7 +41,7 @@ type AdvancedDataTypeUser struct {
|
|||
}
|
||||
|
||||
type BaseModel struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
ID uint
|
||||
CreatedAt time.Time
|
||||
CreatedBy *int
|
||||
Created *VersionUser `gorm:"foreignKey:CreatedBy"`
|
||||
|
@ -52,7 +52,6 @@ type BaseModel struct {
|
|||
type VersionModel struct {
|
||||
BaseModel
|
||||
Version int
|
||||
CompanyID int
|
||||
}
|
||||
|
||||
type VersionUser struct {
|
||||
|
|
|
@ -212,7 +212,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||
}
|
||||
|
||||
if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded {
|
||||
// parse relations for unidentified fields
|
||||
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
|
||||
for _, field := range schema.Fields {
|
||||
if field.DataType == "" && field.Creatable {
|
||||
if schema.parseRelation(field); schema.err != nil {
|
||||
|
@ -238,6 +238,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return schema, schema.err
|
||||
}
|
||||
|
|
|
@ -162,7 +162,23 @@ func TestCustomizeTableName(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNestedModel(t *testing.T) {
|
||||
if _, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}); err != nil {
|
||||
versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse nested user, got error %v", err)
|
||||
}
|
||||
|
||||
fields := []schema.Field{
|
||||
{Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true},
|
||||
{Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Int, Size: 64},
|
||||
{Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64},
|
||||
}
|
||||
|
||||
for _, f := range fields {
|
||||
checkSchemaField(t, versionUser, &f, func(f *schema.Field) {
|
||||
f.Creatable = true
|
||||
f.Updatable = true
|
||||
f.Readable = true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
var embeddedCacheKey = "embedded_cache_store"
|
||||
|
||||
func ParseTagSetting(str string, sep string) map[string]string {
|
||||
settings := map[string]string{}
|
||||
names := strings.Split(str, sep)
|
||||
|
|
|
@ -160,9 +160,9 @@ func TestEmbeddedRelations(t *testing.T) {
|
|||
Advanced bool
|
||||
}
|
||||
|
||||
DB.Debug().Migrator().DropTable(&AdvancedUser{})
|
||||
DB.Migrator().DropTable(&AdvancedUser{})
|
||||
|
||||
if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil {
|
||||
if err := DB.AutoMigrate(&AdvancedUser{}); err != nil {
|
||||
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -76,7 +76,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) {
|
|||
}
|
||||
} else {
|
||||
name := reflect.ValueOf(got).Type().Elem().Name()
|
||||
t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len())
|
||||
t.Errorf("%v expects length: %v, got %v (expects: %+v, got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), expect, got)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue