Don't marshal to null for associations after preloading, close #3395

This commit is contained in:
Jinzhu 2020-09-04 19:02:37 +08:00
parent f121622228
commit d8ddccf147
3 changed files with 42 additions and 4 deletions

View File

@ -110,12 +110,22 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
// clean up old values before preloading // clean up old values before preloading
switch reflectValue.Kind() { switch reflectValue.Kind() {
case reflect.Struct: case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface())
default:
rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
}
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ { for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface())
default:
rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
} }
} }
}
for i := 0; i < reflectResults.Len(); i++ { for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i) elem := reflectResults.Index(i)

View File

@ -1,6 +1,8 @@
package tests_test package tests_test
import ( import (
"encoding/json"
"regexp"
"sort" "sort"
"strconv" "strconv"
"testing" "testing"
@ -188,3 +190,25 @@ func TestNestedPreloadWithConds(t *testing.T) {
CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2]) CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2])
} }
} }
func TestPreloadEmptyData(t *testing.T) {
var user = *GetUser("user_without_associations", Config{})
DB.Create(&user)
DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name)
if r, err := json.Marshal(&user); err != nil {
t.Errorf("failed to marshal users, got error %v", err)
} else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) {
t.Errorf("json marshal is not empty slice, got %v", string(r))
}
var results []User
DB.Preload("Team").Preload("Languages").Preload("Friends").Find(&results, "name = ?", user.Name)
if r, err := json.Marshal(&results); err != nil {
t.Errorf("failed to marshal users, got error %v", err)
} else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) {
t.Errorf("json marshal is not empty slice, got %v", string(r))
}
}

View File

@ -51,11 +51,11 @@ func TestScan(t *testing.T) {
DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results) DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results)
sort.Slice(results, func(i, j int) bool { sort.Slice(results, func(i, j int) bool {
return strings.Compare(results[i].Name, results[j].Name) < -1 return strings.Compare(results[i].Name, results[j].Name) <= -1
}) })
if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name {
t.Errorf("Scan into struct map") t.Errorf("Scan into struct map, got %#v", results)
} }
} }
@ -84,6 +84,10 @@ func TestScanRows(t *testing.T) {
results = append(results, result) results = append(results, result)
} }
sort.Slice(results, func(i, j int) bool {
return strings.Compare(results[i].Name, results[j].Name) <= -1
})
if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) {
t.Errorf("Should find expected results") t.Errorf("Should find expected results")
} }