mirror of https://github.com/go-gorm/gorm.git
Fix can't scan null value into normal data types
This commit is contained in:
parent
2218e32999
commit
94685d1024
|
@ -168,7 +168,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||||
|
|
||||||
return tx.Create(dest)
|
return tx.Create(dest)
|
||||||
} else if len(tx.Statement.assigns) > 0 {
|
} else if len(tx.Statement.assigns) > 0 {
|
||||||
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:])
|
exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
|
||||||
assigns := map[string]interface{}{}
|
assigns := map[string]interface{}{}
|
||||||
for _, expr := range exprs {
|
for _, expr := range exprs {
|
||||||
if eq, ok := expr.(clause.Eq); ok {
|
if eq, ok := expr.(clause.Eq); ok {
|
||||||
|
|
102
scan.go
102
scan.go
|
@ -14,40 +14,53 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
|
|
||||||
switch dest := db.Statement.Dest.(type) {
|
switch dest := db.Statement.Dest.(type) {
|
||||||
case map[string]interface{}, *map[string]interface{}:
|
case map[string]interface{}, *map[string]interface{}:
|
||||||
for idx, _ := range columns {
|
if initialized || rows.Next() {
|
||||||
|
for idx := range columns {
|
||||||
values[idx] = new(interface{})
|
values[idx] = new(interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
if initialized || rows.Next() {
|
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
}
|
|
||||||
|
|
||||||
mapValue, ok := dest.(map[string]interface{})
|
mapValue, ok := dest.(map[string]interface{})
|
||||||
if ok {
|
if !ok {
|
||||||
if v, ok := dest.(*map[string]interface{}); ok {
|
if v, ok := dest.(*map[string]interface{}); ok {
|
||||||
mapValue = *v
|
mapValue = *v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
mapValue[column] = *(values[idx].(*interface{}))
|
if v, ok := values[idx].(*interface{}); ok {
|
||||||
|
if v == nil {
|
||||||
|
mapValue[column] = nil
|
||||||
|
} else {
|
||||||
|
mapValue[column] = *v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case *[]map[string]interface{}:
|
case *[]map[string]interface{}:
|
||||||
for idx, _ := range columns {
|
for initialized || rows.Next() {
|
||||||
|
for idx := range columns {
|
||||||
values[idx] = new(interface{})
|
values[idx] = new(interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
for initialized || rows.Next() {
|
|
||||||
initialized = false
|
initialized = false
|
||||||
db.RowsAffected++
|
db.RowsAffected++
|
||||||
db.AddError(rows.Scan(values...))
|
db.AddError(rows.Scan(values...))
|
||||||
|
|
||||||
v := map[string]interface{}{}
|
mapValue := map[string]interface{}{}
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
v[column] = *(values[idx].(*interface{}))
|
if v, ok := values[idx].(*interface{}); ok {
|
||||||
|
if v == nil {
|
||||||
|
mapValue[column] = nil
|
||||||
|
} else {
|
||||||
|
mapValue[column] = *v
|
||||||
}
|
}
|
||||||
*dest = append(*dest, v)
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*dest = append(*dest, mapValue)
|
||||||
}
|
}
|
||||||
case *int, *int64, *uint, *uint64:
|
case *int, *int64, *uint, *uint64:
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
|
@ -85,28 +98,52 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for initialized || rows.Next() {
|
for initialized || rows.Next() {
|
||||||
|
for idx := range columns {
|
||||||
|
values[idx] = new(interface{})
|
||||||
|
}
|
||||||
|
|
||||||
initialized = false
|
initialized = false
|
||||||
|
db.RowsAffected++
|
||||||
|
|
||||||
elem := reflect.New(reflectValueType).Elem()
|
elem := reflect.New(reflectValueType).Elem()
|
||||||
|
|
||||||
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
|
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
|
||||||
|
// pluck
|
||||||
values[0] = elem.Addr().Interface()
|
values[0] = elem.Addr().Interface()
|
||||||
|
db.AddError(rows.Scan(values...))
|
||||||
} else {
|
} else {
|
||||||
|
db.AddError(rows.Scan(values...))
|
||||||
|
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
|
if v, ok := values[idx].(*interface{}); ok {
|
||||||
if field != nil {
|
if field != nil {
|
||||||
values[idx] = field.ReflectValueOf(elem).Addr().Interface()
|
if v == nil {
|
||||||
|
field.Set(elem, v)
|
||||||
|
} else {
|
||||||
|
field.Set(elem, *v)
|
||||||
|
}
|
||||||
} else if joinFields[idx][0] != nil {
|
} else if joinFields[idx][0] != nil {
|
||||||
relValue := joinFields[idx][0].ReflectValueOf(elem)
|
relValue := joinFields[idx][0].ReflectValueOf(elem)
|
||||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||||
|
if v == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||||
}
|
}
|
||||||
|
|
||||||
values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface()
|
if v == nil {
|
||||||
|
joinFields[idx][1].Set(relValue, nil)
|
||||||
|
} else {
|
||||||
|
joinFields[idx][1].Set(relValue, *v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
db.RowsAffected++
|
for idx := range columns {
|
||||||
db.AddError(rows.Scan(values...))
|
values[idx] = new(interface{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if isPtr {
|
if isPtr {
|
||||||
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
|
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr()))
|
||||||
|
@ -115,30 +152,45 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case reflect.Struct:
|
case reflect.Struct:
|
||||||
|
if initialized || rows.Next() {
|
||||||
|
for idx := range columns {
|
||||||
|
values[idx] = new(interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
db.RowsAffected++
|
||||||
|
db.AddError(rows.Scan(values...))
|
||||||
|
|
||||||
for idx, column := range columns {
|
for idx, column := range columns {
|
||||||
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable {
|
||||||
values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface()
|
if v, ok := values[idx].(*interface{}); ok {
|
||||||
|
if v == nil {
|
||||||
|
field.Set(db.Statement.ReflectValue, v)
|
||||||
|
} else {
|
||||||
|
field.Set(db.Statement.ReflectValue, *v)
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||||
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok {
|
||||||
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
||||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||||
|
if v, ok := values[idx].(*interface{}); ok {
|
||||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||||
|
if v == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||||
}
|
}
|
||||||
|
|
||||||
values[idx] = field.ReflectValueOf(relValue).Addr().Interface()
|
if v == nil {
|
||||||
continue
|
field.Set(relValue, nil)
|
||||||
}
|
|
||||||
}
|
|
||||||
values[idx] = &sql.RawBytes{}
|
|
||||||
} else {
|
} else {
|
||||||
values[idx] = &sql.RawBytes{}
|
field.Set(relValue, *v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if initialized || rows.Next() {
|
|
||||||
db.RowsAffected++
|
|
||||||
db.AddError(rows.Scan(values...))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
107
schema/field.go
107
schema/field.go
|
@ -402,34 +402,48 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
|
fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
|
||||||
if v == nil {
|
if v == nil {
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||||
} else {
|
} else {
|
||||||
reflectV := reflect.ValueOf(v)
|
reflectV := reflect.ValueOf(v)
|
||||||
|
if reflectV.Type().AssignableTo(field.FieldType) {
|
||||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
field.ReflectValueOf(value).Set(reflectV)
|
||||||
|
return
|
||||||
|
} else if reflectV.Type().ConvertibleTo(field.FieldType) {
|
||||||
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
||||||
} else if valuer, ok := v.(driver.Valuer); ok {
|
return
|
||||||
if v, err = valuer.Value(); err == nil {
|
} else if field.FieldType.Kind() == reflect.Ptr {
|
||||||
return setter(value, v)
|
|
||||||
}
|
|
||||||
} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
|
|
||||||
fieldValue := field.ReflectValueOf(value)
|
fieldValue := field.ReflectValueOf(value)
|
||||||
|
|
||||||
|
if reflectV.Type().AssignableTo(field.FieldType.Elem()) {
|
||||||
if fieldValue.IsNil() {
|
if fieldValue.IsNil() {
|
||||||
if v == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||||
}
|
}
|
||||||
|
fieldValue.Elem().Set(reflectV)
|
||||||
|
return
|
||||||
|
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
|
||||||
|
if fieldValue.IsNil() {
|
||||||
|
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||||
|
}
|
||||||
|
|
||||||
fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem()))
|
fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if valuer, ok := v.(driver.Valuer); ok {
|
||||||
|
if v, err = valuer.Value(); err == nil {
|
||||||
|
setter(value, v)
|
||||||
|
}
|
||||||
} else if reflectV.Kind() == reflect.Ptr {
|
} else if reflectV.Kind() == reflect.Ptr {
|
||||||
return field.Set(value, reflectV.Elem().Interface())
|
setter(value, reflectV.Elem().Interface())
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set
|
// Set
|
||||||
|
@ -441,8 +455,17 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
field.ReflectValueOf(value).SetBool(data)
|
field.ReflectValueOf(value).SetBool(data)
|
||||||
case *bool:
|
case *bool:
|
||||||
field.ReflectValueOf(value).SetBool(*data)
|
field.ReflectValueOf(value).SetBool(*data)
|
||||||
|
case int64:
|
||||||
|
if data > 0 {
|
||||||
|
field.ReflectValueOf(value).SetBool(true)
|
||||||
|
} else {
|
||||||
|
field.ReflectValueOf(value).SetBool(false)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
b, _ := strconv.ParseBool(data)
|
||||||
|
field.ReflectValueOf(value).SetBool(b)
|
||||||
default:
|
default:
|
||||||
return recoverFunc(value, v, field.Set)
|
return fallbackSetter(value, v, field.Set)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -498,7 +521,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
field.ReflectValueOf(value).SetInt(0)
|
field.ReflectValueOf(value).SetInt(0)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return recoverFunc(value, v, field.Set)
|
return fallbackSetter(value, v, field.Set)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -538,7 +561,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return recoverFunc(value, v, field.Set)
|
return fallbackSetter(value, v, field.Set)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -578,7 +601,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return recoverFunc(value, v, field.Set)
|
return fallbackSetter(value, v, field.Set)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -594,7 +617,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
case float64, float32:
|
case float64, float32:
|
||||||
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
|
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
|
||||||
default:
|
default:
|
||||||
return recoverFunc(value, v, field.Set)
|
return fallbackSetter(value, v, field.Set)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -615,7 +638,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
|
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return recoverFunc(value, v, field.Set)
|
return fallbackSetter(value, v, field.Set)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -625,9 +648,6 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
case time.Time:
|
case time.Time:
|
||||||
fieldValue := field.ReflectValueOf(value)
|
fieldValue := field.ReflectValueOf(value)
|
||||||
if fieldValue.IsNil() {
|
if fieldValue.IsNil() {
|
||||||
if v == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||||
}
|
}
|
||||||
fieldValue.Elem().Set(reflect.ValueOf(v))
|
fieldValue.Elem().Set(reflect.ValueOf(v))
|
||||||
|
@ -647,7 +667,7 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
|
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return recoverFunc(value, v, field.Set)
|
return fallbackSetter(value, v, field.Set)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -655,53 +675,42 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||||
// struct scanner
|
// struct scanner
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||||
if v == nil {
|
if valuer, ok := v.(driver.Valuer); ok {
|
||||||
|
v, _ = valuer.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
reflectV := reflect.ValueOf(v)
|
||||||
|
if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||||
} else {
|
} else {
|
||||||
reflectV := reflect.ValueOf(v)
|
|
||||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
|
||||||
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
|
||||||
} else if valuer, ok := v.(driver.Valuer); ok {
|
|
||||||
if v, err = valuer.Value(); err == nil {
|
|
||||||
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
||||||
// pointer scanner
|
// pointer scanner
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||||
if v == nil {
|
if valuer, ok := v.(driver.Valuer); ok {
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
v, _ = valuer.Value()
|
||||||
} else {
|
}
|
||||||
|
|
||||||
reflectV := reflect.ValueOf(v)
|
reflectV := reflect.ValueOf(v)
|
||||||
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
if reflectV.Type().ConvertibleTo(field.FieldType) {
|
||||||
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
|
||||||
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
|
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||||
|
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||||
|
} else {
|
||||||
fieldValue := field.ReflectValueOf(value)
|
fieldValue := field.ReflectValueOf(value)
|
||||||
if fieldValue.IsNil() {
|
if fieldValue.IsNil() {
|
||||||
if v == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||||
}
|
}
|
||||||
fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem()))
|
err = fieldValue.Interface().(sql.Scanner).Scan(v)
|
||||||
} else if valuer, ok := v.(driver.Valuer); ok {
|
|
||||||
if v, err = valuer.Value(); err == nil {
|
|
||||||
err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||||
return recoverFunc(value, v, field.Set)
|
return fallbackSetter(value, v, field.Set)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
12
statement.go
12
statement.go
|
@ -146,8 +146,16 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||||
case clause.Column, clause.Table:
|
case clause.Column, clause.Table:
|
||||||
stmt.QuoteTo(writer, v)
|
stmt.QuoteTo(writer, v)
|
||||||
case clause.Expr:
|
case clause.Expr:
|
||||||
writer.WriteString(v.SQL)
|
var varStr strings.Builder
|
||||||
stmt.Vars = append(stmt.Vars, v.Vars...)
|
var sql = v.SQL
|
||||||
|
for _, arg := range v.Vars {
|
||||||
|
stmt.Vars = append(stmt.Vars, arg)
|
||||||
|
stmt.DB.Dialector.BindVarTo(&varStr, stmt, arg)
|
||||||
|
sql = strings.Replace(sql, "?", varStr.String(), 1)
|
||||||
|
varStr.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
writer.WriteString(sql)
|
||||||
case driver.Valuer:
|
case driver.Valuer:
|
||||||
stmt.Vars = append(stmt.Vars, v)
|
stmt.Vars = append(stmt.Vars, v)
|
||||||
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
||||||
|
|
|
@ -6,11 +6,6 @@ import (
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
|
||||||
RunMigrations()
|
|
||||||
m.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExceptionsWithInvalidSql(t *testing.T) {
|
func TestExceptionsWithInvalidSql(t *testing.T) {
|
||||||
var columns []string
|
var columns []string
|
||||||
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
||||||
|
|
|
@ -1299,7 +1299,6 @@ func TestNilPointerSlice(t *testing.T) {
|
||||||
)
|
)
|
||||||
|
|
||||||
DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{})
|
DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{})
|
||||||
|
|
||||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil {
|
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -435,3 +435,35 @@ func TestSubQueryWithHaving(t *testing.T) {
|
||||||
t.Errorf("Two user group should be found, instead found %d", len(results))
|
t.Errorf("Two user group should be found, instead found %d", len(results))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScanNullValue(t *testing.T) {
|
||||||
|
user := GetUser("scan_null_value", Config{})
|
||||||
|
DB.Create(&user)
|
||||||
|
|
||||||
|
if err := DB.Model(&user).Update("age", nil).Error; err != nil {
|
||||||
|
t.Fatalf("failed to update column age for struct, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result User
|
||||||
|
if err := DB.First(&result, "id = ?", user.ID).Error; err != nil {
|
||||||
|
t.Fatalf("failed to query struct data with null age, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
AssertEqual(t, result, user)
|
||||||
|
|
||||||
|
users := []User{
|
||||||
|
*GetUser("scan_null_value_for_slice_1", Config{}),
|
||||||
|
*GetUser("scan_null_value_for_slice_2", Config{}),
|
||||||
|
*GetUser("scan_null_value_for_slice_3", Config{}),
|
||||||
|
}
|
||||||
|
DB.Create(&users)
|
||||||
|
|
||||||
|
if err := DB.Model(&users[0]).Update("age", nil).Error; err != nil {
|
||||||
|
t.Fatalf("failed to update column age for struct, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []User
|
||||||
|
if err := DB.Find(&results, "name like ?", "scan_null_value_for_slice%").Error; err != nil {
|
||||||
|
t.Fatalf("failed to query slice data with null age, got error %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
dialects=("sqlite" "mysql" "postgres" "mssql")
|
dialects=("sqlite" "mysql" "postgres" "sqlserver")
|
||||||
|
|
||||||
if [[ $(pwd) == *"gorm/tests"* ]]; then
|
if [[ $(pwd) == *"gorm/tests"* ]]; then
|
||||||
cd ..
|
cd ..
|
||||||
|
@ -10,7 +10,7 @@ for dialect in "${dialects[@]}" ; do
|
||||||
echo "testing ${dialect}..."
|
echo "testing ${dialect}..."
|
||||||
|
|
||||||
race=""
|
race=""
|
||||||
if [ "$GORM_VERBOSE" = "" ]
|
if [ "$GORM_DIALECT" = "sqlserver" ]
|
||||||
then
|
then
|
||||||
race="-race"
|
race="-race"
|
||||||
fi
|
fi
|
||||||
|
|
|
@ -23,6 +23,8 @@ func init() {
|
||||||
if DB, err = OpenTestConnection(); err != nil {
|
if DB, err = OpenTestConnection(); err != nil {
|
||||||
log.Printf("failed to connect database, got error %v\n", err)
|
log.Printf("failed to connect database, got error %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
} else {
|
||||||
|
RunMigrations()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -155,12 +155,14 @@ func TestUpdates(t *testing.T) {
|
||||||
AssertEqual(t, user2.UpdatedAt, user3.UpdatedAt)
|
AssertEqual(t, user2.UpdatedAt, user3.UpdatedAt)
|
||||||
|
|
||||||
// update with gorm exprs
|
// update with gorm exprs
|
||||||
DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)})
|
if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
|
||||||
|
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
|
||||||
|
}
|
||||||
var user4 User
|
var user4 User
|
||||||
DB.First(&user4, user3.ID)
|
DB.First(&user4, user3.ID)
|
||||||
|
|
||||||
user3.Age += 100
|
user3.Age += 100
|
||||||
AssertEqual(t, user4.UpdatedAt, user3.UpdatedAt)
|
AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateColumn(t *testing.T) {
|
func TestUpdateColumn(t *testing.T) {
|
||||||
|
|
|
@ -121,6 +121,11 @@ func TestFindOrCreate(t *testing.T) {
|
||||||
|
|
||||||
updatedAt1 := user4.UpdatedAt
|
updatedAt1 := user4.UpdatedAt
|
||||||
DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4)
|
DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4)
|
||||||
|
|
||||||
|
if user4.Age != 55 {
|
||||||
|
t.Errorf("Failed to set change to 55, got %v", user4.Age)
|
||||||
|
}
|
||||||
|
|
||||||
if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) {
|
if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) {
|
||||||
t.Errorf("UpdateAt should be changed when update values with assign")
|
t.Errorf("UpdateAt should be changed when update values with assign")
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue