Fix can't scan null value into normal data types

This commit is contained in:
Jinzhu 2020-06-03 08:44:13 +08:00
parent 94685d1024
commit b32658358c
3 changed files with 73 additions and 62 deletions

74
scan.go
View File

@ -87,6 +87,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
joinFields[idx] = [2]*schema.Field{rel.Field, field}
continue
}
@ -98,52 +99,41 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
}
for initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{})
}
initialized = false
db.RowsAffected++
elem := reflect.New(reflectValueType).Elem()
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
// pluck
values[0] = elem.Addr().Interface()
db.AddError(rows.Scan(values...))
} else {
for idx, field := range fields {
if field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
}
}
db.AddError(rows.Scan(values...))
for idx, field := range fields {
if v, ok := values[idx].(*interface{}); ok {
if field != nil {
if v == nil {
field.Set(elem, v)
} else {
field.Set(elem, *v)
}
} else if joinFields[idx][0] != nil {
if joinFields[idx][0] != nil {
value := reflect.ValueOf(values[idx]).Elem()
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if v == nil {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
if v == nil {
joinFields[idx][1].Set(relValue, nil)
} else {
joinFields[idx][1].Set(relValue, *v)
field.Set(relValue, values[idx])
} else if field != nil {
field.Set(elem, values[idx])
}
}
}
}
for idx := range columns {
values[idx] = new(interface{})
}
}
if isPtr {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
@ -153,8 +143,20 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
}
case reflect.Struct:
if initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{})
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
db.RowsAffected++
@ -162,31 +164,21 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
if v, ok := values[idx].(*interface{}); ok {
if v == nil {
field.Set(db.Statement.ReflectValue, v)
} else {
field.Set(db.Statement.ReflectValue, *v)
}
}
field.Set(db.Statement.ReflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
if v, ok := values[idx].(*interface{}); ok {
value := reflect.ValueOf(values[idx]).Elem()
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if v == nil {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
if v == nil {
field.Set(relValue, nil)
} else {
field.Set(relValue, *v)
}
}
field.Set(relValue, values[idx])
}
}
}

View File

@ -247,7 +247,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
}
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int)) {
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if strings.ToUpper(v) == "NANO" {
field.AutoCreateTime = UnixNanosecond
} else {
@ -255,7 +255,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
}
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int)) {
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if strings.ToUpper(v) == "NANO" {
field.AutoUpdateTime = UnixNanosecond
} else {
@ -407,6 +407,7 @@ func (field *Field) setupValuerAndSetter() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
reflectV := reflect.ValueOf(v)
if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV)
return
@ -437,7 +438,11 @@ func (field *Field) setupValuerAndSetter() {
setter(value, v)
}
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
setter(value, reflectV.Elem().Interface())
}
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
}
@ -680,8 +685,14 @@ func (field *Field) setupValuerAndSetter() {
}
reflectV := reflect.ValueOf(v)
if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(value, reflectV.Elem().Interface())
}
} else {
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
}
@ -691,14 +702,22 @@ func (field *Field) setupValuerAndSetter() {
// pointer scanner
field.Set = func(value reflect.Value, v interface{}) (err error) {
if valuer, ok := v.(driver.Valuer); ok {
if valuer == nil {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
v, _ = valuer.Value()
}
}
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
field.Set(value, reflectV.Elem().Interface())
}
} else {
fieldValue := field.ReflectValueOf(value)
if fieldValue.IsNil() {

View File

@ -7,8 +7,8 @@ require (
gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0
gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286
gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8
gorm.io/driver/sqlserver v0.0.0-20200602015206-ef9f739c6a30
gorm.io/gorm v1.9.12
gorm.io/driver/sqlserver v0.0.0-20200602144728-79c224f6c1a2
gorm.io/gorm v0.0.0-00010101000000-000000000000
)
replace gorm.io/gorm => ../