Refactor Preload

This commit is contained in:
Jinzhu 2016-01-15 15:53:53 +08:00
parent 41620f3d6c
commit 3326a4e69d
2 changed files with 127 additions and 131 deletions

View File

@ -1,139 +1,109 @@
package gorm package gorm
import ( import (
"database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
) )
func getRealValue(value reflect.Value, columns []string) (results []interface{}) { // Preload preload relations callback
// If value is a nil pointer, Indirect returns a zero Value!
// Therefor we need to check for a zero value,
// as FieldByName could panic
if pointedValue := reflect.Indirect(value); pointedValue.IsValid() {
for _, column := range columns {
if pointedValue.FieldByName(column).IsValid() {
result := pointedValue.FieldByName(column).Interface()
if r, ok := result.(driver.Valuer); ok {
result, _ = r.Value()
}
results = append(results, result)
}
}
}
return
}
func equalAsString(a interface{}, b interface{}) bool {
return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b)
}
func Preload(scope *Scope) { func Preload(scope *Scope) {
if scope.Search.preload == nil || scope.HasError() { if scope.Search.preload == nil || scope.HasError() {
return return
} }
preloadMap := map[string]bool{} var (
fields := scope.Fields() preloadedMap = map[string]bool{}
fields = scope.Fields()
)
for _, preload := range scope.Search.preload { for _, preload := range scope.Search.preload {
schema, conditions := preload.schema, preload.conditions var (
keys := strings.Split(schema, ".") preloadFields = strings.Split(preload.schema, ".")
currentScope := scope currentScope = scope
currentFields := fields currentFields = fields
originalConditions := conditions )
conditions = []interface{}{}
for i, key := range keys {
var found bool
if preloadMap[strings.Join(keys[:i+1], ".")] {
goto nextLoop
}
if i == len(keys)-1 { for idx, preloadField := range preloadFields {
conditions = originalConditions var currentPreloadConditions []interface{}
}
for _, field := range currentFields { // if not preloaded
if field.Name != key || field.Relationship == nil { if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] {
continue
// assign search conditions to last preload
if idx == len(preloadFields)-1 {
currentPreloadConditions = preload.conditions
} }
found = true for _, field := range currentFields {
switch field.Relationship.Kind { if field.Name != preloadField || field.Relationship == nil {
case "has_one": continue
currentScope.handleHasOnePreload(field, conditions) }
case "has_many":
currentScope.handleHasManyPreload(field, conditions) switch field.Relationship.Kind {
case "belongs_to": case "has_one":
currentScope.handleBelongsToPreload(field, conditions) currentScope.handleHasOnePreload(field, currentPreloadConditions)
case "many_to_many": case "has_many":
currentScope.handleManyToManyPreload(field, conditions) currentScope.handleHasManyPreload(field, currentPreloadConditions)
default: case "belongs_to":
currentScope.Err(errors.New("not supported relation")) currentScope.handleBelongsToPreload(field, currentPreloadConditions)
case "many_to_many":
currentScope.handleManyToManyPreload(field, currentPreloadConditions)
default:
scope.Err(errors.New("unsupported relation"))
}
preloadedMap[preloadKey] = true
break
}
if !preloadedMap[preloadKey] {
scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType))
return
} }
break
} }
if !found { // preload next level
value := reflect.ValueOf(currentScope.Value) if idx < len(preloadFields)-1 {
if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface { currentScope = currentScope.getColumnAsScope(preloadField)
value = value.Index(0).Elem()
}
scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type()))
return
}
preloadMap[strings.Join(keys[:i+1], ".")] = true
nextLoop:
if i < len(keys)-1 {
currentScope = currentScope.getColumnsAsScope(key)
currentFields = currentScope.Fields() currentFields = currentScope.Fields()
} }
} }
} }
}
func makeSlice(typ reflect.Type) interface{} {
if typ.Kind() == reflect.Slice {
typ = typ.Elem()
}
sliceType := reflect.SliceOf(typ)
slice := reflect.New(sliceType)
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
return slice.Interface()
} }
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
relation := field.Relationship relation := field.Relationship
// get relations's primary keys
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
if len(primaryKeys) == 0 { if len(primaryKeys) == 0 {
return return
} }
// find relations
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ { // assign find results
result := resultValues.Index(i) var (
if scope.IndirectValue().Kind() == reflect.Slice { resultsValue = reflect.Indirect(reflect.ValueOf(results))
value := getRealValue(result, relation.ForeignFieldNames) indirectScopeValue = scope.IndirectValue()
objects := scope.IndirectValue() )
for j := 0; j < objects.Len(); j++ {
if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) { for i := 0; i < resultsValue.Len(); i++ {
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) result := resultsValue.Index(i)
if indirectScopeValue.Kind() == reflect.Slice {
value := getValueFromFields(result, relation.ForeignFieldNames)
for j := 0; j < indirectScopeValue.Len(); j++ {
if equalAsString(getValueFromFields(indirectScopeValue.Index(j), relation.AssociationForeignFieldNames), value) {
reflect.Indirect(indirectScopeValue.Index(j)).FieldByName(field.Name).Set(result)
break break
} }
} }
} else { } else {
if err := scope.SetColumn(field, result); err != nil { scope.Err(field.Set(result))
scope.Err(err)
return
}
} }
} }
} }
@ -152,11 +122,11 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
if scope.IndirectValue().Kind() == reflect.Slice { if scope.IndirectValue().Kind() == reflect.Slice {
for i := 0; i < resultValues.Len(); i++ { for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i) result := resultValues.Index(i)
value := getRealValue(result, relation.ForeignFieldNames) value := getValueFromFields(result, relation.ForeignFieldNames)
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j)) object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, relation.AssociationForeignFieldNames), value) { if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), value) {
f := object.FieldByName(field.Name) f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result)) f.Set(reflect.Append(f, result))
break break
@ -182,11 +152,11 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
for i := 0; i < resultValues.Len(); i++ { for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i) result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice { if scope.IndirectValue().Kind() == reflect.Slice {
value := getRealValue(result, relation.AssociationForeignFieldNames) value := getValueFromFields(result, relation.AssociationForeignFieldNames)
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j)) object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) { if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) {
object.FieldByName(field.Name).Set(result) object.FieldByName(field.Name).Set(result)
} }
} }
@ -276,10 +246,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
if indirectScopeValue.Kind() == reflect.Slice { if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ { for j := 0; j < indirectScopeValue.Len(); j++ {
object := reflect.Indirect(indirectScopeValue.Index(j)) object := reflect.Indirect(indirectScopeValue.Index(j))
fieldsSourceMap[toString(getRealValue(object, foreignFieldNames))] = object.FieldByName(field.Name) fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name)
} }
} else if indirectScopeValue.IsValid() { } else if indirectScopeValue.IsValid() {
fieldsSourceMap[toString(getRealValue(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name) fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name)
} }
for source, link := range linkHash { for source, link := range linkHash {
@ -308,46 +278,38 @@ func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{})
return return
} }
func (scope *Scope) getColumnsAsScope(column string) *Scope { func (scope *Scope) getColumnAsScope(column string) *Scope {
values := scope.IndirectValue() indirectScopeValue := scope.IndirectValue()
switch values.Kind() {
switch indirectScopeValue.Kind() {
case reflect.Slice: case reflect.Slice:
modelType := values.Type().Elem() if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok {
if modelType.Kind() == reflect.Ptr { fieldType := fieldStruct.Type
modelType = modelType.Elem() if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr {
} fieldType = fieldType.Elem()
fieldStruct, _ := modelType.FieldByName(column)
var columns reflect.Value
if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
} else {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
}
for i := 0; i < values.Len(); i++ {
column := reflect.Indirect(values.Index(i)).FieldByName(column)
if column.Kind() == reflect.Ptr {
column = column.Elem()
} }
if column.Kind() == reflect.Slice {
for i := 0; i < column.Len(); i++ { results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem()
elem := column.Index(i)
if elem.CanAddr() { for i := 0; i < indirectScopeValue.Len(); i++ {
columns = reflect.Append(columns, elem.Addr()) result := reflect.Indirect(reflect.Indirect(indirectScopeValue.Index(i)).FieldByName(column))
if result.Kind() == reflect.Slice {
for j := 0; j < result.Len(); j++ {
if elem := result.Index(j); elem.CanAddr() {
results = reflect.Append(results, elem.Addr())
}
} }
} } else if result.CanAddr() {
} else { results = reflect.Append(results, result.Addr())
if column.CanAddr() {
columns = reflect.Append(columns, column.Addr())
} }
} }
return scope.New(results.Interface())
} }
return scope.New(columns.Interface())
case reflect.Struct: case reflect.Struct:
field := values.FieldByName(column) if field := indirectScopeValue.FieldByName(column); field.CanAddr() {
if !field.CanAddr() { return scope.New(field.Addr().Interface())
return nil
} }
return scope.New(field.Addr().Interface())
} }
return nil return nil
} }

View File

@ -1,6 +1,7 @@
package gorm package gorm
import ( import (
"database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
@ -73,6 +74,10 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} {
return attrs return attrs
} }
func equalAsString(a interface{}, b interface{}) bool {
return toString(a) == toString(b)
}
func toString(str interface{}) string { func toString(str interface{}) string {
if values, ok := str.([]interface{}); ok { if values, ok := str.([]interface{}); ok {
var results []string var results []string
@ -87,6 +92,16 @@ func toString(str interface{}) string {
} }
} }
func makeSlice(elemType reflect.Type) interface{} {
if elemType.Kind() == reflect.Slice {
elemType = elemType.Elem()
}
sliceType := reflect.SliceOf(elemType)
slice := reflect.New(sliceType)
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
return slice.Interface()
}
func strInSlice(a string, list []string) bool { func strInSlice(a string, list []string) bool {
for _, b := range list { for _, b := range list {
if b == a { if b == a {
@ -95,3 +110,22 @@ func strInSlice(a string, list []string) bool {
} }
return false return false
} }
// getValueFromFields return given fields's value
func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) {
// If value is a nil pointer, Indirect returns a zero Value!
// Therefor we need to check for a zero value,
// as FieldByName could panic
if indirectValue := reflect.Indirect(value); indirectValue.IsValid() {
for _, fieldName := range fieldNames {
if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() {
result := fieldValue.Interface()
if r, ok := result.(driver.Valuer); ok {
result, _ = r.Value()
}
results = append(results, result)
}
}
}
return
}