Fix driver.Valuer interface returns nil, close #3248

This commit is contained in:
Jinzhu 2020-08-13 12:05:55 +08:00
parent a3dda47afa
commit 7d45833f3e
2 changed files with 56 additions and 56 deletions

View File

@ -731,40 +731,10 @@ func (field *Field) setupValuerAndSetter() {
return nil return nil
} }
default: default:
if _, ok := fieldValue.Interface().(sql.Scanner); ok { if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// struct scanner
field.Set = func(value reflect.Value, v interface{}) (err error) {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(value, reflectV.Elem().Interface())
}
} else {
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
}
return
}
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner // pointer scanner
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if valuer, ok := v.(driver.Valuer); ok {
if valuer == nil || reflectV.IsNil() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
v, _ = valuer.Value()
}
}
if reflectV.Type().AssignableTo(field.FieldType) { if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV) field.ReflectValueOf(value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
@ -778,10 +748,38 @@ func (field *Field) setupValuerAndSetter() {
if fieldValue.IsNil() { if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(field.FieldType.Elem()))
} }
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
err = fieldValue.Interface().(sql.Scanner).Scan(v) err = fieldValue.Interface().(sql.Scanner).Scan(v)
} }
return return
} }
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner
field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() || !reflectV.IsValid() {
field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
} else {
return field.Set(value, reflectV.Elem().Interface())
}
} else {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
}
return
}
} else { } else {
field.Set = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
return fallbackSetter(value, v, field.Set) return fallbackSetter(value, v, field.Set)

View File

@ -36,8 +36,8 @@ func TestScannerValuer(t *testing.T) {
{"name2", "value2"}, {"name2", "value2"},
}, },
Role: Role{Name: "admin"}, Role: Role{Name: "admin"},
ExampleStruct: ExampleStruct1{"name", "value"}, ExampleStruct: ExampleStruct{"name", "value1"},
ExampleStructPtr: &ExampleStruct1{"name", "value"}, ExampleStructPtr: &ExampleStruct{"name", "value2"},
} }
if err := DB.Create(&data).Error; err != nil { if err := DB.Create(&data).Error; err != nil {
@ -46,19 +46,18 @@ func TestScannerValuer(t *testing.T) {
var result ScannerValuerStruct var result ScannerValuerStruct
if err := DB.Find(&result).Error; err != nil { if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil {
t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err) t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err)
} }
if result.ExampleStructPtr.Val != "value2" {
t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val)
}
if result.ExampleStruct.Val != "value1" {
t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct)
}
AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs")
if result.ExampleStructPtr.Val != "value" {
t.Errorf(`ExampleStructPtr.Val should equal to "value", but got %v`, result.ExampleStructPtr.Val)
}
if result.ExampleStruct.Val != "value" {
t.Errorf(`ExampleStruct.Val should equal to "value", but got %v`, result.ExampleStruct.Val)
}
} }
func TestScannerValuerWithFirstOrCreate(t *testing.T) { func TestScannerValuerWithFirstOrCreate(t *testing.T) {
@ -68,9 +67,11 @@ func TestScannerValuerWithFirstOrCreate(t *testing.T) {
} }
data := ScannerValuerStruct{ data := ScannerValuerStruct{
Name: sql.NullString{String: "name", Valid: true}, Name: sql.NullString{String: "name", Valid: true},
Gender: &sql.NullString{String: "M", Valid: true}, Gender: &sql.NullString{String: "M", Valid: true},
Age: sql.NullInt64{Int64: 18, Valid: true}, Age: sql.NullInt64{Int64: 18, Valid: true},
ExampleStruct: ExampleStruct{"name", "value1"},
ExampleStructPtr: &ExampleStruct{"name", "value2"},
} }
var result ScannerValuerStruct var result ScannerValuerStruct
@ -109,7 +110,9 @@ func TestInvalidValuer(t *testing.T) {
} }
data := ScannerValuerStruct{ data := ScannerValuerStruct{
Password: EncryptedData("xpass1"), Password: EncryptedData("xpass1"),
ExampleStruct: ExampleStruct{"name", "value1"},
ExampleStructPtr: &ExampleStruct{"name", "value2"},
} }
if err := DB.Create(&data).Error; err == nil { if err := DB.Create(&data).Error; err == nil {
@ -149,8 +152,8 @@ type ScannerValuerStruct struct {
UserID *sql.NullInt64 UserID *sql.NullInt64
User User User User
EmptyTime EmptyTime EmptyTime EmptyTime
ExampleStruct ExampleStruct1 ExampleStruct ExampleStruct
ExampleStructPtr *ExampleStruct1 ExampleStructPtr *ExampleStruct
} }
type EncryptedData []byte type EncryptedData []byte
@ -215,25 +218,24 @@ func (l *StringsSlice) Scan(input interface{}) error {
} }
type ExampleStruct struct { type ExampleStruct struct {
Name string Name string
Value string Val string
} }
type ExampleStruct1 struct { func (ExampleStruct) GormDataType() string {
Name string `json:"name,omitempty"` return "bytes"
Val string `json:"val,omitempty"`
} }
func (s ExampleStruct1) Value() (driver.Value, error) { func (s ExampleStruct) Value() (driver.Value, error) {
if len(s.Name) == 0 { if len(s.Name) == 0 {
return nil, nil return nil, nil
} }
//for test, has no practical meaning // for test, has no practical meaning
s.Name = "" s.Name = ""
return json.Marshal(s) return json.Marshal(s)
} }
func (s *ExampleStruct1) Scan(src interface{}) error { func (s *ExampleStruct) Scan(src interface{}) error {
switch value := src.(type) { switch value := src.(type) {
case string: case string:
return json.Unmarshal([]byte(value), s) return json.Unmarshal([]byte(value), s)