Refactor nested preloading associations, close #3970

This commit is contained in:
Jinzhu 2021-01-26 16:33:19 +08:00
parent 08678106a4
commit 7f198ead0e
2 changed files with 20 additions and 32 deletions

View File

@ -9,10 +9,9 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) {
var ( var (
reflectValue = db.Statement.ReflectValue reflectValue = db.Statement.ReflectValue
rel = rels[len(rels)-1]
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
relForeignKeys []string relForeignKeys []string
relForeignFields []*schema.Field relForeignFields []*schema.Field
@ -27,10 +26,6 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
return true return true
}) })
if len(rels) > 1 {
reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)-1])
}
if rel.JoinTable != nil { if rel.JoinTable != nil {
var joinForeignFields, joinRelForeignFields []*schema.Field var joinForeignFields, joinRelForeignFields []*schema.Field
var joinForeignKeys []string var joinForeignKeys []string
@ -97,6 +92,11 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
} }
} }
// nested preload
for p, pvs := range preloads {
tx = tx.Preload(p, pvs...)
}
reflectResults := rel.FieldSchema.MakeSlice().Elem() reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)

View File

@ -8,7 +8,6 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/schema"
) )
func Query(db *gorm.DB) { func Query(db *gorm.DB) {
@ -168,48 +167,37 @@ func BuildQuerySQL(db *gorm.DB) {
func Preload(db *gorm.DB) { func Preload(db *gorm.DB) {
if db.Error == nil && len(db.Statement.Preloads) > 0 { if db.Error == nil && len(db.Statement.Preloads) > 0 {
preloadMap := map[string][]string{} preloadMap := map[string]map[string][]interface{}{}
for name := range db.Statement.Preloads { for name := range db.Statement.Preloads {
if name == clause.Associations { if name == clause.Associations {
for _, rel := range db.Statement.Schema.Relationships.Relations { for _, rel := range db.Statement.Schema.Relationships.Relations {
if rel.Schema == db.Statement.Schema { if rel.Schema == db.Statement.Schema {
preloadMap[rel.Name] = []string{rel.Name} preloadMap[rel.Name] = nil
} }
} }
} else { } else {
preloadFields := strings.Split(name, ".") preloadFields := strings.Split(name, ".")
for idx := range preloadFields { if _, ok := preloadMap[preloadFields[0]]; !ok {
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] preloadMap[preloadFields[0]] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
} }
} }
} }
preloadNames := make([]string, len(preloadMap)) preloadNames := make([]string, 0, len(preloadMap))
idx := 0
for key := range preloadMap { for key := range preloadMap {
preloadNames[idx] = key preloadNames = append(preloadNames, key)
idx++
} }
sort.Strings(preloadNames) sort.Strings(preloadNames)
for _, name := range preloadNames { for _, name := range preloadNames {
var ( if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
curSchema = db.Statement.Schema preload(db, rel, db.Statement.Preloads[name], preloadMap[name])
preloadFields = preloadMap[name]
rels = make([]*schema.Relationship, len(preloadFields))
)
for idx, preloadField := range preloadFields {
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
rels[idx] = rel
curSchema = rel.FieldSchema
} else { } else {
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) db.AddError(fmt.Errorf("%v: %w for schema %v", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
}
}
if db.Error == nil {
preload(db, rels, db.Statement.Preloads[name])
} }
} }
} }