Refactor Convert SQL null values to zero values for model fields which are not pointers

This commit is contained in:
Jinzhu 2021-10-13 21:01:32 +08:00
parent 19cf645dbd
commit b27095e8a1
1 changed files with 16 additions and 13 deletions
callbacks

View File

@ -149,8 +149,11 @@ func CreateWithReturning(db *gorm.DB) {
switch db.Statement.ReflectValue.Kind() { switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
c := db.Statement.Clauses["ON CONFLICT"] var (
onConflict, _ := c.Expression.(clause.OnConflict) c = db.Statement.Clauses["ON CONFLICT"]
onConflict, _ = c.Expression.(clause.OnConflict)
resetFieldValues = map[int]reflect.Value{}
)
for rows.Next() { for rows.Next() {
BEGIN: BEGIN:
@ -159,7 +162,6 @@ func CreateWithReturning(db *gorm.DB) {
break break
} }
resetFields := map[int]*schema.Field{}
for idx, field := range fields { for idx, field := range fields {
fieldValue := field.ReflectValueOf(reflectValue) fieldValue := field.ReflectValueOf(reflectValue)
@ -179,7 +181,7 @@ func CreateWithReturning(db *gorm.DB) {
reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) reflectValue := reflect.New(reflect.PtrTo(field.FieldType))
reflectValue.Elem().Set(fieldValue.Addr()) reflectValue.Elem().Set(fieldValue.Addr())
values[idx] = reflectValue.Interface() values[idx] = reflectValue.Interface()
resetFields[idx] = field resetFieldValues[idx] = fieldValue
} }
} }
@ -188,30 +190,31 @@ func CreateWithReturning(db *gorm.DB) {
db.AddError(err) db.AddError(err)
} }
for idx, field := range resetFields { for idx, fv := range resetFieldValues {
if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() { if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() {
field.ReflectValueOf(reflectValue).Set(v) fv.Set(v.Elem())
} }
} }
} }
case reflect.Struct: case reflect.Struct:
resetFields := map[int]*schema.Field{} resetFieldValues := map[int]reflect.Value{}
for idx, field := range fields { for idx, field := range fields {
if field.FieldType.Kind() == reflect.Ptr { if field.FieldType.Kind() == reflect.Ptr {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
} else { } else {
reflectValue := reflect.New(reflect.PtrTo(field.FieldType)) reflectValue := reflect.New(reflect.PtrTo(field.FieldType))
reflectValue.Elem().Set(field.ReflectValueOf(db.Statement.ReflectValue).Addr()) fieldValue := field.ReflectValueOf(db.Statement.ReflectValue)
reflectValue.Elem().Set(fieldValue.Addr())
values[idx] = reflectValue.Interface() values[idx] = reflectValue.Interface()
resetFields[idx] = field resetFieldValues[idx] = fieldValue
} }
} }
if rows.Next() { if rows.Next() {
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
for idx, field := range resetFields { for idx, fv := range resetFieldValues {
if v := reflect.ValueOf(values[idx]).Elem().Elem(); v.IsValid() { if v := reflect.ValueOf(values[idx]).Elem(); !v.IsNil() {
field.ReflectValueOf(db.Statement.ReflectValue).Set(v) fv.Set(v.Elem())
} }
} }
} }