Add more preload tests

This commit is contained in:
Jinzhu 2020-06-02 00:44:48 +08:00
parent bc01eb28ad
commit b71171dd92
3 changed files with 1544 additions and 8 deletions

View File

@ -19,6 +19,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
foreignFields []*schema.Field foreignFields []*schema.Field
foreignValues [][]interface{} foreignValues [][]interface{}
identityMap = map[string][]reflect.Value{} identityMap = map[string][]reflect.Value{}
inlineConds []interface{}
) )
if len(rels) > 1 { if len(rels) > 1 {
@ -64,7 +65,8 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
} }
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
identityMap[utils.ToStringKey(joinFieldValues...)] = results joinKey := utils.ToStringKey(joinFieldValues...)
identityMap[joinKey] = append(identityMap[joinKey], results...)
} }
} }
@ -92,12 +94,23 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
reflectResults := rel.FieldSchema.MakeSlice().Elem() reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(relForeignKeys, foreignValues) column, values := schema.ToQueryValues(relForeignKeys, foreignValues)
tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), conds...)
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
} else {
inlineConds = append(inlineConds, cond)
}
}
tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...)
fieldValues := make([]interface{}, len(relForeignFields)) fieldValues := make([]interface{}, len(relForeignFields))
for i := 0; i < reflectResults.Len(); i++ { for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i)
for idx, field := range relForeignFields { for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(reflectResults.Index(i)) fieldValues[idx], _ = field.ValueOf(elem)
} }
for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { for _, data := range identityMap[utils.ToStringKey(fieldValues...)] {
@ -105,15 +118,16 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
} }
reflectFieldValue = reflect.Indirect(reflectFieldValue) reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() { switch reflectFieldValue.Kind() {
case reflect.Struct: case reflect.Struct:
rel.Field.Set(data, reflectResults.Index(i).Interface()) rel.Field.Set(data, reflectResults.Index(i).Interface())
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
rel.Field.Set(data, reflect.Append(reflectFieldValue, reflectResults.Index(i)).Interface()) rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface())
} else { } else {
rel.Field.Set(data, reflect.Append(reflectFieldValue, reflectResults.Index(i).Elem()).Interface()) rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
} }
} }
} }

View File

@ -95,6 +95,7 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
var ( var (
results = [][]interface{}{} results = [][]interface{}{}
dataResults = map[string][]reflect.Value{} dataResults = map[string][]reflect.Value{}
loaded = map[interface{}]bool{}
notZero, zero bool notZero, zero bool
) )
@ -114,10 +115,21 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue}
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
elem := reflectValue.Index(i)
elemKey := elem.Interface()
if elem.Kind() != reflect.Ptr {
elemKey = elem.Addr().Interface()
}
if _, ok := loaded[elemKey]; ok {
continue
}
loaded[elemKey] = true
fieldValues := make([]interface{}, len(fields)) fieldValues := make([]interface{}, len(fields))
notZero = false notZero = false
for idx, field := range fields { for idx, field := range fields {
fieldValues[idx], zero = field.ValueOf(reflectValue.Index(i)) fieldValues[idx], zero = field.ValueOf(elem)
notZero = notZero || !zero notZero = notZero || !zero
} }
@ -125,9 +137,9 @@ func GetIdentityFieldValuesMap(reflectValue reflect.Value, fields []*Field) (map
dataKey := utils.ToStringKey(fieldValues...) dataKey := utils.ToStringKey(fieldValues...)
if _, ok := dataResults[dataKey]; !ok { if _, ok := dataResults[dataKey]; !ok {
results = append(results, fieldValues[:]) results = append(results, fieldValues[:])
dataResults[dataKey] = []reflect.Value{reflectValue.Index(i)} dataResults[dataKey] = []reflect.Value{elem}
} else { } else {
dataResults[dataKey] = append(dataResults[dataKey], reflectValue.Index(i)) dataResults[dataKey] = append(dataResults[dataKey], elem)
} }
} }
} }

1510
tests/preload_suits_test.go Normal file

File diff suppressed because it is too large Load Diff