From b32658358cd0bd5ee76f1229dfaa4613c0045fee Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 3 Jun 2020 08:44:13 +0800 Subject: [PATCH] Fix can't scan null value into normal data types --- scan.go | 94 ++++++++++++++++++++++--------------------------- schema/field.go | 37 ++++++++++++++----- tests/go.mod | 4 +-- 3 files changed, 73 insertions(+), 62 deletions(-) diff --git a/scan.go b/scan.go index 14a4699d..acba4e9f 100644 --- a/scan.go +++ b/scan.go @@ -87,6 +87,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + fields[idx] = field joinFields[idx] = [2]*schema.Field{rel.Field, field} continue } @@ -98,50 +99,39 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } for initialized || rows.Next() { - for idx := range columns { - values[idx] = new(interface{}) - } - initialized = false db.RowsAffected++ elem := reflect.New(reflectValueType).Elem() - if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 { // pluck values[0] = elem.Addr().Interface() db.AddError(rows.Scan(values...)) } else { - db.AddError(rows.Scan(values...)) - for idx, field := range fields { - if v, ok := values[idx].(*interface{}); ok { - if field != nil { - if v == nil { - field.Set(elem, v) - } else { - field.Set(elem, *v) - } - } else if joinFields[idx][0] != nil { - relValue := joinFields[idx][0].ReflectValueOf(elem) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if v == nil { - continue - } - relValue.Set(reflect.New(relValue.Type().Elem())) - } - - if v == nil { - joinFields[idx][1].Set(relValue, nil) - } else { - joinFields[idx][1].Set(relValue, *v) - } - } + if field != nil { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() } } - for idx := range columns { - values[idx] = new(interface{}) + db.AddError(rows.Scan(values...)) + + for idx, field := range fields { + if joinFields[idx][0] != nil { + value := reflect.ValueOf(values[idx]).Elem() + relValue := joinFields[idx][0].ReflectValueOf(elem) + + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value.IsNil() { + continue + } + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + field.Set(relValue, values[idx]) + } else if field != nil { + field.Set(elem, values[idx]) + } } } @@ -153,8 +143,20 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } case reflect.Struct: if initialized || rows.Next() { - for idx := range columns { - values[idx] = new(interface{}) + for idx, column := range columns { + if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + continue + } + } + values[idx] = &sql.RawBytes{} + } else { + values[idx] = &sql.RawBytes{} + } } db.RowsAffected++ @@ -162,31 +164,21 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { for idx, column := range columns { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { - if v, ok := values[idx].(*interface{}); ok { - if v == nil { - field.Set(db.Statement.ReflectValue, v) - } else { - field.Set(db.Statement.ReflectValue, *v) - } - } + field.Set(db.Statement.ReflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - if v, ok := values[idx].(*interface{}); ok { - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - if v == nil { - continue - } - relValue.Set(reflect.New(relValue.Type().Elem())) - } + value := reflect.ValueOf(values[idx]).Elem() - if v == nil { - field.Set(relValue, nil) - } else { - field.Set(relValue, *v) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if value.IsNil() { + continue } + relValue.Set(reflect.New(relValue.Type().Elem())) } + + field.Set(relValue, values[idx]) } } } diff --git a/schema/field.go b/schema/field.go index 8861a00d..a27fdd87 100644 --- a/schema/field.go +++ b/schema/field.go @@ -247,7 +247,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int)) { + if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if strings.ToUpper(v) == "NANO" { field.AutoCreateTime = UnixNanosecond } else { @@ -255,7 +255,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int)) { + if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if strings.ToUpper(v) == "NANO" { field.AutoUpdateTime = UnixNanosecond } else { @@ -407,6 +407,7 @@ func (field *Field) setupValuerAndSetter() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { reflectV := reflect.ValueOf(v) + if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV) return @@ -437,7 +438,11 @@ func (field *Field) setupValuerAndSetter() { setter(value, v) } } else if reflectV.Kind() == reflect.Ptr { - setter(value, reflectV.Elem().Interface()) + if reflectV.IsNil() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + setter(value, reflectV.Elem().Interface()) + } } else { return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) } @@ -680,8 +685,14 @@ func (field *Field) setupValuerAndSetter() { } reflectV := reflect.ValueOf(v) - if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + if !reflectV.IsValid() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(value, reflectV.Elem().Interface()) + } } else { err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) } @@ -691,14 +702,22 @@ func (field *Field) setupValuerAndSetter() { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() + if valuer == nil { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + v, _ = valuer.Value() + } } reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + field.Set(value, reflectV.Elem().Interface()) + } } else { fieldValue := field.ReflectValueOf(value) if fieldValue.IsNil() { diff --git a/tests/go.mod b/tests/go.mod index 3954c442..3401b9b2 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,8 +7,8 @@ require ( gorm.io/driver/mysql v0.0.0-20200602015408-0407d0c21cf0 gorm.io/driver/postgres v0.0.0-20200602015520-15fcc29eb286 gorm.io/driver/sqlite v0.0.0-20200602015323-284b563f81c8 - gorm.io/driver/sqlserver v0.0.0-20200602015206-ef9f739c6a30 - gorm.io/gorm v1.9.12 + gorm.io/driver/sqlserver v0.0.0-20200602144728-79c224f6c1a2 + gorm.io/gorm v0.0.0-00010101000000-000000000000 ) replace gorm.io/gorm => ../