Add more tests for setter, valuer

This commit is contained in:
Jinzhu 2020-02-16 00:37:59 +08:00
parent faee069a9f
commit 18236fa3d7
5 changed files with 281 additions and 133 deletions

View File

@ -164,6 +164,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
case reflect.Struct:
if _, ok := fieldValue.Interface().(*time.Time); ok {
field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
field.DataType = Time
}
case reflect.Array, reflect.Slice:
if fieldValue.Type().Elem() == reflect.TypeOf(uint8(0)) {
@ -311,6 +313,24 @@ func (field *Field) setupValuerAndSetter() {
}
}
recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
return setter(value, v)
}
} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else if reflectV.Kind() == reflect.Ptr {
return field.Setter(value, reflectV.Elem().Interface())
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
}
return err
}
// Setter
switch field.FieldType.Kind() {
case reflect.Bool:
@ -321,17 +341,12 @@ func (field *Field) setupValuerAndSetter() {
case *bool:
field.ReflectValuer(value).SetBool(*data)
default:
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else {
field.ReflectValuer(value).SetBool(!reflect.ValueOf(v).IsZero())
}
return recoverFunc(value, v, field.Setter)
}
return nil
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.Setter = func(value reflect.Value, v interface{}) error {
field.Setter = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case int64:
field.ReflectValuer(value).SetInt(data)
@ -366,19 +381,12 @@ func (field *Field) setupValuerAndSetter() {
return err
}
default:
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Kind() == reflect.Ptr {
return field.Setter(value, reflectV.Elem().Interface())
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
return recoverFunc(value, v, field.Setter)
}
}
return nil
return err
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.Setter = func(value reflect.Value, v interface{}) error {
field.Setter = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case uint64:
field.ReflectValuer(value).SetUint(data)
@ -413,19 +421,12 @@ func (field *Field) setupValuerAndSetter() {
return err
}
default:
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Kind() == reflect.Ptr {
return field.Setter(value, reflectV.Elem().Interface())
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
return recoverFunc(value, v, field.Setter)
}
}
return nil
return err
}
case reflect.Float32, reflect.Float64:
field.Setter = func(value reflect.Value, v interface{}) error {
field.Setter = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case float64:
field.ReflectValuer(value).SetFloat(data)
@ -460,19 +461,12 @@ func (field *Field) setupValuerAndSetter() {
return err
}
default:
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Kind() == reflect.Ptr {
return field.Setter(value, reflectV.Elem().Interface())
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
return recoverFunc(value, v, field.Setter)
}
}
return nil
return err
}
case reflect.String:
field.Setter = func(value reflect.Value, v interface{}) error {
field.Setter = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case string:
field.ReflectValuer(value).SetString(data)
@ -483,16 +477,9 @@ func (field *Field) setupValuerAndSetter() {
case float64, float32:
field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default:
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Kind() == reflect.Ptr {
return field.Setter(value, reflectV.Elem().Interface())
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
return recoverFunc(value, v, field.Setter)
}
}
return nil
return err
}
default:
fieldValue := reflect.New(field.FieldType)
@ -511,7 +498,7 @@ func (field *Field) setupValuerAndSetter() {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name)
return recoverFunc(value, v, field.Setter)
}
return nil
}
@ -529,29 +516,18 @@ func (field *Field) setupValuerAndSetter() {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name)
return recoverFunc(value, v, field.Setter)
}
return nil
}
default:
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner
field.Setter = func(value reflect.Value, v interface{}) (err error) {
if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
}
} else {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
}
return
}
return
}
if fieldValue.CanAddr() {
if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
field.Setter = func(value reflect.Value, v interface{}) (err error) {
if valuer, ok := v.(driver.Valuer); ok {
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v)
}
@ -560,31 +536,26 @@ func (field *Field) setupValuerAndSetter() {
}
return
}
return
}
}
if field.FieldType.Kind() == reflect.Ptr {
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner
field.Setter = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
}
return nil
} else {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
}
return
}
} else {
field.Setter = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
}
return nil
return recoverFunc(value, v, field.Setter)
}
}
}

View File

@ -1,6 +1,7 @@
package schema_test
import (
"database/sql"
"reflect"
"sync"
"testing"
@ -13,8 +14,7 @@ import (
func TestFieldValuerAndSetter(t *testing.T) {
var (
cacheMap = sync.Map{}
userSchema, _ = schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{})
userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
user = tests.User{
Model: gorm.Model{
ID: 10,
@ -54,18 +54,36 @@ func TestFieldValuerAndSetter(t *testing.T) {
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v", k)
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
checkField(t, userSchema, reflectValue, newValues)
// test valuer and other type
age := myint(10)
newValues2 := map[string]interface{}{
"name": sql.NullString{String: "valuer_and_setter_3", Valid: true},
"id": &sql.NullInt64{Int64: 3, Valid: true},
"created_at": tests.Now(),
"deleted_at": time.Now(),
"age": &age,
"birthday": mytime(time.Now()),
"active": mybool(true),
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
checkField(t, userSchema, reflectValue, newValues2)
}
func TestPointerFieldValuerAndSetter(t *testing.T) {
var (
cacheMap = sync.Map{}
userSchema, _ = schema.Parse(&User{}, &cacheMap, schema.NamingStrategy{})
userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
name = "pointer_field_valuer_and_setter"
age = 18
age uint = 18
active = true
user = User{
Model: &gorm.Model{
@ -110,22 +128,91 @@ func TestPointerFieldValuerAndSetter(t *testing.T) {
}
}
checkField(t, userSchema, reflectValue, newValues)
// test valuer and other type
age2 := myint(10)
newValues2 := map[string]interface{}{
"name": sql.NullString{String: "valuer_and_setter_3", Valid: true},
"id": &sql.NullInt64{Int64: 3, Valid: true},
"created_at": tests.Now(),
"deleted_at": time.Now(),
"age": &age2,
"birthday": mytime(time.Now()),
"active": mybool(true),
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
checkField(t, userSchema, reflectValue, newValues2)
}
type User struct {
*gorm.Model
Name *string
Age *int
Birthday *time.Time
Account *tests.Account
Pets []*tests.Pet
Toys []tests.Toy `gorm:"polymorphic:Owner"`
CompanyID *int
Company *tests.Company
ManagerID *int
Manager *User
Team []User `gorm:"foreignkey:ManagerID"`
Languages []tests.Language `gorm:"many2many:UserSpeak"`
Friends []*User `gorm:"many2many:user_friends"`
Active *bool
func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
var (
userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
name = "advanced_data_type_valuer_and_setter"
deletedAt = mytime(time.Now())
isAdmin = mybool(false)
user = AdvancedDataTypeUser{
ID: sql.NullInt64{Int64: 10, Valid: true},
Name: &sql.NullString{String: name, Valid: true},
Birthday: sql.NullTime{Time: time.Now(), Valid: true},
RegisteredAt: mytime(time.Now()),
DeletedAt: &deletedAt,
Active: mybool(true),
Admin: &isAdmin,
}
reflectValue = reflect.ValueOf(&user)
)
// test valuer
values := map[string]interface{}{
"id": user.ID,
"name": user.Name,
"birthday": user.Birthday,
"registered_at": user.RegisteredAt,
"deleted_at": user.DeletedAt,
"active": user.Active,
"admin": user.Admin,
}
checkField(t, userSchema, reflectValue, values)
// test setter
newDeletedAt := mytime(time.Now())
newIsAdmin := mybool(true)
newValues := map[string]interface{}{
"id": sql.NullInt64{Int64: 1, Valid: true},
"name": &sql.NullString{String: name + "rename", Valid: true},
"birthday": time.Now(),
"registered_at": mytime(time.Now()),
"deleted_at": &newDeletedAt,
"active": mybool(false),
"admin": &newIsAdmin,
}
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
checkField(t, userSchema, reflectValue, newValues)
newValues2 := map[string]interface{}{
"id": 5,
"name": name + "rename2",
"birthday": time.Now(),
"registered_at": time.Now(),
"deleted_at": time.Now(),
"active": true,
"admin": false,
}
for k, v := range newValues2 {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
}
}
checkField(t, userSchema, reflectValue, newValues2)
}

41
schema/model_test.go Normal file
View File

@ -0,0 +1,41 @@
package schema_test
import (
"database/sql"
"time"
"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/tests"
)
type User struct {
*gorm.Model
Name *string
Age *uint
Birthday *time.Time
Account *tests.Account
Pets []*tests.Pet
Toys []*tests.Toy `gorm:"polymorphic:Owner"`
CompanyID *int
Company *tests.Company
ManagerID *int
Manager *User
Team []*User `gorm:"foreignkey:ManagerID"`
Languages []*tests.Language `gorm:"many2many:UserSpeak"`
Friends []*User `gorm:"many2many:user_friends"`
Active *bool
}
type mytime time.Time
type myint int
type mybool = bool
type AdvancedDataTypeUser struct {
ID sql.NullInt64
Name *sql.NullString
Birthday sql.NullTime
RegisteredAt mytime
DeletedAt *mytime
Active mybool
Admin *mybool
}

View File

@ -1,6 +1,7 @@
package schema_test
import (
"database/sql/driver"
"fmt"
"reflect"
"strings"
@ -194,30 +195,39 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
for k, v := range values {
t.Run("CheckField/"+k, func(t *testing.T) {
field := s.FieldsByDBName[k]
fv := field.ValueOf(value)
var (
checker func(fv interface{}, v interface{})
field = s.FieldsByDBName[k]
fv = field.ValueOf(value)
)
if reflect.ValueOf(fv).Kind() == reflect.Ptr {
if reflect.ValueOf(v).Kind() == reflect.Ptr {
if fv != v {
t.Errorf("pointer expects: %p, but got %p", v, fv)
checker = func(fv interface{}, v interface{}) {
if reflect.ValueOf(fv).Type() == reflect.ValueOf(v).Type() && fv != v {
t.Errorf("expects: %p, but got %p", v, fv)
} else if reflect.ValueOf(v).Type().ConvertibleTo(reflect.ValueOf(fv).Type()) {
if reflect.ValueOf(v).Convert(reflect.ValueOf(fv).Type()).Interface() != fv {
t.Errorf("expects: %p, but got %p", v, fv)
}
} else if fv == nil {
if v != nil {
t.Errorf("expects: %+v, but got nil", v)
}
} else if reflect.ValueOf(fv).Elem().Interface() != v {
t.Errorf("expects: %+v, but got %+v", v, fv)
} else if reflect.ValueOf(fv).Type().ConvertibleTo(reflect.ValueOf(v).Type()) {
if reflect.ValueOf(fv).Convert(reflect.ValueOf(fv).Type()).Interface() != v {
t.Errorf("expects: %p, but got %p", v, fv)
}
} else if valuer, isValuer := fv.(driver.Valuer); isValuer {
valuerv, _ := valuer.Value()
checker(valuerv, v)
} else if valuer, isValuer := v.(driver.Valuer); isValuer {
valuerv, _ := valuer.Value()
checker(fv, valuerv)
} else if reflect.ValueOf(fv).Kind() == reflect.Ptr {
checker(reflect.ValueOf(fv).Elem().Interface(), v)
} else if reflect.ValueOf(v).Kind() == reflect.Ptr {
if reflect.ValueOf(v).Elem().Interface() != fv {
t.Errorf("expects: %+v, but got %+v", v, fv)
}
} else if reflect.ValueOf(v).Type().ConvertibleTo(field.FieldType) {
if reflect.ValueOf(v).Convert(field.FieldType).Interface() != fv {
checker(fv, reflect.ValueOf(v).Elem().Interface())
} else {
t.Errorf("expects: %+v, but got %+v", v, fv)
}
}
checker(fv, v)
})
}
}

View File

@ -9,13 +9,24 @@ import (
)
func TestParseSchema(t *testing.T) {
cacheMap := sync.Map{}
user, err := schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{})
user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user, got error %v", err)
}
checkUserSchema(t, user)
}
func TestParseSchemaWithPointerFields(t *testing.T) {
user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err)
}
checkUserSchema(t, user)
}
func checkUserSchema(t *testing.T, user *schema.Schema) {
// check schema
checkSchema(t, user, schema.Schema{Name: "User", Table: "users"}, []string{"ID"})
@ -101,3 +112,31 @@ func TestParseSchema(t *testing.T) {
checkSchemaRelation(t, user, relation)
}
}
func TestParseSchemaWithAdvancedDataType(t *testing.T) {
user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse pointer user, got error %v", err)
}
// check schema
checkSchema(t, user, schema.Schema{Name: "AdvancedDataTypeUser", Table: "advanced_data_type_users"}, []string{"ID"})
// check fields
fields := []schema.Field{
{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true},
{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
{Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time},
{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"DeletedAt"}, DataType: schema.Time},
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
}
for _, f := range fields {
checkSchemaField(t, user, &f, func(f *schema.Field) {
f.Creatable = true
f.Updatable = true
})
}
}