Refactor nested preloading associations, close #3970

This commit is contained in:
Jinzhu 2021-01-26 16:33:19 +08:00
parent 08678106a4
commit 7f198ead0e
2 changed files with 20 additions and 32 deletions

View File

@ -9,10 +9,9 @@ import (
"gorm.io/gorm/utils"
)
func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) {
var (
reflectValue = db.Statement.ReflectValue
rel = rels[len(rels)-1]
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
relForeignKeys []string
relForeignFields []*schema.Field
@ -27,10 +26,6 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
return true
})
if len(rels) > 1 {
reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)-1])
}
if rel.JoinTable != nil {
var joinForeignFields, joinRelForeignFields []*schema.Field
var joinForeignKeys []string
@ -97,6 +92,11 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
}
}
// nested preload
for p, pvs := range preloads {
tx = tx.Preload(p, pvs...)
}
reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)

View File

@ -8,7 +8,6 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
func Query(db *gorm.DB) {
@ -168,48 +167,37 @@ func BuildQuerySQL(db *gorm.DB) {
func Preload(db *gorm.DB) {
if db.Error == nil && len(db.Statement.Preloads) > 0 {
preloadMap := map[string][]string{}
preloadMap := map[string]map[string][]interface{}{}
for name := range db.Statement.Preloads {
if name == clause.Associations {
for _, rel := range db.Statement.Schema.Relationships.Relations {
if rel.Schema == db.Statement.Schema {
preloadMap[rel.Name] = []string{rel.Name}
preloadMap[rel.Name] = nil
}
}
} else {
preloadFields := strings.Split(name, ".")
for idx := range preloadFields {
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1]
if _, ok := preloadMap[preloadFields[0]]; !ok {
preloadMap[preloadFields[0]] = map[string][]interface{}{}
}
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
}
}
}
preloadNames := make([]string, len(preloadMap))
idx := 0
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames[idx] = key
idx++
preloadNames = append(preloadNames, key)
}
sort.Strings(preloadNames)
for _, name := range preloadNames {
var (
curSchema = db.Statement.Schema
preloadFields = preloadMap[name]
rels = make([]*schema.Relationship, len(preloadFields))
)
for idx, preloadField := range preloadFields {
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
rels[idx] = rel
curSchema = rel.FieldSchema
} else {
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
}
}
if db.Error == nil {
preload(db, rels, db.Statement.Preloads[name])
if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
preload(db, rel, db.Statement.Preloads[name], preloadMap[name])
} else {
db.AddError(fmt.Errorf("%v: %w for schema %v", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
}
}
}