Test Field Valuer, Setter

This commit is contained in:
Jinzhu 2020-02-15 19:45:27 +08:00
parent 2cb88dc7c5
commit faee069a9f
6 changed files with 225 additions and 90 deletions

View File

@ -42,6 +42,7 @@ type Field struct {
Size int Size int
Precision int Precision int
FieldType reflect.Type FieldType reflect.Type
IndirectFieldType reflect.Type
StructField reflect.StructField StructField reflect.StructField
Tag reflect.StructTag Tag reflect.StructTag
TagSettings map[string]string TagSettings map[string]string
@ -57,6 +58,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
Name: fieldStruct.Name, Name: fieldStruct.Name,
BindNames: []string{fieldStruct.Name}, BindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type, FieldType: fieldStruct.Type,
IndirectFieldType: fieldStruct.Type,
StructField: fieldStruct, StructField: fieldStruct,
Creatable: true, Creatable: true,
Updatable: true, Updatable: true,
@ -65,12 +67,11 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
Schema: schema, Schema: schema,
} }
for field.FieldType.Kind() == reflect.Ptr { for field.IndirectFieldType.Kind() == reflect.Ptr {
field.FieldType = field.FieldType.Elem() field.IndirectFieldType = field.IndirectFieldType.Elem()
} }
fieldValue := reflect.New(field.FieldType) fieldValue := reflect.New(field.IndirectFieldType)
// if field is valuer, used its value or first fields as data type // if field is valuer, used its value or first fields as data type
if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer { if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer {
var overrideFieldValue bool var overrideFieldValue bool
@ -79,10 +80,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
fieldValue = reflect.ValueOf(v) fieldValue = reflect.ValueOf(v)
} }
if field.FieldType.Kind() == reflect.Struct { if field.IndirectFieldType.Kind() == reflect.Struct {
for i := 0; i < field.FieldType.NumField(); i++ { for i := 0; i < field.IndirectFieldType.NumField(); i++ {
if !overrideFieldValue { if !overrideFieldValue {
newFieldType := field.FieldType.Field(i).Type newFieldType := field.IndirectFieldType.Field(i).Type
for newFieldType.Kind() == reflect.Ptr { for newFieldType.Kind() == reflect.Ptr {
newFieldType = newFieldType.Elem() newFieldType = newFieldType.Elem()
} }
@ -92,7 +93,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
// copy tag settings from valuer // copy tag settings from valuer
for key, value := range ParseTagSetting(field.FieldType.Field(i).Tag) { for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag) {
if _, ok := field.TagSettings[key]; !ok { if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value field.TagSettings[key] = value
} }
@ -197,7 +198,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
if field.FieldType.Kind() == reflect.Struct { if field.FieldType.Kind() == reflect.Struct {
ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
} else { } else {
ef.StructField.Index = append([]int{-fieldStruct.Index[0]}, ef.StructField.Index...) ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
} }
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok {
@ -235,26 +236,29 @@ func (field *Field) setupValuerAndSetter() {
switch { switch {
case len(field.StructField.Index) == 1: case len(field.StructField.Index) == 1:
field.Valuer = func(value reflect.Value) interface{} { field.Valuer = func(value reflect.Value) interface{} {
return value.Field(field.StructField.Index[0]).Interface() return reflect.Indirect(value).Field(field.StructField.Index[0]).Interface()
} }
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
field.Valuer = func(value reflect.Value) interface{} { field.Valuer = func(value reflect.Value) interface{} {
return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface()
} }
default: default:
field.Valuer = func(value reflect.Value) interface{} { field.Valuer = func(value reflect.Value) interface{} {
v := value.Field(field.StructField.Index[0]) v := reflect.Indirect(value)
for _, idx := range field.StructField.Index[1:] {
if v.Kind() == reflect.Ptr { for _, idx := range field.StructField.Index {
if idx >= 0 {
v = v.Field(idx)
} else {
v = v.Field(-idx - 1)
if v.Type().Elem().Kind() == reflect.Struct { if v.Type().Elem().Kind() == reflect.Struct {
if !v.IsNil() { if !v.IsNil() {
v = v.Elem().Field(-idx) v = v.Elem()
continue
} }
}
return nil
} else { } else {
v = v.Field(idx) return nil
}
} }
} }
return v.Interface() return v.Interface()
@ -266,7 +270,7 @@ func (field *Field) setupValuerAndSetter() {
case len(field.StructField.Index) == 1: case len(field.StructField.Index) == 1:
if field.FieldType.Kind() == reflect.Ptr { if field.FieldType.Kind() == reflect.Ptr {
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValuer = func(value reflect.Value) reflect.Value {
fieldValue := value.Field(field.StructField.Index[0]) fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
if fieldValue.IsNil() { if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(field.FieldType.Elem()))
} }
@ -274,32 +278,34 @@ func (field *Field) setupValuerAndSetter() {
} }
} else { } else {
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValuer = func(value reflect.Value) reflect.Value {
return value.Field(field.StructField.Index[0]) return reflect.Indirect(value).Field(field.StructField.Index[0])
} }
} }
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
field.Valuer = func(value reflect.Value) interface{} { field.ReflectValuer = func(value reflect.Value) reflect.Value {
return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
} }
default: default:
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValuer = func(value reflect.Value) reflect.Value {
v := value.Field(field.StructField.Index[0]) v := reflect.Indirect(value)
for _, idx := range field.StructField.Index[1:] { for _, idx := range field.StructField.Index {
if idx >= 0 {
v = v.Field(idx)
} else {
v = v.Field(-idx - 1)
}
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
if v.Type().Elem().Kind() == reflect.Struct { if v.Type().Elem().Kind() == reflect.Struct {
if v.IsNil() { if v.IsNil() {
v.Set(reflect.New(v.Type().Elem())) v.Set(reflect.New(v.Type().Elem()))
} }
}
if idx >= 0 { if idx < len(field.StructField.Index)-1 {
v = v.Elem().Field(idx) v = v.Elem()
} else {
v = v.Elem().Field(-idx)
} }
} }
} else {
v = v.Field(idx)
}
} }
return v return v
} }
@ -490,7 +496,7 @@ func (field *Field) setupValuerAndSetter() {
} }
default: default:
fieldValue := reflect.New(field.FieldType) fieldValue := reflect.New(field.FieldType)
switch fieldValue.Interface().(type) { switch fieldValue.Elem().Interface().(type) {
case time.Time: case time.Time:
field.Setter = func(value reflect.Value, v interface{}) error { field.Setter = func(value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
@ -528,6 +534,20 @@ func (field *Field) setupValuerAndSetter() {
return nil return nil
} }
default: default:
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
field.Setter = func(value reflect.Value, v interface{}) (err error) {
if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
}
} else {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
}
return
}
return
}
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Setter = func(value reflect.Value, v interface{}) (err error) {
@ -544,6 +564,19 @@ func (field *Field) setupValuerAndSetter() {
} }
} }
if field.FieldType.Kind() == reflect.Ptr {
field.Setter = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
}
return nil
}
} else {
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Setter = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) { if reflectV.Type().ConvertibleTo(field.FieldType) {
@ -556,3 +589,4 @@ func (field *Field) setupValuerAndSetter() {
} }
} }
} }
}

View File

@ -24,10 +24,12 @@ func TestFieldValuerAndSetter(t *testing.T) {
Name: "valuer_and_setter", Name: "valuer_and_setter",
Age: 18, Age: 18,
Birthday: tests.Now(), Birthday: tests.Now(),
Active: true,
} }
reflectValue = reflect.ValueOf(user) reflectValue = reflect.ValueOf(&user)
) )
// test valuer
values := map[string]interface{}{ values := map[string]interface{}{
"name": user.Name, "name": user.Name,
"id": user.ID, "id": user.ID,
@ -35,30 +37,95 @@ func TestFieldValuerAndSetter(t *testing.T) {
"deleted_at": user.DeletedAt, "deleted_at": user.DeletedAt,
"age": user.Age, "age": user.Age,
"birthday": user.Birthday, "birthday": user.Birthday,
"active": true,
} }
checkField(t, userSchema, reflectValue, values)
for k, v := range values { // test setter
if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v {
t.Errorf("user's %v value should equal %+v, but got %+v", k, v, rv)
}
}
newValues := map[string]interface{}{ newValues := map[string]interface{}{
"name": "valuer_and_setter_2", "name": "valuer_and_setter_2",
"id": "2", "id": 2,
"created_at": time.Now(), "created_at": time.Now(),
"deleted_at": tests.Now(), "deleted_at": tests.Now(),
"age": 20, "age": 20,
"birthday": time.Now(), "birthday": time.Now(),
"active": false,
} }
for k, v := range newValues { for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v", k) t.Errorf("no error should happen when assign value to field %v", k)
} }
}
checkField(t, userSchema, reflectValue, newValues)
}
if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v { func TestPointerFieldValuerAndSetter(t *testing.T) {
t.Errorf("user's %v value should equal %+v after assign new value, but got %+v", k, v, rv) var (
cacheMap = sync.Map{}
userSchema, _ = schema.Parse(&User{}, &cacheMap, schema.NamingStrategy{})
name = "pointer_field_valuer_and_setter"
age = 18
active = true
user = User{
Model: &gorm.Model{
ID: 10,
CreatedAt: time.Now(),
DeletedAt: tests.Now(),
},
Name: &name,
Age: &age,
Birthday: tests.Now(),
Active: &active,
}
reflectValue = reflect.ValueOf(&user)
)
// test valuer
values := map[string]interface{}{
"name": user.Name,
"id": user.ID,
"created_at": user.CreatedAt,
"deleted_at": user.DeletedAt,
"age": user.Age,
"birthday": user.Birthday,
"active": true,
}
checkField(t, userSchema, reflectValue, values)
// test setter
newValues := map[string]interface{}{
"name": "valuer_and_setter_2",
"id": 2,
"created_at": time.Now(),
"deleted_at": tests.Now(),
"age": 20,
"birthday": time.Now(),
"active": false,
}
for k, v := range newValues {
if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil {
t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
} }
} }
checkField(t, userSchema, reflectValue, newValues)
}
type User struct {
*gorm.Model
Name *string
Age *int
Birthday *time.Time
Account *tests.Account
Pets []*tests.Pet
Toys []tests.Toy `gorm:"polymorphic:Owner"`
CompanyID *int
Company *tests.Company
ManagerID *int
Manager *User
Team []User `gorm:"foreignkey:ManagerID"`
Languages []tests.Language `gorm:"many2many:UserSpeak"`
Friends []*User `gorm:"many2many:user_friends"`
Active *bool
} }

View File

@ -54,7 +54,7 @@ type Reference struct {
func (schema *Schema) parseRelation(field *Field) { func (schema *Schema) parseRelation(field *Field) {
var ( var (
err error err error
fieldValue = reflect.New(field.FieldType).Interface() fieldValue = reflect.New(field.IndirectFieldType).Interface()
relation = &Relationship{ relation = &Relationship{
Name: field.Name, Name: field.Name,
Field: field, Field: field,
@ -74,7 +74,7 @@ func (schema *Schema) parseRelation(field *Field) {
} else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" { } else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many) schema.buildMany2ManyRelation(relation, field, many2many)
} else { } else {
switch field.FieldType.Kind() { switch field.IndirectFieldType.Kind() {
case reflect.Struct, reflect.Slice: case reflect.Struct, reflect.Slice:
schema.guessRelation(relation, field, true) schema.guessRelation(relation, field, true)
default: default:
@ -83,7 +83,7 @@ func (schema *Schema) parseRelation(field *Field) {
} }
if relation.Type == "has" { if relation.Type == "has" {
switch field.FieldType.Kind() { switch field.IndirectFieldType.Kind() {
case reflect.Struct: case reflect.Struct:
relation.Type = HasOne relation.Type = HasOne
case reflect.Slice: case reflect.Slice:

View File

@ -2,6 +2,7 @@ package schema_test
import ( import (
"fmt" "fmt"
"reflect"
"strings" "strings"
"testing" "testing"
@ -189,3 +190,34 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) {
} }
}) })
} }
func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) {
for k, v := range values {
t.Run("CheckField/"+k, func(t *testing.T) {
field := s.FieldsByDBName[k]
fv := field.ValueOf(value)
if reflect.ValueOf(fv).Kind() == reflect.Ptr {
if reflect.ValueOf(v).Kind() == reflect.Ptr {
if fv != v {
t.Errorf("pointer expects: %p, but got %p", v, fv)
}
} else if fv == nil {
if v != nil {
t.Errorf("expects: %+v, but got nil", v)
}
} else if reflect.ValueOf(fv).Elem().Interface() != v {
t.Errorf("expects: %+v, but got %+v", v, fv)
}
} else if reflect.ValueOf(v).Kind() == reflect.Ptr {
if reflect.ValueOf(v).Elem().Interface() != fv {
t.Errorf("expects: %+v, but got %+v", v, fv)
}
} else if reflect.ValueOf(v).Type().ConvertibleTo(field.FieldType) {
if reflect.ValueOf(v).Convert(field.FieldType).Interface() != fv {
t.Errorf("expects: %+v, but got %+v", v, fv)
}
}
})
}
}

View File

@ -29,7 +29,8 @@ func TestParseSchema(t *testing.T) {
{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint},
{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int},
{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Int},
{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
} }
for _, f := range fields { for _, f := range fields {

View File

@ -21,11 +21,12 @@ type User struct {
Toys []Toy `gorm:"polymorphic:Owner"` Toys []Toy `gorm:"polymorphic:Owner"`
CompanyID *int CompanyID *int
Company Company Company Company
ManagerID uint ManagerID int
Manager *User Manager *User
Team []User `gorm:"foreignkey:ManagerID"` Team []User `gorm:"foreignkey:ManagerID"`
Languages []Language `gorm:"many2many:UserSpeak"` Languages []Language `gorm:"many2many:UserSpeak"`
Friends []*User `gorm:"many2many:user_friends"` Friends []*User `gorm:"many2many:user_friends"`
Active bool
} }
type Account struct { type Account struct {