mirror of https://github.com/go-gorm/gorm.git
feat: support embedded preload (#6137)
* feat: support embedded preload * fix lint and test * fix test...
This commit is contained in:
parent
4b0da0e97a
commit
828e22b17f
|
@ -3,6 +3,7 @@ package callbacks
|
|||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
@ -10,6 +11,98 @@ import (
|
|||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// parsePreloadMap extracts nested preloads. e.g.
|
||||
//
|
||||
// // schema has a "k0" relation and a "k7.k8" embedded relation
|
||||
// parsePreloadMap(schema, map[string][]interface{}{
|
||||
// clause.Associations: {"arg1"},
|
||||
// "k1": {"arg2"},
|
||||
// "k2.k3": {"arg3"},
|
||||
// "k4.k5.k6": {"arg4"},
|
||||
// })
|
||||
// // preloadMap is
|
||||
// map[string]map[string][]interface{}{
|
||||
// "k0": {},
|
||||
// "k7": {
|
||||
// "k8": {},
|
||||
// },
|
||||
// "k1": {},
|
||||
// "k2": {
|
||||
// "k3": {"arg3"},
|
||||
// },
|
||||
// "k4": {
|
||||
// "k5.k6": {"arg4"},
|
||||
// },
|
||||
// }
|
||||
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
|
||||
preloadMap := map[string]map[string][]interface{}{}
|
||||
setPreloadMap := func(name, value string, args []interface{}) {
|
||||
if _, ok := preloadMap[name]; !ok {
|
||||
preloadMap[name] = map[string][]interface{}{}
|
||||
}
|
||||
if value != "" {
|
||||
preloadMap[name][value] = args
|
||||
}
|
||||
}
|
||||
|
||||
for name, args := range preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
|
||||
if preloadFields[0] == clause.Associations {
|
||||
for _, relation := range s.Relationships.Relations {
|
||||
if relation.Schema == s {
|
||||
setPreloadMap(relation.Name, value, args)
|
||||
}
|
||||
}
|
||||
|
||||
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
|
||||
for _, value := range embeddedValues(embeddedRelations) {
|
||||
setPreloadMap(embedded, value, args)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
setPreloadMap(preloadFields[0], value, args)
|
||||
}
|
||||
}
|
||||
return preloadMap
|
||||
}
|
||||
|
||||
func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||
if embeddedRelations == nil {
|
||||
return nil
|
||||
}
|
||||
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
|
||||
for _, relation := range embeddedRelations.Relations {
|
||||
// skip first struct name
|
||||
names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
|
||||
}
|
||||
for _, relations := range embeddedRelations.EmbeddedRelations {
|
||||
names = append(names, embeddedValues(relations)...)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error {
|
||||
if relationships == nil {
|
||||
return nil
|
||||
}
|
||||
preloadMap := parsePreloadMap(s, preloads)
|
||||
for name := range preloadMap {
|
||||
if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil {
|
||||
if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if rel := relationships.Relations[name]; rel != nil {
|
||||
if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
||||
var (
|
||||
reflectValue = tx.Statement.ReflectValue
|
||||
|
|
|
@ -267,32 +267,7 @@ func Preload(db *gorm.DB) {
|
|||
return
|
||||
}
|
||||
|
||||
preloadMap := map[string]map[string][]interface{}{}
|
||||
for name := range db.Statement.Preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
if preloadFields[0] == clause.Associations {
|
||||
for _, rel := range db.Statement.Schema.Relationships.Relations {
|
||||
if rel.Schema == db.Statement.Schema {
|
||||
if _, ok := preloadMap[rel.Name]; !ok {
|
||||
preloadMap[rel.Name] = map[string][]interface{}{}
|
||||
}
|
||||
|
||||
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
|
||||
preloadMap[rel.Name][value] = db.Statement.Preloads[name]
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, ok := preloadMap[preloadFields[0]]; !ok {
|
||||
preloadMap[preloadFields[0]] = map[string][]interface{}{}
|
||||
}
|
||||
|
||||
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
|
||||
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads)
|
||||
preloadNames := make([]string, 0, len(preloadMap))
|
||||
for key := range preloadMap {
|
||||
preloadNames = append(preloadNames, key)
|
||||
|
@ -312,7 +287,9 @@ func Preload(db *gorm.DB) {
|
|||
preloadDB.Statement.Unscoped = db.Statement.Unscoped
|
||||
|
||||
for _, name := range preloadNames {
|
||||
if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
||||
if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil {
|
||||
db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations]))
|
||||
} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
||||
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
||||
|
|
|
@ -89,6 +89,10 @@ type Field struct {
|
|||
NewValuePool FieldNewValuePool
|
||||
}
|
||||
|
||||
func (field *Field) BindName() string {
|
||||
return strings.Join(field.BindNames, ".")
|
||||
}
|
||||
|
||||
// ParseField parses reflect.StructField to Field
|
||||
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
var (
|
||||
|
|
|
@ -27,6 +27,8 @@ type Relationships struct {
|
|||
HasMany []*Relationship
|
||||
Many2Many []*Relationship
|
||||
Relations map[string]*Relationship
|
||||
|
||||
EmbeddedRelations map[string]*Relationships
|
||||
}
|
||||
|
||||
type Relationship struct {
|
||||
|
@ -106,7 +108,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
|||
}
|
||||
|
||||
if schema.err == nil {
|
||||
schema.Relationships.Relations[relation.Name] = relation
|
||||
schema.setRelation(relation)
|
||||
switch relation.Type {
|
||||
case HasOne:
|
||||
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
|
||||
|
@ -122,6 +124,39 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
|
|||
return relation
|
||||
}
|
||||
|
||||
func (schema *Schema) setRelation(relation *Relationship) {
|
||||
// set non-embedded relation
|
||||
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
|
||||
if len(rel.Field.BindNames) > 1 {
|
||||
schema.Relationships.Relations[relation.Name] = relation
|
||||
}
|
||||
} else {
|
||||
schema.Relationships.Relations[relation.Name] = relation
|
||||
}
|
||||
|
||||
// set embedded relation
|
||||
if len(relation.Field.BindNames) <= 1 {
|
||||
return
|
||||
}
|
||||
relationships := &schema.Relationships
|
||||
for i, name := range relation.Field.BindNames {
|
||||
if i < len(relation.Field.BindNames)-1 {
|
||||
if relationships.EmbeddedRelations == nil {
|
||||
relationships.EmbeddedRelations = map[string]*Relationships{}
|
||||
}
|
||||
if r := relationships.EmbeddedRelations[name]; r == nil {
|
||||
relationships.EmbeddedRelations[name] = &Relationships{}
|
||||
}
|
||||
relationships = relationships.EmbeddedRelations[name]
|
||||
} else {
|
||||
if relationships.Relations == nil {
|
||||
relationships.Relations = map[string]*Relationship{}
|
||||
}
|
||||
relationships.Relations[relation.Name] = relation
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
|
||||
//
|
||||
// type User struct {
|
||||
|
@ -166,6 +201,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
|
|||
}
|
||||
}
|
||||
|
||||
if primaryKeyField == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name)
|
||||
return
|
||||
}
|
||||
|
||||
// use same data type for foreign keys
|
||||
if copyableDataType(primaryKeyField.DataType) {
|
||||
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
|
||||
|
@ -443,6 +483,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
|||
primaryFields = primarySchema.PrimaryFields
|
||||
}
|
||||
|
||||
primaryFieldLoop:
|
||||
for _, primaryField := range primaryFields {
|
||||
lookUpName := primarySchemaName + primaryField.Name
|
||||
if gl == guessBelongs {
|
||||
|
@ -454,11 +495,18 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
|
|||
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
|
||||
}
|
||||
|
||||
for _, name := range lookUpNames {
|
||||
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
|
||||
foreignFields = append(foreignFields, f)
|
||||
primaryFields = append(primaryFields, primaryField)
|
||||
continue primaryFieldLoop
|
||||
}
|
||||
}
|
||||
for _, name := range lookUpNames {
|
||||
if f := foreignSchema.LookUpField(name); f != nil {
|
||||
foreignFields = append(foreignFields, f)
|
||||
primaryFields = append(primaryFields, primaryField)
|
||||
break
|
||||
continue primaryFieldLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -518,6 +518,132 @@ func TestEmbeddedRelation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedHas(t *testing.T) {
|
||||
type Toy struct {
|
||||
ID int
|
||||
Name string
|
||||
OwnerID int
|
||||
OwnerType string
|
||||
}
|
||||
type User struct {
|
||||
ID int
|
||||
Cat struct {
|
||||
Name string
|
||||
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
} `gorm:"embedded;embeddedPrefix:cat_"`
|
||||
Dog struct {
|
||||
ID int
|
||||
Name string
|
||||
UserID int
|
||||
Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
}
|
||||
Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"Cat": {
|
||||
Relations: map[string]Relation{
|
||||
"Toy": {
|
||||
Name: "Toy",
|
||||
Type: schema.HasOne,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
"Toys": {
|
||||
Name: "Toys",
|
||||
Type: schema.HasMany,
|
||||
Schema: "User",
|
||||
FieldSchema: "Toy",
|
||||
Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
|
||||
References: []Reference{
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmbeddedBelongsTo(t *testing.T) {
|
||||
type Country struct {
|
||||
ID int `gorm:"primaryKey"`
|
||||
Name string
|
||||
}
|
||||
type Address struct {
|
||||
CountryID int
|
||||
Country Country
|
||||
}
|
||||
type NestedAddress struct {
|
||||
Address
|
||||
}
|
||||
type Org struct {
|
||||
ID int
|
||||
PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"`
|
||||
VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"`
|
||||
AddressID int
|
||||
Address struct {
|
||||
ID int
|
||||
Address
|
||||
}
|
||||
NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
|
||||
}
|
||||
|
||||
s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{})
|
||||
if err != nil {
|
||||
t.Errorf("Failed to parse schema, got error %v", err)
|
||||
}
|
||||
|
||||
checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
|
||||
"PostalAddress": {
|
||||
Relations: map[string]Relation{
|
||||
"Country": {
|
||||
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||
References: []Reference{
|
||||
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"VisitingAddress": {
|
||||
Relations: map[string]Relation{
|
||||
"Country": {
|
||||
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||
References: []Reference{
|
||||
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"NestedAddress": {
|
||||
EmbeddedRelations: map[string]EmbeddedRelations{
|
||||
"Address": {
|
||||
Relations: map[string]Relation{
|
||||
"Country": {
|
||||
Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country",
|
||||
References: []Reference{
|
||||
{PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestVariableRelation(t *testing.T) {
|
||||
var result struct {
|
||||
User
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
|
@ -25,6 +26,7 @@ type Schema struct {
|
|||
PrimaryFieldDBNames []string
|
||||
Fields []*Field
|
||||
FieldsByName map[string]*Field
|
||||
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
|
||||
FieldsByDBName map[string]*Field
|
||||
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
|
||||
Relationships Relationships
|
||||
|
@ -67,6 +69,27 @@ func (schema Schema) LookUpField(name string) *Field {
|
|||
return nil
|
||||
}
|
||||
|
||||
// LookUpFieldByBindName looks for the closest field in the embedded struct.
|
||||
//
|
||||
// type Struct struct {
|
||||
// Embedded struct {
|
||||
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
|
||||
// }
|
||||
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
|
||||
// }
|
||||
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
|
||||
if len(bindNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
for i := len(bindNames) - 1; i >= 0; i-- {
|
||||
find := strings.Join(bindNames[:i], ".") + "." + name
|
||||
if field, ok := schema.FieldsByBindName[find]; ok {
|
||||
return field
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Tabler interface {
|
||||
TableName() string
|
||||
}
|
||||
|
@ -140,15 +163,16 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||
}
|
||||
|
||||
schema := &Schema{
|
||||
Name: modelType.Name(),
|
||||
ModelType: modelType,
|
||||
Table: tableName,
|
||||
FieldsByName: map[string]*Field{},
|
||||
FieldsByDBName: map[string]*Field{},
|
||||
Relationships: Relationships{Relations: map[string]*Relationship{}},
|
||||
cacheStore: cacheStore,
|
||||
namer: namer,
|
||||
initialized: make(chan struct{}),
|
||||
Name: modelType.Name(),
|
||||
ModelType: modelType,
|
||||
Table: tableName,
|
||||
FieldsByName: map[string]*Field{},
|
||||
FieldsByBindName: map[string]*Field{},
|
||||
FieldsByDBName: map[string]*Field{},
|
||||
Relationships: Relationships{Relations: map[string]*Relationship{}},
|
||||
cacheStore: cacheStore,
|
||||
namer: namer,
|
||||
initialized: make(chan struct{}),
|
||||
}
|
||||
// When the schema initialization is completed, the channel will be closed
|
||||
defer close(schema.initialized)
|
||||
|
@ -176,6 +200,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||
field.DBName = namer.ColumnName(schema.Table, field.Name)
|
||||
}
|
||||
|
||||
bindName := field.BindName()
|
||||
if field.DBName != "" {
|
||||
// nonexistence or shortest path or first appear prioritized if has permission
|
||||
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
|
||||
|
@ -184,6 +209,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||
}
|
||||
schema.FieldsByDBName[field.DBName] = field
|
||||
schema.FieldsByName[field.Name] = field
|
||||
schema.FieldsByBindName[bindName] = field
|
||||
|
||||
if v != nil && v.PrimaryKey {
|
||||
for idx, f := range schema.PrimaryFields {
|
||||
|
@ -202,6 +228,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
|
||||
schema.FieldsByName[field.Name] = field
|
||||
}
|
||||
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
|
||||
schema.FieldsByBindName[bindName] = field
|
||||
}
|
||||
|
||||
field.setupValuerAndSetter()
|
||||
}
|
||||
|
@ -293,6 +322,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
|
|||
return schema, schema.err
|
||||
} else {
|
||||
schema.FieldsByName[field.Name] = field
|
||||
schema.FieldsByBindName[field.BindName()] = field
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -201,6 +201,37 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
|
|||
})
|
||||
}
|
||||
|
||||
type EmbeddedRelations struct {
|
||||
Relations map[string]Relation
|
||||
EmbeddedRelations map[string]EmbeddedRelations
|
||||
}
|
||||
|
||||
func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) {
|
||||
for name, relations := range actual {
|
||||
rs := expected[name]
|
||||
t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) {
|
||||
if len(relations.Relations) != len(rs.Relations) {
|
||||
t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations))
|
||||
}
|
||||
if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) {
|
||||
t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations))
|
||||
}
|
||||
for n, rel := range relations.Relations {
|
||||
if r, ok := rs.Relations[n]; !ok {
|
||||
t.Errorf("failed to find relation by name %s", n)
|
||||
} else {
|
||||
checkSchemaRelation(t, &schema.Schema{
|
||||
Relationships: schema.Relationships{
|
||||
Relations: map[string]*schema.Relationship{n: rel},
|
||||
},
|
||||
}, r)
|
||||
}
|
||||
}
|
||||
checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
|
||||
for k, v := range values {
|
||||
t.Run("CheckField/"+k, func(t *testing.T) {
|
||||
|
|
|
@ -306,3 +306,141 @@ func TestNestedPreloadWithUnscoped(t *testing.T) {
|
|||
DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID)
|
||||
CheckUserUnscoped(t, *user6, user)
|
||||
}
|
||||
|
||||
func TestEmbedPreload(t *testing.T) {
|
||||
type Country struct {
|
||||
ID int `gorm:"primaryKey"`
|
||||
Name string
|
||||
}
|
||||
type EmbeddedAddress struct {
|
||||
ID int
|
||||
Name string
|
||||
CountryID *int
|
||||
Country *Country
|
||||
}
|
||||
type NestedAddress struct {
|
||||
EmbeddedAddress
|
||||
}
|
||||
type Org struct {
|
||||
ID int
|
||||
PostalAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:postal_address_"`
|
||||
VisitingAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:visiting_address_"`
|
||||
AddressID *int
|
||||
Address *EmbeddedAddress
|
||||
NestedAddress NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"`
|
||||
}
|
||||
|
||||
DB.Migrator().DropTable(&Org{}, &EmbeddedAddress{}, &Country{})
|
||||
DB.AutoMigrate(&Org{}, &EmbeddedAddress{}, &Country{})
|
||||
|
||||
org := Org{
|
||||
PostalAddress: EmbeddedAddress{Name: "a1", Country: &Country{Name: "c1"}},
|
||||
VisitingAddress: EmbeddedAddress{Name: "a2", Country: &Country{Name: "c2"}},
|
||||
Address: &EmbeddedAddress{Name: "a3", Country: &Country{Name: "c3"}},
|
||||
NestedAddress: NestedAddress{
|
||||
EmbeddedAddress: EmbeddedAddress{Name: "a4", Country: &Country{Name: "c4"}},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&org).Error; err != nil {
|
||||
t.Errorf("failed to create org, got err: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
preloads map[string][]interface{}
|
||||
expect Org
|
||||
}{
|
||||
{
|
||||
name: "address country",
|
||||
preloads: map[string][]interface{}{"Address.Country": {}},
|
||||
expect: Org{
|
||||
ID: org.ID,
|
||||
PostalAddress: EmbeddedAddress{
|
||||
ID: org.PostalAddress.ID,
|
||||
Name: org.PostalAddress.Name,
|
||||
CountryID: org.PostalAddress.CountryID,
|
||||
Country: nil,
|
||||
},
|
||||
VisitingAddress: EmbeddedAddress{
|
||||
ID: org.VisitingAddress.ID,
|
||||
Name: org.VisitingAddress.Name,
|
||||
CountryID: org.VisitingAddress.CountryID,
|
||||
Country: nil,
|
||||
},
|
||||
AddressID: org.AddressID,
|
||||
Address: org.Address,
|
||||
NestedAddress: NestedAddress{EmbeddedAddress{
|
||||
ID: org.NestedAddress.ID,
|
||||
Name: org.NestedAddress.Name,
|
||||
CountryID: org.NestedAddress.CountryID,
|
||||
Country: nil,
|
||||
}},
|
||||
},
|
||||
}, {
|
||||
name: "postal address country",
|
||||
preloads: map[string][]interface{}{"PostalAddress.Country": {}},
|
||||
expect: Org{
|
||||
ID: org.ID,
|
||||
PostalAddress: org.PostalAddress,
|
||||
VisitingAddress: EmbeddedAddress{
|
||||
ID: org.VisitingAddress.ID,
|
||||
Name: org.VisitingAddress.Name,
|
||||
CountryID: org.VisitingAddress.CountryID,
|
||||
Country: nil,
|
||||
},
|
||||
AddressID: org.AddressID,
|
||||
Address: nil,
|
||||
NestedAddress: NestedAddress{EmbeddedAddress{
|
||||
ID: org.NestedAddress.ID,
|
||||
Name: org.NestedAddress.Name,
|
||||
CountryID: org.NestedAddress.CountryID,
|
||||
Country: nil,
|
||||
}},
|
||||
},
|
||||
}, {
|
||||
name: "nested address country",
|
||||
preloads: map[string][]interface{}{"NestedAddress.EmbeddedAddress.Country": {}},
|
||||
expect: Org{
|
||||
ID: org.ID,
|
||||
PostalAddress: EmbeddedAddress{
|
||||
ID: org.PostalAddress.ID,
|
||||
Name: org.PostalAddress.Name,
|
||||
CountryID: org.PostalAddress.CountryID,
|
||||
Country: nil,
|
||||
},
|
||||
VisitingAddress: EmbeddedAddress{
|
||||
ID: org.VisitingAddress.ID,
|
||||
Name: org.VisitingAddress.Name,
|
||||
CountryID: org.VisitingAddress.CountryID,
|
||||
Country: nil,
|
||||
},
|
||||
AddressID: org.AddressID,
|
||||
Address: nil,
|
||||
NestedAddress: org.NestedAddress,
|
||||
},
|
||||
}, {
|
||||
name: "associations",
|
||||
preloads: map[string][]interface{}{
|
||||
clause.Associations: {},
|
||||
// clause.Associations won’t preload nested associations
|
||||
"Address.Country": {},
|
||||
},
|
||||
expect: org,
|
||||
},
|
||||
}
|
||||
|
||||
DB = DB.Debug()
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
actual := Org{}
|
||||
tx := DB.Where("id = ?", org.ID).Session(&gorm.Session{})
|
||||
for name, args := range test.preloads {
|
||||
tx = tx.Preload(name, args...)
|
||||
}
|
||||
if err := tx.Find(&actual).Error; err != nil {
|
||||
t.Errorf("failed to find org, got err: %v", err)
|
||||
}
|
||||
AssertEqual(t, actual, test.expect)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue