diff --git a/preload.go b/preload.go index d12995f3..15998c40 100644 --- a/preload.go +++ b/preload.go @@ -197,18 +197,20 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { - relation := field.Relationship - joinTableHandler := relation.JoinTableHandler - destType := field.StructField.Struct.Type.Elem() - var isPtr bool + var ( + relation = field.Relationship + joinTableHandler = relation.JoinTableHandler + destType = field.StructField.Struct.Type.Elem() + linkHash = make(map[string][]reflect.Value) + sourceKeys = []string{} + isPtr bool + ) + if destType.Kind() == reflect.Ptr { isPtr = true destType = destType.Elem() } - var sourceKeys []string - var linkHash = make(map[string][]reflect.Value) - for _, key := range joinTableHandler.SourceForeignKeys() { sourceKeys = append(sourceKeys, key.DBName) } @@ -217,9 +219,11 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value) + // preload inline conditions if len(conditions) > 0 { preloadJoinDB = preloadJoinDB.Where(conditions[0], conditions[1:]...) } + rows, err := preloadJoinDB.Rows() if scope.Err(err) != nil { diff --git a/preload_test.go b/preload_test.go index 29ea39a7..9e0716bd 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1011,6 +1011,79 @@ func TestNestedManyToManyPreload2(t *testing.T) { } } +func TestNestedManyToManyPreload3(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 *Level2 + } + ) + + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists("level1_level2") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + level1Zh := &Level1{Value: "zh"} + level1Ru := &Level1{Value: "ru"} + level1En := &Level1{Value: "en"} + + level21 := &Level2{ + Value: "Level2-1", + Level1s: []*Level1{level1Zh, level1Ru}, + } + + level22 := &Level2{ + Value: "Level2-2", + Level1s: []*Level1{level1Zh, level1En}, + } + + wants := []*Level3{ + { + Value: "Level3-1", + Level2: level21, + }, + { + Value: "Level3-2", + Level2: level22, + }, + { + Value: "Level3-3", + Level2: level21, + }, + } + + for _, want := range wants { + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + } + + var gots []*Level3 + if err := DB.Preload("Level2.Level1s").Find(&gots).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(gots, wants) { + t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) + } +} + func TestNilPointerSlice(t *testing.T) { type ( Level3 struct {