diff --git a/callbacks/preload.go b/callbacks/preload.go index cf7a0d2b..09f151c7 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -75,7 +75,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string { names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) for _, relation := range embeddedRelations.Relations { // skip first struct name - names = append(names, strings.Join(relation.Field.BindNames[1:], ".")) + names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], ".")) } for _, relations := range embeddedRelations.EmbeddedRelations { names = append(names, embeddedValues(relations)...) diff --git a/schema/field.go b/schema/field.go index ca2e1148..a16c98ab 100644 --- a/schema/field.go +++ b/schema/field.go @@ -56,6 +56,7 @@ type Field struct { Name string DBName string BindNames []string + EmbeddedBindNames []string DataType DataType GORMDataType DataType PrimaryKey bool @@ -112,6 +113,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { Name: fieldStruct.Name, DBName: tagSetting["COLUMN"], BindNames: []string{fieldStruct.Name}, + EmbeddedBindNames: []string{fieldStruct.Name}, FieldType: fieldStruct.Type, IndirectFieldType: fieldStruct.Type, StructField: fieldStruct, @@ -403,6 +405,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.Schema = schema ef.OwnerSchema = field.EmbeddedSchema ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) + if _, ok := field.TagSettings["EMBEDDED"]; ok || !fieldStruct.Anonymous { + ef.EmbeddedBindNames = append([]string{fieldStruct.Name}, ef.EmbeddedBindNames...) + } // index is negative means is pointer if field.FieldType.Kind() == reflect.Struct { ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) diff --git a/schema/relationship.go b/schema/relationship.go index 2e94fc2c..c11918a5 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -150,12 +150,12 @@ func (schema *Schema) setRelation(relation *Relationship) { } // set embedded relation - if len(relation.Field.BindNames) <= 1 { + if len(relation.Field.EmbeddedBindNames) <= 1 { return } relationships := &schema.Relationships - for i, name := range relation.Field.BindNames { - if i < len(relation.Field.BindNames)-1 { + for i, name := range relation.Field.EmbeddedBindNames { + if i < len(relation.Field.EmbeddedBindNames)-1 { if relationships.EmbeddedRelations == nil { relationships.EmbeddedRelations = map[string]*Relationships{} } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 23d79bbb..f1acf2d9 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -121,6 +121,29 @@ func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { }) } +func TestBelongsToWithMixin(t *testing.T) { + type Profile struct { + gorm.Model + Refer string + Name string + } + + type ProfileMixin struct { + Profile Profile `gorm:"References:Refer"` + ProfileRefer int + } + + type User struct { + gorm.Model + ProfileMixin + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", + References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}}, + }) +} + func TestHasOneOverrideForeignKey(t *testing.T) { type Profile struct { gorm.Model @@ -776,6 +799,10 @@ func TestEmbeddedBelongsTo(t *testing.T) { type NestedAddress struct { Address } + type CountryMixin struct { + CountryID int + Country Country + } type Org struct { ID int PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"` @@ -786,6 +813,7 @@ func TestEmbeddedBelongsTo(t *testing.T) { Address } NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` + CountryMixin } s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{}) @@ -815,15 +843,11 @@ func TestEmbeddedBelongsTo(t *testing.T) { }, }, "NestedAddress": { - EmbeddedRelations: map[string]EmbeddedRelations{ - "Address": { - Relations: map[string]Relation{ - "Country": { - Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", - References: []Reference{ - {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, - }, - }, + Relations: map[string]Relation{ + "Country": { + Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, }, }, }, diff --git a/tests/preload_test.go b/tests/preload_test.go index 14f94139..5c87534f 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -466,7 +466,7 @@ func TestEmbedPreload(t *testing.T) { }, }, { name: "nested address country", - preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}}, + preloads: map[string][]interface{}{"NestedAddress.Country": {}}, expect: Org{ ID: org.ID, PostalAddress: EmbeddedAddress{