From 7693c093a96ef1bd0c8312b938201997dfe99f46 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 22 Apr 2015 15:36:10 +0800 Subject: [PATCH] Refactor Preload --- preload.go | 30 ++++++++++++++---------------- preload_test.go | 4 ++-- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/preload.go b/preload.go index 42836067..add077ab 100644 --- a/preload.go +++ b/preload.go @@ -106,10 +106,9 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) results := makeSlice(field.Struct.Type) relation := field.Relationship condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - // TODO: handle error? - scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) + scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) @@ -123,8 +122,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) } } } else { - err := scope.SetColumn(field, result) - if err != nil { + if err := scope.SetColumn(field, result); err != nil { scope.Err(err) return } @@ -142,9 +140,9 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) results := makeSlice(field.Struct.Type) relation := field.Relationship condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) + scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + resultValues := reflect.Indirect(reflect.ValueOf(results)) if scope.IndirectValue().Kind() == reflect.Slice { for i := 0; i < resultValues.Len(); i++ { @@ -173,10 +171,10 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } results := makeSlice(field.Struct.Type) - resultValues := reflect.Indirect(reflect.ValueOf(results)) associationPrimaryKey := scope.New(results).PrimaryField().Name - scope.NewDB().Where(primaryKeys).Find(results, conditions...) + scope.Err(scope.NewDB().Where(primaryKeys).Find(results, conditions...).Error) + resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) @@ -212,16 +210,16 @@ func (scope *Scope) getColumnsAsScope(column string) *Scope { values := scope.IndirectValue() switch values.Kind() { case reflect.Slice: - model := values.Type().Elem() - if model.Kind() == reflect.Ptr { - model = model.Elem() + modelType := values.Type().Elem() + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() } - fieldType, _ := model.FieldByName(column) + fieldStruct, _ := modelType.FieldByName(column) var columns reflect.Value - if fieldType.Type.Kind() == reflect.Slice { - columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type.Elem()))).Elem() + if fieldStruct.Type.Kind() == reflect.Slice { + columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem() } else { - columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type))).Elem() + columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem() } for i := 0; i < values.Len(); i++ { column := reflect.Indirect(values.Index(i)).FieldByName(column) diff --git a/preload_test.go b/preload_test.go index 7fbba5c0..a6647bbd 100644 --- a/preload_test.go +++ b/preload_test.go @@ -156,7 +156,6 @@ func TestNestedPreload2(t *testing.T) { } want := Level3{ - Name: "name", Level2s: []Level2{ { Level1s: []*Level1{ @@ -211,7 +210,6 @@ func TestNestedPreload3(t *testing.T) { } want := Level3{ - Name: "name", Level2s: []Level2{ {Level1: Level1{Value: "value1"}}, {Level1: Level1{Value: "value2"}}, @@ -368,6 +366,7 @@ func TestNestedPreload6(t *testing.T) { if err := DB.Create(&want[0]).Error; err != nil { panic(err) } + want[1] = Level3{ Level2s: []Level2{ { @@ -432,6 +431,7 @@ func TestNestedPreload7(t *testing.T) { if err := DB.Create(&want[0]).Error; err != nil { panic(err) } + want[1] = Level3{ Level2s: []Level2{ {Level1: Level1{Value: "value3"}},