diff --git a/callbacks/preload.go b/callbacks/preload.go index aaac31b5..9f23a2ca 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -9,102 +9,6 @@ import ( "github.com/jinzhu/gorm/utils" ) -// getRelationsValue get relations's value from a reflect value -func getRelationsValue(reflectValue reflect.Value, rels []*schema.Relationship) (reflectResults reflect.Value) { - for _, rel := range rels { - reflectResults = reflect.MakeSlice(reflect.SliceOf(rel.FieldSchema.ModelType), 0, 0) - - appendToResults := func(value reflect.Value) { - if _, isZero := rel.Field.ValueOf(value); !isZero { - result := reflect.Indirect(rel.Field.ReflectValueOf(value)) - switch result.Kind() { - case reflect.Struct: - reflectResults = reflect.Append(reflectResults, result) - case reflect.Slice, reflect.Array: - for i := 0; i < value.Len(); i++ { - reflectResults = reflect.Append(reflectResults, reflect.Indirect(value.Index(i))) - } - } - } - } - - switch reflectValue.Kind() { - case reflect.Struct: - appendToResults(reflectValue) - case reflect.Slice: - for i := 0; i < reflectValue.Len(); i++ { - appendToResults(reflectValue.Index(i)) - } - } - - reflectValue = reflectResults - } - - return -} - -func getIdentityFieldValuesMap(reflectValue reflect.Value, fields []*schema.Field) (map[string][]reflect.Value, [][]interface{}) { - var ( - fieldValues = make([]reflect.Value, len(fields)) - results = [][]interface{}{} - dataResults = map[string][]reflect.Value{} - ) - - switch reflectValue.Kind() { - case reflect.Struct: - results = [][]interface{}{make([]interface{}, len(fields))} - - for idx, field := range fields { - fieldValues[idx] = field.ReflectValueOf(reflectValue) - results[0][idx] = fieldValues[idx].Interface() - } - - dataResults[utils.ToStringKey(fieldValues...)] = []reflect.Value{reflectValue} - case reflect.Slice, reflect.Array: - for i := 0; i < reflectValue.Len(); i++ { - for idx, field := range fields { - fieldValues[idx] = field.ReflectValueOf(reflectValue.Index(i)) - } - - dataKey := utils.ToStringKey(fieldValues...) - if _, ok := dataResults[dataKey]; !ok { - result := make([]interface{}, len(fieldValues)) - for idx, fieldValue := range fieldValues { - result[idx] = fieldValue.Interface() - } - results = append(results, result) - - dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)} - } else { - dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i)) - } - } - } - - return dataResults, results -} - -func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, foreignValues [][]interface{}, conds []interface{}) reflect.Value { - slice := reflect.MakeSlice(reflect.SliceOf(resultSchema.ModelType), 0, 0) - results := reflect.New(slice.Type()) - results.Elem().Set(slice) - - queryValues := make([]interface{}, len(foreignValues)) - if len(foreignKeys) == 1 { - for idx, r := range foreignValues { - queryValues[idx] = r[0] - } - tx.Where(clause.IN{Column: foreignKeys[0], Values: queryValues}).Find(results.Interface(), conds...) - } else { - for idx, r := range foreignValues { - queryValues[idx] = r - } - tx.Where(clause.IN{Column: foreignKeys, Values: queryValues}).Find(results.Interface(), conds...) - } - - return results.Elem() -} - func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { var ( reflectValue = db.Statement.ReflectValue @@ -118,7 +22,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { ) if len(rels) > 1 { - reflectValue = getRelationsValue(reflectValue, rels[:len(rels)]) + reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)]) } if rel.JoinTable != nil { @@ -138,8 +42,11 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - joinIdentityMap, joinForeignValues := getIdentityFieldValuesMap(reflectValue, joinForeignFields) - joinResults := preloadData(tx, rel.JoinTable, joinForeignKeys, joinForeignValues, nil) + joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, joinForeignFields) + + joinResults := rel.JoinTable.MakeSlice().Elem() + column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues) + tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) // convert join identity map to relation identity map fieldValues := make([]reflect.Value, len(foreignFields)) @@ -158,7 +65,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - _, foreignValues = getIdentityFieldValuesMap(joinResults, joinRelForeignFields) + _, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields) } else { for _, ref := range rel.References { if ref.OwnPrimaryKey { @@ -174,10 +81,12 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - identityMap, foreignValues = getIdentityFieldValuesMap(reflectValue, foreignFields) + identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields) } - reflectResults := preloadData(tx, rel.FieldSchema, relForeignKeys, foreignValues, conds) + reflectResults := rel.FieldSchema.MakeSlice().Elem() + column, values := schema.ToQueryValues(relForeignKeys, foreignValues) + tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), conds...) fieldValues := make([]reflect.Value, len(foreignFields)) for i := 0; i < reflectResults.Len(); i++ { diff --git a/schema/schema.go b/schema/schema.go index 3abac2ba..5a28797b 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -43,6 +43,13 @@ func (schema Schema) String() string { return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name()) } +func (schema Schema) MakeSlice() reflect.Value { + slice := reflect.MakeSlice(reflect.SliceOf(schema.ModelType), 0, 0) + results := reflect.New(slice.Type()) + results.Elem().Set(slice) + return results +} + func (schema Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByDBName[name]; ok { return field diff --git a/schema/utils.go b/schema/utils.go index 7be78bc5..7a26332d 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -4,6 +4,8 @@ import ( "reflect" "regexp" "strings" + + "github.com/jinzhu/gorm/utils" ) func ParseTagSetting(str string, sep string) map[string]string { @@ -49,3 +51,96 @@ func toColumns(val string) (results []string) { func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag { return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`:.*?)(;|("))`).ReplaceAllString(string(tag), "${1}${4}")) } + +// GetRelationsValues get relations's values from a reflect value +func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { + for _, rel := range rels { + reflectResults = reflect.MakeSlice(reflect.SliceOf(rel.FieldSchema.ModelType), 0, 0) + + appendToResults := func(value reflect.Value) { + if _, isZero := rel.Field.ValueOf(value); !isZero { + result := reflect.Indirect(rel.Field.ReflectValueOf(value)) + switch result.Kind() { + case reflect.Struct: + reflectResults = reflect.Append(reflectResults, result) + case reflect.Slice, reflect.Array: + for i := 0; i < value.Len(); i++ { + reflectResults = reflect.Append(reflectResults, reflect.Indirect(value.Index(i))) + } + } + } + } + + switch reflectValue.Kind() { + case reflect.Struct: + appendToResults(reflectValue) + case reflect.Slice: + for i := 0; i < reflectValue.Len(); i++ { + appendToResults(reflectValue.Index(i)) + } + } + + reflectValue = reflectResults + } + + return +} + +// GetIdentityFieldValuesMap get identity map from fields +func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { + var ( + fieldValues = make([]reflect.Value, len(fields)) + results = [][]interface{}{} + dataResults = map[string][]reflect.Value{} + ) + + switch reflectValue.Kind() { + case reflect.Struct: + results = [][]interface{}{make([]interface{}, len(fields))} + + for idx, field := range fields { + fieldValues[idx] = field.ReflectValueOf(reflectValue) + results[0][idx] = fieldValues[idx].Interface() + } + + dataResults[utils.ToStringKey(fieldValues...)] = []reflect.Value{reflectValue} + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + for idx, field := range fields { + fieldValues[idx] = field.ReflectValueOf(reflectValue.Index(i)) + } + + dataKey := utils.ToStringKey(fieldValues...) + if _, ok := dataResults[dataKey]; !ok { + result := make([]interface{}, len(fieldValues)) + for idx, fieldValue := range fieldValues { + result[idx] = fieldValue.Interface() + } + results = append(results, result) + + dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)} + } else { + dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i)) + } + } + } + + return dataResults, results +} + +// ToQueryValues to query values +func ToQueryValues(foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { + queryValues := make([]interface{}, len(foreignValues)) + if len(foreignKeys) == 1 { + for idx, r := range foreignValues { + queryValues[idx] = r[0] + } + + return foreignKeys[0], queryValues + } else { + for idx, r := range foreignValues { + queryValues[idx] = r + } + } + return foreignKeys, queryValues +}