From ca46038cb43072306bca032b73ba22d873fe1afc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 10 Jul 2016 21:34:37 +0800 Subject: [PATCH] Fix preload duplicates has many related objects --- callback_query_preload.go | 7 ------- preload_test.go | 6 +++++- scope.go | 7 +++++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/callback_query_preload.go b/callback_query_preload.go index c9bfa866..d9ec8bdd 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -186,13 +186,6 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) for j := 0; j < indirectScopeValue.Len(); j++ { object := indirect(indirectScopeValue.Index(j)) objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) - if j > 0 { - prevObject := indirect(indirectScopeValue.Index(j - 1)) - prevObjectRealValue := getValueFromFields(prevObject, relation.AssociationForeignFieldNames) - if toString(prevObjectRealValue) == toString(objectRealValue) { - continue - } - } if results, ok := preloadMap[toString(objectRealValue)]; ok { f := object.FieldByName(field.Name) f.Set(reflect.Append(f, results...)) diff --git a/preload_test.go b/preload_test.go index fd5b3af6..8c56a8ac 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1513,19 +1513,23 @@ func TestPrefixedPreloadDuplication(t *testing.T) { type ( Level4 struct { ID uint + Name string Level3ID uint } Level3 struct { ID uint + Name string Level4s []*Level4 } Level2 struct { ID uint + Name string Level3ID sql.NullInt64 `sql:"index"` Level3 *Level3 } Level1 struct { ID uint + Name string Level2ID sql.NullInt64 `sql:"index"` Level2 *Level2 } @@ -1540,7 +1544,7 @@ func TestPrefixedPreloadDuplication(t *testing.T) { t.Error(err) } - lvl := new(Level3) + lvl := &Level3{} if err := DB.Save(lvl).Error; err != nil { t.Error(err) } diff --git a/scope.go b/scope.go index 0ecf43df..974ff035 100644 --- a/scope.go +++ b/scope.go @@ -1237,6 +1237,7 @@ func (scope *Scope) getColumnAsScope(column string) *Scope { fieldType = fieldType.Elem() } + resultsMap := map[interface{}]bool{} results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() for i := 0; i < indirectScopeValue.Len(); i++ { @@ -1244,11 +1245,13 @@ func (scope *Scope) getColumnAsScope(column string) *Scope { if result.Kind() == reflect.Slice { for j := 0; j < result.Len(); j++ { - if elem := result.Index(j); elem.CanAddr() { + if elem := result.Index(j); elem.CanAddr() && resultsMap[elem.Addr()] != true { + resultsMap[elem.Addr()] = true results = reflect.Append(results, elem.Addr()) } } - } else if result.CanAddr() { + } else if result.CanAddr() && resultsMap[result.Addr()] != true { + resultsMap[result.Addr()] = true results = reflect.Append(results, result.Addr()) } }