From d8ddccf1478bf1aaf3726f2301c08fe6a9ca4183 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 4 Sep 2020 19:02:37 +0800 Subject: [PATCH] Don't marshal to null for associations after preloading, close #3395 --- callbacks/preload.go | 14 ++++++++++++-- tests/preload_test.go | 24 ++++++++++++++++++++++++ tests/scan_test.go | 8 ++++++-- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 9b8f762a..aec10ec5 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -110,10 +110,20 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { // clean up old values before preloading switch reflectValue.Kind() { case reflect.Struct: - rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + switch rel.Type { + case schema.HasMany, schema.Many2Many: + rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface()) + default: + rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { - rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + switch rel.Type { + case schema.HasMany, schema.Many2Many: + rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface()) + default: + rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + } } } diff --git a/tests/preload_test.go b/tests/preload_test.go index 76b72f14..d9035661 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -1,6 +1,8 @@ package tests_test import ( + "encoding/json" + "regexp" "sort" "strconv" "testing" @@ -188,3 +190,25 @@ func TestNestedPreloadWithConds(t *testing.T) { CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2]) } } + +func TestPreloadEmptyData(t *testing.T) { + var user = *GetUser("user_without_associations", Config{}) + DB.Create(&user) + + DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name) + + if r, err := json.Marshal(&user); err != nil { + t.Errorf("failed to marshal users, got error %v", err) + } else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { + t.Errorf("json marshal is not empty slice, got %v", string(r)) + } + + var results []User + DB.Preload("Team").Preload("Languages").Preload("Friends").Find(&results, "name = ?", user.Name) + + if r, err := json.Marshal(&results); err != nil { + t.Errorf("failed to marshal users, got error %v", err) + } else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { + t.Errorf("json marshal is not empty slice, got %v", string(r)) + } +} diff --git a/tests/scan_test.go b/tests/scan_test.go index 3e66a25a..92e89521 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -51,11 +51,11 @@ func TestScan(t *testing.T) { DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results) sort.Slice(results, func(i, j int) bool { - return strings.Compare(results[i].Name, results[j].Name) < -1 + return strings.Compare(results[i].Name, results[j].Name) <= -1 }) if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { - t.Errorf("Scan into struct map") + t.Errorf("Scan into struct map, got %#v", results) } } @@ -84,6 +84,10 @@ func TestScanRows(t *testing.T) { results = append(results, result) } + sort.Slice(results, func(i, j int) bool { + return strings.Compare(results[i].Name, results[j].Name) <= -1 + }) + if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { t.Errorf("Should find expected results") }