diff --git a/migrator/migrator.go b/migrator/migrator.go index 3e5d86d3..d50159dd 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -120,6 +120,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } return nil }); err != nil { + fmt.Println(err) return err } } diff --git a/schema/field.go b/schema/field.go index 4eb95b98..1ca4cb6d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -62,6 +62,7 @@ type Field struct { TagSettings map[string]string Schema *Schema EmbeddedSchema *Schema + OwnerSchema *Schema ReflectValueOf func(reflect.Value) reflect.Value ValueOf func(reflect.Value) (value interface{}, zero bool) Set func(reflect.Value, interface{}) error @@ -321,6 +322,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } for _, ef := range field.EmbeddedSchema.Fields { ef.Schema = schema + ef.OwnerSchema = field.EmbeddedSchema ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) // index is negative means is pointer if field.FieldType.Kind() == reflect.Struct { diff --git a/schema/relationship.go b/schema/relationship.go index b7ab4f66..93080105 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -5,6 +5,7 @@ import ( "reflect" "regexp" "strings" + "sync" "github.com/jinzhu/inflection" "gorm.io/gorm/clause" @@ -66,9 +67,16 @@ func (schema *Schema) parseRelation(field *Field) { } ) - if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { - schema.err = err - return + if field.OwnerSchema != nil { + if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil { + schema.err = err + return + } + } else { + if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { + schema.err = err + return + } } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { @@ -78,7 +86,7 @@ func (schema *Schema) parseRelation(field *Field) { } else { switch field.IndirectFieldType.Kind() { case reflect.Struct, reflect.Slice: - schema.guessRelation(relation, field, true) + schema.guessRelation(relation, field, guessHas) default: schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) } @@ -316,21 +324,50 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } -func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) { +type guessLevel int + +const ( + guessHas guessLevel = iota + guessEmbeddedHas + guessBelongs + guessEmbeddedBelongs +) + +func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { var ( primaryFields, foreignFields []*Field primarySchema, foreignSchema = schema, relation.FieldSchema ) - if !guessHas { - primarySchema, foreignSchema = relation.FieldSchema, schema + reguessOrErr := func(err string, args ...interface{}) { + switch gl { + case guessHas: + schema.guessRelation(relation, field, guessEmbeddedHas) + case guessEmbeddedHas: + schema.guessRelation(relation, field, guessBelongs) + case guessBelongs: + schema.guessRelation(relation, field, guessEmbeddedBelongs) + default: + schema.err = fmt.Errorf(err, args...) + } } - reguessOrErr := func(err string, args ...interface{}) { - if guessHas { - schema.guessRelation(relation, field, false) + switch gl { + case guessEmbeddedHas: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema } else { - schema.err = fmt.Errorf(err, args...) + reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + return + } + case guessBelongs: + primarySchema, foreignSchema = relation.FieldSchema, schema + case guessEmbeddedBelongs: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema + } else { + reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + return } } @@ -345,8 +382,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH } } else { for _, primaryField := range primarySchema.PrimaryFields { - lookUpName := schema.Name + primaryField.Name - if !guessHas { + lookUpName := primarySchema.Name + primaryField.Name + if gl == guessBelongs { lookUpName = field.Name + primaryField.Name } @@ -358,7 +395,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH } if len(foreignFields) == 0 { - reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas) + reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) return } else if len(relation.primaryKeys) > 0 { for idx, primaryKey := range relation.primaryKeys { @@ -394,11 +431,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], ForeignKey: foreignField, - OwnPrimaryKey: schema == primarySchema && guessHas, + OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas), }) } - if guessHas { + if gl == guessHas || gl == guessEmbeddedHas { relation.Type = "has" } else { relation.Type = BelongsTo diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 84f56165..02765b8c 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -101,8 +101,12 @@ func TestCallbacks(t *testing.T) { results: []string{"c5", "c1", "c2", "c3", "c4"}, }, { - callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}}, - results: []string{"c5", "c1", "c2", "c4", "c3"}, + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}}, + results: []string{"c3", "c5", "c1", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}}, + results: []string{"c5", "c1", "c2", "c3", "c4"}, }, } diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 7f40a0a4..fb0d6f23 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -7,6 +7,7 @@ import ( "testing" "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" ) func TestEmbeddedStruct(t *testing.T) { @@ -152,3 +153,16 @@ func TestEmbeddedScanValuer(t *testing.T) { t.Errorf("Failed to create got error %v", err) } } + +func TestEmbeddedRelations(t *testing.T) { + type AdvancedUser struct { + User `gorm:"embedded"` + Advanced bool + } + + DB.Debug().Migrator().DropTable(&AdvancedUser{}) + + if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil { + t.Errorf("Failed to auto migrate advanced user, got error %v", err) + } +}