Test Field Valuer, Setter

This commit is contained in:
Jinzhu 2020-02-15 19:45:27 +08:00
parent 2cb88dc7c5
commit faee069a9f
6 changed files with 225 additions and 90 deletions

View File

@ -25,52 +25,53 @@ const (
) )
type Field struct { type Field struct {
Name string Name string
DBName string DBName string
BindNames []string BindNames []string
DataType DataType DataType DataType
DBDataType string DBDataType string
PrimaryKey bool PrimaryKey bool
AutoIncrement bool AutoIncrement bool
Creatable bool Creatable bool
Updatable bool Updatable bool
HasDefaultValue bool HasDefaultValue bool
DefaultValue string DefaultValue string
NotNull bool NotNull bool
Unique bool Unique bool
Comment string Comment string
Size int Size int
Precision int Precision int
FieldType reflect.Type FieldType reflect.Type
StructField reflect.StructField IndirectFieldType reflect.Type
Tag reflect.StructTag StructField reflect.StructField
TagSettings map[string]string Tag reflect.StructTag
Schema *Schema TagSettings map[string]string
EmbeddedSchema *Schema Schema *Schema
ReflectValuer func(reflect.Value) reflect.Value EmbeddedSchema *Schema
Valuer func(reflect.Value) interface{} ReflectValuer func(reflect.Value) reflect.Value
Setter func(reflect.Value, interface{}) error Valuer func(reflect.Value) interface{}
Setter func(reflect.Value, interface{}) error
} }
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field := &Field{ field := &Field{
Name: fieldStruct.Name, Name: fieldStruct.Name,
BindNames: []string{fieldStruct.Name}, BindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type, FieldType: fieldStruct.Type,
StructField: fieldStruct, IndirectFieldType: fieldStruct.Type,
Creatable: true, StructField: fieldStruct,
Updatable: true, Creatable: true,
Tag: fieldStruct.Tag, Updatable: true,
TagSettings: ParseTagSetting(fieldStruct.Tag), Tag: fieldStruct.Tag,
Schema: schema, TagSettings: ParseTagSetting(fieldStruct.Tag),
Schema: schema,
} }
for field.FieldType.Kind() == reflect.Ptr { for field.IndirectFieldType.Kind() == reflect.Ptr {
field.FieldType = field.FieldType.Elem() field.IndirectFieldType = field.IndirectFieldType.Elem()
} }
fieldValue := reflect.New(field.FieldType) fieldValue := reflect.New(field.IndirectFieldType)
// if field is valuer, used its value or first fields as data type // if field is valuer, used its value or first fields as data type
if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer { if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer {
var overrideFieldValue bool var overrideFieldValue bool
@ -79,10 +80,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
fieldValue = reflect.ValueOf(v) fieldValue = reflect.ValueOf(v)
} }
if field.FieldType.Kind() == reflect.Struct { if field.IndirectFieldType.Kind() == reflect.Struct {
for i := 0; i < field.FieldType.NumField(); i++ { for i := 0; i < field.IndirectFieldType.NumField(); i++ {
if !overrideFieldValue { if !overrideFieldValue {
newFieldType := field.FieldType.Field(i).Type newFieldType := field.IndirectFieldType.Field(i).Type
for newFieldType.Kind() == reflect.Ptr { for newFieldType.Kind() == reflect.Ptr {
newFieldType = newFieldType.Elem() newFieldType = newFieldType.Elem()
} }
@ -92,7 +93,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
// copy tag settings from valuer // copy tag settings from valuer
for key, value := range ParseTagSetting(field.FieldType.Field(i).Tag) { for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag) {
if _, ok := field.TagSettings[key]; !ok { if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value field.TagSettings[key] = value
} }
@ -197,7 +198,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
if field.FieldType.Kind() == reflect.Struct { if field.FieldType.Kind() == reflect.Struct {
ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
} else { } else {
ef.StructField.Index = append([]int{-fieldStruct.Index[0]}, ef.StructField.Index...) ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
} }
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok {
@ -235,26 +236,29 @@ func (field *Field) setupValuerAndSetter() {
switch { switch {
case len(field.StructField.Index) == 1: case len(field.StructField.Index) == 1:
field.Valuer = func(value reflect.Value) interface{} { field.Valuer = func(value reflect.Value) interface{} {
return value.Field(field.StructField.Index[0]).Interface() return reflect.Indirect(value).Field(field.StructField.Index[0]).Interface()
} }
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
field.Valuer = func(value reflect.Value) interface{} { field.Valuer = func(value reflect.Value) interface{} {
return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface()
} }
default: default:
field.Valuer = func(value reflect.Value) interface{} { field.Valuer = func(value reflect.Value) interface{} {
v := value.Field(field.StructField.Index[0]) v := reflect.Indirect(value)
for _, idx := range field.StructField.Index[1:] {
if v.Kind() == reflect.Ptr { for _, idx := range field.StructField.Index {
if idx >= 0 {
v = v.Field(idx)
} else {
v = v.Field(-idx - 1)
if v.Type().Elem().Kind() == reflect.Struct { if v.Type().Elem().Kind() == reflect.Struct {
if !v.IsNil() { if !v.IsNil() {
v = v.Elem().Field(-idx) v = v.Elem()
continue
} }
} else {
return nil
} }
return nil
} else {
v = v.Field(idx)
} }
} }
return v.Interface() return v.Interface()
@ -266,7 +270,7 @@ func (field *Field) setupValuerAndSetter() {
case len(field.StructField.Index) == 1: case len(field.StructField.Index) == 1:
if field.FieldType.Kind() == reflect.Ptr { if field.FieldType.Kind() == reflect.Ptr {
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValuer = func(value reflect.Value) reflect.Value {
fieldValue := value.Field(field.StructField.Index[0]) fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
if fieldValue.IsNil() { if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(field.FieldType.Elem()))
} }
@ -274,31 +278,33 @@ func (field *Field) setupValuerAndSetter() {
} }
} else { } else {
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValuer = func(value reflect.Value) reflect.Value {
return value.Field(field.StructField.Index[0]) return reflect.Indirect(value).Field(field.StructField.Index[0])
} }
} }
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
field.Valuer = func(value reflect.Value) interface{} { field.ReflectValuer = func(value reflect.Value) reflect.Value {
return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
} }
default: default:
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValuer = func(value reflect.Value) reflect.Value {
v := value.Field(field.StructField.Index[0]) v := reflect.Indirect(value)
for _, idx := range field.StructField.Index[1:] { for _, idx := range field.StructField.Index {
if idx >= 0 {
v = v.Field(idx)
} else {
v = v.Field(-idx - 1)
}
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
if v.Type().Elem().Kind() == reflect.Struct { if v.Type().Elem().Kind() == reflect.Struct {
if v.IsNil() { if v.IsNil() {
v.Set(reflect.New(v.Type().Elem())) v.Set(reflect.New(v.Type().Elem()))
} }
if idx >= 0 {
v = v.Elem().Field(idx)
} else {
v = v.Elem().Field(-idx)
}
} }
} else {
v = v.Field(idx) if idx < len(field.StructField.Index)-1 {
v = v.Elem()
}
} }
} }
return v return v
@ -490,7 +496,7 @@ func (field *Field) setupValuerAndSetter() {
} }
default: default:
fieldValue := reflect.New(field.FieldType) fieldValue := reflect.New(field.FieldType)
switch fieldValue.Interface().(type) { switch fieldValue.Elem().Interface().(type) {
case time.Time: case time.Time:
field.Setter = func(value reflect.Value, v interface{}) error { field.Setter = func(value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
@ -528,6 +534,20 @@ func (field *Field) setupValuerAndSetter() {
return nil return nil
} }
default: default:
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
field.Setter = func(value reflect.Value, v interface{}) (err error) {
if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
}
} else {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
}
return
}
return
}
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Setter = func(value reflect.Value, v interface{}) (err error) {
@ -544,14 +564,28 @@ func (field *Field) setupValuerAndSetter() {
} }
} }
field.Setter = func(value reflect.Value, v interface{}) (err error) { if field.FieldType.Kind() == reflect.Ptr {
reflectV := reflect.ValueOf(v) field.Setter = func(value reflect.Value, v interface{}) (err error) {
if reflectV.Type().ConvertibleTo(field.FieldType) { reflectV := reflect.ValueOf(v)
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) if reflectV.Type().ConvertibleTo(field.FieldType) {
} else { field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
}
return nil
}
} else {
field.Setter = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
}
return nil
} }
return nil
} }
} }
} }

View File

@ -24,10 +24,12 @@ func TestFieldValuerAndSetter(t *testing.T) {
Name: "valuer_and_setter", Name: "valuer_and_setter",
Age: 18, Age: 18,
Birthday: tests.Now(), Birthday: tests.Now(),
Active: true,
} }
reflectValue = reflect.ValueOf(user) reflectValue = reflect.ValueOf(&user)
) )
// test valuer
values := map[string]interface{}{ values := map[string]interface{}{
"name": user.Name, "name": user.Name,
"id": user.ID, "id": user.ID,
@ -35,30 +37,95 @@ func TestFieldValuerAndSetter(t *testing.T) {
"deleted_at": user.DeletedAt, "deleted_at": user.DeletedAt,
"age": user.Age, "age": user.Age,
"birthday": user.Birthday, "birthday": user.Birthday,
"active": true,
} }
checkField(t, userSchema, reflectValue, values)
for k, v := range values { // test setter
if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v {
t.Errorf("user's %v value should equal %+v, but got %+v", k, v, rv)
}
}
newValues := map[string]interface{}{ newValues := map[string]interface{}{
"name": "valuer_and_setter_2", "name": "valuer_and_setter_2",
"id": "2", "id": 2,
"created_at": time.Now(), "created_at": time.Now(),
"deleted_at": tests.Now(), "deleted_at": tests.Now(),
"age": 20, "age": 20,
"birthday": time.Now(), "birthday": time.Now(),
"active": false,
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v", k) t.Errorf("no error should happen when assign value to field %v", k)
} }
}
checkField(t, userSchema, reflectValue, newValues)
}
if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v { func TestPointerFieldValuerAndSetter(t *testing.T) {
t.Errorf("user's %v value should equal %+v after assign new value, but got %+v", k, v, rv) var (
cacheMap = sync.Map{}
userSchema, _ = schema.Parse(&User{}, &cacheMap, schema.NamingStrategy{})
name = "pointer_field_valuer_and_setter"
age = 18
active = true
user = User{
Model: &gorm.Model{
ID: 10,
CreatedAt: time.Now(),
DeletedAt: tests.Now(),
},
Name: &name,
Age: &age,
Birthday: tests.Now(),
Active: &active,
}
reflectValue = reflect.ValueOf(&user)
)
// test valuer
values := map[string]interface{}{
"name": user.Name,
"id": user.ID,
"created_at": user.CreatedAt,
"deleted_at": user.DeletedAt,
"age": user.Age,
"birthday": user.Birthday,
"active": true,
}
checkField(t, userSchema, reflectValue, values)
// test setter
newValues := map[string]interface{}{
"name": "valuer_and_setter_2",
"id": 2,
"created_at": time.Now(),
"deleted_at": tests.Now(),
"age": 20,
"birthday": time.Now(),
"active": false,
}
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
checkField(t, userSchema, reflectValue, newValues)
}
type User struct {
*gorm.Model
Name *string
Age *int
Birthday *time.Time
Account *tests.Account
Pets []*tests.Pet
Toys []tests.Toy `gorm:"polymorphic:Owner"`
CompanyID *int
Company *tests.Company
ManagerID *int
Manager *User
Team []User `gorm:"foreignkey:ManagerID"`
Languages []tests.Language `gorm:"many2many:UserSpeak"`
Friends []*User `gorm:"many2many:user_friends"`
Active *bool
} }

View File

@ -54,7 +54,7 @@ type Reference struct {
func (schema *Schema) parseRelation(field *Field) { func (schema *Schema) parseRelation(field *Field) {
var ( var (
err error err error
fieldValue = reflect.New(field.FieldType).Interface() fieldValue = reflect.New(field.IndirectFieldType).Interface()
relation = &Relationship{ relation = &Relationship{
Name: field.Name, Name: field.Name,
Field: field, Field: field,
@ -74,7 +74,7 @@ func (schema *Schema) parseRelation(field *Field) {
} else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" { } else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many) schema.buildMany2ManyRelation(relation, field, many2many)
} else { } else {
switch field.FieldType.Kind() { switch field.IndirectFieldType.Kind() {
case reflect.Struct, reflect.Slice: case reflect.Struct, reflect.Slice:
schema.guessRelation(relation, field, true) schema.guessRelation(relation, field, true)
default: default:
@ -83,7 +83,7 @@ func (schema *Schema) parseRelation(field *Field) {
} }
if relation.Type == "has" { if relation.Type == "has" {
switch field.FieldType.Kind() { switch field.IndirectFieldType.Kind() {
case reflect.Struct: case reflect.Struct:
relation.Type = HasOne relation.Type = HasOne
case reflect.Slice: case reflect.Slice:

View File

@ -2,6 +2,7 @@ package schema_test
import ( import (
"fmt" "fmt"
"reflect"
"strings" "strings"
"testing" "testing"
@ -189,3 +190,34 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
} }
}) })
} }
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
for k, v := range values {
t.Run("CheckField/"+k, func(t *testing.T) {
field := s.FieldsByDBName[k]
fv := field.ValueOf(value)
if reflect.ValueOf(fv).Kind() == reflect.Ptr {
if reflect.ValueOf(v).Kind() == reflect.Ptr {
if fv != v {
t.Errorf("pointer expects: %p, but got %p", v, fv)
}
} else if fv == nil {
if v != nil {
t.Errorf("expects: %+v, but got nil", v)
}
} else if reflect.ValueOf(fv).Elem().Interface() != v {
t.Errorf("expects: %+v, but got %+v", v, fv)
}
} else if reflect.ValueOf(v).Kind() == reflect.Ptr {
if reflect.ValueOf(v).Elem().Interface() != fv {
t.Errorf("expects: %+v, but got %+v", v, fv)
}
} else if reflect.ValueOf(v).Type().ConvertibleTo(field.FieldType) {
if reflect.ValueOf(v).Convert(field.FieldType).Interface() != fv {
t.Errorf("expects: %+v, but got %+v", v, fv)
}
}
})
}
}

View File

@ -29,7 +29,8 @@ func TestParseSchema(t *testing.T) {
{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint},
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int},
{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Int},
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
} }
for _, f := range fields { for _, f := range fields {

View File

@ -21,11 +21,12 @@ type User struct {
Toys []Toy `gorm:"polymorphic:Owner"` Toys []Toy `gorm:"polymorphic:Owner"`
CompanyID *int CompanyID *int
Company Company Company Company
ManagerID uint ManagerID int
Manager *User Manager *User
Team []User `gorm:"foreignkey:ManagerID"` Team []User `gorm:"foreignkey:ManagerID"`
Languages []Language `gorm:"many2many:UserSpeak"` Languages []Language `gorm:"many2many:UserSpeak"`
Friends []*User `gorm:"many2many:user_friends"` Friends []*User `gorm:"many2many:user_friends"`
Active bool
} }
type Account struct { type Account struct {