forked from mirror/gorm
feat: Convert SQL nulls to zero values (ConvertNullToZeroValues)
Makes it the default behavior to convert SQL null values to zero values for model fields which are not pointers.
This commit is contained in:
parent
696092e287
commit
19cf645dbd
|
@ -159,6 +159,7 @@ func CreateWithReturning(db *gorm.DB) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resetFields := map[int]*schema.Field{}
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
fieldValue := field.ReflectValueOf(reflectValue)
|
fieldValue := field.ReflectValueOf(reflectValue)
|
||||||
|
|
||||||
|
@ -172,22 +173,47 @@ func CreateWithReturning(db *gorm.DB) {
|
||||||
goto BEGIN
|
goto BEGIN
|
||||||
}
|
}
|
||||||
|
|
||||||
values[idx] = fieldValue.Addr().Interface()
|
if field.FieldType.Kind() == reflect.Ptr {
|
||||||
|
values[idx] = fieldValue.Addr().Interface()
|
||||||
|
} else {
|
||||||
|
reflectValue := reflect.New(reflect.PtrTo(field.FieldType))
|
||||||
|
reflectValue.Elem().Set(fieldValue.Addr())
|
||||||
|
values[idx] = reflectValue.Interface()
|
||||||
|
resetFields[idx] = field
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
if err := rows.Scan(values...); err != nil {
|
if err := rows.Scan(values...); err != nil {
|
||||||
db.AddError(err)
|
db.AddError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for idx, field := range resetFields {
|
||||||
|
if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() {
|
||||||
|
field.ReflectValueOf(reflectValue).Set(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
|
resetFields := map[int]*schema.Field{}
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
if field.FieldType.Kind() == reflect.Ptr {
|
||||||
|
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||||
|
} else {
|
||||||
|
reflectValue := reflect.New(reflect.PtrTo(field.FieldType))
|
||||||
|
reflectValue.Elem().Set(field.ReflectValueOf(db.Statement.ReflectValue).Addr())
|
||||||
|
values[idx] = reflectValue.Interface()
|
||||||
|
resetFields[idx] = field
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows.Next() {
|
if rows.Next() {
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
|
for idx, field := range resetFields {
|
||||||
|
if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() {
|
||||||
|
field.ReflectValueOf(db.Statement.ReflectValue).Set(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -0,0 +1,98 @@
|
||||||
|
package tests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/callbacks"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReturningWithNullToZeroValues(t *testing.T) {
|
||||||
|
dialect := DB.Dialector.Name()
|
||||||
|
switch dialect {
|
||||||
|
case "mysql", "sqlserver":
|
||||||
|
// these dialects do not support the "returning" clause
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// This user struct will leverage the existing users table, but override
|
||||||
|
// the Name field to default to null.
|
||||||
|
type user struct {
|
||||||
|
gorm.Model
|
||||||
|
Name string `gorm:"default:null"`
|
||||||
|
}
|
||||||
|
u1 := user{}
|
||||||
|
c := DB.Callback().Create().Get("gorm:create")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
DB.Callback().Create().Replace("gorm:create", c)
|
||||||
|
})
|
||||||
|
DB.Callback().Create().Replace("gorm:create", callbacks.Create(&callbacks.Config{WithReturning: true}))
|
||||||
|
|
||||||
|
if results := DB.Create(&u1); results.Error != nil {
|
||||||
|
t.Fatalf("errors happened on create: %v", results.Error)
|
||||||
|
} else if results.RowsAffected != 1 {
|
||||||
|
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
|
||||||
|
} else if u1.ID == 0 {
|
||||||
|
t.Fatalf("ID expects : not equal 0, got %v", u1.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := user{}
|
||||||
|
results := DB.First(&got, "id = ?", u1.ID)
|
||||||
|
if results.Error != nil {
|
||||||
|
t.Fatalf("errors happened on first: %v", results.Error)
|
||||||
|
} else if results.RowsAffected != 1 {
|
||||||
|
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
|
||||||
|
} else if got.ID != u1.ID {
|
||||||
|
t.Fatalf("first expects: %v, got %v", u1, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
results = DB.Select("id, name").Find(&got)
|
||||||
|
if results.Error != nil {
|
||||||
|
t.Fatalf("errors happened on first: %v", results.Error)
|
||||||
|
} else if results.RowsAffected != 1 {
|
||||||
|
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
|
||||||
|
} else if got.ID != u1.ID {
|
||||||
|
t.Fatalf("select expects: %v, got %v", u1, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
u1.Name = "jinzhu"
|
||||||
|
if results := DB.Save(&u1); results.Error != nil {
|
||||||
|
t.Fatalf("errors happened on update: %v", results.Error)
|
||||||
|
} else if results.RowsAffected != 1 {
|
||||||
|
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
u1 = user{} // important to reinitialize this before creating it again
|
||||||
|
u2 := user{}
|
||||||
|
db := DB.Session(&gorm.Session{CreateBatchSize: 10})
|
||||||
|
|
||||||
|
if results := db.Create([]*user{&u1, &u2}); results.Error != nil {
|
||||||
|
t.Fatalf("errors happened on create: %v", results.Error)
|
||||||
|
} else if results.RowsAffected != 2 {
|
||||||
|
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
|
||||||
|
} else if u1.ID == 0 {
|
||||||
|
t.Fatalf("ID expects : not equal 0, got %v", u1.ID)
|
||||||
|
} else if u2.ID == 0 {
|
||||||
|
t.Fatalf("ID expects : not equal 0, got %v", u2.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotUsers []user
|
||||||
|
results = DB.Where("id in (?, ?)", u1.ID, u2.ID).Order("id asc").Select("id, name").Find(&gotUsers)
|
||||||
|
if results.Error != nil {
|
||||||
|
t.Fatalf("errors happened on first: %v", results.Error)
|
||||||
|
} else if results.RowsAffected != 2 {
|
||||||
|
t.Fatalf("rows affected expects: %v, got %v", 2, results.RowsAffected)
|
||||||
|
} else if gotUsers[0].ID != u1.ID {
|
||||||
|
t.Fatalf("select expects: %v, got %v", u1.ID, gotUsers[0].ID)
|
||||||
|
} else if gotUsers[1].ID != u2.ID {
|
||||||
|
t.Fatalf("select expects: %v, got %v", u2.ID, gotUsers[1].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
u1.Name = "Jinzhu"
|
||||||
|
u2.Name = "Zhang"
|
||||||
|
if results := DB.Save([]*user{&u1, &u2}); results.Error != nil {
|
||||||
|
t.Fatalf("errors happened on update: %v", results.Error)
|
||||||
|
} else if results.RowsAffected != 2 {
|
||||||
|
t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue