diff --git a/callbacks/preload.go b/callbacks/preload.go index 3614346f..27e3c3dd 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -9,10 +9,9 @@ import ( "gorm.io/gorm/utils" ) -func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { +func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) { var ( reflectValue = db.Statement.ReflectValue - rel = rels[len(rels)-1] tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) relForeignKeys []string relForeignFields []*schema.Field @@ -27,10 +26,6 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { return true }) - if len(rels) > 1 { - reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)-1]) - } - if rel.JoinTable != nil { var joinForeignFields, joinRelForeignFields []*schema.Field var joinForeignKeys []string @@ -97,6 +92,11 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } + // nested preload + for p, pvs := range preloads { + tx = tx.Preload(p, pvs...) + } + reflectResults := rel.FieldSchema.MakeSlice().Elem() column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) diff --git a/callbacks/query.go b/callbacks/query.go index ebb09d6b..fff46d57 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -8,7 +8,6 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" - "gorm.io/gorm/schema" ) func Query(db *gorm.DB) { @@ -168,48 +167,37 @@ func BuildQuerySQL(db *gorm.DB) { func Preload(db *gorm.DB) { if db.Error == nil && len(db.Statement.Preloads) > 0 { - preloadMap := map[string][]string{} + preloadMap := map[string]map[string][]interface{}{} for name := range db.Statement.Preloads { if name == clause.Associations { for _, rel := range db.Statement.Schema.Relationships.Relations { if rel.Schema == db.Statement.Schema { - preloadMap[rel.Name] = []string{rel.Name} + preloadMap[rel.Name] = nil } } } else { preloadFields := strings.Split(name, ".") - for idx := range preloadFields { - preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + if _, ok := preloadMap[preloadFields[0]]; !ok { + preloadMap[preloadFields[0]] = map[string][]interface{}{} + } + + if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { + preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] } } } - preloadNames := make([]string, len(preloadMap)) - idx := 0 + preloadNames := make([]string, 0, len(preloadMap)) for key := range preloadMap { - preloadNames[idx] = key - idx++ + preloadNames = append(preloadNames, key) } sort.Strings(preloadNames) for _, name := range preloadNames { - var ( - curSchema = db.Statement.Schema - preloadFields = preloadMap[name] - rels = make([]*schema.Relationship, len(preloadFields)) - ) - - for idx, preloadField := range preloadFields { - if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { - rels[idx] = rel - curSchema = rel.FieldSchema - } else { - db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) - } - } - - if db.Error == nil { - preload(db, rels, db.Statement.Preloads[name]) + if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { + preload(db, rel, db.Statement.Preloads[name], preloadMap[name]) + } else { + db.AddError(fmt.Errorf("%v: %w for schema %v", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) } } }