diff --git a/preload.go b/preload.go index f3f2df12..ebbbeb32 100644 --- a/preload.go +++ b/preload.go @@ -186,6 +186,9 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) + if object.Kind() == reflect.Ptr { + object = reflect.Indirect(objects.Index(j).Elem()) + } if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) { object.FieldByName(field.Name).Set(result) } @@ -312,7 +315,11 @@ func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) for i := 0; i < values.Len(); i++ { var result []interface{} for _, column := range columns { - result = append(result, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) + value := reflect.Indirect(values.Index(i)) + if value.Kind() == reflect.Ptr { + value = reflect.Indirect(values.Index(i).Elem()) + } + result = append(result, value.FieldByName(column).Interface()) } results = append(results, result) } diff --git a/preload_test.go b/preload_test.go index 29ea39a7..010682fd 100644 --- a/preload_test.go +++ b/preload_test.go @@ -611,6 +611,70 @@ func TestNestedPreload9(t *testing.T) { } } +type Level1A struct { + ID uint + Value string +} + +type Level1B struct { + ID uint + Value string + Level2s []*Level2 +} + +type Level2 struct { + ID uint + Value string + Level1AID sql.NullInt64 + Level1A *Level1A + Level1BID sql.NullInt64 + Level1B *Level1B +} + +func TestNestedPreload10(t *testing.T) { + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1B{}) + DB.DropTableIfExists(&Level1A{}) + + if err := DB.AutoMigrate(&Level1A{}, &Level1B{}, &Level2{}).Error; err != nil { + t.Error(err) + } + + level1A := &Level1A{Value: "foo"} + if err := DB.Save(&level1A).Error; err != nil { + t.Error(err) + } + + want := []*Level1B{ + &Level1B{ + Value: "bar", + Level2s: []*Level2{ + &Level2{ + Value: "qux", + Level1A: level1A, + }, + }, + }, + &Level1B{ + Value: "bar 2", + }, + } + for _, level1B := range want { + if err := DB.Save(level1B).Error; err != nil { + t.Error(err) + } + } + + var got []*Level1B + if err := DB.Preload("Level2s.Level1A").Find(&got).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { if dialect := os.Getenv("GORM_DIALECT"); dialect == "" || dialect == "sqlite" { return