From 308d22b166eb3b71d2a3374bfc565be29ed88eda Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 13:48:37 +0800 Subject: [PATCH] Clean up associations before Preload, close #3345 --- callbacks/preload.go | 10 ++++++++++ tests/helper_test.go | 10 +++++----- tests/preload_test.go | 14 ++++++++++++++ 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 25b8cb2b..9b8f762a 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -107,6 +107,16 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { fieldValues := make([]interface{}, len(relForeignFields)) + // clean up old values before preloading + switch reflectValue.Kind() { + case reflect.Struct: + rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + } + } + for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { diff --git a/tests/helper_test.go b/tests/helper_test.go index cc0d808c..eee34e99 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -115,7 +115,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Pets", func(t *testing.T) { if len(user.Pets) != len(expect.Pets) { - t.Errorf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) + t.Fatalf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) } sort.Slice(user.Pets, func(i, j int) bool { @@ -137,7 +137,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Toys", func(t *testing.T) { if len(user.Toys) != len(expect.Toys) { - t.Errorf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) + t.Fatalf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) } sort.Slice(user.Toys, func(i, j int) bool { @@ -177,7 +177,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Team", func(t *testing.T) { if len(user.Team) != len(expect.Team) { - t.Errorf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) + t.Fatalf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) } sort.Slice(user.Team, func(i, j int) bool { @@ -195,7 +195,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Languages", func(t *testing.T) { if len(user.Languages) != len(expect.Languages) { - t.Errorf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) + t.Fatalf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) } sort.Slice(user.Languages, func(i, j int) bool { @@ -212,7 +212,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Friends", func(t *testing.T) { if len(user.Friends) != len(expect.Friends) { - t.Errorf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) + t.Fatalf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) } sort.Slice(user.Friends, func(i, j int) bool { diff --git a/tests/preload_test.go b/tests/preload_test.go index 7e5d2622..76b72f14 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -31,6 +31,20 @@ func TestPreloadWithAssociations(t *testing.T) { var user2 User DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + var user3 = *GetUser("preload_with_associations_new", Config{ + Account: true, + Pets: 2, + Toys: 3, + Company: true, + Manager: true, + Team: 4, + Languages: 3, + Friends: 1, + }) + + DB.Preload(clause.Associations).Find(&user3, "id = ?", user.ID) + CheckUser(t, user3, user) } func TestNestedPreload(t *testing.T) {