package callbacks import ( "fmt" "reflect" "sort" "strings" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) // parsePreloadMap extracts nested preloads. e.g. // // // schema has a "k0" relation and a "k7.k8" embedded relation // parsePreloadMap(schema, map[string][]interface{}{ // clause.Associations: {"arg1"}, // "k1": {"arg2"}, // "k2.k3": {"arg3"}, // "k4.k5.k6": {"arg4"}, // }) // // preloadMap is // map[string]map[string][]interface{}{ // "k0": {}, // "k7": { // "k8": {}, // }, // "k1": {}, // "k2": { // "k3": {"arg3"}, // }, // "k4": { // "k5.k6": {"arg4"}, // }, // } func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} { preloadMap := map[string]map[string][]interface{}{} setPreloadMap := func(name, value string, args []interface{}) { if _, ok := preloadMap[name]; !ok { preloadMap[name] = map[string][]interface{}{} } if value != "" { preloadMap[name][value] = args } } for name, args := range preloads { preloadFields := strings.Split(name, ".") value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".") if preloadFields[0] == clause.Associations { for _, relation := range s.Relationships.Relations { if relation.Schema == s { setPreloadMap(relation.Name, value, args) } } for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations { for _, value := range embeddedValues(embeddedRelations) { setPreloadMap(embedded, value, args) } } } else { setPreloadMap(preloadFields[0], value, args) } } return preloadMap } func embeddedValues(embeddedRelations *schema.Relationships) []string { if embeddedRelations == nil { return nil } names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) for _, relation := range embeddedRelations.Relations { // skip first struct name names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], ".")) } for _, relations := range embeddedRelations.EmbeddedRelations { names = append(names, embeddedValues(relations)...) } return names } // preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point. // If the current relationship is embedded or joined, current query will be ignored. // //nolint:cyclop func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error { preloadMap := parsePreloadMap(db.Statement.Schema, preloads) // avoid random traversal of the map preloadNames := make([]string, 0, len(preloadMap)) for key := range preloadMap { preloadNames = append(preloadNames, key) } sort.Strings(preloadNames) isJoined := func(name string) (joined bool, nestedJoins []string) { for _, join := range joins { if _, ok := relationships.Relations[join]; ok && name == join { joined = true continue } joinNames := strings.SplitN(join, ".", 2) if len(joinNames) == 2 { if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] { joined = true nestedJoins = append(nestedJoins, joinNames[1]) } } } return joined, nestedJoins } for _, name := range preloadNames { if relations := relationships.EmbeddedRelations[name]; relations != nil { if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil { return err } } else if rel := relationships.Relations[name]; rel != nil { if joined, nestedJoins := isJoined(name); joined { switch rv := db.Statement.ReflectValue; rv.Kind() { case reflect.Slice, reflect.Array: if rv.Len() > 0 { reflectValue := rel.FieldSchema.MakeSlice().Elem() for i := 0; i < rv.Len(); i++ { frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) if frv.Kind() != reflect.Ptr { reflectValue = reflect.Append(reflectValue, frv.Addr()) } else { if frv.IsNil() { continue } reflectValue = reflect.Append(reflectValue, frv) } } tx := preloadDB(db, reflectValue, reflectValue.Interface()) if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { return err } } case reflect.Struct, reflect.Pointer: reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv) tx := preloadDB(db, reflectValue, reflectValue.Interface()) if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { return err } default: return gorm.ErrInvalidData } } else { tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) tx.Statement.ReflectValue = db.Statement.ReflectValue tx.Statement.Unscoped = db.Statement.Unscoped if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil { return err } } } else { return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name) } } return nil } func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB { tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) db.Statement.Settings.Range(func(k, v interface{}) bool { tx.Statement.Settings.Store(k, v) return true }) if err := tx.Statement.Parse(dest); err != nil { tx.AddError(err) return tx } tx.Statement.ReflectValue = reflectValue tx.Statement.Unscoped = db.Statement.Unscoped return tx } func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { var ( reflectValue = tx.Statement.ReflectValue relForeignKeys []string relForeignFields []*schema.Field foreignFields []*schema.Field foreignValues [][]interface{} identityMap = map[string][]reflect.Value{} inlineConds []interface{} ) if rel.JoinTable != nil { var ( joinForeignFields = make([]*schema.Field, 0, len(rel.References)) joinRelForeignFields = make([]*schema.Field, 0, len(rel.References)) joinForeignKeys = make([]string, 0, len(rel.References)) ) for _, ref := range rel.References { if ref.OwnPrimaryKey { joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName) joinForeignFields = append(joinForeignFields, ref.ForeignKey) foreignFields = append(foreignFields, ref.PrimaryKey) } else if ref.PrimaryValue != "" { tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } else { joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey) relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) relForeignFields = append(relForeignFields, ref.PrimaryKey) } } joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) if len(joinForeignValues) == 0 { return nil } joinResults := rel.JoinTable.MakeSlice().Elem() column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil { return err } // convert join identity map to relation identity map fieldValues := make([]interface{}, len(joinForeignFields)) joinFieldValues := make([]interface{}, len(joinRelForeignFields)) for i := 0; i < joinResults.Len(); i++ { joinIndexValue := joinResults.Index(i) for idx, field := range joinForeignFields { fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) } for idx, field := range joinRelForeignFields { joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { joinKey := utils.ToStringKey(joinFieldValues...) identityMap[joinKey] = append(identityMap[joinKey], results...) } } _, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields) } else { for _, ref := range rel.References { if ref.OwnPrimaryKey { relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) relForeignFields = append(relForeignFields, ref.ForeignKey) foreignFields = append(foreignFields, ref.PrimaryKey) } else if ref.PrimaryValue != "" { tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } else { relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) relForeignFields = append(relForeignFields, ref.PrimaryKey) foreignFields = append(foreignFields, ref.ForeignKey) } } identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) if len(foreignValues) == 0 { return nil } } // nested preload for p, pvs := range preloads { tx = tx.Preload(p, pvs...) } reflectResults := rel.FieldSchema.MakeSlice().Elem() column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) if len(values) != 0 { for _, cond := range conds { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { tx = fc(tx) } else { inlineConds = append(inlineConds, cond) } } if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil { return err } } fieldValues := make([]interface{}, len(relForeignFields)) // clean up old values before preloading switch reflectValue.Kind() { case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) default: tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())) } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) default: tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())) } } } for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem) } datas, ok := identityMap[utils.ToStringKey(fieldValues...)] if !ok { return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface()) } for _, data := range datas { reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data) if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) } reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface())) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())) } else { tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())) } } } } return tx.Error }