diff --git a/preload_test.go b/preload_test.go index aec6d816..db37f37d 100644 --- a/preload_test.go +++ b/preload_test.go @@ -624,21 +624,47 @@ func TestManyToManyPreload(t *testing.T) { } want := Level2{Value: "Bob", Level1s: []Level1{ - Level1{Value: "ru"}, - Level1{Value: "en"}, + {Value: "ru"}, + {Value: "en"}, }} if err := DB.Save(&want).Error; err != nil { panic(err) } + want2 := Level2{Value: "Tom", Level1s: []Level1{ + {Value: "zh"}, + {Value: "de"}, + }} + if err := DB.Save(&want2).Error; err != nil { + panic(err) + } + var got Level2 - if err := DB.Preload("Level1s").Find(&got).Error; err != nil { + if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { panic(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } + + var got2 Level2 + if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } + + var got3 []Level2 + if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got3, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) + } } func toJSONString(v interface{}) []byte {