diff --git a/callback_query_preload.go b/callback_query_preload.go index 13a109dd..16498ba7 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -308,7 +308,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface // assign find results var ( indirectScopeValue = scope.IndirectValue() - fieldsSourceMap = map[string]reflect.Value{} + fieldsSourceMap = map[string][]reflect.Value{} foreignFieldNames = []string{} ) @@ -321,13 +321,21 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface if indirectScopeValue.Kind() == reflect.Slice { for j := 0; j < indirectScopeValue.Len(); j++ { object := indirect(indirectScopeValue.Index(j)) - fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name) + key := toString(getValueFromFields(object, foreignFieldNames)) + fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) } } else if indirectScopeValue.IsValid() { - fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name) + key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) + fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) } - for source, link := range linkHash { - fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...)) + for i, field := range fieldsSourceMap[source] { + //If not 0 this means Value is a pointer and we already added preloaded models to it + if fieldsSourceMap[source][i].Len() != 0 { + continue + } + field.Set(reflect.Append(fieldsSourceMap[source][i], link...)) + } + } } diff --git a/preload_test.go b/preload_test.go index e2fb35ff..144d08ad 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1118,6 +1118,81 @@ func TestNestedManyToManyPreload3(t *testing.T) { } } +func TestNestedManyToManyPreload3ForStruct(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 Level2 + } + ) + + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists("level1_level2") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + level1Zh := Level1{Value: "zh"} + level1Ru := Level1{Value: "ru"} + level1En := Level1{Value: "en"} + + level21 := Level2{ + Value: "Level2-1", + Level1s: []Level1{level1Zh, level1Ru}, + } + + level22 := Level2{ + Value: "Level2-2", + Level1s: []Level1{level1Zh, level1En}, + } + + wants := []*Level3{ + { + Value: "Level3-1", + Level2: level21, + }, + { + Value: "Level3-2", + Level2: level22, + }, + { + Value: "Level3-3", + Level2: level21, + }, + } + + for _, want := range wants { + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + } + + var gots []*Level3 + if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { + return db.Order("level1.id ASC") + }).Find(&gots).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(gots, wants) { + t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) + } +} + func TestNestedManyToManyPreload4(t *testing.T) { type ( Level4 struct {