diff --git a/callbacks/preload.go b/callbacks/preload.go index ea2570ba..15669c84 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -3,6 +3,7 @@ package callbacks import ( "fmt" "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -10,6 +11,98 @@ import ( "gorm.io/gorm/utils" ) +// parsePreloadMap extracts nested preloads. e.g. +// +// // schema has a "k0" relation and a "k7.k8" embedded relation +// parsePreloadMap(schema, map[string][]interface{}{ +// clause.Associations: {"arg1"}, +// "k1": {"arg2"}, +// "k2.k3": {"arg3"}, +// "k4.k5.k6": {"arg4"}, +// }) +// // preloadMap is +// map[string]map[string][]interface{}{ +// "k0": {}, +// "k7": { +// "k8": {}, +// }, +// "k1": {}, +// "k2": { +// "k3": {"arg3"}, +// }, +// "k4": { +// "k5.k6": {"arg4"}, +// }, +// } +func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} { + preloadMap := map[string]map[string][]interface{}{} + setPreloadMap := func(name, value string, args []interface{}) { + if _, ok := preloadMap[name]; !ok { + preloadMap[name] = map[string][]interface{}{} + } + if value != "" { + preloadMap[name][value] = args + } + } + + for name, args := range preloads { + preloadFields := strings.Split(name, ".") + value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".") + if preloadFields[0] == clause.Associations { + for _, relation := range s.Relationships.Relations { + if relation.Schema == s { + setPreloadMap(relation.Name, value, args) + } + } + + for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations { + for _, value := range embeddedValues(embeddedRelations) { + setPreloadMap(embedded, value, args) + } + } + } else { + setPreloadMap(preloadFields[0], value, args) + } + } + return preloadMap +} + +func embeddedValues(embeddedRelations *schema.Relationships) []string { + if embeddedRelations == nil { + return nil + } + names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) + for _, relation := range embeddedRelations.Relations { + // skip first struct name + names = append(names, strings.Join(relation.Field.BindNames[1:], ".")) + } + for _, relations := range embeddedRelations.EmbeddedRelations { + names = append(names, embeddedValues(relations)...) + } + return names +} + +func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error { + if relationships == nil { + return nil + } + preloadMap := parsePreloadMap(s, preloads) + for name := range preloadMap { + if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil { + if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil { + return err + } + } else if rel := relationships.Relations[name]; rel != nil { + if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil { + return err + } + } else { + return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name) + } + } + return nil +} + func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { var ( reflectValue = tx.Statement.ReflectValue diff --git a/callbacks/query.go b/callbacks/query.go index c87f17bc..95db1f0a 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -267,32 +267,7 @@ func Preload(db *gorm.DB) { return } - preloadMap := map[string]map[string][]interface{}{} - for name := range db.Statement.Preloads { - preloadFields := strings.Split(name, ".") - if preloadFields[0] == clause.Associations { - for _, rel := range db.Statement.Schema.Relationships.Relations { - if rel.Schema == db.Statement.Schema { - if _, ok := preloadMap[rel.Name]; !ok { - preloadMap[rel.Name] = map[string][]interface{}{} - } - - if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { - preloadMap[rel.Name][value] = db.Statement.Preloads[name] - } - } - } - } else { - if _, ok := preloadMap[preloadFields[0]]; !ok { - preloadMap[preloadFields[0]] = map[string][]interface{}{} - } - - if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { - preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] - } - } - } - + preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads) preloadNames := make([]string, 0, len(preloadMap)) for key := range preloadMap { preloadNames = append(preloadNames, key) @@ -312,7 +287,9 @@ func Preload(db *gorm.DB) { preloadDB.Statement.Unscoped = db.Statement.Unscoped for _, name := range preloadNames { - if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { + if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil { + db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations])) + } else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) } else { db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) diff --git a/schema/field.go b/schema/field.go index 15edab93..b5103d53 100644 --- a/schema/field.go +++ b/schema/field.go @@ -89,6 +89,10 @@ type Field struct { NewValuePool FieldNewValuePool } +func (field *Field) BindName() string { + return strings.Join(field.BindNames, ".") +} + // ParseField parses reflect.StructField to Field func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var ( diff --git a/schema/relationship.go b/schema/relationship.go index b33b94a7..e03dcc52 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -27,6 +27,8 @@ type Relationships struct { HasMany []*Relationship Many2Many []*Relationship Relations map[string]*Relationship + + EmbeddedRelations map[string]*Relationships } type Relationship struct { @@ -106,7 +108,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } if schema.err == nil { - schema.Relationships.Relations[relation.Name] = relation + schema.setRelation(relation) switch relation.Type { case HasOne: schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) @@ -122,6 +124,39 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { return relation } +func (schema *Schema) setRelation(relation *Relationship) { + // set non-embedded relation + if rel := schema.Relationships.Relations[relation.Name]; rel != nil { + if len(rel.Field.BindNames) > 1 { + schema.Relationships.Relations[relation.Name] = relation + } + } else { + schema.Relationships.Relations[relation.Name] = relation + } + + // set embedded relation + if len(relation.Field.BindNames) <= 1 { + return + } + relationships := &schema.Relationships + for i, name := range relation.Field.BindNames { + if i < len(relation.Field.BindNames)-1 { + if relationships.EmbeddedRelations == nil { + relationships.EmbeddedRelations = map[string]*Relationships{} + } + if r := relationships.EmbeddedRelations[name]; r == nil { + relationships.EmbeddedRelations[name] = &Relationships{} + } + relationships = relationships.EmbeddedRelations[name] + } else { + if relationships.Relations == nil { + relationships.Relations = map[string]*Relationship{} + } + relationships.Relations[relation.Name] = relation + } + } +} + // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` // // type User struct { @@ -166,6 +201,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi } } + if primaryKeyField == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name) + return + } + // use same data type for foreign keys if copyableDataType(primaryKeyField.DataType) { relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType @@ -443,6 +483,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu primaryFields = primarySchema.PrimaryFields } + primaryFieldLoop: for _, primaryField := range primaryFields { lookUpName := primarySchemaName + primaryField.Name if gl == guessBelongs { @@ -454,11 +495,18 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) } + for _, name := range lookUpNames { + if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil { + foreignFields = append(foreignFields, f) + primaryFields = append(primaryFields, primaryField) + continue primaryFieldLoop + } + } for _, name := range lookUpNames { if f := foreignSchema.LookUpField(name); f != nil { foreignFields = append(foreignFields, f) primaryFields = append(primaryFields, primaryField) - break + continue primaryFieldLoop } } } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 85c45589..732f6f75 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -518,6 +518,132 @@ func TestEmbeddedRelation(t *testing.T) { } } +func TestEmbeddedHas(t *testing.T) { + type Toy struct { + ID int + Name string + OwnerID int + OwnerType string + } + type User struct { + ID int + Cat struct { + Name string + Toy Toy `gorm:"polymorphic:Owner;"` + Toys []Toy `gorm:"polymorphic:Owner;"` + } `gorm:"embedded;embeddedPrefix:cat_"` + Dog struct { + ID int + Name string + UserID int + Toy Toy `gorm:"polymorphic:Owner;"` + Toys []Toy `gorm:"polymorphic:Owner;"` + } + Toys []Toy `gorm:"polymorphic:Owner;"` + } + + s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "Cat": { + Relations: map[string]Relation{ + "Toy": { + Name: "Toy", + Type: schema.HasOne, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + "Toys": { + Name: "Toys", + Type: schema.HasMany, + Schema: "User", + FieldSchema: "Toy", + Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"}, + References: []Reference{ + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + {ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"}, + }, + }, + }, + }, + }) +} + +func TestEmbeddedBelongsTo(t *testing.T) { + type Country struct { + ID int `gorm:"primaryKey"` + Name string + } + type Address struct { + CountryID int + Country Country + } + type NestedAddress struct { + Address + } + type Org struct { + ID int + PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"` + VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"` + AddressID int + Address struct { + ID int + Address + } + NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` + } + + s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Errorf("Failed to parse schema, got error %v", err) + } + + checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ + "PostalAddress": { + Relations: map[string]Relation{ + "Country": { + Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, + }, + }, + }, + }, + "VisitingAddress": { + Relations: map[string]Relation{ + "Country": { + Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, + }, + }, + }, + }, + "NestedAddress": { + EmbeddedRelations: map[string]EmbeddedRelations{ + "Address": { + Relations: map[string]Relation{ + "Country": { + Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", + References: []Reference{ + {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, + }, + }, + }, + }, + }, + }, + }) +} + func TestVariableRelation(t *testing.T) { var result struct { User diff --git a/schema/schema.go b/schema/schema.go index 17bdb25e..e13a5ed1 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -6,6 +6,7 @@ import ( "fmt" "go/ast" "reflect" + "strings" "sync" "gorm.io/gorm/clause" @@ -25,6 +26,7 @@ type Schema struct { PrimaryFieldDBNames []string Fields []*Field FieldsByName map[string]*Field + FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field' FieldsByDBName map[string]*Field FieldsWithDefaultDBValue []*Field // fields with default value assigned by database Relationships Relationships @@ -67,6 +69,27 @@ func (schema Schema) LookUpField(name string) *Field { return nil } +// LookUpFieldByBindName looks for the closest field in the embedded struct. +// +// type Struct struct { +// Embedded struct { +// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID") +// } +// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") +// } +func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { + if len(bindNames) == 0 { + return nil + } + for i := len(bindNames) - 1; i >= 0; i-- { + find := strings.Join(bindNames[:i], ".") + "." + name + if field, ok := schema.FieldsByBindName[find]; ok { + return field + } + } + return nil +} + type Tabler interface { TableName() string } @@ -140,15 +163,16 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } schema := &Schema{ - Name: modelType.Name(), - ModelType: modelType, - Table: tableName, - FieldsByName: map[string]*Field{}, - FieldsByDBName: map[string]*Field{}, - Relationships: Relationships{Relations: map[string]*Relationship{}}, - cacheStore: cacheStore, - namer: namer, - initialized: make(chan struct{}), + Name: modelType.Name(), + ModelType: modelType, + Table: tableName, + FieldsByName: map[string]*Field{}, + FieldsByBindName: map[string]*Field{}, + FieldsByDBName: map[string]*Field{}, + Relationships: Relationships{Relations: map[string]*Relationship{}}, + cacheStore: cacheStore, + namer: namer, + initialized: make(chan struct{}), } // When the schema initialization is completed, the channel will be closed defer close(schema.initialized) @@ -176,6 +200,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam field.DBName = namer.ColumnName(schema.Table, field.Name) } + bindName := field.BindName() if field.DBName != "" { // nonexistence or shortest path or first appear prioritized if has permission if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { @@ -184,6 +209,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } schema.FieldsByDBName[field.DBName] = field schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[bindName] = field if v != nil && v.PrimaryKey { for idx, f := range schema.PrimaryFields { @@ -202,6 +228,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { schema.FieldsByName[field.Name] = field } + if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" { + schema.FieldsByBindName[bindName] = field + } field.setupValuerAndSetter() } @@ -293,6 +322,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return schema, schema.err } else { schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[field.BindName()] = field } } diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 9abaecba..605aa03a 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -201,6 +201,37 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { }) } +type EmbeddedRelations struct { + Relations map[string]Relation + EmbeddedRelations map[string]EmbeddedRelations +} + +func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) { + for name, relations := range actual { + rs := expected[name] + t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) { + if len(relations.Relations) != len(rs.Relations) { + t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations)) + } + if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) { + t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations)) + } + for n, rel := range relations.Relations { + if r, ok := rs.Relations[n]; !ok { + t.Errorf("failed to find relation by name %s", n) + } else { + checkSchemaRelation(t, &schema.Schema{ + Relationships: schema.Relationships{ + Relations: map[string]*schema.Relationship{n: rel}, + }, + }, r) + } + } + checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations) + }) + } +} + func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { for k, v := range values { t.Run("CheckField/"+k, func(t *testing.T) { diff --git a/tests/preload_test.go b/tests/preload_test.go index e7223b3e..7304e350 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -306,3 +306,141 @@ func TestNestedPreloadWithUnscoped(t *testing.T) { DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID) CheckUserUnscoped(t, *user6, user) } + +func TestEmbedPreload(t *testing.T) { + type Country struct { + ID int `gorm:"primaryKey"` + Name string + } + type EmbeddedAddress struct { + ID int + Name string + CountryID *int + Country *Country + } + type NestedAddress struct { + EmbeddedAddress + } + type Org struct { + ID int + PostalAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:postal_address_"` + VisitingAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:visiting_address_"` + AddressID *int + Address *EmbeddedAddress + NestedAddress NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` + } + + DB.Migrator().DropTable(&Org{}, &EmbeddedAddress{}, &Country{}) + DB.AutoMigrate(&Org{}, &EmbeddedAddress{}, &Country{}) + + org := Org{ + PostalAddress: EmbeddedAddress{Name: "a1", Country: &Country{Name: "c1"}}, + VisitingAddress: EmbeddedAddress{Name: "a2", Country: &Country{Name: "c2"}}, + Address: &EmbeddedAddress{Name: "a3", Country: &Country{Name: "c3"}}, + NestedAddress: NestedAddress{ + EmbeddedAddress: EmbeddedAddress{Name: "a4", Country: &Country{Name: "c4"}}, + }, + } + if err := DB.Create(&org).Error; err != nil { + t.Errorf("failed to create org, got err: %v", err) + } + + tests := []struct { + name string + preloads map[string][]interface{} + expect Org + }{ + { + name: "address country", + preloads: map[string][]interface{}{"Address.Country": {}}, + expect: Org{ + ID: org.ID, + PostalAddress: EmbeddedAddress{ + ID: org.PostalAddress.ID, + Name: org.PostalAddress.Name, + CountryID: org.PostalAddress.CountryID, + Country: nil, + }, + VisitingAddress: EmbeddedAddress{ + ID: org.VisitingAddress.ID, + Name: org.VisitingAddress.Name, + CountryID: org.VisitingAddress.CountryID, + Country: nil, + }, + AddressID: org.AddressID, + Address: org.Address, + NestedAddress: NestedAddress{EmbeddedAddress{ + ID: org.NestedAddress.ID, + Name: org.NestedAddress.Name, + CountryID: org.NestedAddress.CountryID, + Country: nil, + }}, + }, + }, { + name: "postal address country", + preloads: map[string][]interface{}{"PostalAddress.Country": {}}, + expect: Org{ + ID: org.ID, + PostalAddress: org.PostalAddress, + VisitingAddress: EmbeddedAddress{ + ID: org.VisitingAddress.ID, + Name: org.VisitingAddress.Name, + CountryID: org.VisitingAddress.CountryID, + Country: nil, + }, + AddressID: org.AddressID, + Address: nil, + NestedAddress: NestedAddress{EmbeddedAddress{ + ID: org.NestedAddress.ID, + Name: org.NestedAddress.Name, + CountryID: org.NestedAddress.CountryID, + Country: nil, + }}, + }, + }, { + name: "nested address country", + preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}}, + expect: Org{ + ID: org.ID, + PostalAddress: EmbeddedAddress{ + ID: org.PostalAddress.ID, + Name: org.PostalAddress.Name, + CountryID: org.PostalAddress.CountryID, + Country: nil, + }, + VisitingAddress: EmbeddedAddress{ + ID: org.VisitingAddress.ID, + Name: org.VisitingAddress.Name, + CountryID: org.VisitingAddress.CountryID, + Country: nil, + }, + AddressID: org.AddressID, + Address: nil, + NestedAddress: org.NestedAddress, + }, + }, { + name: "associations", + preloads: map[string][]interface{}{ + clause.Associations: {}, + // clause.Associations won’t preload nested associations + "Address.Country": {}, + }, + expect: org, + }, + } + + DB = DB.Debug() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual := Org{} + tx := DB.Where("id = ?", org.ID).Session(&gorm.Session{}) + for name, args := range test.preloads { + tx = tx.Preload(name, args...) + } + if err := tx.Find(&actual).Error; err != nil { + t.Errorf("failed to find org, got err: %v", err) + } + AssertEqual(t, actual, test.expect) + }) + } +}