diff --git a/callback_query_preload.go b/callback_query_preload.go index ff99fea9..75225c27 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -89,7 +89,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) // assign find results var ( - resultsValue = reflect.Indirect(reflect.ValueOf(results)) + resultsValue = indirect(reflect.ValueOf(results)) indirectScopeValue = scope.IndirectValue() ) @@ -98,7 +98,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) if indirectScopeValue.Kind() == reflect.Slice { foreignValues := getValueFromFields(result, relation.ForeignFieldNames) for j := 0; j < indirectScopeValue.Len(); j++ { - if indirectValue := reflect.Indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { + if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { indirectValue.FieldByName(field.Name).Set(result) break } @@ -125,7 +125,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) // assign find results var ( - resultsValue = reflect.Indirect(reflect.ValueOf(results)) + resultsValue = indirect(reflect.ValueOf(results)) indirectScopeValue = scope.IndirectValue() ) @@ -134,7 +134,7 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) result := resultsValue.Index(i) foreignValues := getValueFromFields(result, relation.ForeignFieldNames) for j := 0; j < indirectScopeValue.Len(); j++ { - object := reflect.Indirect(indirectScopeValue.Index(j)) + object := indirect(indirectScopeValue.Index(j)) if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) { objectField := object.FieldByName(field.Name) objectField.Set(reflect.Append(objectField, result)) @@ -163,7 +163,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ // assign find results var ( - resultsValue = reflect.Indirect(reflect.ValueOf(results)) + resultsValue = indirect(reflect.ValueOf(results)) indirectScopeValue = scope.IndirectValue() ) @@ -172,7 +172,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ if indirectScopeValue.Kind() == reflect.Slice { value := getValueFromFields(result, relation.AssociationForeignFieldNames) for j := 0; j < indirectScopeValue.Len(); j++ { - object := reflect.Indirect(indirectScopeValue.Index(j)) + object := indirect(indirectScopeValue.Index(j)) if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { object.FieldByName(field.Name).Set(result) } @@ -265,7 +265,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface if indirectScopeValue.Kind() == reflect.Slice { for j := 0; j < indirectScopeValue.Len(); j++ { - object := reflect.Indirect(indirectScopeValue.Index(j)) + object := indirect(indirectScopeValue.Index(j)) fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name) } } else if indirectScopeValue.IsValid() { diff --git a/preload_test.go b/preload_test.go index 7f4e7d50..2fc441a9 100644 --- a/preload_test.go +++ b/preload_test.go @@ -611,6 +611,70 @@ func TestNestedPreload9(t *testing.T) { } } +type Level1A struct { + ID uint + Value string +} + +type Level1B struct { + ID uint + Value string + Level2s []*Level2 +} + +type Level2 struct { + ID uint + Value string + Level1AID sql.NullInt64 + Level1A *Level1A + Level1BID sql.NullInt64 + Level1B *Level1B +} + +func TestNestedPreload10(t *testing.T) { + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1B{}) + DB.DropTableIfExists(&Level1A{}) + + if err := DB.AutoMigrate(&Level1A{}, &Level1B{}, &Level2{}).Error; err != nil { + t.Error(err) + } + + level1A := &Level1A{Value: "foo"} + if err := DB.Save(&level1A).Error; err != nil { + t.Error(err) + } + + want := []*Level1B{ + &Level1B{ + Value: "bar", + Level2s: []*Level2{ + &Level2{ + Value: "qux", + Level1A: level1A, + }, + }, + }, + &Level1B{ + Value: "bar 2", + }, + } + for _, level1B := range want { + if err := DB.Save(level1B).Error; err != nil { + t.Error(err) + } + } + + var got []*Level1B + if err := DB.Preload("Level2s.Level1A").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { return diff --git a/scope.go b/scope.go index 09da8e11..1608a99b 100644 --- a/scope.go +++ b/scope.go @@ -16,7 +16,6 @@ type Scope struct { Sql string SqlVars []interface{} db *DB - indirectValue *reflect.Value instanceID string primaryKeyField *Field skipLeft bool @@ -25,14 +24,7 @@ type Scope struct { } func (scope *Scope) IndirectValue() reflect.Value { - if scope.indirectValue == nil { - value := reflect.Indirect(reflect.ValueOf(scope.Value)) - if value.Kind() == reflect.Ptr { - value = value.Elem() - } - scope.indirectValue = &value - } - return *scope.indirectValue + return indirect(reflect.ValueOf(scope.Value)) } // New create a new Scope without search information diff --git a/scope_utils.go b/scope_utils.go index ffaa99b4..2d914314 100644 --- a/scope_utils.go +++ b/scope_utils.go @@ -13,7 +13,7 @@ func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (r case reflect.Slice: for i := 0; i < indirectValue.Len(); i++ { var result []interface{} - var object = reflect.Indirect(indirectValue.Index(i)) + var object = indirect(indirectValue.Index(i)) for _, column := range columns { result = append(result, object.FieldByName(column).Interface()) } @@ -44,7 +44,7 @@ func (scope *Scope) getColumnAsScope(column string) *Scope { results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() for i := 0; i < indirectScopeValue.Len(); i++ { - result := reflect.Indirect(reflect.Indirect(indirectScopeValue.Index(i)).FieldByName(column)) + result := indirect(indirect(indirectScopeValue.Index(i)).FieldByName(column)) if result.Kind() == reflect.Slice { for j := 0; j < result.Len(); j++ { diff --git a/utils.go b/utils.go index 43d0031c..9d2bb075 100644 --- a/utils.go +++ b/utils.go @@ -3,6 +3,7 @@ package gorm import ( "bytes" "fmt" + "reflect" "strings" "sync" ) @@ -102,6 +103,13 @@ func Expr(expression string, args ...interface{}) *expr { return &expr{expr: expression, args: args} } +func indirect(reflectValue reflect.Value) reflect.Value { + for reflectValue.Kind() == reflect.Ptr { + reflectValue = reflectValue.Elem() + } + return reflectValue +} + func toQueryMarks(primaryValues [][]interface{}) string { var results []string