Add more tests for setter, valuer

This commit is contained in:
Jinzhu 2020-02-16 00:37:59 +08:00
parent faee069a9f
commit 18236fa3d7
5 changed files with 281 additions and 133 deletions

View File

@ -164,6 +164,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
case reflect.Struct: case reflect.Struct:
if _, ok := fieldValue.Interface().(*time.Time); ok { if _, ok := fieldValue.Interface().(*time.Time); ok {
field.DataType = Time field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
field.DataType = Time
} }
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
if fieldValue.Type().Elem() == reflect.TypeOf(uint8(0)) { if fieldValue.Type().Elem() == reflect.TypeOf(uint8(0)) {
@ -311,6 +313,24 @@ func (field *Field) setupValuerAndSetter() {
} }
} }
recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
return setter(value, v)
}
} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else if reflectV.Kind() == reflect.Ptr {
return field.Setter(value, reflectV.Elem().Interface())
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
}
return err
}
// Setter // Setter
switch field.FieldType.Kind() { switch field.FieldType.Kind() {
case reflect.Bool: case reflect.Bool:
@ -321,17 +341,12 @@ func (field *Field) setupValuerAndSetter() {
case *bool: case *bool:
field.ReflectValuer(value).SetBool(*data) field.ReflectValuer(value).SetBool(*data)
default: default:
reflectV := reflect.ValueOf(v) return recoverFunc(value, v, field.Setter)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else {
field.ReflectValuer(value).SetBool(!reflect.ValueOf(v).IsZero())
}
} }
return nil return nil
} }
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.Setter = func(value reflect.Value, v interface{}) error { field.Setter = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case int64: case int64:
field.ReflectValuer(value).SetInt(data) field.ReflectValuer(value).SetInt(data)
@ -366,19 +381,12 @@ func (field *Field) setupValuerAndSetter() {
return err return err
} }
default: default:
reflectV := reflect.ValueOf(v) return recoverFunc(value, v, field.Setter)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Kind() == reflect.Ptr {
return field.Setter(value, reflectV.Elem().Interface())
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
} }
} return err
return nil
} }
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.Setter = func(value reflect.Value, v interface{}) error { field.Setter = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case uint64: case uint64:
field.ReflectValuer(value).SetUint(data) field.ReflectValuer(value).SetUint(data)
@ -413,19 +421,12 @@ func (field *Field) setupValuerAndSetter() {
return err return err
} }
default: default:
reflectV := reflect.ValueOf(v) return recoverFunc(value, v, field.Setter)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Kind() == reflect.Ptr {
return field.Setter(value, reflectV.Elem().Interface())
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
} }
} return err
return nil
} }
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
field.Setter = func(value reflect.Value, v interface{}) error { field.Setter = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case float64: case float64:
field.ReflectValuer(value).SetFloat(data) field.ReflectValuer(value).SetFloat(data)
@ -460,19 +461,12 @@ func (field *Field) setupValuerAndSetter() {
return err return err
} }
default: default:
reflectV := reflect.ValueOf(v) return recoverFunc(value, v, field.Setter)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Kind() == reflect.Ptr {
return field.Setter(value, reflectV.Elem().Interface())
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
} }
} return err
return nil
} }
case reflect.String: case reflect.String:
field.Setter = func(value reflect.Value, v interface{}) error { field.Setter = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case string: case string:
field.ReflectValuer(value).SetString(data) field.ReflectValuer(value).SetString(data)
@ -483,16 +477,9 @@ func (field *Field) setupValuerAndSetter() {
case float64, float32: case float64, float32:
field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default: default:
reflectV := reflect.ValueOf(v) return recoverFunc(value, v, field.Setter)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Kind() == reflect.Ptr {
return field.Setter(value, reflectV.Elem().Interface())
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
} }
} return err
return nil
} }
default: default:
fieldValue := reflect.New(field.FieldType) fieldValue := reflect.New(field.FieldType)
@ -511,7 +498,7 @@ func (field *Field) setupValuerAndSetter() {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
} }
default: default:
return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name) return recoverFunc(value, v, field.Setter)
} }
return nil return nil
} }
@ -529,29 +516,18 @@ func (field *Field) setupValuerAndSetter() {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
} }
default: default:
return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name) return recoverFunc(value, v, field.Setter)
} }
return nil return nil
} }
default: default:
if _, ok := fieldValue.Interface().(sql.Scanner); ok { if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Setter = func(value reflect.Value, v interface{}) (err error) {
if valuer, ok := v.(driver.Valuer); ok { reflectV := reflect.ValueOf(v)
if v, err = valuer.Value(); err == nil { if reflectV.Type().ConvertibleTo(field.FieldType) {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} } else if valuer, ok := v.(driver.Valuer); ok {
} else {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
}
return
}
return
}
if fieldValue.CanAddr() {
if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
field.Setter = func(value reflect.Value, v interface{}) (err error) {
if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil { if v, err = valuer.Value(); err == nil {
err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v)
} }
@ -560,31 +536,26 @@ func (field *Field) setupValuerAndSetter() {
} }
return return
} }
return } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
} // pointer scanner
}
if field.FieldType.Kind() == reflect.Ptr {
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Setter = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) { if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else { } else if valuer, ok := v.(driver.Valuer); ok {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) if v, err = valuer.Value(); err == nil {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
} }
return nil } else {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
}
return
} }
} else { } else {
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Setter = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v) return recoverFunc(value, v, field.Setter)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
}
return nil
} }
} }
} }

View File

@ -1,6 +1,7 @@
package schema_test package schema_test
import ( import (
"database/sql"
"reflect" "reflect"
"sync" "sync"
"testing" "testing"
@ -13,8 +14,7 @@ import (
func TestFieldValuerAndSetter(t *testing.T) { func TestFieldValuerAndSetter(t *testing.T) {
var ( var (
cacheMap = sync.Map{} userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
userSchema, _ = schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{})
user = tests.User{ user = tests.User{
Model: gorm.Model{ Model: gorm.Model{
ID: 10, ID: 10,
@ -54,18 +54,36 @@ func TestFieldValuerAndSetter(t *testing.T) {
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v", k) t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
checkField(t, userSchema, reflectValue, newValues) checkField(t, userSchema, reflectValue, newValues)
// test valuer and other type
age := myint(10)
newValues2 := map[string]interface{}{
"name": sql.NullString{String: "valuer_and_setter_3", Valid: true},
"id": &sql.NullInt64{Int64: 3, Valid: true},
"created_at": tests.Now(),
"deleted_at": time.Now(),
"age": &age,
"birthday": mytime(time.Now()),
"active": mybool(true),
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
checkField(t, userSchema, reflectValue, newValues2)
} }
func TestPointerFieldValuerAndSetter(t *testing.T) { func TestPointerFieldValuerAndSetter(t *testing.T) {
var ( var (
cacheMap = sync.Map{} userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
userSchema, _ = schema.Parse(&User{}, &cacheMap, schema.NamingStrategy{})
name = "pointer_field_valuer_and_setter" name = "pointer_field_valuer_and_setter"
age = 18 age uint = 18
active = true active = true
user = User{ user = User{
Model: &gorm.Model{ Model: &gorm.Model{
@ -110,22 +128,91 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
} }
} }
checkField(t, userSchema, reflectValue, newValues) checkField(t, userSchema, reflectValue, newValues)
// test valuer and other type
age2 := myint(10)
newValues2 := map[string]interface{}{
"name": sql.NullString{String: "valuer_and_setter_3", Valid: true},
"id": &sql.NullInt64{Int64: 3, Valid: true},
"created_at": tests.Now(),
"deleted_at": time.Now(),
"age": &age2,
"birthday": mytime(time.Now()),
"active": mybool(true),
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
checkField(t, userSchema, reflectValue, newValues2)
} }
type User struct { func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
*gorm.Model var (
Name *string userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
Age *int name = "advanced_data_type_valuer_and_setter"
Birthday *time.Time deletedAt = mytime(time.Now())
Account *tests.Account isAdmin = mybool(false)
Pets []*tests.Pet user = AdvancedDataTypeUser{
Toys []tests.Toy `gorm:"polymorphic:Owner"` ID: sql.NullInt64{Int64: 10, Valid: true},
CompanyID *int Name: &sql.NullString{String: name, Valid: true},
Company *tests.Company Birthday: sql.NullTime{Time: time.Now(), Valid: true},
ManagerID *int RegisteredAt: mytime(time.Now()),
Manager *User DeletedAt: &deletedAt,
Team []User `gorm:"foreignkey:ManagerID"` Active: mybool(true),
Languages []tests.Language `gorm:"many2many:UserSpeak"` Admin: &isAdmin,
Friends []*User `gorm:"many2many:user_friends"` }
Active *bool reflectValue = reflect.ValueOf(&user)
)
// test valuer
values := map[string]interface{}{
"id": user.ID,
"name": user.Name,
"birthday": user.Birthday,
"registered_at": user.RegisteredAt,
"deleted_at": user.DeletedAt,
"active": user.Active,
"admin": user.Admin,
}
checkField(t, userSchema, reflectValue, values)
// test setter
newDeletedAt := mytime(time.Now())
newIsAdmin := mybool(true)
newValues := map[string]interface{}{
"id": sql.NullInt64{Int64: 1, Valid: true},
"name": &sql.NullString{String: name + "rename", Valid: true},
"birthday": time.Now(),
"registered_at": mytime(time.Now()),
"deleted_at": &newDeletedAt,
"active": mybool(false),
"admin": &newIsAdmin,
}
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
checkField(t, userSchema, reflectValue, newValues)
newValues2 := map[string]interface{}{
"id": 5,
"name": name + "rename2",
"birthday": time.Now(),
"registered_at": time.Now(),
"deleted_at": time.Now(),
"active": true,
"admin": false,
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
checkField(t, userSchema, reflectValue, newValues2)
} }

41
schema/model_test.go Normal file
View File

@ -0,0 +1,41 @@
package schema_test
import (
"database/sql"
"time"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/tests"
)
type User struct {
*gorm.Model
Name *string
Age *uint
Birthday *time.Time
Account *tests.Account
Pets []*tests.Pet
Toys []*tests.Toy `gorm:"polymorphic:Owner"`
CompanyID *int
Company *tests.Company
ManagerID *int
Manager *User
Team []*User `gorm:"foreignkey:ManagerID"`
Languages []*tests.Language `gorm:"many2many:UserSpeak"`
Friends []*User `gorm:"many2many:user_friends"`
Active *bool
}
type mytime time.Time
type myint int
type mybool = bool
type AdvancedDataTypeUser struct {
ID sql.NullInt64
Name *sql.NullString
Birthday sql.NullTime
RegisteredAt mytime
DeletedAt *mytime
Active mybool
Admin *mybool
}

View File

@ -1,6 +1,7 @@
package schema_test package schema_test
import ( import (
"database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
@ -194,30 +195,39 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
for k, v := range values { for k, v := range values {
t.Run("CheckField/"+k, func(t *testing.T) { t.Run("CheckField/"+k, func(t *testing.T) {
field := s.FieldsByDBName[k] var (
fv := field.ValueOf(value) checker func(fv interface{}, v interface{})
field = s.FieldsByDBName[k]
fv = field.ValueOf(value)
)
if reflect.ValueOf(fv).Kind() == reflect.Ptr { checker = func(fv interface{}, v interface{}) {
if reflect.ValueOf(v).Kind() == reflect.Ptr { if reflect.ValueOf(fv).Type() == reflect.ValueOf(v).Type() && fv != v {
if fv != v { t.Errorf("expects: %p, but got %p", v, fv)
t.Errorf("pointer expects: %p, but got %p", v, fv) } else if reflect.ValueOf(v).Type().ConvertibleTo(reflect.ValueOf(fv).Type()) {
if reflect.ValueOf(v).Convert(reflect.ValueOf(fv).Type()).Interface() != fv {
t.Errorf("expects: %p, but got %p", v, fv)
} }
} else if fv == nil { } else if reflect.ValueOf(fv).Type().ConvertibleTo(reflect.ValueOf(v).Type()) {
if v != nil { if reflect.ValueOf(fv).Convert(reflect.ValueOf(fv).Type()).Interface() != v {
t.Errorf("expects: %+v, but got nil", v) t.Errorf("expects: %p, but got %p", v, fv)
}
} else if reflect.ValueOf(fv).Elem().Interface() != v {
t.Errorf("expects: %+v, but got %+v", v, fv)
} }
} else if valuer, isValuer := fv.(driver.Valuer); isValuer {
valuerv, _ := valuer.Value()
checker(valuerv, v)
} else if valuer, isValuer := v.(driver.Valuer); isValuer {
valuerv, _ := valuer.Value()
checker(fv, valuerv)
} else if reflect.ValueOf(fv).Kind() == reflect.Ptr {
checker(reflect.ValueOf(fv).Elem().Interface(), v)
} else if reflect.ValueOf(v).Kind() == reflect.Ptr { } else if reflect.ValueOf(v).Kind() == reflect.Ptr {
if reflect.ValueOf(v).Elem().Interface() != fv { checker(fv, reflect.ValueOf(v).Elem().Interface())
t.Errorf("expects: %+v, but got %+v", v, fv) } else {
}
} else if reflect.ValueOf(v).Type().ConvertibleTo(field.FieldType) {
if reflect.ValueOf(v).Convert(field.FieldType).Interface() != fv {
t.Errorf("expects: %+v, but got %+v", v, fv) t.Errorf("expects: %+v, but got %+v", v, fv)
} }
} }
checker(fv, v)
}) })
} }
} }

View File

@ -9,13 +9,24 @@ import (
) )
func TestParseSchema(t *testing.T) { func TestParseSchema(t *testing.T) {
cacheMap := sync.Map{} user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
user, err := schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{})
if err != nil { if err != nil {
t.Fatalf("failed to parse user, got error %v", err) t.Fatalf("failed to parse user, got error %v", err)
} }
checkUserSchema(t, user)
}
func TestParseSchemaWithPointerFields(t *testing.T) {
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err)
}
checkUserSchema(t, user)
}
func checkUserSchema(t *testing.T, user *schema.Schema) {
// check schema // check schema
checkSchema(t, user, schema.Schema{Name: "User", Table: "users"}, []string{"ID"}) checkSchema(t, user, schema.Schema{Name: "User", Table: "users"}, []string{"ID"})
@ -101,3 +112,31 @@ func TestParseSchema(t *testing.T) {
checkSchemaRelation(t, user, relation) checkSchemaRelation(t, user, relation)
} }
} }
func TestParseSchemaWithAdvancedDataType(t *testing.T) {
user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err)
}
// check schema
checkSchema(t, user, schema.Schema{Name: "AdvancedDataTypeUser", Table: "advanced_data_type_users"}, []string{"ID"})
// check fields
fields := []schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true},
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
{Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time},
{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"DeletedAt"}, DataType: schema.Time},
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
}
for _, f := range fields {
checkSchemaField(t, user, &f, func(f *schema.Field) {
f.Creatable = true
f.Updatable = true
})
}
}