diff --git a/callbacks/preload.go b/callbacks/preload.go index c8dcd05e..112f67f7 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -1,9 +1,196 @@ package callbacks import ( + "reflect" + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/utils" ) -func preload(db *gorm.DB, preloadFields []string, rel *schema.Relationship) { +// getRelationsValue get relations's value from a reflect value +func getRelationsValue(reflectValue reflect.Value, rels []*schema.Relationship) (reflectResults reflect.Value) { + for _, rel := range rels { + reflectResults = reflect.MakeSlice(reflect.SliceOf(rel.FieldSchema.ModelType), 0, 0) + + appendToResults := func(value reflect.Value) { + if _, isZero := rel.Field.ValueOf(value); !isZero { + result := reflect.Indirect(rel.Field.ReflectValueOf(value)) + switch result.Kind() { + case reflect.Struct: + reflectResults = reflect.Append(reflectResults, result) + case reflect.Slice, reflect.Array: + for i := 0; i < value.Len(); i++ { + reflectResults = reflect.Append(reflectResults, reflect.Indirect(value.Index(i))) + } + } + } + } + + switch reflectValue.Kind() { + case reflect.Struct: + appendToResults(reflectValue) + case reflect.Slice: + for i := 0; i < reflectValue.Len(); i++ { + appendToResults(reflectValue.Index(i)) + } + } + + reflectValue = reflectResults + } + + return +} + +func getIdentityFieldValuesMap(reflectValue reflect.Value, fields []*schema.Field) (map[string][]reflect.Value, [][]interface{}) { + var ( + fieldValues = make([]reflect.Value, len(fields)) + results = [][]interface{}{} + dataResults = map[string][]reflect.Value{} + ) + + switch reflectValue.Kind() { + case reflect.Struct: + results = [][]interface{}{make([]interface{}, len(fields))} + + for idx, field := range fields { + fieldValues[idx] = field.ReflectValueOf(reflectValue) + results[0][idx] = fieldValues[idx].Interface() + } + + dataResults[utils.ToStringKey(fieldValues...)] = []reflect.Value{reflectValue} + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + for idx, field := range fields { + fieldValues[idx] = field.ReflectValueOf(reflectValue.Index(i)) + } + + dataKey := utils.ToStringKey(fieldValues...) + if _, ok := dataResults[dataKey]; !ok { + result := make([]interface{}, len(fieldValues)) + for idx, fieldValue := range fieldValues { + result[idx] = fieldValue.Interface() + } + results = append(results, result) + + dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)} + } else { + dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i)) + } + } + } + + return dataResults, results +} + +func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, foreignValues [][]interface{}) reflect.Value { + results := reflect.MakeSlice(reflect.SliceOf(resultSchema.ModelType), 0, 0) + queryValues := make([]interface{}, len(foreignValues)) + if len(foreignKeys) == 1 { + for idx, r := range foreignValues { + queryValues[idx] = r[0] + } + tx.Where(clause.IN{Column: foreignKeys[0], Values: queryValues}).Find(results.Addr().Interface()) + } else { + for idx, r := range foreignValues { + queryValues[idx] = r + } + tx.Where(clause.IN{Column: foreignKeys, Values: queryValues}).Find(results.Addr().Interface()) + } + + return results +} + +func preload(tx *gorm.DB, rels []*schema.Relationship, conds []interface{}) { + var ( + reflectValue = tx.Statement.ReflectValue + rel = rels[len(rels)-1] + relForeignKeys []string + relForeignFields []*schema.Field + foreignFields []*schema.Field + foreignValues [][]interface{} + identityMap = map[string][]reflect.Value{} + ) + + if len(rels) > 1 { + reflectValue = getRelationsValue(reflectValue, rels[:len(rels)]) + } + + if rel.JoinTable != nil { + var joinForeignFields, joinRelForeignFields []*schema.Field + var joinForeignKeys []string + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName) + joinForeignFields = append(joinForeignFields, ref.ForeignKey) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } else { + joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey) + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + relForeignFields = append(relForeignFields, ref.PrimaryKey) + } + } + + joinIdentityMap, joinForeignValues := getIdentityFieldValuesMap(reflectValue, joinForeignFields) + joinResults := preloadData(tx, rel.JoinTable, joinForeignKeys, joinForeignValues) + + // convert join identity map to relation identity map + fieldValues := make([]reflect.Value, len(foreignFields)) + joinFieldValues := make([]reflect.Value, len(joinForeignFields)) + for i := 0; i < joinResults.Len(); i++ { + for idx, field := range foreignFields { + fieldValues[idx] = field.ReflectValueOf(joinResults.Index(i)) + } + + for idx, field := range joinForeignFields { + joinFieldValues[idx] = field.ReflectValueOf(joinResults.Index(i)) + } + + if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { + identityMap[utils.ToStringKey(joinFieldValues...)] = results + } + } + + _, foreignValues = getIdentityFieldValuesMap(joinResults, joinRelForeignFields) + } else { + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + relForeignFields = append(relForeignFields, ref.ForeignKey) + foreignFields = append(foreignFields, ref.PrimaryKey) + } else if ref.PrimaryValue != "" { + tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) + } else { + relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) + relForeignFields = append(relForeignFields, ref.PrimaryKey) + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + + identityMap, foreignValues = getIdentityFieldValuesMap(reflectValue, foreignFields) + } + + reflectResults := preloadData(tx, rel.FieldSchema, relForeignKeys, foreignValues) + + fieldValues := make([]reflect.Value, len(foreignFields)) + for i := 0; i < reflectResults.Len(); i++ { + for idx, field := range foreignFields { + fieldValues[idx] = field.ReflectValueOf(reflectResults.Index(i)) + } + + for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { + reflectFieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) + switch reflectFieldValue.Kind() { + case reflect.Struct: + elem := reflectResults.Index(i).Convert(reflectFieldValue.Type().Elem()) + rel.Field.Set(data, elem.Interface()) + case reflect.Slice, reflect.Array: + elem := reflectResults.Index(i).Convert(reflectFieldValue.Type().Elem()) + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + } + } + } } diff --git a/callbacks/query.go b/callbacks/query.go index ca9e84a9..2c187868 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -25,6 +25,7 @@ func Query(db *gorm.DB) { } } + // inline joins if len(db.Statement.Joins) != 0 { joins := []clause.Join{} @@ -101,7 +102,6 @@ func Query(db *gorm.DB) { func Preload(db *gorm.DB) { if len(db.Statement.Preloads) > 0 { preloadMap := map[string][]string{} - for name := range db.Statement.Preloads { preloadFields := strings.Split(name, ".") for idx := range preloadFields { @@ -118,27 +118,22 @@ func Preload(db *gorm.DB) { sort.Strings(preloadNames) for _, name := range preloadNames { - curSchema := db.Statement.Schema - preloadFields := preloadMap[name] + var ( + curSchema = db.Statement.Schema + preloadFields = preloadMap[name] + rels = make([]*schema.Relationship, len(preloadFields)) + ) for idx, preloadField := range preloadFields { if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { - if idx == len(preloadFields)-1 { - conds := db.Statement.Preloads[strings.Join(preloadFields[:idx+1], ".")] - - switch rel.Type { - case schema.HasOne: - case schema.HasMany: - case schema.BelongsTo: - case schema.Many2Many: - } - } else { - curSchema = rel.FieldSchema - } + rels[idx] = rel + curSchema = rel.FieldSchema } else { db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) } } + + preload(db.Session(&gorm.Session{}), rels, db.Statement.Preloads[name]) } } } diff --git a/statement.go b/statement.go index 3f2ceca3..f3090eb7 100644 --- a/statement.go +++ b/statement.go @@ -95,6 +95,15 @@ func (stmt Statement) QuoteTo(writer clause.Writer, field interface{}) { } case string: stmt.DB.Dialector.QuoteTo(writer, v) + case []string: + writer.WriteByte('(') + for idx, d := range v { + if idx != 0 { + writer.WriteString(",") + } + stmt.DB.Dialector.QuoteTo(writer, d) + } + writer.WriteByte(')') default: stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) } diff --git a/utils/utils.go b/utils/utils.go index 8dd500a5..f3dedec2 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -5,6 +5,7 @@ import ( "reflect" "regexp" "runtime" + "strconv" "strings" "unicode" ) @@ -38,3 +39,24 @@ func CheckTruth(val interface{}) bool { return !reflect.ValueOf(val).IsZero() } + +func ToStringKey(values ...reflect.Value) string { + results := make([]string, len(values)) + + for idx, value := range values { + rv := reflect.Indirect(value).Interface() + + switch v := rv.(type) { + case string: + results[idx] = v + case []byte: + results[idx] = string(v) + case uint: + results[idx] = strconv.FormatUint(uint64(v), 10) + default: + results[idx] = fmt.Sprint(v) + } + } + + return strings.Join(results, "_") +}