Fix can't scan null value into normal data types

This commit is contained in:
Jinzhu 2020-06-02 22:13:53 +08:00
parent 2218e32999
commit 94685d1024
11 changed files with 226 additions and 122 deletions

View File

@ -168,7 +168,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.Create(dest)
} else if len(tx.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:])
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
assigns := map[string]interface{}{}
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {

102
scan.go
View File

@ -14,40 +14,53 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}:
for idx, _ := range columns {
if initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{})
}
if initialized || rows.Next() {
db.RowsAffected++
db.AddError(rows.Scan(values...))
}
mapValue, ok := dest.(map[string]interface{})
if ok {
if !ok {
if v, ok := dest.(*map[string]interface{}); ok {
mapValue = *v
}
}
for idx, column := range columns {
mapValue[column] = *(values[idx].(*interface{}))
if v, ok := values[idx].(*interface{}); ok {
if v == nil {
mapValue[column] = nil
} else {
mapValue[column] = *v
}
}
}
}
case *[]map[string]interface{}:
for idx, _ := range columns {
for initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{})
}
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(values...))
v := map[string]interface{}{}
mapValue := map[string]interface{}{}
for idx, column := range columns {
v[column] = *(values[idx].(*interface{}))
if v, ok := values[idx].(*interface{}); ok {
if v == nil {
mapValue[column] = nil
} else {
mapValue[column] = *v
}
*dest = append(*dest, v)
}
}
*dest = append(*dest, mapValue)
}
case *int, *int64, *uint, *uint64:
for initialized || rows.Next() {
@ -85,28 +98,52 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
}
for initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{})
}
initialized = false
db.RowsAffected++
elem := reflect.New(reflectValueType).Elem()
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
// pluck
values[0] = elem.Addr().Interface()
db.AddError(rows.Scan(values...))
} else {
db.AddError(rows.Scan(values...))
for idx, field := range fields {
if v, ok := values[idx].(*interface{}); ok {
if field != nil {
values[idx] = field.ReflectValueOf(elem).Addr().Interface()
if v == nil {
field.Set(elem, v)
} else {
field.Set(elem, *v)
}
} else if joinFields[idx][0] != nil {
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if v == nil {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface()
if v == nil {
joinFields[idx][1].Set(relValue, nil)
} else {
joinFields[idx][1].Set(relValue, *v)
}
}
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
for idx := range columns {
values[idx] = new(interface{})
}
}
if isPtr {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
@ -115,30 +152,45 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
}
}
case reflect.Struct:
if initialized || rows.Next() {
for idx := range columns {
values[idx] = new(interface{})
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
for idx, column := range columns {
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
if v, ok := values[idx].(*interface{}); ok {
if v == nil {
field.Set(db.Statement.ReflectValue, v)
} else {
field.Set(db.Statement.ReflectValue, *v)
}
}
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
if v, ok := values[idx].(*interface{}); ok {
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if v == nil {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
}
values[idx] = field.ReflectValueOf(relValue).Addr().Interface()
continue
}
}
values[idx] = &sql.RawBytes{}
if v == nil {
field.Set(relValue, nil)
} else {
values[idx] = &sql.RawBytes{}
field.Set(relValue, *v)
}
}
}
}
}
}
if initialized || rows.Next() {
db.RowsAffected++
db.AddError(rows.Scan(values...))
}
}
}

View File

@ -402,34 +402,48 @@ func (field *Field) setupValuerAndSetter() {
}
}
recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
if v == nil {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV)
return
} else if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
return setter(value, v)
}
} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
return
} else if field.FieldType.Kind() == reflect.Ptr {
fieldValue := field.ReflectValueOf(value)
if reflectV.Type().AssignableTo(field.FieldType.Elem()) {
if fieldValue.IsNil() {
if v == nil {
return nil
}
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflectV)
return
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem()))
return
}
}
if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
setter(value, v)
}
} else if reflectV.Kind() == reflect.Ptr {
return field.Set(value, reflectV.Elem().Interface())
setter(value, reflectV.Elem().Interface())
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
}
}
return err
return
}
// Set
@ -441,8 +455,17 @@ func (field *Field) setupValuerAndSetter() {
field.ReflectValueOf(value).SetBool(data)
case *bool:
field.ReflectValueOf(value).SetBool(*data)
case int64:
if data > 0 {
field.ReflectValueOf(value).SetBool(true)
} else {
field.ReflectValueOf(value).SetBool(false)
}
case string:
b, _ := strconv.ParseBool(data)
field.ReflectValueOf(value).SetBool(b)
default:
return recoverFunc(value, v, field.Set)
return fallbackSetter(value, v, field.Set)
}
return nil
}
@ -498,7 +521,7 @@ func (field *Field) setupValuerAndSetter() {
field.ReflectValueOf(value).SetInt(0)
}
default:
return recoverFunc(value, v, field.Set)
return fallbackSetter(value, v, field.Set)
}
return err
}
@ -538,7 +561,7 @@ func (field *Field) setupValuerAndSetter() {
return err
}
default:
return recoverFunc(value, v, field.Set)
return fallbackSetter(value, v, field.Set)
}
return err
}
@ -578,7 +601,7 @@ func (field *Field) setupValuerAndSetter() {
return err
}
default:
return recoverFunc(value, v, field.Set)
return fallbackSetter(value, v, field.Set)
}
return err
}
@ -594,7 +617,7 @@ func (field *Field) setupValuerAndSetter() {
case float64, float32:
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default:
return recoverFunc(value, v, field.Set)
return fallbackSetter(value, v, field.Set)
}
return err
}
@ -615,7 +638,7 @@ func (field *Field) setupValuerAndSetter() {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return recoverFunc(value, v, field.Set)
return fallbackSetter(value, v, field.Set)
}
return nil
}
@ -625,9 +648,6 @@ func (field *Field) setupValuerAndSetter() {
case time.Time:
fieldValue := field.ReflectValueOf(value)
if fieldValue.IsNil() {
if v == nil {
return nil
}
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflect.ValueOf(v))
@ -647,7 +667,7 @@ func (field *Field) setupValuerAndSetter() {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return recoverFunc(value, v, field.Set)
return fallbackSetter(value, v, field.Set)
}
return nil
}
@ -655,53 +675,42 @@ func (field *Field) setupValuerAndSetter() {
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner
field.Set = func(value reflect.Value, v interface{}) (err error) {
if v == nil {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
reflectV := reflect.ValueOf(v)
if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
}
} else {
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
}
}
return
}
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner
field.Set = func(value reflect.Value, v interface{}) (err error) {
if v == nil {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
fieldValue := field.ReflectValueOf(value)
if fieldValue.IsNil() {
if v == nil {
return nil
}
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
}
} else {
err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
}
err = fieldValue.Interface().(sql.Scanner).Scan(v)
}
return
}
} else {
field.Set = func(value reflect.Value, v interface{}) (err error) {
return recoverFunc(value, v, field.Set)
return fallbackSetter(value, v, field.Set)
}
}
}

View File

@ -146,8 +146,16 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
case clause.Column, clause.Table:
stmt.QuoteTo(writer, v)
case clause.Expr:
writer.WriteString(v.SQL)
stmt.Vars = append(stmt.Vars, v.Vars...)
var varStr strings.Builder
var sql = v.SQL
for _, arg := range v.Vars {
stmt.Vars = append(stmt.Vars, arg)
stmt.DB.Dialector.BindVarTo(&varStr, stmt, arg)
sql = strings.Replace(sql, "?", varStr.String(), 1)
varStr.Reset()
}
writer.WriteString(sql)
case driver.Valuer:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)

View File

@ -6,11 +6,6 @@ import (
. "gorm.io/gorm/utils/tests"
)
func TestMain(m *testing.M) {
RunMigrations()
m.Run()
}
func TestExceptionsWithInvalidSql(t *testing.T) {
var columns []string
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {

View File

@ -1299,7 +1299,6 @@ func TestNilPointerSlice(t *testing.T) {
)
DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil {
t.Error(err)
}

View File

@ -435,3 +435,35 @@ func TestSubQueryWithHaving(t *testing.T) {
t.Errorf("Two user group should be found, instead found %d", len(results))
}
}
func TestScanNullValue(t *testing.T) {
user := GetUser("scan_null_value", Config{})
DB.Create(&user)
if err := DB.Model(&user).Update("age", nil).Error; err != nil {
t.Fatalf("failed to update column age for struct, got error %v", err)
}
var result User
if err := DB.First(&result, "id = ?", user.ID).Error; err != nil {
t.Fatalf("failed to query struct data with null age, got error %v", err)
}
AssertEqual(t, result, user)
users := []User{
*GetUser("scan_null_value_for_slice_1", Config{}),
*GetUser("scan_null_value_for_slice_2", Config{}),
*GetUser("scan_null_value_for_slice_3", Config{}),
}
DB.Create(&users)
if err := DB.Model(&users[0]).Update("age", nil).Error; err != nil {
t.Fatalf("failed to update column age for struct, got error %v", err)
}
var results []User
if err := DB.Find(&results, "name like ?", "scan_null_value_for_slice%").Error; err != nil {
t.Fatalf("failed to query slice data with null age, got error %v", err)
}
}

View File

@ -1,4 +1,4 @@
dialects=("sqlite" "mysql" "postgres" "mssql")
dialects=("sqlite" "mysql" "postgres" "sqlserver")
if [[ $(pwd) == *"gorm/tests"* ]]; then
cd ..
@ -10,7 +10,7 @@ for dialect in "${dialects[@]}" ; do
echo "testing ${dialect}..."
race=""
if [ "$GORM_VERBOSE" = "" ]
if [ "$GORM_DIALECT" = "sqlserver" ]
then
race="-race"
fi

View File

@ -23,6 +23,8 @@ func init() {
if DB, err = OpenTestConnection(); err != nil {
log.Printf("failed to connect database, got error %v\n", err)
os.Exit(1)
} else {
RunMigrations()
}
}

View File

@ -155,12 +155,14 @@ func TestUpdates(t *testing.T) {
AssertEqual(t, user2.UpdatedAt, user3.UpdatedAt)
// update with gorm exprs
DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)})
if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
}
var user4 User
DB.First(&user4, user3.ID)
user3.Age += 100
AssertEqual(t, user4.UpdatedAt, user3.UpdatedAt)
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
}
func TestUpdateColumn(t *testing.T) {

View File

@ -121,6 +121,11 @@ func TestFindOrCreate(t *testing.T) {
updatedAt1 := user4.UpdatedAt
DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4)
if user4.Age != 55 {
t.Errorf("Failed to set change to 55, got %v", user4.Age)
}
if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) {
t.Errorf("UpdateAt should be changed when update values with assign")
}