diff --git a/preload.go b/preload.go index dd85c327..4621dd91 100644 --- a/preload.go +++ b/preload.go @@ -337,15 +337,24 @@ func (scope *Scope) getColumnsAsScope(column string) *Scope { } if column.Kind() == reflect.Slice { for i := 0; i < column.Len(); i++ { - columns = reflect.Append(columns, column.Index(i).Addr()) + elem := column.Index(i) + if elem.CanAddr() { + columns = reflect.Append(columns, elem.Addr()) + } } } else { - columns = reflect.Append(columns, column.Addr()) + if column.CanAddr() { + columns = reflect.Append(columns, column.Addr()) + } } } return scope.New(columns.Interface()) case reflect.Struct: - return scope.New(values.FieldByName(column).Addr().Interface()) + field := values.FieldByName(column) + if !field.CanAddr() { + return nil + } + return scope.New(field.Addr().Interface()) } return nil } diff --git a/preload_test.go b/preload_test.go index 3dcd325b..490c3134 100644 --- a/preload_test.go +++ b/preload_test.go @@ -772,6 +772,67 @@ func TestManyToManyPreloadForPointer(t *testing.T) { } } +func TestNilPointerSlice(t *testing.T) { + type ( + Level3 struct { + ID uint `gorm:"primary_key;"` + Value string + } + Level2 struct { + ID uint `gorm:"primary_key;"` + Value string + Level3ID uint + Level3 *Level3 + } + Level1 struct { + ID uint `gorm:"primary_key;"` + Value string + Level2ID uint + Level2 *Level2 + } + ) + + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level1{Value: "Bob", Level2: &Level2{ + Value: "en", + Level3: &Level3{ + Value: "native", + }, + }} + if err := DB.Save(&want).Error; err != nil { + panic(err) + } + + want2 := Level1{Value: "Tom", Level2: nil} + if err := DB.Save(&want2).Error; err != nil { + panic(err) + } + + var got []Level1 + if err := DB.Debug().Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil { + panic(err) + } + + if len(got) != 2 { + t.Fatalf("got %v items, expected 2", len(got)) + } + + if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) { + t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want)) + } + + if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) { + t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want2)) + } +} + func toJSONString(v interface{}) []byte { r, _ := json.MarshalIndent(v, "", " ") return r