diff --git a/schema/field.go b/schema/field.go index 570b3c50..15e94279 100644 --- a/schema/field.go +++ b/schema/field.go @@ -1,11 +1,15 @@ package schema import ( + "database/sql" "database/sql/driver" + "fmt" "reflect" "strconv" "sync" "time" + + "github.com/jinzhu/now" ) type DataType string @@ -43,6 +47,9 @@ type Field struct { TagSettings map[string]string Schema *Schema EmbeddedSchema *Schema + ReflectValuer func(reflect.Value) reflect.Value + Valuer func(reflect.Value) interface{} + Setter func(reflect.Value, interface{}) error } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { @@ -186,6 +193,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { for _, ef := range field.EmbeddedSchema.Fields { ef.Schema = schema ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) + // index is negative means is pointer + if field.FieldType.Kind() == reflect.Struct { + ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) + } else { + ef.StructField.Index = append([]int{-fieldStruct.Index[0]}, ef.StructField.Index...) + } if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { ef.DBName = prefix + ef.DBName @@ -199,3 +212,347 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { return field } + +// ValueOf field value of +func (field *Field) ValueOf(value reflect.Value) interface{} { + if field != nil { + return field.Valuer(value) + } + return nil +} + +func (field *Field) Set(value reflect.Value, v interface{}) error { + if field != nil { + return field.Setter(value, v) + } + + return fmt.Errorf("failed to set field value: %v", field.Name) +} + +// create valuer, setter when parse struct +func (field *Field) setupValuerAndSetter() { + // Valuer + switch { + case len(field.StructField.Index) == 1: + field.Valuer = func(value reflect.Value) interface{} { + return value.Field(field.StructField.Index[0]).Interface() + } + case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: + field.Valuer = func(value reflect.Value) interface{} { + return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() + } + default: + field.Valuer = func(value reflect.Value) interface{} { + v := value.Field(field.StructField.Index[0]) + for _, idx := range field.StructField.Index[1:] { + if v.Kind() == reflect.Ptr { + if v.Type().Elem().Kind() == reflect.Struct { + if !v.IsNil() { + v = v.Elem().Field(-idx) + continue + } + } + return nil + } else { + v = v.Field(idx) + } + } + return v.Interface() + } + } + + // ReflectValuer + switch { + case len(field.StructField.Index) == 1: + if field.FieldType.Kind() == reflect.Ptr { + field.ReflectValuer = func(value reflect.Value) reflect.Value { + fieldValue := value.Field(field.StructField.Index[0]) + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(field.FieldType.Elem())) + } + return fieldValue + } + } else { + field.ReflectValuer = func(value reflect.Value) reflect.Value { + return value.Field(field.StructField.Index[0]) + } + } + case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: + field.Valuer = func(value reflect.Value) interface{} { + return value.Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() + } + default: + field.ReflectValuer = func(value reflect.Value) reflect.Value { + v := value.Field(field.StructField.Index[0]) + for _, idx := range field.StructField.Index[1:] { + if v.Kind() == reflect.Ptr { + if v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + + if idx >= 0 { + v = v.Elem().Field(idx) + } else { + v = v.Elem().Field(-idx) + } + } + } else { + v = v.Field(idx) + } + } + return v + } + } + + // Setter + switch field.FieldType.Kind() { + case reflect.Bool: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case bool: + field.ReflectValuer(value).SetBool(data) + case *bool: + field.ReflectValuer(value).SetBool(*data) + default: + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else { + field.ReflectValuer(value).SetBool(!reflect.ValueOf(v).IsZero()) + } + } + return nil + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case int64: + field.ReflectValuer(value).SetInt(data) + case int: + field.ReflectValuer(value).SetInt(int64(data)) + case int8: + field.ReflectValuer(value).SetInt(int64(data)) + case int16: + field.ReflectValuer(value).SetInt(int64(data)) + case int32: + field.ReflectValuer(value).SetInt(int64(data)) + case uint: + field.ReflectValuer(value).SetInt(int64(data)) + case uint8: + field.ReflectValuer(value).SetInt(int64(data)) + case uint16: + field.ReflectValuer(value).SetInt(int64(data)) + case uint32: + field.ReflectValuer(value).SetInt(int64(data)) + case uint64: + field.ReflectValuer(value).SetInt(int64(data)) + case float32: + field.ReflectValuer(value).SetInt(int64(data)) + case float64: + field.ReflectValuer(value).SetInt(int64(data)) + case []byte: + return field.Setter(value, string(data)) + case string: + if i, err := strconv.ParseInt(data, 0, 64); err == nil { + field.ReflectValuer(value).SetInt(i) + } else { + return err + } + default: + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Kind() == reflect.Ptr { + return field.Setter(value, reflectV.Elem().Interface()) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + } + return nil + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case uint64: + field.ReflectValuer(value).SetUint(data) + case uint: + field.ReflectValuer(value).SetUint(uint64(data)) + case uint8: + field.ReflectValuer(value).SetUint(uint64(data)) + case uint16: + field.ReflectValuer(value).SetUint(uint64(data)) + case uint32: + field.ReflectValuer(value).SetUint(uint64(data)) + case int64: + field.ReflectValuer(value).SetUint(uint64(data)) + case int: + field.ReflectValuer(value).SetUint(uint64(data)) + case int8: + field.ReflectValuer(value).SetUint(uint64(data)) + case int16: + field.ReflectValuer(value).SetUint(uint64(data)) + case int32: + field.ReflectValuer(value).SetUint(uint64(data)) + case float32: + field.ReflectValuer(value).SetUint(uint64(data)) + case float64: + field.ReflectValuer(value).SetUint(uint64(data)) + case []byte: + return field.Setter(value, string(data)) + case string: + if i, err := strconv.ParseUint(data, 0, 64); err == nil { + field.ReflectValuer(value).SetUint(i) + } else { + return err + } + default: + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Kind() == reflect.Ptr { + return field.Setter(value, reflectV.Elem().Interface()) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + } + return nil + } + case reflect.Float32, reflect.Float64: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case float64: + field.ReflectValuer(value).SetFloat(data) + case float32: + field.ReflectValuer(value).SetFloat(float64(data)) + case int64: + field.ReflectValuer(value).SetFloat(float64(data)) + case int: + field.ReflectValuer(value).SetFloat(float64(data)) + case int8: + field.ReflectValuer(value).SetFloat(float64(data)) + case int16: + field.ReflectValuer(value).SetFloat(float64(data)) + case int32: + field.ReflectValuer(value).SetFloat(float64(data)) + case uint: + field.ReflectValuer(value).SetFloat(float64(data)) + case uint8: + field.ReflectValuer(value).SetFloat(float64(data)) + case uint16: + field.ReflectValuer(value).SetFloat(float64(data)) + case uint32: + field.ReflectValuer(value).SetFloat(float64(data)) + case uint64: + field.ReflectValuer(value).SetFloat(float64(data)) + case []byte: + return field.Setter(value, string(data)) + case string: + if i, err := strconv.ParseFloat(data, 64); err == nil { + field.ReflectValuer(value).SetFloat(i) + } else { + return err + } + default: + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Kind() == reflect.Ptr { + return field.Setter(value, reflectV.Elem().Interface()) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + } + return nil + } + case reflect.String: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case string: + field.ReflectValuer(value).SetString(data) + case []byte: + field.ReflectValuer(value).SetString(string(data)) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + field.ReflectValuer(value).SetString(fmt.Sprint(data)) + case float64, float32: + field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) + default: + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else if reflectV.Kind() == reflect.Ptr { + return field.Setter(value, reflectV.Elem().Interface()) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + } + return nil + } + default: + fieldValue := reflect.New(field.FieldType) + switch fieldValue.Interface().(type) { + case time.Time: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case time.Time: + field.ReflectValuer(value).Set(reflect.ValueOf(v)) + case *time.Time: + field.ReflectValuer(value).Set(reflect.ValueOf(v).Elem()) + case string: + if t, err := now.Parse(data); err == nil { + field.ReflectValuer(value).Set(reflect.ValueOf(t)) + } else { + return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + } + default: + return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name) + } + return nil + } + case *time.Time: + field.Setter = func(value reflect.Value, v interface{}) error { + switch data := v.(type) { + case time.Time: + field.ReflectValuer(value).Elem().Set(reflect.ValueOf(v)) + case *time.Time: + field.ReflectValuer(value).Set(reflect.ValueOf(v)) + case string: + if t, err := now.Parse(data); err == nil { + field.ReflectValuer(value).Elem().Set(reflect.ValueOf(t)) + } else { + return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + } + default: + return fmt.Errorf("failed to set value %+v to time.Time field %v", v, field.Name) + } + return nil + } + default: + if fieldValue.CanAddr() { + if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + field.Setter = func(value reflect.Value, v interface{}) (err error) { + if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) + } + } else { + err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) + } + return + } + return + } + } + + field.Setter = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if reflectV.Type().ConvertibleTo(field.FieldType) { + field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + } else { + return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + } + return nil + } + } + } +} diff --git a/schema/field_test.go b/schema/field_test.go new file mode 100644 index 00000000..c7814fbf --- /dev/null +++ b/schema/field_test.go @@ -0,0 +1,64 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + "time" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/tests" +) + +func TestFieldValuerAndSetter(t *testing.T) { + var ( + cacheMap = sync.Map{} + userSchema, _ = schema.Parse(&tests.User{}, &cacheMap, schema.NamingStrategy{}) + user = tests.User{ + Model: gorm.Model{ + ID: 10, + CreatedAt: time.Now(), + DeletedAt: tests.Now(), + }, + Name: "valuer_and_setter", + Age: 18, + Birthday: tests.Now(), + } + reflectValue = reflect.ValueOf(user) + ) + + values := map[string]interface{}{ + "name": user.Name, + "id": user.ID, + "created_at": user.CreatedAt, + "deleted_at": user.DeletedAt, + "age": user.Age, + "birthday": user.Birthday, + } + + for k, v := range values { + if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v { + t.Errorf("user's %v value should equal %+v, but got %+v", k, v, rv) + } + } + + newValues := map[string]interface{}{ + "name": "valuer_and_setter_2", + "id": "2", + "created_at": time.Now(), + "deleted_at": tests.Now(), + "age": 20, + "birthday": time.Now(), + } + + for k, v := range newValues { + if err := userSchema.FieldsByDBName[k].Set(reflectValue, v); err != nil { + t.Errorf("no error should happen when assign value to field %v", k) + } + + if rv := userSchema.FieldsByDBName[k].ValueOf(reflectValue); rv != v { + t.Errorf("user's %v value should equal %+v after assign new value, but got %+v", k, v, rv) + } + } +} diff --git a/schema/schema.go b/schema/schema.go index 53170e18..2f3cdf88 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -128,6 +128,8 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if _, ok := schema.FieldsByName[field.Name]; !ok { schema.FieldsByName[field.Name] = field } + + field.setupValuerAndSetter() } if f := schema.LookUpField("id"); f != nil {