From 13ac0839c45fe0f213b223973eed4be2f030987f Mon Sep 17 00:00:00 2001 From: Igor Noskov Date: Tue, 10 May 2016 12:43:50 +0600 Subject: [PATCH] fix m2m nested preload --- callback_query_preload.go | 18 +++++++--- preload_test.go | 75 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 5 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index a4b1cf1f..305ac4f3 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -288,7 +288,7 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface // assign find results var ( indirectScopeValue = scope.IndirectValue() - fieldsSourceMap = map[string]reflect.Value{} + fieldsSourceMap = map[string][]reflect.Value{} foreignFieldNames = []string{} ) @@ -301,13 +301,21 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface if indirectScopeValue.Kind() == reflect.Slice { for j := 0; j < indirectScopeValue.Len(); j++ { object := indirect(indirectScopeValue.Index(j)) - fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name) + key := toString(getValueFromFields(object, foreignFieldNames)) + fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) } } else if indirectScopeValue.IsValid() { - fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name) + key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) + fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) } - for source, link := range linkHash { - fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...)) + for i, field := range fieldsSourceMap[source] { + //If not 0 this means Value is a pointer and we already added preloaded models to it + if fieldsSourceMap[source][i].Len() != 0 { + continue + } + field.Set(reflect.Append(fieldsSourceMap[source][i], link...)) + } + } } diff --git a/preload_test.go b/preload_test.go index 5c49ecc2..144e3820 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1118,6 +1118,81 @@ func TestNestedManyToManyPreload3(t *testing.T) { } } +func TestNestedManyToManyPreload3ForStruct(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + } + Level2 struct { + ID uint + Value string + Level1s []Level1 `gorm:"many2many:level1_level2;"` + } + Level3 struct { + ID uint + Value string + Level2ID sql.NullInt64 + Level2 Level2 + } + ) + + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists("level1_level2") + + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + t.Error(err) + } + + level1Zh := Level1{Value: "zh"} + level1Ru := Level1{Value: "ru"} + level1En := Level1{Value: "en"} + + level21 := Level2{ + Value: "Level2-1", + Level1s: []Level1{level1Zh, level1Ru}, + } + + level22 := Level2{ + Value: "Level2-2", + Level1s: []Level1{level1Zh, level1En}, + } + + wants := []*Level3{ + { + Value: "Level3-1", + Level2: level21, + }, + { + Value: "Level3-2", + Level2: level22, + }, + { + Value: "Level3-3", + Level2: level21, + }, + } + + for _, want := range wants { + if err := DB.Save(&want).Error; err != nil { + t.Error(err) + } + } + + var gots []*Level3 + if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { + return db.Order("level1.id ASC") + }).Find(&gots).Error; err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(gots, wants) { + t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) + } +} + func TestNestedManyToManyPreload4(t *testing.T) { type ( Level4 struct {