mirror of https://github.com/go-gorm/gorm.git
Fix driver.Valuer interface returns nil, close #3248
This commit is contained in:
parent
a3dda47afa
commit
7d45833f3e
|
@ -731,40 +731,10 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
||||||
// struct scanner
|
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
|
||||||
if valuer, ok := v.(driver.Valuer); ok {
|
|
||||||
v, _ = valuer.Value()
|
|
||||||
}
|
|
||||||
|
|
||||||
reflectV := reflect.ValueOf(v)
|
|
||||||
if !reflectV.IsValid() {
|
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
|
||||||
} else if reflectV.Kind() == reflect.Ptr {
|
|
||||||
if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() {
|
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
|
||||||
} else {
|
|
||||||
return field.Set(value, reflectV.Elem().Interface())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
|
||||||
// pointer scanner
|
// pointer scanner
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||||
reflectV := reflect.ValueOf(v)
|
reflectV := reflect.ValueOf(v)
|
||||||
|
|
||||||
if valuer, ok := v.(driver.Valuer); ok {
|
|
||||||
if valuer == nil || reflectV.IsNil() {
|
|
||||||
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
|
||||||
} else {
|
|
||||||
v, _ = valuer.Value()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if reflectV.Type().AssignableTo(field.FieldType) {
|
if reflectV.Type().AssignableTo(field.FieldType) {
|
||||||
field.ReflectValueOf(value).Set(reflectV)
|
field.ReflectValueOf(value).Set(reflectV)
|
||||||
} else if reflectV.Kind() == reflect.Ptr {
|
} else if reflectV.Kind() == reflect.Ptr {
|
||||||
|
@ -778,10 +748,38 @@ func (field *Field) setupValuerAndSetter() {
|
||||||
if fieldValue.IsNil() {
|
if fieldValue.IsNil() {
|
||||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if valuer, ok := v.(driver.Valuer); ok {
|
||||||
|
v, _ = valuer.Value()
|
||||||
|
}
|
||||||
|
|
||||||
err = fieldValue.Interface().(sql.Scanner).Scan(v)
|
err = fieldValue.Interface().(sql.Scanner).Scan(v)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||||
|
// struct scanner
|
||||||
|
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||||
|
reflectV := reflect.ValueOf(v)
|
||||||
|
if !reflectV.IsValid() {
|
||||||
|
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||||
|
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||||
|
field.ReflectValueOf(value).Set(reflectV)
|
||||||
|
} else if reflectV.Kind() == reflect.Ptr {
|
||||||
|
if reflectV.IsNil() || !reflectV.IsValid() {
|
||||||
|
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
|
||||||
|
} else {
|
||||||
|
return field.Set(value, reflectV.Elem().Interface())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if valuer, ok := v.(driver.Valuer); ok {
|
||||||
|
v, _ = valuer.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
field.Set = func(value reflect.Value, v interface{}) (err error) {
|
||||||
return fallbackSetter(value, v, field.Set)
|
return fallbackSetter(value, v, field.Set)
|
||||||
|
|
|
@ -36,8 +36,8 @@ func TestScannerValuer(t *testing.T) {
|
||||||
{"name2", "value2"},
|
{"name2", "value2"},
|
||||||
},
|
},
|
||||||
Role: Role{Name: "admin"},
|
Role: Role{Name: "admin"},
|
||||||
ExampleStruct: ExampleStruct1{"name", "value"},
|
ExampleStruct: ExampleStruct{"name", "value1"},
|
||||||
ExampleStructPtr: &ExampleStruct1{"name", "value"},
|
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Create(&data).Error; err != nil {
|
if err := DB.Create(&data).Error; err != nil {
|
||||||
|
@ -46,19 +46,18 @@ func TestScannerValuer(t *testing.T) {
|
||||||
|
|
||||||
var result ScannerValuerStruct
|
var result ScannerValuerStruct
|
||||||
|
|
||||||
if err := DB.Find(&result).Error; err != nil {
|
if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil {
|
||||||
t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err)
|
t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if result.ExampleStructPtr.Val != "value2" {
|
||||||
|
t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ExampleStruct.Val != "value1" {
|
||||||
|
t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct)
|
||||||
|
}
|
||||||
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs")
|
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs")
|
||||||
|
|
||||||
if result.ExampleStructPtr.Val != "value" {
|
|
||||||
t.Errorf(`ExampleStructPtr.Val should equal to "value", but got %v`, result.ExampleStructPtr.Val)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.ExampleStruct.Val != "value" {
|
|
||||||
t.Errorf(`ExampleStruct.Val should equal to "value", but got %v`, result.ExampleStruct.Val)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScannerValuerWithFirstOrCreate(t *testing.T) {
|
func TestScannerValuerWithFirstOrCreate(t *testing.T) {
|
||||||
|
@ -71,6 +70,8 @@ func TestScannerValuerWithFirstOrCreate(t *testing.T) {
|
||||||
Name: sql.NullString{String: "name", Valid: true},
|
Name: sql.NullString{String: "name", Valid: true},
|
||||||
Gender: &sql.NullString{String: "M", Valid: true},
|
Gender: &sql.NullString{String: "M", Valid: true},
|
||||||
Age: sql.NullInt64{Int64: 18, Valid: true},
|
Age: sql.NullInt64{Int64: 18, Valid: true},
|
||||||
|
ExampleStruct: ExampleStruct{"name", "value1"},
|
||||||
|
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
||||||
}
|
}
|
||||||
|
|
||||||
var result ScannerValuerStruct
|
var result ScannerValuerStruct
|
||||||
|
@ -110,6 +111,8 @@ func TestInvalidValuer(t *testing.T) {
|
||||||
|
|
||||||
data := ScannerValuerStruct{
|
data := ScannerValuerStruct{
|
||||||
Password: EncryptedData("xpass1"),
|
Password: EncryptedData("xpass1"),
|
||||||
|
ExampleStruct: ExampleStruct{"name", "value1"},
|
||||||
|
ExampleStructPtr: &ExampleStruct{"name", "value2"},
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := DB.Create(&data).Error; err == nil {
|
if err := DB.Create(&data).Error; err == nil {
|
||||||
|
@ -149,8 +152,8 @@ type ScannerValuerStruct struct {
|
||||||
UserID *sql.NullInt64
|
UserID *sql.NullInt64
|
||||||
User User
|
User User
|
||||||
EmptyTime EmptyTime
|
EmptyTime EmptyTime
|
||||||
ExampleStruct ExampleStruct1
|
ExampleStruct ExampleStruct
|
||||||
ExampleStructPtr *ExampleStruct1
|
ExampleStructPtr *ExampleStruct
|
||||||
}
|
}
|
||||||
|
|
||||||
type EncryptedData []byte
|
type EncryptedData []byte
|
||||||
|
@ -216,24 +219,23 @@ func (l *StringsSlice) Scan(input interface{}) error {
|
||||||
|
|
||||||
type ExampleStruct struct {
|
type ExampleStruct struct {
|
||||||
Name string
|
Name string
|
||||||
Value string
|
Val string
|
||||||
}
|
}
|
||||||
|
|
||||||
type ExampleStruct1 struct {
|
func (ExampleStruct) GormDataType() string {
|
||||||
Name string `json:"name,omitempty"`
|
return "bytes"
|
||||||
Val string `json:"val,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s ExampleStruct1) Value() (driver.Value, error) {
|
func (s ExampleStruct) Value() (driver.Value, error) {
|
||||||
if len(s.Name) == 0 {
|
if len(s.Name) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
//for test, has no practical meaning
|
// for test, has no practical meaning
|
||||||
s.Name = ""
|
s.Name = ""
|
||||||
return json.Marshal(s)
|
return json.Marshal(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ExampleStruct1) Scan(src interface{}) error {
|
func (s *ExampleStruct) Scan(src interface{}) error {
|
||||||
switch value := src.(type) {
|
switch value := src.(type) {
|
||||||
case string:
|
case string:
|
||||||
return json.Unmarshal([]byte(value), s)
|
return json.Unmarshal([]byte(value), s)
|
||||||
|
|
Loading…
Reference in New Issue