diff --git a/schema/field.go b/schema/field.go index 15e94279..b4610436 100644 --- a/schema/field.go +++ b/schema/field.go @@ -25,52 +25,53 @@ const ( ) type Field struct { - Name string - DBName string - BindNames []string - DataType DataType - DBDataType string - PrimaryKey bool - AutoIncrement bool - Creatable bool - Updatable bool - HasDefaultValue bool - DefaultValue string - NotNull bool - Unique bool - Comment string - Size int - Precision int - FieldType reflect.Type - StructField reflect.StructField - Tag reflect.StructTag - TagSettings map[string]string - Schema *Schema - EmbeddedSchema *Schema - ReflectValuer func(reflect.Value) reflect.Value - Valuer func(reflect.Value) interface{} - Setter func(reflect.Value, interface{}) error + Name string + DBName string + BindNames []string + DataType DataType + DBDataType string + PrimaryKey bool + AutoIncrement bool + Creatable bool + Updatable bool + HasDefaultValue bool + DefaultValue string + NotNull bool + Unique bool + Comment string + Size int + Precision int + FieldType reflect.Type + IndirectFieldType reflect.Type + StructField reflect.StructField + Tag reflect.StructTag + TagSettings map[string]string + Schema *Schema + EmbeddedSchema *Schema + ReflectValuer func(reflect.Value) reflect.Value + Valuer func(reflect.Value) interface{} + Setter func(reflect.Value, interface{}) error } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field := &Field{ - Name: fieldStruct.Name, - BindNames: []string{fieldStruct.Name}, - FieldType: fieldStruct.Type, - StructField: fieldStruct, - Creatable: true, - Updatable: true, - Tag: fieldStruct.Tag, - TagSettings: ParseTagSetting(fieldStruct.Tag), - Schema: schema, + Name: fieldStruct.Name, + BindNames: []string{fieldStruct.Name}, + FieldType: fieldStruct.Type, + IndirectFieldType: fieldStruct.Type, + StructField: fieldStruct, + Creatable: true, + Updatable: true, + Tag: fieldStruct.Tag, + TagSettings: ParseTagSetting(fieldStruct.Tag), + Schema: schema, } - for field.FieldType.Kind() == reflect.Ptr { - field.FieldType = field.FieldType.Elem() + for field.IndirectFieldType.Kind() == reflect.Ptr { + field.IndirectFieldType = field.IndirectFieldType.Elem() } - fieldValue := reflect.New(field.FieldType) - + fieldValue := reflect.New(field.IndirectFieldType) // if field is valuer, used its value or first fields as data type if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer { var overrideFieldValue bool @@ -79,10 +80,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { fieldValue = reflect.ValueOf(v) } - if field.FieldType.Kind() == reflect.Struct { - for i := 0; i < field.FieldType.NumField(); i++ { + if field.IndirectFieldType.Kind() == reflect.Struct { + for i := 0; i < field.IndirectFieldType.NumField(); i++ { if !overrideFieldValue { - newFieldType := field.FieldType.Field(i).Type + newFieldType := field.IndirectFieldType.Field(i).Type for newFieldType.Kind() == reflect.Ptr { newFieldType = newFieldType.Elem() } @@ -92,7 +93,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } // copy tag settings from valuer - for key, value := range ParseTagSetting(field.FieldType.Field(i).Tag) { + for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag) { if _, ok := field.TagSettings[key]; !ok { field.TagSettings[key] = value } @@ -197,7 +198,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if field.FieldType.Kind() == reflect.Struct { ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) } else { - ef.StructField.Index = append([]int{-fieldStruct.Index[0]}, ef.StructField.Index...) + ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) } if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { @@ -235,26 +236,29 @@ func (field *Field) setupValuerAndSetter() { switch { case len(field.StructField.Index) == 1: field.Valuer = func(value reflect.Value) interface{} { - return value.Field(field.StructField.Index[0]).Interface() + return reflect.Indirect(value).Field(field.StructField.Index[0]).Interface() } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: field.Valuer = func(value reflect.Value) interface{} { - return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() + return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() } default: field.Valuer = func(value reflect.Value) interface{} { - v := value.Field(field.StructField.Index[0]) - for _, idx := range field.StructField.Index[1:] { - if v.Kind() == reflect.Ptr { + v := reflect.Indirect(value) + + for _, idx := range field.StructField.Index { + if idx >= 0 { + v = v.Field(idx) + } else { + v = v.Field(-idx - 1) + if v.Type().Elem().Kind() == reflect.Struct { if !v.IsNil() { - v = v.Elem().Field(-idx) - continue + v = v.Elem() } + } else { + return nil } - return nil - } else { - v = v.Field(idx) } } return v.Interface() @@ -266,7 +270,7 @@ func (field *Field) setupValuerAndSetter() { case len(field.StructField.Index) == 1: if field.FieldType.Kind() == reflect.Ptr { field.ReflectValuer = func(value reflect.Value) reflect.Value { - fieldValue := value.Field(field.StructField.Index[0]) + fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } @@ -274,31 +278,33 @@ func (field *Field) setupValuerAndSetter() { } } else { field.ReflectValuer = func(value reflect.Value) reflect.Value { - return value.Field(field.StructField.Index[0]) + return reflect.Indirect(value).Field(field.StructField.Index[0]) } } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: - field.Valuer = func(value reflect.Value) interface{} { - return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() + field.ReflectValuer = func(value reflect.Value) reflect.Value { + return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) } default: field.ReflectValuer = func(value reflect.Value) reflect.Value { - v := value.Field(field.StructField.Index[0]) - for _, idx := range field.StructField.Index[1:] { + v := reflect.Indirect(value) + for _, idx := range field.StructField.Index { + if idx >= 0 { + v = v.Field(idx) + } else { + v = v.Field(-idx - 1) + } + if v.Kind() == reflect.Ptr { if v.Type().Elem().Kind() == reflect.Struct { if v.IsNil() { v.Set(reflect.New(v.Type().Elem())) } - - if idx >= 0 { - v = v.Elem().Field(idx) - } else { - v = v.Elem().Field(-idx) - } } - } else { - v = v.Field(idx) + + if idx < len(field.StructField.Index)-1 { + v = v.Elem() + } } } return v @@ -490,7 +496,7 @@ func (field *Field) setupValuerAndSetter() { } default: fieldValue := reflect.New(field.FieldType) - switch fieldValue.Interface().(type) { + switch fieldValue.Elem().Interface().(type) { case time.Time: field.Setter = func(value reflect.Value, v interface{}) error { switch data := v.(type) { @@ -528,6 +534,20 @@ func (field *Field) setupValuerAndSetter() { return nil } default: + if _, ok := fieldValue.Interface().(sql.Scanner); ok { + field.Setter = func(value reflect.Value, v interface{}) (err error) { + if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) + } + } else { + err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) + } + return + } + return + } + if fieldValue.CanAddr() { if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { field.Setter = func(value reflect.Value, v interface{}) (err error) { @@ -544,14 +564,28 @@ func (field *Field) setupValuerAndSetter() { } } - field.Setter = func(value reflect.Value, v interface{}) (err error) { - reflectV := reflect.ValueOf(v) - if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) - } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + if field.FieldType.Kind() == reflect.Ptr { + field.Setter = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { + field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + return nil + } + } else { + field.Setter = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + return nil } - return nil } } } diff --git a/schema/field_test.go b/schema/field_test.go index c7814fbf..065d6d05 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -24,10 +24,12 @@ func TestFieldValuerAndSetter(t *testing.T) { Name: "valuer_and_setter", Age: 18, Birthday: tests.Now(), + Active: true, } - reflectValue = reflect.ValueOf(user) + reflectValue = reflect.ValueOf(&user) ) + // test valuer values := map[string]interface{}{ "name": user.Name, "id": user.ID, @@ -35,30 +37,95 @@ func TestFieldValuerAndSetter(t *testing.T) { "deleted_at": user.DeletedAt, "age": user.Age, "birthday": user.Birthday, + "active": true, } + checkField(t, userSchema, reflectValue, values) - for k, v := range values { - if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v { - t.Errorf("user's %v value should equal %+v, but got %+v", k, v, rv) - } - } - + // test setter newValues := map[string]interface{}{ "name": "valuer_and_setter_2", - "id": "2", + "id": 2, "created_at": time.Now(), "deleted_at": tests.Now(), "age": 20, "birthday": time.Now(), + "active": false, } for k, v := range newValues { if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { t.Errorf("no error should happen when assign value to field %v", k) } + } + checkField(t, userSchema, reflectValue, newValues) +} - if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v { - t.Errorf("user's %v value should equal %+v after assign new value, but got %+v", k, v, rv) +func TestPointerFieldValuerAndSetter(t *testing.T) { + var ( + cacheMap = sync.Map{} + userSchema, _ = schema.Parse(&User{}, &cacheMap, schema.NamingStrategy{}) + name = "pointer_field_valuer_and_setter" + age = 18 + active = true + user = User{ + Model: &gorm.Model{ + ID: 10, + CreatedAt: time.Now(), + DeletedAt: tests.Now(), + }, + Name: &name, + Age: &age, + Birthday: tests.Now(), + Active: &active, + } + reflectValue = reflect.ValueOf(&user) + ) + + // test valuer + values := map[string]interface{}{ + "name": user.Name, + "id": user.ID, + "created_at": user.CreatedAt, + "deleted_at": user.DeletedAt, + "age": user.Age, + "birthday": user.Birthday, + "active": true, + } + checkField(t, userSchema, reflectValue, values) + + // test setter + newValues := map[string]interface{}{ + "name": "valuer_and_setter_2", + "id": 2, + "created_at": time.Now(), + "deleted_at": tests.Now(), + "age": 20, + "birthday": time.Now(), + "active": false, + } + + for k, v := range newValues { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v, but got %v", k, err) } } + checkField(t, userSchema, reflectValue, newValues) +} + +type User struct { + *gorm.Model + Name *string + Age *int + Birthday *time.Time + Account *tests.Account + Pets []*tests.Pet + Toys []tests.Toy `gorm:"polymorphic:Owner"` + CompanyID *int + Company *tests.Company + ManagerID *int + Manager *User + Team []User `gorm:"foreignkey:ManagerID"` + Languages []tests.Language `gorm:"many2many:UserSpeak"` + Friends []*User `gorm:"many2many:user_friends"` + Active *bool } diff --git a/schema/relationship.go b/schema/relationship.go index b6aaefbd..671371fe 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -54,7 +54,7 @@ type Reference struct { func (schema *Schema) parseRelation(field *Field) { var ( err error - fieldValue = reflect.New(field.FieldType).Interface() + fieldValue = reflect.New(field.IndirectFieldType).Interface() relation = &Relationship{ Name: field.Name, Field: field, @@ -74,7 +74,7 @@ func (schema *Schema) parseRelation(field *Field) { } else if many2many, _ := field.TagSettings["MANY2MANY"]; many2many != "" { schema.buildMany2ManyRelation(relation, field, many2many) } else { - switch field.FieldType.Kind() { + switch field.IndirectFieldType.Kind() { case reflect.Struct, reflect.Slice: schema.guessRelation(relation, field, true) default: @@ -83,7 +83,7 @@ func (schema *Schema) parseRelation(field *Field) { } if relation.Type == "has" { - switch field.FieldType.Kind() { + switch field.IndirectFieldType.Kind() { case reflect.Struct: relation.Type = HasOne case reflect.Slice: diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index db38355d..4af0fc89 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -2,6 +2,7 @@ package schema_test import ( "fmt" + "reflect" "strings" "testing" @@ -189,3 +190,34 @@ func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { } }) } + +func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { + for k, v := range values { + t.Run("CheckField/"+k, func(t *testing.T) { + field := s.FieldsByDBName[k] + fv := field.ValueOf(value) + + if reflect.ValueOf(fv).Kind() == reflect.Ptr { + if reflect.ValueOf(v).Kind() == reflect.Ptr { + if fv != v { + t.Errorf("pointer expects: %p, but got %p", v, fv) + } + } else if fv == nil { + if v != nil { + t.Errorf("expects: %+v, but got nil", v) + } + } else if reflect.ValueOf(fv).Elem().Interface() != v { + t.Errorf("expects: %+v, but got %+v", v, fv) + } + } else if reflect.ValueOf(v).Kind() == reflect.Ptr { + if reflect.ValueOf(v).Elem().Interface() != fv { + t.Errorf("expects: %+v, but got %+v", v, fv) + } + } else if reflect.ValueOf(v).Type().ConvertibleTo(field.FieldType) { + if reflect.ValueOf(v).Convert(field.FieldType).Interface() != fv { + t.Errorf("expects: %+v, but got %+v", v, fv) + } + } + }) + } +} diff --git a/schema/schema_test.go b/schema/schema_test.go index 526a98bd..97da1d5d 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -29,7 +29,8 @@ func TestParseSchema(t *testing.T) { {Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint}, {Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time}, {Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int}, - {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint}, + {Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Int}, + {Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool}, } for _, f := range fields { diff --git a/tests/model.go b/tests/model.go index 62000352..ac2156c7 100644 --- a/tests/model.go +++ b/tests/model.go @@ -21,11 +21,12 @@ type User struct { Toys []Toy `gorm:"polymorphic:Owner"` CompanyID *int Company Company - ManagerID uint + ManagerID int Manager *User Team []User `gorm:"foreignkey:ManagerID"` Languages []Language `gorm:"many2many:UserSpeak"` Friends []*User `gorm:"many2many:user_friends"` + Active bool } type Account struct {