Fix can't scan null value into normal data types

This commit is contained in:
Jinzhu 2020-06-03 08:44:13 +08:00
parent 94685d1024
commit b32658358c
3 changed files with 73 additions and 62 deletions

74
scan.go
View File

@ -87,6 +87,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} else if names := strings.Split(column, "__"); len(names) > 1 { } else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
joinFields[idx] = [2]*schema.Field{rel.Field, field} joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue continue
} }
@ -98,52 +99,41 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} }
for initialized || rows.Next() { for initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{})
}
initialized = false initialized = false
db.RowsAffected++ db.RowsAffected++
elem := reflect.New(reflectValueType).Elem() elem := reflect.New(reflectValueType).Elem()
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 { if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
// pluck // pluck
values[0] = elem.Addr().Interface() values[0] = elem.Addr().Interface()
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
} else { } else {
for idx, field := range fields {
if field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
}
}
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
for idx, field := range fields { for idx, field := range fields {
if v, ok := values[idx].(*interface{}); ok { if joinFields[idx][0] != nil {
if field != nil { value := reflect.ValueOf(values[idx]).Elem()
if v == nil {
field.Set(elem, v)
} else {
field.Set(elem, *v)
}
} else if joinFields[idx][0] != nil {
relValue := joinFields[idx][0].ReflectValueOf(elem) relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() { if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if v == nil { if value.IsNil() {
continue continue
} }
relValue.Set(reflect.New(relValue.Type().Elem())) relValue.Set(reflect.New(relValue.Type().Elem()))
} }
if v == nil { field.Set(relValue, values[idx])
joinFields[idx][1].Set(relValue, nil) } else if field != nil {
} else { field.Set(elem, values[idx])
joinFields[idx][1].Set(relValue, *v)
} }
} }
} }
}
for idx := range columns {
values[idx] = new(interface{})
}
}
if isPtr { if isPtr {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
@ -153,8 +143,20 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} }
case reflect.Struct: case reflect.Struct:
if initialized || rows.Next() { if initialized || rows.Next() {
for idx := range columns { for idx, column := range columns {
values[idx] = new(interface{}) if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
} }
db.RowsAffected++ db.RowsAffected++
@ -162,31 +164,21 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
for idx, column := range columns { for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
if v, ok := values[idx].(*interface{}); ok { field.Set(db.Statement.ReflectValue, values[idx])
if v == nil {
field.Set(db.Statement.ReflectValue, v)
} else {
field.Set(db.Statement.ReflectValue, *v)
}
}
} else if names := strings.Split(column, "__"); len(names) > 1 { } else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
if v, ok := values[idx].(*interface{}); ok { value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() { if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if v == nil { if value.IsNil() {
continue continue
} }
relValue.Set(reflect.New(relValue.Type().Elem())) relValue.Set(reflect.New(relValue.Type().Elem()))
} }
if v == nil { field.Set(relValue, values[idx])
field.Set(relValue, nil)
} else {
field.Set(relValue, *v)
}
}
} }
} }
} }

View File

@ -247,7 +247,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int)) { if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if strings.ToUpper(v) == "NANO" { if strings.ToUpper(v) == "NANO" {
field.AutoCreateTime = UnixNanosecond field.AutoCreateTime = UnixNanosecond
} else { } else {
@ -255,7 +255,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
} }
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int)) { if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if strings.ToUpper(v) == "NANO" { if strings.ToUpper(v) == "NANO" {
field.AutoUpdateTime = UnixNanosecond field.AutoUpdateTime = UnixNanosecond
} else { } else {
@ -407,6 +407,7 @@ func (field *Field) setupValuerAndSetter() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else { } else {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().AssignableTo(field.FieldType) { if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV) field.ReflectValueOf(value).Set(reflectV)
return return
@ -437,7 +438,11 @@ func (field *Field) setupValuerAndSetter() {
setter(value, v) setter(value, v)
} }
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
setter(value, reflectV.Elem().Interface()) setter(value, reflectV.Elem().Interface())
}
} else { } else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
} }
@ -680,8 +685,14 @@ func (field *Field) setupValuerAndSetter() {
} }
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(value, reflectV.Elem().Interface())
}
} else { } else {
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
} }
@ -691,14 +702,22 @@ func (field *Field) setupValuerAndSetter() {
// pointer scanner // pointer scanner
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
if valuer, ok := v.(driver.Valuer); ok { if valuer, ok := v.(driver.Valuer); ok {
if valuer == nil {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
v, _ = valuer.Value() v, _ = valuer.Value()
} }
}
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) { if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValueOf(value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { } else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
field.Set(value, reflectV.Elem().Interface())
}
} else { } else {
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(value)
if fieldValue.IsNil() { if fieldValue.IsNil() {

View File

@ -7,8 +7,8 @@ require (
gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0
gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286
gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8 gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8
gorm.io/driver/sqlserver v0.0.0-20200602015206-ef9f739c6a30 gorm.io/driver/sqlserver v0.0.0-20200602144728-79c224f6c1a2
gorm.io/gorm v1.9.12 gorm.io/gorm v0.0.0-00010101000000-000000000000
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../