Refacotr Preload

This commit is contained in:
Jinzhu 2020-05-18 13:07:11 +08:00
parent f999240e10
commit 59365b776b
3 changed files with 113 additions and 102 deletions

View File

@ -9,102 +9,6 @@ import (
"github.com/jinzhu/gorm/utils" "github.com/jinzhu/gorm/utils"
) )
// getRelationsValue get relations's value from a reflect value
func getRelationsValue(reflectValue reflect.Value, rels []*schema.Relationship) (reflectResults reflect.Value) {
for _, rel := range rels {
reflectResults = reflect.MakeSlice(reflect.SliceOf(rel.FieldSchema.ModelType), 0, 0)
appendToResults := func(value reflect.Value) {
if _, isZero := rel.Field.ValueOf(value); !isZero {
result := reflect.Indirect(rel.Field.ReflectValueOf(value))
switch result.Kind() {
case reflect.Struct:
reflectResults = reflect.Append(reflectResults, result)
case reflect.Slice, reflect.Array:
for i := 0; i < value.Len(); i++ {
reflectResults = reflect.Append(reflectResults, reflect.Indirect(value.Index(i)))
}
}
}
}
switch reflectValue.Kind() {
case reflect.Struct:
appendToResults(reflectValue)
case reflect.Slice:
for i := 0; i < reflectValue.Len(); i++ {
appendToResults(reflectValue.Index(i))
}
}
reflectValue = reflectResults
}
return
}
func getIdentityFieldValuesMap(reflectValue reflect.Value, fields []*schema.Field) (map[string][]reflect.Value, [][]interface{}) {
var (
fieldValues = make([]reflect.Value, len(fields))
results = [][]interface{}{}
dataResults = map[string][]reflect.Value{}
)
switch reflectValue.Kind() {
case reflect.Struct:
results = [][]interface{}{make([]interface{}, len(fields))}
for idx, field := range fields {
fieldValues[idx] = field.ReflectValueOf(reflectValue)
results[0][idx] = fieldValues[idx].Interface()
}
dataResults[utils.ToStringKey(fieldValues...)] = []reflect.Value{reflectValue}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
for idx, field := range fields {
fieldValues[idx] = field.ReflectValueOf(reflectValue.Index(i))
}
dataKey := utils.ToStringKey(fieldValues...)
if _, ok := dataResults[dataKey]; !ok {
result := make([]interface{}, len(fieldValues))
for idx, fieldValue := range fieldValues {
result[idx] = fieldValue.Interface()
}
results = append(results, result)
dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)}
} else {
dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i))
}
}
}
return dataResults, results
}
func preloadData(tx *gorm.DB, resultSchema *schema.Schema, foreignKeys []string, foreignValues [][]interface{}, conds []interface{}) reflect.Value {
slice := reflect.MakeSlice(reflect.SliceOf(resultSchema.ModelType), 0, 0)
results := reflect.New(slice.Type())
results.Elem().Set(slice)
queryValues := make([]interface{}, len(foreignValues))
if len(foreignKeys) == 1 {
for idx, r := range foreignValues {
queryValues[idx] = r[0]
}
tx.Where(clause.IN{Column: foreignKeys[0], Values: queryValues}).Find(results.Interface(), conds...)
} else {
for idx, r := range foreignValues {
queryValues[idx] = r
}
tx.Where(clause.IN{Column: foreignKeys, Values: queryValues}).Find(results.Interface(), conds...)
}
return results.Elem()
}
func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
var ( var (
reflectValue = db.Statement.ReflectValue reflectValue = db.Statement.ReflectValue
@ -118,7 +22,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
) )
if len(rels) > 1 { if len(rels) > 1 {
reflectValue = getRelationsValue(reflectValue, rels[:len(rels)]) reflectValue = schema.GetRelationsValues(reflectValue, rels[:len(rels)])
} }
if rel.JoinTable != nil { if rel.JoinTable != nil {
@ -138,8 +42,11 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
} }
} }
joinIdentityMap, joinForeignValues := getIdentityFieldValuesMap(reflectValue, joinForeignFields) joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, joinForeignFields)
joinResults := preloadData(tx, rel.JoinTable, joinForeignKeys, joinForeignValues, nil)
joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues)
tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface())
// convert join identity map to relation identity map // convert join identity map to relation identity map
fieldValues := make([]reflect.Value, len(foreignFields)) fieldValues := make([]reflect.Value, len(foreignFields))
@ -158,7 +65,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
} }
} }
_, foreignValues = getIdentityFieldValuesMap(joinResults, joinRelForeignFields) _, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields)
} else { } else {
for _, ref := range rel.References { for _, ref := range rel.References {
if ref.OwnPrimaryKey { if ref.OwnPrimaryKey {
@ -174,10 +81,12 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
} }
} }
identityMap, foreignValues = getIdentityFieldValuesMap(reflectValue, foreignFields) identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
} }
reflectResults := preloadData(tx, rel.FieldSchema, relForeignKeys, foreignValues, conds) reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(relForeignKeys, foreignValues)
tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), conds...)
fieldValues := make([]reflect.Value, len(foreignFields)) fieldValues := make([]reflect.Value, len(foreignFields))
for i := 0; i < reflectResults.Len(); i++ { for i := 0; i < reflectResults.Len(); i++ {

View File

@ -43,6 +43,13 @@ func (schema Schema) String() string {
return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name()) return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name())
} }
func (schema Schema) MakeSlice() reflect.Value {
slice := reflect.MakeSlice(reflect.SliceOf(schema.ModelType), 0, 0)
results := reflect.New(slice.Type())
results.Elem().Set(slice)
return results
}
func (schema Schema) LookUpField(name string) *Field { func (schema Schema) LookUpField(name string) *Field {
if field, ok := schema.FieldsByDBName[name]; ok { if field, ok := schema.FieldsByDBName[name]; ok {
return field return field

View File

@ -4,6 +4,8 @@ import (
"reflect" "reflect"
"regexp" "regexp"
"strings" "strings"
"github.com/jinzhu/gorm/utils"
) )
func ParseTagSetting(str string, sep string) map[string]string { func ParseTagSetting(str string, sep string) map[string]string {
@ -49,3 +51,96 @@ func toColumns(val string) (results []string) {
func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag { func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag {
return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`:.*?)(;|("))`).ReplaceAllString(string(tag), "${1}${4}")) return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`:.*?)(;|("))`).ReplaceAllString(string(tag), "${1}${4}"))
} }
// GetRelationsValues get relations's values from a reflect value
func GetRelationsValues(reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
for _, rel := range rels {
reflectResults = reflect.MakeSlice(reflect.SliceOf(rel.FieldSchema.ModelType), 0, 0)
appendToResults := func(value reflect.Value) {
if _, isZero := rel.Field.ValueOf(value); !isZero {
result := reflect.Indirect(rel.Field.ReflectValueOf(value))
switch result.Kind() {
case reflect.Struct:
reflectResults = reflect.Append(reflectResults, result)
case reflect.Slice, reflect.Array:
for i := 0; i < value.Len(); i++ {
reflectResults = reflect.Append(reflectResults, reflect.Indirect(value.Index(i)))
}
}
}
}
switch reflectValue.Kind() {
case reflect.Struct:
appendToResults(reflectValue)
case reflect.Slice:
for i := 0; i < reflectValue.Len(); i++ {
appendToResults(reflectValue.Index(i))
}
}
reflectValue = reflectResults
}
return
}
// GetIdentityFieldValuesMap get identity map from fields
func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
var (
fieldValues = make([]reflect.Value, len(fields))
results = [][]interface{}{}
dataResults = map[string][]reflect.Value{}
)
switch reflectValue.Kind() {
case reflect.Struct:
results = [][]interface{}{make([]interface{}, len(fields))}
for idx, field := range fields {
fieldValues[idx] = field.ReflectValueOf(reflectValue)
results[0][idx] = fieldValues[idx].Interface()
}
dataResults[utils.ToStringKey(fieldValues...)] = []reflect.Value{reflectValue}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
for idx, field := range fields {
fieldValues[idx] = field.ReflectValueOf(reflectValue.Index(i))
}
dataKey := utils.ToStringKey(fieldValues...)
if _, ok := dataResults[dataKey]; !ok {
result := make([]interface{}, len(fieldValues))
for idx, fieldValue := range fieldValues {
result[idx] = fieldValue.Interface()
}
results = append(results, result)
dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)}
} else {
dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i))
}
}
}
return dataResults, results
}
// ToQueryValues to query values
func ToQueryValues(foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) {
queryValues := make([]interface{}, len(foreignValues))
if len(foreignKeys) == 1 {
for idx, r := range foreignValues {
queryValues[idx] = r[0]
}
return foreignKeys[0], queryValues
} else {
for idx, r := range foreignValues {
queryValues[idx] = r
}
}
return foreignKeys, queryValues
}