From 3326a4e69dd09e830f69155f2e45f295f9cba13d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 15 Jan 2016 15:53:53 +0800 Subject: [PATCH] Refactor Preload --- preload.go | 224 ++++++++++++++++++++--------------------------- utils_private.go | 34 +++++++ 2 files changed, 127 insertions(+), 131 deletions(-) diff --git a/preload.go b/preload.go index f415386d..cfc65380 100644 --- a/preload.go +++ b/preload.go @@ -1,139 +1,109 @@ package gorm import ( - "database/sql/driver" "errors" "fmt" "reflect" "strings" ) -func getRealValue(value reflect.Value, columns []string) (results []interface{}) { - // If value is a nil pointer, Indirect returns a zero Value! - // Therefor we need to check for a zero value, - // as FieldByName could panic - if pointedValue := reflect.Indirect(value); pointedValue.IsValid() { - for _, column := range columns { - if pointedValue.FieldByName(column).IsValid() { - result := pointedValue.FieldByName(column).Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() - } - results = append(results, result) - } - } - } - return -} - -func equalAsString(a interface{}, b interface{}) bool { - return fmt.Sprintf("%v", a) == fmt.Sprintf("%v", b) -} - +// Preload preload relations callback func Preload(scope *Scope) { if scope.Search.preload == nil || scope.HasError() { return } - preloadMap := map[string]bool{} - fields := scope.Fields() + var ( + preloadedMap = map[string]bool{} + fields = scope.Fields() + ) + for _, preload := range scope.Search.preload { - schema, conditions := preload.schema, preload.conditions - keys := strings.Split(schema, ".") - currentScope := scope - currentFields := fields - originalConditions := conditions - conditions = []interface{}{} - for i, key := range keys { - var found bool - if preloadMap[strings.Join(keys[:i+1], ".")] { - goto nextLoop - } + var ( + preloadFields = strings.Split(preload.schema, ".") + currentScope = scope + currentFields = fields + ) - if i == len(keys)-1 { - conditions = originalConditions - } + for idx, preloadField := range preloadFields { + var currentPreloadConditions []interface{} - for _, field := range currentFields { - if field.Name != key || field.Relationship == nil { - continue + // if not preloaded + if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { + + // assign search conditions to last preload + if idx == len(preloadFields)-1 { + currentPreloadConditions = preload.conditions } - found = true - switch field.Relationship.Kind { - case "has_one": - currentScope.handleHasOnePreload(field, conditions) - case "has_many": - currentScope.handleHasManyPreload(field, conditions) - case "belongs_to": - currentScope.handleBelongsToPreload(field, conditions) - case "many_to_many": - currentScope.handleManyToManyPreload(field, conditions) - default: - currentScope.Err(errors.New("not supported relation")) + for _, field := range currentFields { + if field.Name != preloadField || field.Relationship == nil { + continue + } + + switch field.Relationship.Kind { + case "has_one": + currentScope.handleHasOnePreload(field, currentPreloadConditions) + case "has_many": + currentScope.handleHasManyPreload(field, currentPreloadConditions) + case "belongs_to": + currentScope.handleBelongsToPreload(field, currentPreloadConditions) + case "many_to_many": + currentScope.handleManyToManyPreload(field, currentPreloadConditions) + default: + scope.Err(errors.New("unsupported relation")) + } + + preloadedMap[preloadKey] = true + break + } + + if !preloadedMap[preloadKey] { + scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) + return } - break } - if !found { - value := reflect.ValueOf(currentScope.Value) - if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface { - value = value.Index(0).Elem() - } - scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type())) - return - } - - preloadMap[strings.Join(keys[:i+1], ".")] = true - - nextLoop: - if i < len(keys)-1 { - currentScope = currentScope.getColumnsAsScope(key) + // preload next level + if idx < len(preloadFields)-1 { + currentScope = currentScope.getColumnAsScope(preloadField) currentFields = currentScope.Fields() } } } - -} - -func makeSlice(typ reflect.Type) interface{} { - if typ.Kind() == reflect.Slice { - typ = typ.Elem() - } - sliceType := reflect.SliceOf(typ) - slice := reflect.New(sliceType) - slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) - return slice.Interface() } func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { relation := field.Relationship + // get relations's primary keys primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) if len(primaryKeys) == 0 { return } + // find relations results := makeSlice(field.Struct.Type) scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if scope.IndirectValue().Kind() == reflect.Slice { - value := getRealValue(result, relation.ForeignFieldNames) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) { - reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) + // assign find results + var ( + resultsValue = reflect.Indirect(reflect.ValueOf(results)) + indirectScopeValue = scope.IndirectValue() + ) + + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + if indirectScopeValue.Kind() == reflect.Slice { + value := getValueFromFields(result, relation.ForeignFieldNames) + for j := 0; j < indirectScopeValue.Len(); j++ { + if equalAsString(getValueFromFields(indirectScopeValue.Index(j), relation.AssociationForeignFieldNames), value) { + reflect.Indirect(indirectScopeValue.Index(j)).FieldByName(field.Name).Set(result) break } } } else { - if err := scope.SetColumn(field, result); err != nil { - scope.Err(err) - return - } + scope.Err(field.Set(result)) } } } @@ -152,11 +122,11 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) if scope.IndirectValue().Kind() == reflect.Slice { for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) - value := getRealValue(result, relation.ForeignFieldNames) + value := getValueFromFields(result, relation.ForeignFieldNames) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, relation.AssociationForeignFieldNames), value) { + if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), value) { f := object.FieldByName(field.Name) f.Set(reflect.Append(f, result)) break @@ -182,11 +152,11 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) if scope.IndirectValue().Kind() == reflect.Slice { - value := getRealValue(result, relation.AssociationForeignFieldNames) + value := getValueFromFields(result, relation.AssociationForeignFieldNames) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) { + if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { object.FieldByName(field.Name).Set(result) } } @@ -276,10 +246,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface if indirectScopeValue.Kind() == reflect.Slice { for j := 0; j < indirectScopeValue.Len(); j++ { object := reflect.Indirect(indirectScopeValue.Index(j)) - fieldsSourceMap[toString(getRealValue(object, foreignFieldNames))] = object.FieldByName(field.Name) + fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name) } } else if indirectScopeValue.IsValid() { - fieldsSourceMap[toString(getRealValue(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name) + fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name) } for source, link := range linkHash { @@ -308,46 +278,38 @@ func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) return } -func (scope *Scope) getColumnsAsScope(column string) *Scope { - values := scope.IndirectValue() - switch values.Kind() { +func (scope *Scope) getColumnAsScope(column string) *Scope { + indirectScopeValue := scope.IndirectValue() + + switch indirectScopeValue.Kind() { case reflect.Slice: - modelType := values.Type().Elem() - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } - fieldStruct, _ := modelType.FieldByName(column) - var columns reflect.Value - if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr { - columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem() - } else { - columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem() - } - for i := 0; i < values.Len(); i++ { - column := reflect.Indirect(values.Index(i)).FieldByName(column) - if column.Kind() == reflect.Ptr { - column = column.Elem() + if fieldStruct, ok := scope.GetModelStruct().ModelType.FieldByName(column); ok { + fieldType := fieldStruct.Type + if fieldType.Kind() == reflect.Slice || fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() } - if column.Kind() == reflect.Slice { - for i := 0; i < column.Len(); i++ { - elem := column.Index(i) - if elem.CanAddr() { - columns = reflect.Append(columns, elem.Addr()) + + results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() + + for i := 0; i < indirectScopeValue.Len(); i++ { + result := reflect.Indirect(reflect.Indirect(indirectScopeValue.Index(i)).FieldByName(column)) + + if result.Kind() == reflect.Slice { + for j := 0; j < result.Len(); j++ { + if elem := result.Index(j); elem.CanAddr() { + results = reflect.Append(results, elem.Addr()) + } } - } - } else { - if column.CanAddr() { - columns = reflect.Append(columns, column.Addr()) + } else if result.CanAddr() { + results = reflect.Append(results, result.Addr()) } } + return scope.New(results.Interface()) } - return scope.New(columns.Interface()) case reflect.Struct: - field := values.FieldByName(column) - if !field.CanAddr() { - return nil + if field := indirectScopeValue.FieldByName(column); field.CanAddr() { + return scope.New(field.Addr().Interface()) } - return scope.New(field.Addr().Interface()) } return nil } diff --git a/utils_private.go b/utils_private.go index 50549857..f8f918fb 100644 --- a/utils_private.go +++ b/utils_private.go @@ -1,6 +1,7 @@ package gorm import ( + "database/sql/driver" "fmt" "reflect" "regexp" @@ -73,6 +74,10 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { return attrs } +func equalAsString(a interface{}, b interface{}) bool { + return toString(a) == toString(b) +} + func toString(str interface{}) string { if values, ok := str.([]interface{}); ok { var results []string @@ -87,6 +92,16 @@ func toString(str interface{}) string { } } +func makeSlice(elemType reflect.Type) interface{} { + if elemType.Kind() == reflect.Slice { + elemType = elemType.Elem() + } + sliceType := reflect.SliceOf(elemType) + slice := reflect.New(sliceType) + slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0)) + return slice.Interface() +} + func strInSlice(a string, list []string) bool { for _, b := range list { if b == a { @@ -95,3 +110,22 @@ func strInSlice(a string, list []string) bool { } return false } + +// getValueFromFields return given fields's value +func getValueFromFields(value reflect.Value, fieldNames []string) (results []interface{}) { + // If value is a nil pointer, Indirect returns a zero Value! + // Therefor we need to check for a zero value, + // as FieldByName could panic + if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { + for _, fieldName := range fieldNames { + if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { + result := fieldValue.Interface() + if r, ok := result.(driver.Valuer); ok { + result, _ = r.Value() + } + results = append(results, result) + } + } + } + return +}