mirror of https://github.com/go-gorm/gorm.git
Refactor nested preloading associations, close #3970
This commit is contained in:
parent
08678106a4
commit
7f198ead0e
|
@ -9,10 +9,9 @@ import (
|
||||||
"gorm.io/gorm/utils"
|
"gorm.io/gorm/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) {
|
||||||
var (
|
var (
|
||||||
reflectValue = db.Statement.ReflectValue
|
reflectValue = db.Statement.ReflectValue
|
||||||
rel = rels[len(rels)-1]
|
|
||||||
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
|
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
|
||||||
relForeignKeys []string
|
relForeignKeys []string
|
||||||
relForeignFields []*schema.Field
|
relForeignFields []*schema.Field
|
||||||
|
@ -27,10 +26,6 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
if len(rels) > 1 {
|
|
||||||
reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)-1])
|
|
||||||
}
|
|
||||||
|
|
||||||
if rel.JoinTable != nil {
|
if rel.JoinTable != nil {
|
||||||
var joinForeignFields, joinRelForeignFields []*schema.Field
|
var joinForeignFields, joinRelForeignFields []*schema.Field
|
||||||
var joinForeignKeys []string
|
var joinForeignKeys []string
|
||||||
|
@ -97,6 +92,11 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nested preload
|
||||||
|
for p, pvs := range preloads {
|
||||||
|
tx = tx.Preload(p, pvs...)
|
||||||
|
}
|
||||||
|
|
||||||
reflectResults := rel.FieldSchema.MakeSlice().Elem()
|
reflectResults := rel.FieldSchema.MakeSlice().Elem()
|
||||||
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
"gorm.io/gorm/schema"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Query(db *gorm.DB) {
|
func Query(db *gorm.DB) {
|
||||||
|
@ -168,48 +167,37 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||||
|
|
||||||
func Preload(db *gorm.DB) {
|
func Preload(db *gorm.DB) {
|
||||||
if db.Error == nil && len(db.Statement.Preloads) > 0 {
|
if db.Error == nil && len(db.Statement.Preloads) > 0 {
|
||||||
preloadMap := map[string][]string{}
|
preloadMap := map[string]map[string][]interface{}{}
|
||||||
for name := range db.Statement.Preloads {
|
for name := range db.Statement.Preloads {
|
||||||
if name == clause.Associations {
|
if name == clause.Associations {
|
||||||
for _, rel := range db.Statement.Schema.Relationships.Relations {
|
for _, rel := range db.Statement.Schema.Relationships.Relations {
|
||||||
if rel.Schema == db.Statement.Schema {
|
if rel.Schema == db.Statement.Schema {
|
||||||
preloadMap[rel.Name] = []string{rel.Name}
|
preloadMap[rel.Name] = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
preloadFields := strings.Split(name, ".")
|
preloadFields := strings.Split(name, ".")
|
||||||
for idx := range preloadFields {
|
if _, ok := preloadMap[preloadFields[0]]; !ok {
|
||||||
preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1]
|
preloadMap[preloadFields[0]] = map[string][]interface{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
|
||||||
|
preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
preloadNames := make([]string, len(preloadMap))
|
preloadNames := make([]string, 0, len(preloadMap))
|
||||||
idx := 0
|
|
||||||
for key := range preloadMap {
|
for key := range preloadMap {
|
||||||
preloadNames[idx] = key
|
preloadNames = append(preloadNames, key)
|
||||||
idx++
|
|
||||||
}
|
}
|
||||||
sort.Strings(preloadNames)
|
sort.Strings(preloadNames)
|
||||||
|
|
||||||
for _, name := range preloadNames {
|
for _, name := range preloadNames {
|
||||||
var (
|
if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
||||||
curSchema = db.Statement.Schema
|
preload(db, rel, db.Statement.Preloads[name], preloadMap[name])
|
||||||
preloadFields = preloadMap[name]
|
|
||||||
rels = make([]*schema.Relationship, len(preloadFields))
|
|
||||||
)
|
|
||||||
|
|
||||||
for idx, preloadField := range preloadFields {
|
|
||||||
if rel := curSchema.Relationships.Relations[preloadField]; rel != nil {
|
|
||||||
rels[idx] = rel
|
|
||||||
curSchema = rel.FieldSchema
|
|
||||||
} else {
|
} else {
|
||||||
db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation))
|
db.AddError(fmt.Errorf("%v: %w for schema %v", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if db.Error == nil {
|
|
||||||
preload(db, rels, db.Statement.Preloads[name])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue