From 18236fa3d72c196d6a5c5ee4070626e305912645 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 16 Feb 2020 00:37:59 +0800 Subject: [PATCH] Add more tests for setter, valuer --- schema/field.go | 145 ++++++++++++++--------------------- schema/field_test.go | 137 +++++++++++++++++++++++++++------ schema/model_test.go | 41 ++++++++++ schema/schema_helper_test.go | 46 ++++++----- schema/schema_test.go | 45 ++++++++++- 5 files changed, 281 insertions(+), 133 deletions(-) create mode 100644 schema/model_test.go diff --git a/schema/field.go b/schema/field.go index b4610436..76f459ec 100644 --- a/schema/field.go +++ b/schema/field.go @@ -164,6 +164,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.Struct: if _, ok := fieldValue.Interface().(*time.Time); ok { field.DataType = Time + } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { + field.DataType = Time } case reflect.Array, reflect.Slice: if fieldValue.Type().Elem() == reflect.TypeOf(uint8(0)) { @@ -311,6 +313,24 @@ func (field *Field) setupValuerAndSetter() { } } + recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + return setter(value, v) + } + } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + } else if reflectV.Kind() == reflect.Ptr { + return field.Setter(value, reflectV.Elem().Interface()) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + return err + } + // Setter switch field.FieldType.Kind() { case reflect.Bool: @@ -321,17 +341,12 @@ func (field *Field) setupValuerAndSetter() { case *bool: field.ReflectValuer(value).SetBool(*data) default: - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else { - field.ReflectValuer(value).SetBool(!reflect.ValueOf(v).IsZero()) - } + return recoverFunc(value, v, field.Setter) } return nil } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Setter = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case int64: field.ReflectValuer(value).SetInt(data) @@ -366,19 +381,12 @@ func (field *Field) setupValuerAndSetter() { return err } default: - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Kind() == reflect.Ptr { - return field.Setter(value, reflectV.Elem().Interface()) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) - } + return recoverFunc(value, v, field.Setter) } - return nil + return err } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Setter = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case uint64: field.ReflectValuer(value).SetUint(data) @@ -413,19 +421,12 @@ func (field *Field) setupValuerAndSetter() { return err } default: - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Kind() == reflect.Ptr { - return field.Setter(value, reflectV.Elem().Interface()) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) - } + return recoverFunc(value, v, field.Setter) } - return nil + return err } case reflect.Float32, reflect.Float64: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Setter = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case float64: field.ReflectValuer(value).SetFloat(data) @@ -460,19 +461,12 @@ func (field *Field) setupValuerAndSetter() { return err } default: - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Kind() == reflect.Ptr { - return field.Setter(value, reflectV.Elem().Interface()) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) - } + return recoverFunc(value, v, field.Setter) } - return nil + return err } case reflect.String: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Setter = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case string: field.ReflectValuer(value).SetString(data) @@ -483,16 +477,9 @@ func (field *Field) setupValuerAndSetter() { case float64, float32: field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Kind() == reflect.Ptr { - return field.Setter(value, reflectV.Elem().Interface()) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) - } + return recoverFunc(value, v, field.Setter) } - return nil + return err } default: fieldValue := reflect.New(field.FieldType) @@ -511,7 +498,7 @@ func (field *Field) setupValuerAndSetter() { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name) + return recoverFunc(value, v, field.Setter) } return nil } @@ -529,14 +516,35 @@ func (field *Field) setupValuerAndSetter() { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } default: - return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name) + return recoverFunc(value, v, field.Setter) } return nil } default: if _, ok := fieldValue.Interface().(sql.Scanner); ok { + // struct scanner field.Setter = func(value reflect.Value, v interface{}) (err error) { - if valuer, ok := v.(driver.Valuer); ok { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) + } + } else { + err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) + } + return + } + } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { + // pointer scanner + field.Setter = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) } @@ -545,46 +553,9 @@ func (field *Field) setupValuerAndSetter() { } return } - return - } - - if fieldValue.CanAddr() { - if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - field.Setter = func(value reflect.Value, v interface{}) (err error) { - if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) - } - } else { - err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) - } - return - } - return - } - } - - if field.FieldType.Kind() == reflect.Ptr { - field.Setter = func(value reflect.Value, v interface{}) (err error) { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) - } - return nil - } } else { field.Setter = func(value reflect.Value, v interface{}) (err error) { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) - } - return nil + return recoverFunc(value, v, field.Setter) } } } diff --git a/schema/field_test.go b/schema/field_test.go index 065d6d05..15dfa41d 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "database/sql" "reflect" "sync" "testing" @@ -13,8 +14,7 @@ import ( func TestFieldValuerAndSetter(t *testing.T) { var ( - cacheMap = sync.Map{} - userSchema, _ = schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{}) + userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) user = tests.User{ Model: gorm.Model{ ID: 10, @@ -54,20 +54,38 @@ func TestFieldValuerAndSetter(t *testing.T) { for k, v := range newValues { if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { - t.Errorf("no error should happen when assign value to field %v", k) + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } checkField(t, userSchema, reflectValue, newValues) + + // test valuer and other type + age := myint(10) + newValues2 := map[string]interface{}{ + "name": sql.NullString{String: "valuer_and_setter_3", Valid: true}, + "id": &sql.NullInt64{Int64: 3, Valid: true}, + "created_at": tests.Now(), + "deleted_at": time.Now(), + "age": &age, + "birthday": mytime(time.Now()), + "active": mybool(true), + } + + for k, v := range newValues2 { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues2) } func TestPointerFieldValuerAndSetter(t *testing.T) { var ( - cacheMap = sync.Map{} - userSchema, _ = schema.Parse(&User{}, &cacheMap, schema.NamingStrategy{}) - name = "pointer_field_valuer_and_setter" - age = 18 - active = true - user = User{ + userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + name = "pointer_field_valuer_and_setter" + age uint = 18 + active = true + user = User{ Model: &gorm.Model{ ID: 10, CreatedAt: time.Now(), @@ -110,22 +128,91 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { } } checkField(t, userSchema, reflectValue, newValues) + + // test valuer and other type + age2 := myint(10) + newValues2 := map[string]interface{}{ + "name": sql.NullString{String: "valuer_and_setter_3", Valid: true}, + "id": &sql.NullInt64{Int64: 3, Valid: true}, + "created_at": tests.Now(), + "deleted_at": time.Now(), + "age": &age2, + "birthday": mytime(time.Now()), + "active": mybool(true), + } + + for k, v := range newValues2 { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues2) } -type User struct { - *gorm.Model - Name *string - Age *int - Birthday *time.Time - Account *tests.Account - Pets []*tests.Pet - Toys []tests.Toy `gorm:"polymorphic:Owner"` - CompanyID *int - Company *tests.Company - ManagerID *int - Manager *User - Team []User `gorm:"foreignkey:ManagerID"` - Languages []tests.Language `gorm:"many2many:UserSpeak"` - Friends []*User `gorm:"many2many:user_friends"` - Active *bool +func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { + var ( + userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + name = "advanced_data_type_valuer_and_setter" + deletedAt = mytime(time.Now()) + isAdmin = mybool(false) + user = AdvancedDataTypeUser{ + ID: sql.NullInt64{Int64: 10, Valid: true}, + Name: &sql.NullString{String: name, Valid: true}, + Birthday: sql.NullTime{Time: time.Now(), Valid: true}, + RegisteredAt: mytime(time.Now()), + DeletedAt: &deletedAt, + Active: mybool(true), + Admin: &isAdmin, + } + reflectValue = reflect.ValueOf(&user) + ) + + // test valuer + values := map[string]interface{}{ + "id": user.ID, + "name": user.Name, + "birthday": user.Birthday, + "registered_at": user.RegisteredAt, + "deleted_at": user.DeletedAt, + "active": user.Active, + "admin": user.Admin, + } + checkField(t, userSchema, reflectValue, values) + + // test setter + newDeletedAt := mytime(time.Now()) + newIsAdmin := mybool(true) + newValues := map[string]interface{}{ + "id": sql.NullInt64{Int64: 1, Valid: true}, + "name": &sql.NullString{String: name + "rename", Valid: true}, + "birthday": time.Now(), + "registered_at": mytime(time.Now()), + "deleted_at": &newDeletedAt, + "active": mybool(false), + "admin": &newIsAdmin, + } + + for k, v := range newValues { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues) + + newValues2 := map[string]interface{}{ + "id": 5, + "name": name + "rename2", + "birthday": time.Now(), + "registered_at": time.Now(), + "deleted_at": time.Now(), + "active": true, + "admin": false, + } + + for k, v := range newValues2 { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) + } + } + checkField(t, userSchema, reflectValue, newValues2) } diff --git a/schema/model_test.go b/schema/model_test.go new file mode 100644 index 00000000..aca7e617 --- /dev/null +++ b/schema/model_test.go @@ -0,0 +1,41 @@ +package schema_test + +import ( + "database/sql" + "time" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/tests" +) + +type User struct { + *gorm.Model + Name *string + Age *uint + Birthday *time.Time + Account *tests.Account + Pets []*tests.Pet + Toys []*tests.Toy `gorm:"polymorphic:Owner"` + CompanyID *int + Company *tests.Company + ManagerID *int + Manager *User + Team []*User `gorm:"foreignkey:ManagerID"` + Languages []*tests.Language `gorm:"many2many:UserSpeak"` + Friends []*User `gorm:"many2many:user_friends"` + Active *bool +} + +type mytime time.Time +type myint int +type mybool = bool + +type AdvancedDataTypeUser struct { + ID sql.NullInt64 + Name *sql.NullString + Birthday sql.NullTime + RegisteredAt mytime + DeletedAt *mytime + Active mybool + Admin *mybool +} diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 4af0fc89..8ac2f002 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "database/sql/driver" "fmt" "reflect" "strings" @@ -194,30 +195,39 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { for k, v := range values { t.Run("CheckField/"+k, func(t *testing.T) { - field := s.FieldsByDBName[k] - fv := field.ValueOf(value) + var ( + checker func(fv interface{}, v interface{}) + field = s.FieldsByDBName[k] + fv = field.ValueOf(value) + ) - if reflect.ValueOf(fv).Kind() == reflect.Ptr { - if reflect.ValueOf(v).Kind() == reflect.Ptr { - if fv != v { - t.Errorf("pointer expects: %p, but got %p", v, fv) + checker = func(fv interface{}, v interface{}) { + if reflect.ValueOf(fv).Type() == reflect.ValueOf(v).Type() && fv != v { + t.Errorf("expects: %p, but got %p", v, fv) + } else if reflect.ValueOf(v).Type().ConvertibleTo(reflect.ValueOf(fv).Type()) { + if reflect.ValueOf(v).Convert(reflect.ValueOf(fv).Type()).Interface() != fv { + t.Errorf("expects: %p, but got %p", v, fv) } - } else if fv == nil { - if v != nil { - t.Errorf("expects: %+v, but got nil", v) + } else if reflect.ValueOf(fv).Type().ConvertibleTo(reflect.ValueOf(v).Type()) { + if reflect.ValueOf(fv).Convert(reflect.ValueOf(fv).Type()).Interface() != v { + t.Errorf("expects: %p, but got %p", v, fv) } - } else if reflect.ValueOf(fv).Elem().Interface() != v { - t.Errorf("expects: %+v, but got %+v", v, fv) - } - } else if reflect.ValueOf(v).Kind() == reflect.Ptr { - if reflect.ValueOf(v).Elem().Interface() != fv { - t.Errorf("expects: %+v, but got %+v", v, fv) - } - } else if reflect.ValueOf(v).Type().ConvertibleTo(field.FieldType) { - if reflect.ValueOf(v).Convert(field.FieldType).Interface() != fv { + } else if valuer, isValuer := fv.(driver.Valuer); isValuer { + valuerv, _ := valuer.Value() + checker(valuerv, v) + } else if valuer, isValuer := v.(driver.Valuer); isValuer { + valuerv, _ := valuer.Value() + checker(fv, valuerv) + } else if reflect.ValueOf(fv).Kind() == reflect.Ptr { + checker(reflect.ValueOf(fv).Elem().Interface(), v) + } else if reflect.ValueOf(v).Kind() == reflect.Ptr { + checker(fv, reflect.ValueOf(v).Elem().Interface()) + } else { t.Errorf("expects: %+v, but got %+v", v, fv) } } + + checker(fv, v) }) } } diff --git a/schema/schema_test.go b/schema/schema_test.go index 97da1d5d..4134c966 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -9,13 +9,24 @@ import ( ) func TestParseSchema(t *testing.T) { - cacheMap := sync.Map{} - - user, err := schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{}) + user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user, got error %v", err) } + checkUserSchema(t, user) +} + +func TestParseSchemaWithPointerFields(t *testing.T) { + user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse pointer user, got error %v", err) + } + + checkUserSchema(t, user) +} + +func checkUserSchema(t *testing.T, user *schema.Schema) { // check schema checkSchema(t, user, schema.Schema{Name: "User", Table: "users"}, []string{"ID"}) @@ -101,3 +112,31 @@ func TestParseSchema(t *testing.T) { checkSchemaRelation(t, user, relation) } } + +func TestParseSchemaWithAdvancedDataType(t *testing.T) { + user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse pointer user, got error %v", err) + } + + // check schema + checkSchema(t, user, schema.Schema{Name: "AdvancedDataTypeUser", Table: "advanced_data_type_users"}, []string{"ID"}) + + // check fields + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true}, + {Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String}, + {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, + {Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time}, + {Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"DeletedAt"}, DataType: schema.Time}, + {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, + {Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool}, + } + + for _, f := range fields { + checkSchemaField(t, user, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + }) + } +}