feat: Convert SQL nulls to zero values (ConvertNullToZeroValues)

Makes it the default behavior to convert SQL null values to zero
values for model fields which are not pointers.
This commit is contained in:
Jim 2021-09-12 06:42:48 -04:00
parent 696092e287
commit 19cf645dbd
2 changed files with 127 additions and 3 deletions

View File

@ -159,6 +159,7 @@ func CreateWithReturning(db *gorm.DB) {
break
}
resetFields := map[int]*schema.Field{}
for idx, field := range fields {
fieldValue := field.ReflectValueOf(reflectValue)
@ -172,22 +173,47 @@ func CreateWithReturning(db *gorm.DB) {
goto BEGIN
}
if field.FieldType.Kind() == reflect.Ptr {
values[idx] = fieldValue.Addr().Interface()
} else {
reflectValue := reflect.New(reflect.PtrTo(field.FieldType))
reflectValue.Elem().Set(fieldValue.Addr())
values[idx] = reflectValue.Interface()
resetFields[idx] = field
}
}
db.RowsAffected++
if err := rows.Scan(values...); err != nil {
db.AddError(err)
}
for idx, field := range resetFields {
if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() {
field.ReflectValueOf(reflectValue).Set(v)
}
}
}
case reflect.Struct:
resetFields := map[int]*schema.Field{}
for idx, field := range fields {
if field.FieldType.Kind() == reflect.Ptr {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
} else {
reflectValue := reflect.New(reflect.PtrTo(field.FieldType))
reflectValue.Elem().Set(field.ReflectValueOf(db.Statement.ReflectValue).Addr())
values[idx] = reflectValue.Interface()
resetFields[idx] = field
}
}
if rows.Next() {
db.RowsAffected++
db.AddError(rows.Scan(values...))
for idx, field := range resetFields {
if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() {
field.ReflectValueOf(db.Statement.ReflectValue).Set(v)
}
}
}
}
} else {

98
tests/gorm_test.go Normal file
View File

@ -0,0 +1,98 @@
package tests_test
import (
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"testing"
)
func TestReturningWithNullToZeroValues(t *testing.T) {
dialect := DB.Dialector.Name()
switch dialect {
case "mysql", "sqlserver":
// these dialects do not support the "returning" clause
return
default:
// This user struct will leverage the existing users table, but override
// the Name field to default to null.
type user struct {
gorm.Model
Name string `gorm:"default:null"`
}
u1 := user{}
c := DB.Callback().Create().Get("gorm:create")
t.Cleanup(func() {
DB.Callback().Create().Replace("gorm:create", c)
})
DB.Callback().Create().Replace("gorm:create", callbacks.Create(&callbacks.Config{WithReturning: true}))
if results := DB.Create(&u1); results.Error != nil {
t.Fatalf("errors happened on create: %v", results.Error)
} else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
} else if u1.ID == 0 {
t.Fatalf("ID expects : not equal 0, got %v", u1.ID)
}
got := user{}
results := DB.First(&got, "id = ?", u1.ID)
if results.Error != nil {
t.Fatalf("errors happened on first: %v", results.Error)
} else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
} else if got.ID != u1.ID {
t.Fatalf("first expects: %v, got %v", u1, got)
}
results = DB.Select("id, name").Find(&got)
if results.Error != nil {
t.Fatalf("errors happened on first: %v", results.Error)
} else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
} else if got.ID != u1.ID {
t.Fatalf("select expects: %v, got %v", u1, got)
}
u1.Name = "jinzhu"
if results := DB.Save(&u1); results.Error != nil {
t.Fatalf("errors happened on update: %v", results.Error)
} else if results.RowsAffected != 1 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
}
u1 = user{} // important to reinitialize this before creating it again
u2 := user{}
db := DB.Session(&gorm.Session{CreateBatchSize: 10})
if results := db.Create([]*user{&u1, &u2}); results.Error != nil {
t.Fatalf("errors happened on create: %v", results.Error)
} else if results.RowsAffected != 2 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
} else if u1.ID == 0 {
t.Fatalf("ID expects : not equal 0, got %v", u1.ID)
} else if u2.ID == 0 {
t.Fatalf("ID expects : not equal 0, got %v", u2.ID)
}
var gotUsers []user
results = DB.Where("id in (?, ?)", u1.ID, u2.ID).Order("id asc").Select("id, name").Find(&gotUsers)
if results.Error != nil {
t.Fatalf("errors happened on first: %v", results.Error)
} else if results.RowsAffected != 2 {
t.Fatalf("rows affected expects: %v, got %v", 2, results.RowsAffected)
} else if gotUsers[0].ID != u1.ID {
t.Fatalf("select expects: %v, got %v", u1.ID, gotUsers[0].ID)
} else if gotUsers[1].ID != u2.ID {
t.Fatalf("select expects: %v, got %v", u2.ID, gotUsers[1].ID)
}
u1.Name = "Jinzhu"
u2.Name = "Zhang"
if results := DB.Save([]*user{&u1, &u2}); results.Error != nil {
t.Fatalf("errors happened on update: %v", results.Error)
} else if results.RowsAffected != 2 {
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
}
}
}