diff --git a/schema/field.go b/schema/field.go index 76f459ec..e4c80734 100644 --- a/schema/field.go +++ b/schema/field.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strconv" + "strings" "sync" "time" @@ -14,6 +15,13 @@ import ( type DataType string +type TimeType int64 + +const ( + UnixSecond TimeType = 1 + UnixNanosecond TimeType = 2 +) + const ( Bool DataType = "bool" Int = "int" @@ -25,32 +33,35 @@ const ( ) type Field struct { - Name string - DBName string - BindNames []string - DataType DataType - DBDataType string - PrimaryKey bool - AutoIncrement bool - Creatable bool - Updatable bool - HasDefaultValue bool - DefaultValue string - NotNull bool - Unique bool - Comment string - Size int - Precision int - FieldType reflect.Type - IndirectFieldType reflect.Type - StructField reflect.StructField - Tag reflect.StructTag - TagSettings map[string]string - Schema *Schema - EmbeddedSchema *Schema - ReflectValuer func(reflect.Value) reflect.Value - Valuer func(reflect.Value) interface{} - Setter func(reflect.Value, interface{}) error + Name string + DBName string + BindNames []string + DataType DataType + DBDataType string + PrimaryKey bool + AutoIncrement bool + Creatable bool + Updatable bool + HasDefaultValue bool + AutoCreateTime TimeType + AutoUpdateTime TimeType + DefaultValue string + DefaultValueInterface interface{} + NotNull bool + Unique bool + Comment string + Size int + Precision int + FieldType reflect.Type + IndirectFieldType reflect.Type + StructField reflect.StructField + Tag reflect.StructTag + TagSettings map[string]string + Schema *Schema + EmbeddedSchema *Schema + ReflectValueOf func(reflect.Value) reflect.Value + ValueOf func(reflect.Value) (value interface{}, zero bool) + Set func(reflect.Value, interface{}) error } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { @@ -73,7 +84,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { fieldValue := reflect.New(field.IndirectFieldType) // if field is valuer, used its value or first fields as data type - if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer { + if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf { var overrideFieldValue bool if v, err := valuer.Value(); v != nil && err == nil { overrideFieldValue = true @@ -150,17 +161,48 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DBDataType = val } + if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int)) { + if strings.ToUpper(v) == "NANO" { + field.AutoCreateTime = UnixNanosecond + } else { + field.AutoCreateTime = UnixSecond + } + } + + if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int)) { + if strings.ToUpper(v) == "NANO" { + field.AutoUpdateTime = UnixNanosecond + } else { + field.AutoUpdateTime = UnixSecond + } + } + switch fieldValue.Elem().Kind() { case reflect.Bool: field.DataType = Bool + if field.HasDefaultValue { + field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue) + } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: field.DataType = Int + if field.HasDefaultValue { + field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64) + } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: field.DataType = Uint + if field.HasDefaultValue { + field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64) + } case reflect.Float32, reflect.Float64: field.DataType = Float + if field.HasDefaultValue { + field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64) + } case reflect.String: field.DataType = String + if field.HasDefaultValue { + field.DefaultValueInterface = field.DefaultValue + } case reflect.Struct: if _, ok := fieldValue.Interface().(*time.Time); ok { field.DataType = Time @@ -216,36 +258,22 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { return field } -// ValueOf field value of -func (field *Field) ValueOf(value reflect.Value) interface{} { - if field != nil { - return field.Valuer(value) - } - return nil -} - -func (field *Field) Set(value reflect.Value, v interface{}) error { - if field != nil { - return field.Setter(value, v) - } - - return fmt.Errorf("failed to set field value: %v", field.Name) -} - // create valuer, setter when parse struct func (field *Field) setupValuerAndSetter() { - // Valuer + // ValueOf switch { case len(field.StructField.Index) == 1: - field.Valuer = func(value reflect.Value) interface{} { - return reflect.Indirect(value).Field(field.StructField.Index[0]).Interface() + field.ValueOf = func(value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) + return fieldValue.Interface(), fieldValue.IsZero() } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: - field.Valuer = func(value reflect.Value) interface{} { - return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() + field.ValueOf = func(value reflect.Value) (interface{}, bool) { + fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) + return fieldValue.Interface(), fieldValue.IsZero() } default: - field.Valuer = func(value reflect.Value) interface{} { + field.ValueOf = func(value reflect.Value) (interface{}, bool) { v := reflect.Indirect(value) for _, idx := range field.StructField.Index { @@ -259,19 +287,19 @@ func (field *Field) setupValuerAndSetter() { v = v.Elem() } } else { - return nil + return nil, true } } } - return v.Interface() + return v.Interface(), v.IsZero() } } - // ReflectValuer + // ReflectValueOf switch { case len(field.StructField.Index) == 1: if field.FieldType.Kind() == reflect.Ptr { - field.ReflectValuer = func(value reflect.Value) reflect.Value { + field.ReflectValueOf = func(value reflect.Value) reflect.Value { fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) @@ -279,16 +307,16 @@ func (field *Field) setupValuerAndSetter() { return fieldValue } } else { - field.ReflectValuer = func(value reflect.Value) reflect.Value { + field.ReflectValueOf = func(value reflect.Value) reflect.Value { return reflect.Indirect(value).Field(field.StructField.Index[0]) } } case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: - field.ReflectValuer = func(value reflect.Value) reflect.Value { + field.ReflectValueOf = func(value reflect.Value) reflect.Value { return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) } default: - field.ReflectValuer = func(value reflect.Value) reflect.Value { + field.ReflectValueOf = func(value reflect.Value) reflect.Value { v := reflect.Indirect(value) for _, idx := range field.StructField.Index { if idx >= 0 { @@ -316,168 +344,184 @@ func (field *Field) setupValuerAndSetter() { recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { reflectV := reflect.ValueOf(v) if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { return setter(value, v) } } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) } else if reflectV.Kind() == reflect.Ptr { - return field.Setter(value, reflectV.Elem().Interface()) + return field.Set(value, reflectV.Elem().Interface()) } else { return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) } return err } - // Setter + // Set switch field.FieldType.Kind() { case reflect.Bool: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Set = func(value reflect.Value, v interface{}) error { switch data := v.(type) { case bool: - field.ReflectValuer(value).SetBool(data) + field.ReflectValueOf(value).SetBool(data) case *bool: - field.ReflectValuer(value).SetBool(*data) + field.ReflectValueOf(value).SetBool(*data) default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return nil } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - field.Setter = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case int64: - field.ReflectValuer(value).SetInt(data) + field.ReflectValueOf(value).SetInt(data) case int: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case int8: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case int16: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case int32: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case uint: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case uint8: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case uint16: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case uint32: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case uint64: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case float32: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case float64: - field.ReflectValuer(value).SetInt(int64(data)) + field.ReflectValueOf(value).SetInt(int64(data)) case []byte: - return field.Setter(value, string(data)) + return field.Set(value, string(data)) case string: if i, err := strconv.ParseInt(data, 0, 64); err == nil { - field.ReflectValuer(value).SetInt(i) + field.ReflectValueOf(value).SetInt(i) } else { return err } + case time.Time: + if field.AutoCreateTime == UnixNanosecond { + field.ReflectValueOf(value).SetInt(data.UnixNano()) + } else { + field.ReflectValueOf(value).SetInt(data.Unix()) + } + case *time.Time: + if data != nil { + if field.AutoCreateTime == UnixNanosecond { + field.ReflectValueOf(value).SetInt(data.UnixNano()) + } else { + field.ReflectValueOf(value).SetInt(data.Unix()) + } + } else { + field.ReflectValueOf(value).SetInt(0) + } default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return err } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - field.Setter = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case uint64: - field.ReflectValuer(value).SetUint(data) + field.ReflectValueOf(value).SetUint(data) case uint: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case uint8: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case uint16: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case uint32: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case int64: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case int: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case int8: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case int16: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case int32: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case float32: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case float64: - field.ReflectValuer(value).SetUint(uint64(data)) + field.ReflectValueOf(value).SetUint(uint64(data)) case []byte: - return field.Setter(value, string(data)) + return field.Set(value, string(data)) case string: if i, err := strconv.ParseUint(data, 0, 64); err == nil { - field.ReflectValuer(value).SetUint(i) + field.ReflectValueOf(value).SetUint(i) } else { return err } default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return err } case reflect.Float32, reflect.Float64: - field.Setter = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case float64: - field.ReflectValuer(value).SetFloat(data) + field.ReflectValueOf(value).SetFloat(data) case float32: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case int64: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case int: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case int8: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case int16: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case int32: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case uint: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case uint8: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case uint16: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case uint32: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case uint64: - field.ReflectValuer(value).SetFloat(float64(data)) + field.ReflectValueOf(value).SetFloat(float64(data)) case []byte: - return field.Setter(value, string(data)) + return field.Set(value, string(data)) case string: if i, err := strconv.ParseFloat(data, 64); err == nil { - field.ReflectValuer(value).SetFloat(i) + field.ReflectValueOf(value).SetFloat(i) } else { return err } default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return err } case reflect.String: - field.Setter = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(value reflect.Value, v interface{}) (err error) { switch data := v.(type) { case string: - field.ReflectValuer(value).SetString(data) + field.ReflectValueOf(value).SetString(data) case []byte: - field.ReflectValuer(value).SetString(string(data)) + field.ReflectValueOf(value).SetString(string(data)) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - field.ReflectValuer(value).SetString(fmt.Sprint(data)) + field.ReflectValueOf(value).SetString(fmt.Sprint(data)) case float64, float32: - field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) + field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return err } @@ -485,77 +529,77 @@ func (field *Field) setupValuerAndSetter() { fieldValue := reflect.New(field.FieldType) switch fieldValue.Elem().Interface().(type) { case time.Time: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Set = func(value reflect.Value, v interface{}) error { switch data := v.(type) { case time.Time: - field.ReflectValuer(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(value).Set(reflect.ValueOf(v)) case *time.Time: - field.ReflectValuer(value).Set(reflect.ValueOf(v).Elem()) + field.ReflectValueOf(value).Set(reflect.ValueOf(v).Elem()) case string: if t, err := now.Parse(data); err == nil { - field.ReflectValuer(value).Set(reflect.ValueOf(t)) + field.ReflectValueOf(value).Set(reflect.ValueOf(t)) } else { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return nil } case *time.Time: - field.Setter = func(value reflect.Value, v interface{}) error { + field.Set = func(value reflect.Value, v interface{}) error { switch data := v.(type) { case time.Time: - field.ReflectValuer(value).Elem().Set(reflect.ValueOf(v)) + field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(v)) case *time.Time: - field.ReflectValuer(value).Set(reflect.ValueOf(v)) + field.ReflectValueOf(value).Set(reflect.ValueOf(v)) case string: if t, err := now.Parse(data); err == nil { - field.ReflectValuer(value).Elem().Set(reflect.ValueOf(t)) + field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(t)) } else { return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) } default: - return recoverFunc(value, v, field.Setter) + return recoverFunc(value, v, field.Set) } return nil } default: if _, ok := fieldValue.Interface().(sql.Scanner); ok { // struct scanner - field.Setter = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { - err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) } } else { - err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) } return } } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner - field.Setter = func(value reflect.Value, v interface{}) (err error) { + field.Set = func(value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) if reflectV.Type().ConvertibleTo(field.FieldType) { - field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) + field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { - field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) + field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) } else if valuer, ok := v.(driver.Valuer); ok { if v, err = valuer.Value(); err == nil { - err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) + err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) } } else { - err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) + err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) } return } } else { - field.Setter = func(value reflect.Value, v interface{}) (err error) { - return recoverFunc(value, v, field.Setter) + field.Set = func(value reflect.Value, v interface{}) (err error) { + return recoverFunc(value, v, field.Set) } } } diff --git a/schema/schema.go b/schema/schema.go index 2f3cdf88..63e388f5 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -18,6 +18,7 @@ type Schema struct { ModelType reflect.Type Table string PrioritizedPrimaryField *Field + DBNames []string PrimaryFields []*Field Fields []*Field FieldsByName map[string]*Field @@ -99,6 +100,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if field.DBName != "" { // nonexistence or shortest path or first appear prioritized if has permission if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) { + if _, ok := schema.FieldsByDBName[field.DBName]; !ok { + schema.DBNames = append(schema.DBNames, field.DBName) + } schema.FieldsByDBName[field.DBName] = field schema.FieldsByName[field.Name] = field diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 8ac2f002..60e51543 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -198,7 +198,7 @@ func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[ var ( checker func(fv interface{}, v interface{}) field = s.FieldsByDBName[k] - fv = field.ValueOf(value) + fv, _ = field.ValueOf(value) ) checker = func(fv interface{}, v interface{}) {