From 7d45833f3e309f9c15bb9ca301c1782b23cb9f0e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 12:05:55 +0800 Subject: [PATCH] Fix driver.Valuer interface returns nil, close #3248 --- schema/field.go | 60 +++++++++++++++++------------------- tests/scanner_valuer_test.go | 52 ++++++++++++++++--------------- 2 files changed, 56 insertions(+), 56 deletions(-) diff --git a/schema/field.go b/schema/field.go index ea6364a4..84fdb695 100644 --- a/schema/field.go +++ b/schema/field.go @@ -731,40 +731,10 @@ func (field *Field) setupValuerAndSetter() { return nil } default: - if _, ok := fieldValue.Interface().(sql.Scanner); ok { - // struct scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { - if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() - } - - reflectV := reflect.ValueOf(v) - if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else if reflectV.Kind() == reflect.Ptr { - if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(value, reflectV.Elem().Interface()) - } - } else { - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) - } - return - } - } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { + if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) - - if valuer, ok := v.(driver.Valuer); ok { - if valuer == nil || reflectV.IsNil() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else { - v, _ = valuer.Value() - } - } - if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { @@ -778,10 +748,38 @@ func (field *Field) setupValuerAndSetter() { if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } + + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + err = fieldValue.Interface().(sql.Scanner).Scan(v) } return } + } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { + // struct scanner + field.Set = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() || !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(value, reflectV.Elem().Interface()) + } + } else { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + } + return + } } else { field.Set = func(value reflect.Value, v interface{}) (err error) { return fallbackSetter(value, v, field.Set) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 6b8f086e..b8306af7 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -36,8 +36,8 @@ func TestScannerValuer(t *testing.T) { {"name2", "value2"}, }, Role: Role{Name: "admin"}, - ExampleStruct: ExampleStruct1{"name", "value"}, - ExampleStructPtr: &ExampleStruct1{"name", "value"}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } if err := DB.Create(&data).Error; err != nil { @@ -46,19 +46,18 @@ func TestScannerValuer(t *testing.T) { var result ScannerValuerStruct - if err := DB.Find(&result).Error; err != nil { + if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil { t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err) } + if result.ExampleStructPtr.Val != "value2" { + t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val) + } + + if result.ExampleStruct.Val != "value1" { + t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct) + } AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") - - if result.ExampleStructPtr.Val != "value" { - t.Errorf(`ExampleStructPtr.Val should equal to "value", but got %v`, result.ExampleStructPtr.Val) - } - - if result.ExampleStruct.Val != "value" { - t.Errorf(`ExampleStruct.Val should equal to "value", but got %v`, result.ExampleStruct.Val) - } } func TestScannerValuerWithFirstOrCreate(t *testing.T) { @@ -68,9 +67,11 @@ func TestScannerValuerWithFirstOrCreate(t *testing.T) { } data := ScannerValuerStruct{ - Name: sql.NullString{String: "name", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: true}, + Name: sql.NullString{String: "name", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } var result ScannerValuerStruct @@ -109,7 +110,9 @@ func TestInvalidValuer(t *testing.T) { } data := ScannerValuerStruct{ - Password: EncryptedData("xpass1"), + Password: EncryptedData("xpass1"), + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } if err := DB.Create(&data).Error; err == nil { @@ -149,8 +152,8 @@ type ScannerValuerStruct struct { UserID *sql.NullInt64 User User EmptyTime EmptyTime - ExampleStruct ExampleStruct1 - ExampleStructPtr *ExampleStruct1 + ExampleStruct ExampleStruct + ExampleStructPtr *ExampleStruct } type EncryptedData []byte @@ -215,25 +218,24 @@ func (l *StringsSlice) Scan(input interface{}) error { } type ExampleStruct struct { - Name string - Value string + Name string + Val string } -type ExampleStruct1 struct { - Name string `json:"name,omitempty"` - Val string `json:"val,omitempty"` +func (ExampleStruct) GormDataType() string { + return "bytes" } -func (s ExampleStruct1) Value() (driver.Value, error) { +func (s ExampleStruct) Value() (driver.Value, error) { if len(s.Name) == 0 { return nil, nil } - //for test, has no practical meaning + // for test, has no practical meaning s.Name = "" return json.Marshal(s) } -func (s *ExampleStruct1) Scan(src interface{}) error { +func (s *ExampleStruct) Scan(src interface{}) error { switch value := src.(type) { case string: return json.Unmarshal([]byte(value), s)