From 94685d102430d8549aa60180dff83e3970e2fb91 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 2 Jun 2020 22:13:53 +0800 Subject: [PATCH] Fix can't scan null value into normal data types --- finisher_api.go | 2 +- scan.go | 158 ++++++++++++++++++++++++------------ schema/field.go | 121 ++++++++++++++------------- statement.go | 12 ++- tests/main_test.go | 5 -- tests/preload_suits_test.go | 1 - tests/query_test.go | 32 ++++++++ tests/tests_all.sh | 4 +- tests/tests_test.go | 2 + tests/update_test.go | 6 +- tests/upsert_test.go | 5 ++ 11 files changed, 226 insertions(+), 122 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 5023150c..b97f2301 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -168,7 +168,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { return tx.Create(dest) } else if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]) + exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) assigns := map[string]interface{}{} for _, expr := range exprs { if eq, ok := expr.(clause.Eq); ok { diff --git a/scan.go b/scan.go index fc6b211b..14a4699d 100644 --- a/scan.go +++ b/scan.go @@ -14,40 +14,53 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: - for idx, _ := range columns { - values[idx] = new(interface{}) - } - if initialized || rows.Next() { + for idx := range columns { + values[idx] = new(interface{}) + } + db.RowsAffected++ db.AddError(rows.Scan(values...)) - } - mapValue, ok := dest.(map[string]interface{}) - if ok { - if v, ok := dest.(*map[string]interface{}); ok { - mapValue = *v + mapValue, ok := dest.(map[string]interface{}) + if !ok { + if v, ok := dest.(*map[string]interface{}); ok { + mapValue = *v + } + } + + for idx, column := range columns { + if v, ok := values[idx].(*interface{}); ok { + if v == nil { + mapValue[column] = nil + } else { + mapValue[column] = *v + } + } } } - - for idx, column := range columns { - mapValue[column] = *(values[idx].(*interface{})) - } case *[]map[string]interface{}: - for idx, _ := range columns { - values[idx] = new(interface{}) - } - for initialized || rows.Next() { + for idx := range columns { + values[idx] = new(interface{}) + } + initialized = false db.RowsAffected++ db.AddError(rows.Scan(values...)) - v := map[string]interface{}{} + mapValue := map[string]interface{}{} for idx, column := range columns { - v[column] = *(values[idx].(*interface{})) + if v, ok := values[idx].(*interface{}); ok { + if v == nil { + mapValue[column] = nil + } else { + mapValue[column] = *v + } + } } - *dest = append(*dest, v) + + *dest = append(*dest, mapValue) } case *int, *int64, *uint, *uint64: for initialized || rows.Next() { @@ -85,28 +98,52 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } for initialized || rows.Next() { + for idx := range columns { + values[idx] = new(interface{}) + } + initialized = false + db.RowsAffected++ + elem := reflect.New(reflectValueType).Elem() if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 { + // pluck values[0] = elem.Addr().Interface() + db.AddError(rows.Scan(values...)) } else { - for idx, field := range fields { - if field != nil { - values[idx] = field.ReflectValueOf(elem).Addr().Interface() - } else if joinFields[idx][0] != nil { - relValue := joinFields[idx][0].ReflectValueOf(elem) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - relValue.Set(reflect.New(relValue.Type().Elem())) - } + db.AddError(rows.Scan(values...)) - values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() + for idx, field := range fields { + if v, ok := values[idx].(*interface{}); ok { + if field != nil { + if v == nil { + field.Set(elem, v) + } else { + field.Set(elem, *v) + } + } else if joinFields[idx][0] != nil { + relValue := joinFields[idx][0].ReflectValueOf(elem) + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if v == nil { + continue + } + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + if v == nil { + joinFields[idx][1].Set(relValue, nil) + } else { + joinFields[idx][1].Set(relValue, *v) + } + } } } - } - db.RowsAffected++ - db.AddError(rows.Scan(values...)) + for idx := range columns { + values[idx] = new(interface{}) + } + } if isPtr { db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) @@ -115,30 +152,45 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } } case reflect.Struct: - for idx, column := range columns { - if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { - relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { - relValue.Set(reflect.New(relValue.Type().Elem())) - } - - values[idx] = field.ReflectValueOf(relValue).Addr().Interface() - continue - } - } - values[idx] = &sql.RawBytes{} - } else { - values[idx] = &sql.RawBytes{} - } - } - if initialized || rows.Next() { + for idx := range columns { + values[idx] = new(interface{}) + } + db.RowsAffected++ db.AddError(rows.Scan(values...)) + + for idx, column := range columns { + if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { + if v, ok := values[idx].(*interface{}); ok { + if v == nil { + field.Set(db.Statement.ReflectValue, v) + } else { + field.Set(db.Statement.ReflectValue, *v) + } + } + } else if names := strings.Split(column, "__"); len(names) > 1 { + if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { + relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + if v, ok := values[idx].(*interface{}); ok { + if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if v == nil { + continue + } + relValue.Set(reflect.New(relValue.Type().Elem())) + } + + if v == nil { + field.Set(relValue, nil) + } else { + field.Set(relValue, *v) + } + } + } + } + } + } } } } diff --git a/schema/field.go b/schema/field.go index 4f92aae7..8861a00d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -402,34 +402,48 @@ func (field *Field) setupValuerAndSetter() { } } - recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { + fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { if v == nil { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { reflectV := reflect.ValueOf(v) - - if reflectV.Type().ConvertibleTo(field.FieldType) { + if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + return + } else if reflectV.Type().ConvertibleTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - return setter(value, v) - } - } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + return + } else if field.FieldType.Kind() == reflect.Ptr { fieldValue := field.ReflectValueOf(value) - if fieldValue.IsNil() { - if v == nil { - return nil + + if reflectV.Type().AssignableTo(field.FieldType.Elem()) { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) } - fieldValue.Set(reflect.New(field.FieldType.Elem())) + fieldValue.Elem().Set(reflectV) + return + } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + + fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) + return + } + } + + if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + setter(value, v) } - fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) } else if reflectV.Kind() == reflect.Ptr { - return field.Set(value, reflectV.Elem().Interface()) + setter(value, reflectV.Elem().Interface()) } else { return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) } } - return err + + return } // Set @@ -441,8 +455,17 @@ func (field *Field) setupValuerAndSetter() { field.ReflectValueOf(value).SetBool(data) case *bool: field.ReflectValueOf(value).SetBool(*data) + case int64: + if data > 0 { + field.ReflectValueOf(value).SetBool(true) + } else { + field.ReflectValueOf(value).SetBool(false) + } + case string: + b, _ := strconv.ParseBool(data) + field.ReflectValueOf(value).SetBool(b) default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return nil } @@ -498,7 +521,7 @@ func (field *Field) setupValuerAndSetter() { field.ReflectValueOf(value).SetInt(0) } default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return err } @@ -538,7 +561,7 @@ func (field *Field) setupValuerAndSetter() { return err } default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return err } @@ -578,7 +601,7 @@ func (field *Field) setupValuerAndSetter() { return err } default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return err } @@ -594,7 +617,7 @@ func (field *Field) setupValuerAndSetter() { case float64, float32: field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return err } @@ -615,7 +638,7 @@ func (field *Field) setupValuerAndSetter() { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return nil } @@ -625,9 +648,6 @@ func (field *Field) setupValuerAndSetter() { case time.Time: fieldValue := field.ReflectValueOf(value) if fieldValue.IsNil() { - if v == nil { - return nil - } fieldValue.Set(reflect.New(field.FieldType.Elem())) } fieldValue.Elem().Set(reflect.ValueOf(v)) @@ -647,7 +667,7 @@ func (field *Field) setupValuerAndSetter() { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } default: - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } return nil } @@ -655,53 +675,42 @@ func (field *Field) setupValuerAndSetter() { if _, ok := fieldValue.Interface().(sql.Scanner); ok { // struct scanner field.Set = func(value reflect.Value, v interface{}) (err error) { - if v == nil { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + reflectV := reflect.ValueOf(v) + if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) - } - } else { - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) - } + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) } return } } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { - if v == nil { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - fieldValue := field.ReflectValueOf(value) - if fieldValue.IsNil() { - if v == nil { - return nil - } - fieldValue.Set(reflect.New(field.FieldType.Elem())) - } - fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) - } else if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) - } - } else { - err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) + fieldValue := field.ReflectValueOf(value) + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) } + err = fieldValue.Interface().(sql.Scanner).Scan(v) } return } } else { field.Set = func(value reflect.Value, v interface{}) (err error) { - return recoverFunc(value, v, field.Set) + return fallbackSetter(value, v, field.Set) } } } diff --git a/statement.go b/statement.go index ebd6e234..ffe3c75b 100644 --- a/statement.go +++ b/statement.go @@ -146,8 +146,16 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { case clause.Column, clause.Table: stmt.QuoteTo(writer, v) case clause.Expr: - writer.WriteString(v.SQL) - stmt.Vars = append(stmt.Vars, v.Vars...) + var varStr strings.Builder + var sql = v.SQL + for _, arg := range v.Vars { + stmt.Vars = append(stmt.Vars, arg) + stmt.DB.Dialector.BindVarTo(&varStr, stmt, arg) + sql = strings.Replace(sql, "?", varStr.String(), 1) + varStr.Reset() + } + + writer.WriteString(sql) case driver.Valuer: stmt.Vars = append(stmt.Vars, v) stmt.DB.Dialector.BindVarTo(writer, stmt, v) diff --git a/tests/main_test.go b/tests/main_test.go index ff293e6e..9d933caf 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -6,11 +6,6 @@ import ( . "gorm.io/gorm/utils/tests" ) -func TestMain(m *testing.M) { - RunMigrations() - m.Run() -} - func TestExceptionsWithInvalidSql(t *testing.T) { var columns []string if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { diff --git a/tests/preload_suits_test.go b/tests/preload_suits_test.go index 98f24daf..8f678b21 100644 --- a/tests/preload_suits_test.go +++ b/tests/preload_suits_test.go @@ -1299,7 +1299,6 @@ func TestNilPointerSlice(t *testing.T) { ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } diff --git a/tests/query_test.go b/tests/query_test.go index f6fb1081..18ffb3fb 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -435,3 +435,35 @@ func TestSubQueryWithHaving(t *testing.T) { t.Errorf("Two user group should be found, instead found %d", len(results)) } } + +func TestScanNullValue(t *testing.T) { + user := GetUser("scan_null_value", Config{}) + DB.Create(&user) + + if err := DB.Model(&user).Update("age", nil).Error; err != nil { + t.Fatalf("failed to update column age for struct, got error %v", err) + } + + var result User + if err := DB.First(&result, "id = ?", user.ID).Error; err != nil { + t.Fatalf("failed to query struct data with null age, got error %v", err) + } + + AssertEqual(t, result, user) + + users := []User{ + *GetUser("scan_null_value_for_slice_1", Config{}), + *GetUser("scan_null_value_for_slice_2", Config{}), + *GetUser("scan_null_value_for_slice_3", Config{}), + } + DB.Create(&users) + + if err := DB.Model(&users[0]).Update("age", nil).Error; err != nil { + t.Fatalf("failed to update column age for struct, got error %v", err) + } + + var results []User + if err := DB.Find(&results, "name like ?", "scan_null_value_for_slice%").Error; err != nil { + t.Fatalf("failed to query slice data with null age, got error %v", err) + } +} diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 95245804..92a28f3b 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -1,4 +1,4 @@ -dialects=("sqlite" "mysql" "postgres" "mssql") +dialects=("sqlite" "mysql" "postgres" "sqlserver") if [[ $(pwd) == *"gorm/tests"* ]]; then cd .. @@ -10,7 +10,7 @@ for dialect in "${dialects[@]}" ; do echo "testing ${dialect}..." race="" - if [ "$GORM_VERBOSE" = "" ] + if [ "$GORM_DIALECT" = "sqlserver" ] then race="-race" fi diff --git a/tests/tests_test.go b/tests/tests_test.go index 40816c3c..09850003 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -23,6 +23,8 @@ func init() { if DB, err = OpenTestConnection(); err != nil { log.Printf("failed to connect database, got error %v\n", err) os.Exit(1) + } else { + RunMigrations() } } diff --git a/tests/update_test.go b/tests/update_test.go index 524e9ea6..220d3e76 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -155,12 +155,14 @@ func TestUpdates(t *testing.T) { AssertEqual(t, user2.UpdatedAt, user3.UpdatedAt) // update with gorm exprs - DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}) + if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { + t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) + } var user4 User DB.First(&user4, user3.ID) user3.Age += 100 - AssertEqual(t, user4.UpdatedAt, user3.UpdatedAt) + AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") } func TestUpdateColumn(t *testing.T) { diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 412be305..f132a7da 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -121,6 +121,11 @@ func TestFindOrCreate(t *testing.T) { updatedAt1 := user4.UpdatedAt DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) + + if user4.Age != 55 { + t.Errorf("Failed to set change to 55, got %v", user4.Age) + } + if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { t.Errorf("UpdateAt should be changed when update values with assign") }