Fix can't scan null value into normal data types

This commit is contained in:
Jinzhu 2020-06-02 22:13:53 +08:00
parent 2218e32999
commit 94685d1024
11 changed files with 226 additions and 122 deletions

View File

@ -168,7 +168,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.Create(dest) return tx.Create(dest)
} else if len(tx.Statement.assigns) > 0 { } else if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]) exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
assigns := map[string]interface{}{} assigns := map[string]interface{}{}
for _, expr := range exprs { for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok { if eq, ok := expr.(clause.Eq); ok {

102
scan.go
View File

@ -14,40 +14,53 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
switch dest := db.Statement.Dest.(type) { switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}: case map[string]interface{}, *map[string]interface{}:
for idx, _ := range columns { if initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{}) values[idx] = new(interface{})
} }
if initialized || rows.Next() {
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
}
mapValue, ok := dest.(map[string]interface{}) mapValue, ok := dest.(map[string]interface{})
if ok { if !ok {
if v, ok := dest.(*map[string]interface{}); ok { if v, ok := dest.(*map[string]interface{}); ok {
mapValue = *v mapValue = *v
} }
} }
for idx, column := range columns { for idx, column := range columns {
mapValue[column] = *(values[idx].(*interface{})) if v, ok := values[idx].(*interface{}); ok {
if v == nil {
mapValue[column] = nil
} else {
mapValue[column] = *v
}
}
}
} }
case *[]map[string]interface{}: case *[]map[string]interface{}:
for idx, _ := range columns { for initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{}) values[idx] = new(interface{})
} }
for initialized || rows.Next() {
initialized = false initialized = false
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
v := map[string]interface{}{} mapValue := map[string]interface{}{}
for idx, column := range columns { for idx, column := range columns {
v[column] = *(values[idx].(*interface{})) if v, ok := values[idx].(*interface{}); ok {
if v == nil {
mapValue[column] = nil
} else {
mapValue[column] = *v
} }
*dest = append(*dest, v) }
}
*dest = append(*dest, mapValue)
} }
case *int, *int64, *uint, *uint64: case *int, *int64, *uint, *uint64:
for initialized || rows.Next() { for initialized || rows.Next() {
@ -85,28 +98,52 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} }
for initialized || rows.Next() { for initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{})
}
initialized = false initialized = false
db.RowsAffected++
elem := reflect.New(reflectValueType).Elem() elem := reflect.New(reflectValueType).Elem()
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 { if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
// pluck
values[0] = elem.Addr().Interface() values[0] = elem.Addr().Interface()
db.AddError(rows.Scan(values...))
} else { } else {
db.AddError(rows.Scan(values...))
for idx, field := range fields { for idx, field := range fields {
if v, ok := values[idx].(*interface{}); ok {
if field != nil { if field != nil {
values[idx] = field.ReflectValueOf(elem).Addr().Interface() if v == nil {
field.Set(elem, v)
} else {
field.Set(elem, *v)
}
} else if joinFields[idx][0] != nil { } else if joinFields[idx][0] != nil {
relValue := joinFields[idx][0].ReflectValueOf(elem) relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() { if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if v == nil {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem())) relValue.Set(reflect.New(relValue.Type().Elem()))
} }
values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() if v == nil {
joinFields[idx][1].Set(relValue, nil)
} else {
joinFields[idx][1].Set(relValue, *v)
}
} }
} }
} }
db.RowsAffected++ for idx := range columns {
db.AddError(rows.Scan(values...)) values[idx] = new(interface{})
}
}
if isPtr { if isPtr {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
@ -115,30 +152,45 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} }
} }
case reflect.Struct: case reflect.Struct:
if initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{})
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
for idx, column := range columns { for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() if v, ok := values[idx].(*interface{}); ok {
if v == nil {
field.Set(db.Statement.ReflectValue, v)
} else {
field.Set(db.Statement.ReflectValue, *v)
}
}
} else if names := strings.Split(column, "__"); len(names) > 1 { } else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
if v, ok := values[idx].(*interface{}); ok {
if relValue.Kind() == reflect.Ptr && relValue.IsNil() { if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if v == nil {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem())) relValue.Set(reflect.New(relValue.Type().Elem()))
} }
values[idx] = field.ReflectValueOf(relValue).Addr().Interface() if v == nil {
continue field.Set(relValue, nil)
}
}
values[idx] = &sql.RawBytes{}
} else { } else {
values[idx] = &sql.RawBytes{} field.Set(relValue, *v)
}
}
}
}
} }
} }
if initialized || rows.Next() {
db.RowsAffected++
db.AddError(rows.Scan(values...))
} }
} }
} }

View File

@ -402,34 +402,48 @@ func (field *Field) setupValuerAndSetter() {
} }
} }
recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
if v == nil { if v == nil {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else { } else {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().AssignableTo(field.FieldType) {
if reflectV.Type().ConvertibleTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV)
return
} else if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok { return
if v, err = valuer.Value(); err == nil { } else if field.FieldType.Kind() == reflect.Ptr {
return setter(value, v)
}
} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(value)
if reflectV.Type().AssignableTo(field.FieldType.Elem()) {
if fieldValue.IsNil() { if fieldValue.IsNil() {
if v == nil {
return nil
}
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(field.FieldType.Elem()))
} }
fieldValue.Elem().Set(reflectV)
return
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem()))
return
}
}
if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
setter(value, v)
}
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
return field.Set(value, reflectV.Elem().Interface()) setter(value, reflectV.Elem().Interface())
} else { } else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
} }
} }
return err
return
} }
// Set // Set
@ -441,8 +455,17 @@ func (field *Field) setupValuerAndSetter() {
field.ReflectValueOf(value).SetBool(data) field.ReflectValueOf(value).SetBool(data)
case *bool: case *bool:
field.ReflectValueOf(value).SetBool(*data) field.ReflectValueOf(value).SetBool(*data)
case int64:
if data > 0 {
field.ReflectValueOf(value).SetBool(true)
} else {
field.ReflectValueOf(value).SetBool(false)
}
case string:
b, _ := strconv.ParseBool(data)
field.ReflectValueOf(value).SetBool(b)
default: default:
return recoverFunc(value, v, field.Set) return fallbackSetter(value, v, field.Set)
} }
return nil return nil
} }
@ -498,7 +521,7 @@ func (field *Field) setupValuerAndSetter() {
field.ReflectValueOf(value).SetInt(0) field.ReflectValueOf(value).SetInt(0)
} }
default: default:
return recoverFunc(value, v, field.Set) return fallbackSetter(value, v, field.Set)
} }
return err return err
} }
@ -538,7 +561,7 @@ func (field *Field) setupValuerAndSetter() {
return err return err
} }
default: default:
return recoverFunc(value, v, field.Set) return fallbackSetter(value, v, field.Set)
} }
return err return err
} }
@ -578,7 +601,7 @@ func (field *Field) setupValuerAndSetter() {
return err return err
} }
default: default:
return recoverFunc(value, v, field.Set) return fallbackSetter(value, v, field.Set)
} }
return err return err
} }
@ -594,7 +617,7 @@ func (field *Field) setupValuerAndSetter() {
case float64, float32: case float64, float32:
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default: default:
return recoverFunc(value, v, field.Set) return fallbackSetter(value, v, field.Set)
} }
return err return err
} }
@ -615,7 +638,7 @@ func (field *Field) setupValuerAndSetter() {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
} }
default: default:
return recoverFunc(value, v, field.Set) return fallbackSetter(value, v, field.Set)
} }
return nil return nil
} }
@ -625,9 +648,6 @@ func (field *Field) setupValuerAndSetter() {
case time.Time: case time.Time:
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(value)
if fieldValue.IsNil() { if fieldValue.IsNil() {
if v == nil {
return nil
}
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(field.FieldType.Elem()))
} }
fieldValue.Elem().Set(reflect.ValueOf(v)) fieldValue.Elem().Set(reflect.ValueOf(v))
@ -647,7 +667,7 @@ func (field *Field) setupValuerAndSetter() {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
} }
default: default:
return recoverFunc(value, v, field.Set) return fallbackSetter(value, v, field.Set)
} }
return nil return nil
} }
@ -655,53 +675,42 @@ func (field *Field) setupValuerAndSetter() {
if _, ok := fieldValue.Interface().(sql.Scanner); ok { if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner // struct scanner
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
if v == nil { if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
reflectV := reflect.ValueOf(v)
if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else { } else {
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
} }
} else {
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
}
}
return return
} }
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner // pointer scanner
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
if v == nil { if valuer, ok := v.(driver.Valuer); ok {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) v, _ = valuer.Value()
} else { }
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) { if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
fieldValue := field.ReflectValueOf(value) fieldValue := field.ReflectValueOf(value)
if fieldValue.IsNil() { if fieldValue.IsNil() {
if v == nil {
return nil
}
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(field.FieldType.Elem()))
} }
fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) err = fieldValue.Interface().(sql.Scanner).Scan(v)
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
}
} else {
err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
}
} }
return return
} }
} else { } else {
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
return recoverFunc(value, v, field.Set) return fallbackSetter(value, v, field.Set)
} }
} }
} }

View File

@ -146,8 +146,16 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
case clause.Column, clause.Table: case clause.Column, clause.Table:
stmt.QuoteTo(writer, v) stmt.QuoteTo(writer, v)
case clause.Expr: case clause.Expr:
writer.WriteString(v.SQL) var varStr strings.Builder
stmt.Vars = append(stmt.Vars, v.Vars...) var sql = v.SQL
for _, arg := range v.Vars {
stmt.Vars = append(stmt.Vars, arg)
stmt.DB.Dialector.BindVarTo(&varStr, stmt, arg)
sql = strings.Replace(sql, "?", varStr.String(), 1)
varStr.Reset()
}
writer.WriteString(sql)
case driver.Valuer: case driver.Valuer:
stmt.Vars = append(stmt.Vars, v) stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v) stmt.DB.Dialector.BindVarTo(writer, stmt, v)

View File

@ -6,11 +6,6 @@ import (
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
func TestMain(m *testing.M) {
RunMigrations()
m.Run()
}
func TestExceptionsWithInvalidSql(t *testing.T) { func TestExceptionsWithInvalidSql(t *testing.T) {
var columns []string var columns []string
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {

View File

@ -1299,7 +1299,6 @@ func TestNilPointerSlice(t *testing.T) {
) )
DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil {
t.Error(err) t.Error(err)
} }

View File

@ -435,3 +435,35 @@ func TestSubQueryWithHaving(t *testing.T) {
t.Errorf("Two user group should be found, instead found %d", len(results)) t.Errorf("Two user group should be found, instead found %d", len(results))
} }
} }
func TestScanNullValue(t *testing.T) {
user := GetUser("scan_null_value", Config{})
DB.Create(&user)
if err := DB.Model(&user).Update("age", nil).Error; err != nil {
t.Fatalf("failed to update column age for struct, got error %v", err)
}
var result User
if err := DB.First(&result, "id = ?", user.ID).Error; err != nil {
t.Fatalf("failed to query struct data with null age, got error %v", err)
}
AssertEqual(t, result, user)
users := []User{
*GetUser("scan_null_value_for_slice_1", Config{}),
*GetUser("scan_null_value_for_slice_2", Config{}),
*GetUser("scan_null_value_for_slice_3", Config{}),
}
DB.Create(&users)
if err := DB.Model(&users[0]).Update("age", nil).Error; err != nil {
t.Fatalf("failed to update column age for struct, got error %v", err)
}
var results []User
if err := DB.Find(&results, "name like ?", "scan_null_value_for_slice%").Error; err != nil {
t.Fatalf("failed to query slice data with null age, got error %v", err)
}
}

View File

@ -1,4 +1,4 @@
dialects=("sqlite" "mysql" "postgres" "mssql") dialects=("sqlite" "mysql" "postgres" "sqlserver")
if [[ $(pwd) == *"gorm/tests"* ]]; then if [[ $(pwd) == *"gorm/tests"* ]]; then
cd .. cd ..
@ -10,7 +10,7 @@ for dialect in "${dialects[@]}" ; do
echo "testing ${dialect}..." echo "testing ${dialect}..."
race="" race=""
if [ "$GORM_VERBOSE" = "" ] if [ "$GORM_DIALECT" = "sqlserver" ]
then then
race="-race" race="-race"
fi fi

View File

@ -23,6 +23,8 @@ func init() {
if DB, err = OpenTestConnection(); err != nil { if DB, err = OpenTestConnection(); err != nil {
log.Printf("failed to connect database, got error %v\n", err) log.Printf("failed to connect database, got error %v\n", err)
os.Exit(1) os.Exit(1)
} else {
RunMigrations()
} }
} }

View File

@ -155,12 +155,14 @@ func TestUpdates(t *testing.T) {
AssertEqual(t, user2.UpdatedAt, user3.UpdatedAt) AssertEqual(t, user2.UpdatedAt, user3.UpdatedAt)
// update with gorm exprs // update with gorm exprs
DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}) if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
}
var user4 User var user4 User
DB.First(&user4, user3.ID) DB.First(&user4, user3.ID)
user3.Age += 100 user3.Age += 100
AssertEqual(t, user4.UpdatedAt, user3.UpdatedAt) AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
} }
func TestUpdateColumn(t *testing.T) { func TestUpdateColumn(t *testing.T) {

View File

@ -121,6 +121,11 @@ func TestFindOrCreate(t *testing.T) {
updatedAt1 := user4.UpdatedAt updatedAt1 := user4.UpdatedAt
DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4)
if user4.Age != 55 {
t.Errorf("Failed to set change to 55, got %v", user4.Age)
}
if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("UpdateAt should be changed when update values with assign") t.Errorf("UpdateAt should be changed when update values with assign")
} }