diff --git a/preload.go b/preload.go index 354ed7a8..f415386d 100644 --- a/preload.go +++ b/preload.go @@ -217,9 +217,8 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface sourceKeys = append(sourceKeys, key.DBName) } - db := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*") - - preloadJoinDB := joinTableHandler.JoinWith(joinTableHandler, db, scope.Value) + preloadJoinDB := scope.NewDB().Table(scope.New(reflect.New(destType).Interface()).TableName()).Select("*") + preloadJoinDB = joinTableHandler.JoinWith(joinTableHandler, preloadJoinDB, scope.Value) // preload inline conditions if len(conditions) > 0 { @@ -261,31 +260,30 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } } - var foreignFieldNames []string + // assign find results + var ( + indirectScopeValue = scope.IndirectValue() + fieldsSourceMap = map[string]reflect.Value{} + foreignFieldNames = []string{} + ) + for _, dbName := range relation.ForeignFieldNames { if field, ok := scope.FieldByName(dbName); ok { foreignFieldNames = append(foreignFieldNames, field.Name) } } - if scope.IndirectValue().Kind() == reflect.Slice { - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - source := getRealValue(object, foreignFieldNames) - field := object.FieldByName(field.Name) - for _, link := range linkHash[toString(source)] { - field.Set(reflect.Append(field, link)) - } - } - } else { - if object := scope.IndirectValue(); object.IsValid() { - source := getRealValue(object, foreignFieldNames) - field := object.FieldByName(field.Name) - for _, link := range linkHash[toString(source)] { - field.Set(reflect.Append(field, link)) - } + if indirectScopeValue.Kind() == reflect.Slice { + for j := 0; j < indirectScopeValue.Len(); j++ { + object := reflect.Indirect(indirectScopeValue.Index(j)) + fieldsSourceMap[toString(getRealValue(object, foreignFieldNames))] = object.FieldByName(field.Name) } + } else if indirectScopeValue.IsValid() { + fieldsSourceMap[toString(getRealValue(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name) + } + + for source, link := range linkHash { + fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...)) } } diff --git a/preload_test.go b/preload_test.go index d2279e03..4f65d1d8 100644 --- a/preload_test.go +++ b/preload_test.go @@ -702,90 +702,6 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { } } -func TestManyToManyPreloadForPointer(t *testing.T) { - type ( - Level1 struct { - ID uint - Value string - } - Level2 struct { - ID uint - Value string - Level1s []*Level1 `gorm:"many2many:levels;"` - } - ) - - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - DB.DropTableIfExists("levels") - - if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { - t.Error(err) - } - - want := Level2{Value: "Bob", Level1s: []*Level1{ - {Value: "ru"}, - {Value: "en"}, - }} - if err := DB.Save(&want).Error; err != nil { - t.Error(err) - } - - want2 := Level2{Value: "Tom", Level1s: []*Level1{ - {Value: "zh"}, - {Value: "de"}, - }} - if err := DB.Save(&want2).Error; err != nil { - t.Error(err) - } - - var got Level2 - if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - - var got2 Level2 - if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got2, want2) { - t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) - } - - var got3 []Level2 - if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(got3, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) - } - - var got4 []Level2 - if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { - t.Error(err) - } - - var got5 Level2 - DB.Preload("Level1s").First(&got5, "value = ?", "bogus") - - var ruLevel1 Level1 - var zhLevel1 Level1 - DB.First(&ruLevel1, "value = ?", "ru") - DB.First(&zhLevel1, "value = ?", "zh") - - got.Level1s = []*Level1{&ruLevel1} - got2.Level1s = []*Level1{&zhLevel1} - if !reflect.DeepEqual(got4, []Level2{got, got2}) { - t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) - } -} - func TestManyToManyPreloadForNestedPointer(t *testing.T) { type ( Level1 struct { @@ -1012,8 +928,6 @@ func TestNestedManyToManyPreload2(t *testing.T) { } func TestNestedManyToManyPreload3(t *testing.T) { - t.Skip("not implemented") - type ( Level1 struct { ID uint @@ -1086,6 +1000,90 @@ func TestNestedManyToManyPreload3(t *testing.T) { } } +func TestManyToManyPreloadForPointer(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []*Level1 `gorm:"many2many:levels;"` + } + ) + + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists("levels") + + if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + want := Level2{Value: "Bob", Level1s: []*Level1{ + {Value: "ru"}, + {Value: "en"}, + }} + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + + want2 := Level2{Value: "Tom", Level1s: []*Level1{ + {Value: "zh"}, + {Value: "de"}, + }} + if err := DB.Save(&want2).Error; err != nil { + t.Error(err) + } + + var got Level2 + if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + + var got2 Level2 + if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got2, want2) { + t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) + } + + var got3 []Level2 + if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(got3, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) + } + + var got4 []Level2 + if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { + t.Error(err) + } + + var got5 Level2 + DB.Preload("Level1s").First(&got5, "value = ?", "bogus") + + var ruLevel1 Level1 + var zhLevel1 Level1 + DB.First(&ruLevel1, "value = ?", "ru") + DB.First(&zhLevel1, "value = ?", "zh") + + got.Level1s = []*Level1{&ruLevel1} + got2.Level1s = []*Level1{&zhLevel1} + if !reflect.DeepEqual(got4, []Level2{got, got2}) { + t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) + } +} + func TestNilPointerSlice(t *testing.T) { type ( Level3 struct {