forked from mirror/gorm
Fix can't scan null value into normal data types
This commit is contained in:
parent
2218e32999
commit
94685d1024
|
@ -168,7 +168,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
|||
|
||||
return tx.Create(dest)
|
||||
} else if len(tx.Statement.assigns) > 0 {
|
||||
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:])
|
||||
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
|
||||
assigns := map[string]interface{}{}
|
||||
for _, expr := range exprs {
|
||||
if eq, ok := expr.(clause.Eq); ok {
|
||||
|
|
102
scan.go
102
scan.go
|
@ -14,40 +14,53 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
|
||||
switch dest := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}, *map[string]interface{}:
|
||||
for idx, _ := range columns {
|
||||
if initialized || rows.Next() {
|
||||
for idx := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
|
||||
if initialized || rows.Next() {
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
}
|
||||
|
||||
mapValue, ok := dest.(map[string]interface{})
|
||||
if ok {
|
||||
if !ok {
|
||||
if v, ok := dest.(*map[string]interface{}); ok {
|
||||
mapValue = *v
|
||||
}
|
||||
}
|
||||
|
||||
for idx, column := range columns {
|
||||
mapValue[column] = *(values[idx].(*interface{}))
|
||||
if v, ok := values[idx].(*interface{}); ok {
|
||||
if v == nil {
|
||||
mapValue[column] = nil
|
||||
} else {
|
||||
mapValue[column] = *v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case *[]map[string]interface{}:
|
||||
for idx, _ := range columns {
|
||||
for initialized || rows.Next() {
|
||||
for idx := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
|
||||
for initialized || rows.Next() {
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
v := map[string]interface{}{}
|
||||
mapValue := map[string]interface{}{}
|
||||
for idx, column := range columns {
|
||||
v[column] = *(values[idx].(*interface{}))
|
||||
if v, ok := values[idx].(*interface{}); ok {
|
||||
if v == nil {
|
||||
mapValue[column] = nil
|
||||
} else {
|
||||
mapValue[column] = *v
|
||||
}
|
||||
*dest = append(*dest, v)
|
||||
}
|
||||
}
|
||||
|
||||
*dest = append(*dest, mapValue)
|
||||
}
|
||||
case *int, *int64, *uint, *uint64:
|
||||
for initialized || rows.Next() {
|
||||
|
@ -85,28 +98,52 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
}
|
||||
|
||||
for initialized || rows.Next() {
|
||||
for idx := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
|
||||
elem := reflect.New(reflectValueType).Elem()
|
||||
|
||||
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
|
||||
// pluck
|
||||
values[0] = elem.Addr().Interface()
|
||||
db.AddError(rows.Scan(values...))
|
||||
} else {
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
for idx, field := range fields {
|
||||
if v, ok := values[idx].(*interface{}); ok {
|
||||
if field != nil {
|
||||
values[idx] = field.ReflectValueOf(elem).Addr().Interface()
|
||||
if v == nil {
|
||||
field.Set(elem, v)
|
||||
} else {
|
||||
field.Set(elem, *v)
|
||||
}
|
||||
} else if joinFields[idx][0] != nil {
|
||||
relValue := joinFields[idx][0].ReflectValueOf(elem)
|
||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||
}
|
||||
|
||||
values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface()
|
||||
if v == nil {
|
||||
joinFields[idx][1].Set(relValue, nil)
|
||||
} else {
|
||||
joinFields[idx][1].Set(relValue, *v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
for idx := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
}
|
||||
|
||||
if isPtr {
|
||||
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
|
||||
|
@ -115,30 +152,45 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
|||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
if initialized || rows.Next() {
|
||||
for idx := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
for idx, column := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
||||
if v, ok := values[idx].(*interface{}); ok {
|
||||
if v == nil {
|
||||
field.Set(db.Statement.ReflectValue, v)
|
||||
} else {
|
||||
field.Set(db.Statement.ReflectValue, *v)
|
||||
}
|
||||
}
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
||||
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
if v, ok := values[idx].(*interface{}); ok {
|
||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||
}
|
||||
|
||||
values[idx] = field.ReflectValueOf(relValue).Addr().Interface()
|
||||
continue
|
||||
}
|
||||
}
|
||||
values[idx] = &sql.RawBytes{}
|
||||
if v == nil {
|
||||
field.Set(relValue, nil)
|
||||
} else {
|
||||
values[idx] = &sql.RawBytes{}
|
||||
field.Set(relValue, *v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if initialized || rows.Next() {
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
107
schema/field.go
107
schema/field.go
|
@ -402,34 +402,48 @@ func (field *Field) setupValuerAndSetter() {
|
|||
}
|
||||
}
|
||||
|
||||
recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
|
||||
fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
|
||||
if v == nil {
|
||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
|
||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
||||
if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(value).Set(reflectV)
|
||||
return
|
||||
} else if reflectV.Type().ConvertibleTo(field.FieldType) {
|
||||
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
||||
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||
if v, err = valuer.Value(); err == nil {
|
||||
return setter(value, v)
|
||||
}
|
||||
} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
|
||||
return
|
||||
} else if field.FieldType.Kind() == reflect.Ptr {
|
||||
fieldValue := field.ReflectValueOf(value)
|
||||
|
||||
if reflectV.Type().AssignableTo(field.FieldType.Elem()) {
|
||||
if fieldValue.IsNil() {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||
}
|
||||
fieldValue.Elem().Set(reflectV)
|
||||
return
|
||||
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||
}
|
||||
|
||||
fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
if v, err = valuer.Value(); err == nil {
|
||||
setter(value, v)
|
||||
}
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
return field.Set(value, reflectV.Elem().Interface())
|
||||
setter(value, reflectV.Elem().Interface())
|
||||
} else {
|
||||
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
||||
}
|
||||
}
|
||||
return err
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Set
|
||||
|
@ -441,8 +455,17 @@ func (field *Field) setupValuerAndSetter() {
|
|||
field.ReflectValueOf(value).SetBool(data)
|
||||
case *bool:
|
||||
field.ReflectValueOf(value).SetBool(*data)
|
||||
case int64:
|
||||
if data > 0 {
|
||||
field.ReflectValueOf(value).SetBool(true)
|
||||
} else {
|
||||
field.ReflectValueOf(value).SetBool(false)
|
||||
}
|
||||
case string:
|
||||
b, _ := strconv.ParseBool(data)
|
||||
field.ReflectValueOf(value).SetBool(b)
|
||||
default:
|
||||
return recoverFunc(value, v, field.Set)
|
||||
return fallbackSetter(value, v, field.Set)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -498,7 +521,7 @@ func (field *Field) setupValuerAndSetter() {
|
|||
field.ReflectValueOf(value).SetInt(0)
|
||||
}
|
||||
default:
|
||||
return recoverFunc(value, v, field.Set)
|
||||
return fallbackSetter(value, v, field.Set)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -538,7 +561,7 @@ func (field *Field) setupValuerAndSetter() {
|
|||
return err
|
||||
}
|
||||
default:
|
||||
return recoverFunc(value, v, field.Set)
|
||||
return fallbackSetter(value, v, field.Set)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -578,7 +601,7 @@ func (field *Field) setupValuerAndSetter() {
|
|||
return err
|
||||
}
|
||||
default:
|
||||
return recoverFunc(value, v, field.Set)
|
||||
return fallbackSetter(value, v, field.Set)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -594,7 +617,7 @@ func (field *Field) setupValuerAndSetter() {
|
|||
case float64, float32:
|
||||
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
|
||||
default:
|
||||
return recoverFunc(value, v, field.Set)
|
||||
return fallbackSetter(value, v, field.Set)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -615,7 +638,7 @@ func (field *Field) setupValuerAndSetter() {
|
|||
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
|
||||
}
|
||||
default:
|
||||
return recoverFunc(value, v, field.Set)
|
||||
return fallbackSetter(value, v, field.Set)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -625,9 +648,6 @@ func (field *Field) setupValuerAndSetter() {
|
|||
case time.Time:
|
||||
fieldValue := field.ReflectValueOf(value)
|
||||
if fieldValue.IsNil() {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||
}
|
||||
fieldValue.Elem().Set(reflect.ValueOf(v))
|
||||
|
@ -647,7 +667,7 @@ func (field *Field) setupValuerAndSetter() {
|
|||
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
|
||||
}
|
||||
default:
|
||||
return recoverFunc(value, v, field.Set)
|
||||
return fallbackSetter(value, v, field.Set)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -655,53 +675,42 @@ func (field *Field) setupValuerAndSetter() {
|
|||
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||
// struct scanner
|
||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||
if v == nil {
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
}
|
||||
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
||||
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
||||
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||
if v, err = valuer.Value(); err == nil {
|
||||
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
||||
}
|
||||
} else {
|
||||
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
||||
// pointer scanner
|
||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||
if v == nil {
|
||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
}
|
||||
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
||||
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
||||
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
|
||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
fieldValue := field.ReflectValueOf(value)
|
||||
if fieldValue.IsNil() {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||
}
|
||||
fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem()))
|
||||
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||
if v, err = valuer.Value(); err == nil {
|
||||
err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
|
||||
}
|
||||
} else {
|
||||
err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
|
||||
}
|
||||
err = fieldValue.Interface().(sql.Scanner).Scan(v)
|
||||
}
|
||||
return
|
||||
}
|
||||
} else {
|
||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||
return recoverFunc(value, v, field.Set)
|
||||
return fallbackSetter(value, v, field.Set)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
12
statement.go
12
statement.go
|
@ -146,8 +146,16 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
|||
case clause.Column, clause.Table:
|
||||
stmt.QuoteTo(writer, v)
|
||||
case clause.Expr:
|
||||
writer.WriteString(v.SQL)
|
||||
stmt.Vars = append(stmt.Vars, v.Vars...)
|
||||
var varStr strings.Builder
|
||||
var sql = v.SQL
|
||||
for _, arg := range v.Vars {
|
||||
stmt.Vars = append(stmt.Vars, arg)
|
||||
stmt.DB.Dialector.BindVarTo(&varStr, stmt, arg)
|
||||
sql = strings.Replace(sql, "?", varStr.String(), 1)
|
||||
varStr.Reset()
|
||||
}
|
||||
|
||||
writer.WriteString(sql)
|
||||
case driver.Valuer:
|
||||
stmt.Vars = append(stmt.Vars, v)
|
||||
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
||||
|
|
|
@ -6,11 +6,6 @@ import (
|
|||
. "gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
RunMigrations()
|
||||
m.Run()
|
||||
}
|
||||
|
||||
func TestExceptionsWithInvalidSql(t *testing.T) {
|
||||
var columns []string
|
||||
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
||||
|
|
|
@ -1299,7 +1299,6 @@ func TestNilPointerSlice(t *testing.T) {
|
|||
)
|
||||
|
||||
DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{})
|
||||
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
|
|
@ -435,3 +435,35 @@ func TestSubQueryWithHaving(t *testing.T) {
|
|||
t.Errorf("Two user group should be found, instead found %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanNullValue(t *testing.T) {
|
||||
user := GetUser("scan_null_value", Config{})
|
||||
DB.Create(&user)
|
||||
|
||||
if err := DB.Model(&user).Update("age", nil).Error; err != nil {
|
||||
t.Fatalf("failed to update column age for struct, got error %v", err)
|
||||
}
|
||||
|
||||
var result User
|
||||
if err := DB.First(&result, "id = ?", user.ID).Error; err != nil {
|
||||
t.Fatalf("failed to query struct data with null age, got error %v", err)
|
||||
}
|
||||
|
||||
AssertEqual(t, result, user)
|
||||
|
||||
users := []User{
|
||||
*GetUser("scan_null_value_for_slice_1", Config{}),
|
||||
*GetUser("scan_null_value_for_slice_2", Config{}),
|
||||
*GetUser("scan_null_value_for_slice_3", Config{}),
|
||||
}
|
||||
DB.Create(&users)
|
||||
|
||||
if err := DB.Model(&users[0]).Update("age", nil).Error; err != nil {
|
||||
t.Fatalf("failed to update column age for struct, got error %v", err)
|
||||
}
|
||||
|
||||
var results []User
|
||||
if err := DB.Find(&results, "name like ?", "scan_null_value_for_slice%").Error; err != nil {
|
||||
t.Fatalf("failed to query slice data with null age, got error %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
dialects=("sqlite" "mysql" "postgres" "mssql")
|
||||
dialects=("sqlite" "mysql" "postgres" "sqlserver")
|
||||
|
||||
if [[ $(pwd) == *"gorm/tests"* ]]; then
|
||||
cd ..
|
||||
|
@ -10,7 +10,7 @@ for dialect in "${dialects[@]}" ; do
|
|||
echo "testing ${dialect}..."
|
||||
|
||||
race=""
|
||||
if [ "$GORM_VERBOSE" = "" ]
|
||||
if [ "$GORM_DIALECT" = "sqlserver" ]
|
||||
then
|
||||
race="-race"
|
||||
fi
|
||||
|
|
|
@ -23,6 +23,8 @@ func init() {
|
|||
if DB, err = OpenTestConnection(); err != nil {
|
||||
log.Printf("failed to connect database, got error %v\n", err)
|
||||
os.Exit(1)
|
||||
} else {
|
||||
RunMigrations()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -155,12 +155,14 @@ func TestUpdates(t *testing.T) {
|
|||
AssertEqual(t, user2.UpdatedAt, user3.UpdatedAt)
|
||||
|
||||
// update with gorm exprs
|
||||
DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)})
|
||||
if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
||||
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
||||
}
|
||||
var user4 User
|
||||
DB.First(&user4, user3.ID)
|
||||
|
||||
user3.Age += 100
|
||||
AssertEqual(t, user4.UpdatedAt, user3.UpdatedAt)
|
||||
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
|
||||
}
|
||||
|
||||
func TestUpdateColumn(t *testing.T) {
|
||||
|
|
|
@ -121,6 +121,11 @@ func TestFindOrCreate(t *testing.T) {
|
|||
|
||||
updatedAt1 := user4.UpdatedAt
|
||||
DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4)
|
||||
|
||||
if user4.Age != 55 {
|
||||
t.Errorf("Failed to set change to 55, got %v", user4.Age)
|
||||
}
|
||||
|
||||
if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||
t.Errorf("UpdateAt should be changed when update values with assign")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue