diff --git a/preload.go b/preload.go index 69efc01b..2c981a79 100644 --- a/preload.go +++ b/preload.go @@ -295,11 +295,12 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } } } else { - object := scope.IndirectValue() - source := getRealValue(object, associationForeignStructFieldNames) - field := object.FieldByName(field.Name) - for _, link := range linkHash[toString(source)] { - field.Set(reflect.Append(field, link)) + if object := scope.IndirectValue(); object.IsValid() { + source := getRealValue(object, associationForeignStructFieldNames) + field := object.FieldByName(field.Name) + for _, link := range linkHash[toString(source)] { + field.Set(reflect.Append(field, link)) + } } } } diff --git a/preload_test.go b/preload_test.go index 36bfaae5..b5188628 100644 --- a/preload_test.go +++ b/preload_test.go @@ -771,6 +771,9 @@ func TestManyToManyPreloadForPointer(t *testing.T) { panic(err) } + var got5 Level2 + DB.Preload("Level1s").First(&got5, "value = ?", "bogus") + var ruLevel1 Level1 var zhLevel1 Level1 DB.First(&ruLevel1, "value = ?", "ru") @@ -783,6 +786,109 @@ func TestManyToManyPreloadForPointer(t *testing.T) { } } +func TestManyToManyPreloadForNestedPointer(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:levels;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 *Level2 + } + ) + + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + DB.Table("levels").DropTableIfExists("levels") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Value: "Bob", + Level2: &Level2{ + Value: "Foo", + Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }, + }, + } + if err := DB.Save(&want).Error; err != nil { + panic(err) + } + + want2 := Level3{ + Value: "Tom", + Level2: &Level2{ + Value: "Bar", + Level1s: []*Level1{ + {Value: "zh"}, + {Value: "de"}, + }, + }, + } + if err := DB.Save(&want2).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + var got2 Level3 + if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } + + var got3 []Level3 + if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got3, []Level3{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level3{got, got2})) + } + + var got4 []Level3 + if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + panic(err) + } + + var got5 Level3 + DB.Preload("Level2.Level1s").Find(&got5, "value = ?", "bogus") + + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") + + got.Level2.Level1s = []*Level1{&ruLevel1} + got2.Level2.Level1s = []*Level1{&zhLevel1} + if !reflect.DeepEqual(got4, []Level3{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level3{got, got2})) + } +} + func TestNestedManyToManyPreload(t *testing.T) { type ( Level1 struct {