Update schema

This commit is contained in:
Jinzhu 2020-02-18 22:56:37 +08:00
parent 98ad29f2c2
commit cbbf8f3d49
3 changed files with 188 additions and 140 deletions

View File

@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"time"
@ -14,6 +15,13 @@ import (
type DataType string
type TimeType int64
const (
UnixSecond TimeType = 1
UnixNanosecond TimeType = 2
)
const (
Bool DataType = "bool"
Int = "int"
@ -25,32 +33,35 @@ const (
)
type Field struct {
Name string
DBName string
BindNames []string
DataType DataType
DBDataType string
PrimaryKey bool
AutoIncrement bool
Creatable bool
Updatable bool
HasDefaultValue bool
DefaultValue string
NotNull bool
Unique bool
Comment string
Size int
Precision int
FieldType reflect.Type
IndirectFieldType reflect.Type
StructField reflect.StructField
Tag reflect.StructTag
TagSettings map[string]string
Schema *Schema
EmbeddedSchema *Schema
ReflectValuer func(reflect.Value) reflect.Value
Valuer func(reflect.Value) interface{}
Setter func(reflect.Value, interface{}) error
Name string
DBName string
BindNames []string
DataType DataType
DBDataType string
PrimaryKey bool
AutoIncrement bool
Creatable bool
Updatable bool
HasDefaultValue bool
AutoCreateTime TimeType
AutoUpdateTime TimeType
DefaultValue string
DefaultValueInterface interface{}
NotNull bool
Unique bool
Comment string
Size int
Precision int
FieldType reflect.Type
IndirectFieldType reflect.Type
StructField reflect.StructField
Tag reflect.StructTag
TagSettings map[string]string
Schema *Schema
EmbeddedSchema *Schema
ReflectValueOf func(reflect.Value) reflect.Value
ValueOf func(reflect.Value) (value interface{}, zero bool)
Set func(reflect.Value, interface{}) error
}
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
@ -73,7 +84,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
fieldValue := reflect.New(field.IndirectFieldType)
// if field is valuer, used its value or first fields as data type
if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer {
if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf {
var overrideFieldValue bool
if v, err := valuer.Value(); v != nil && err == nil {
overrideFieldValue = true
@ -150,17 +161,48 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.DBDataType = val
}
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int)) {
if strings.ToUpper(v) == "NANO" {
field.AutoCreateTime = UnixNanosecond
} else {
field.AutoCreateTime = UnixSecond
}
}
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int)) {
if strings.ToUpper(v) == "NANO" {
field.AutoUpdateTime = UnixNanosecond
} else {
field.AutoUpdateTime = UnixSecond
}
}
switch fieldValue.Elem().Kind() {
case reflect.Bool:
field.DataType = Bool
if field.HasDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.DataType = Int
if field.HasDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.DataType = Uint
if field.HasDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64)
}
case reflect.Float32, reflect.Float64:
field.DataType = Float
if field.HasDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64)
}
case reflect.String:
field.DataType = String
if field.HasDefaultValue {
field.DefaultValueInterface = field.DefaultValue
}
case reflect.Struct:
if _, ok := fieldValue.Interface().(*time.Time); ok {
field.DataType = Time
@ -216,36 +258,22 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
return field
}
// ValueOf field value of
func (field *Field) ValueOf(value reflect.Value) interface{} {
if field != nil {
return field.Valuer(value)
}
return nil
}
func (field *Field) Set(value reflect.Value, v interface{}) error {
if field != nil {
return field.Setter(value, v)
}
return fmt.Errorf("failed to set field value: %v", field.Name)
}
// create valuer, setter when parse struct
func (field *Field) setupValuerAndSetter() {
// Valuer
// ValueOf
switch {
case len(field.StructField.Index) == 1:
field.Valuer = func(value reflect.Value) interface{} {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Interface()
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
return fieldValue.Interface(), fieldValue.IsZero()
}
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
field.Valuer = func(value reflect.Value) interface{} {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface()
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
return fieldValue.Interface(), fieldValue.IsZero()
}
default:
field.Valuer = func(value reflect.Value) interface{} {
field.ValueOf = func(value reflect.Value) (interface{}, bool) {
v := reflect.Indirect(value)
for _, idx := range field.StructField.Index {
@ -259,19 +287,19 @@ func (field *Field) setupValuerAndSetter() {
v = v.Elem()
}
} else {
return nil
return nil, true
}
}
}
return v.Interface()
return v.Interface(), v.IsZero()
}
}
// ReflectValuer
// ReflectValueOf
switch {
case len(field.StructField.Index) == 1:
if field.FieldType.Kind() == reflect.Ptr {
field.ReflectValuer = func(value reflect.Value) reflect.Value {
field.ReflectValueOf = func(value reflect.Value) reflect.Value {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
@ -279,16 +307,16 @@ func (field *Field) setupValuerAndSetter() {
return fieldValue
}
} else {
field.ReflectValuer = func(value reflect.Value) reflect.Value {
field.ReflectValueOf = func(value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0])
}
}
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
field.ReflectValuer = func(value reflect.Value) reflect.Value {
field.ReflectValueOf = func(value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
}
default:
field.ReflectValuer = func(value reflect.Value) reflect.Value {
field.ReflectValueOf = func(value reflect.Value) reflect.Value {
v := reflect.Indirect(value)
for _, idx := range field.StructField.Index {
if idx >= 0 {
@ -316,168 +344,184 @@ func (field *Field) setupValuerAndSetter() {
recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
return setter(value, v)
}
} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else if reflectV.Kind() == reflect.Ptr {
return field.Setter(value, reflectV.Elem().Interface())
return field.Set(value, reflectV.Elem().Interface())
} else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
}
return err
}
// Setter
// Set
switch field.FieldType.Kind() {
case reflect.Bool:
field.Setter = func(value reflect.Value, v interface{}) error {
field.Set = func(value reflect.Value, v interface{}) error {
switch data := v.(type) {
case bool:
field.ReflectValuer(value).SetBool(data)
field.ReflectValueOf(value).SetBool(data)
case *bool:
field.ReflectValuer(value).SetBool(*data)
field.ReflectValueOf(value).SetBool(*data)
default:
return recoverFunc(value, v, field.Setter)
return recoverFunc(value, v, field.Set)
}
return nil
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.Setter = func(value reflect.Value, v interface{}) (err error) {
field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case int64:
field.ReflectValuer(value).SetInt(data)
field.ReflectValueOf(value).SetInt(data)
case int:
field.ReflectValuer(value).SetInt(int64(data))
field.ReflectValueOf(value).SetInt(int64(data))
case int8:
field.ReflectValuer(value).SetInt(int64(data))
field.ReflectValueOf(value).SetInt(int64(data))
case int16:
field.ReflectValuer(value).SetInt(int64(data))
field.ReflectValueOf(value).SetInt(int64(data))
case int32:
field.ReflectValuer(value).SetInt(int64(data))
field.ReflectValueOf(value).SetInt(int64(data))
case uint:
field.ReflectValuer(value).SetInt(int64(data))
field.ReflectValueOf(value).SetInt(int64(data))
case uint8:
field.ReflectValuer(value).SetInt(int64(data))
field.ReflectValueOf(value).SetInt(int64(data))
case uint16:
field.ReflectValuer(value).SetInt(int64(data))
field.ReflectValueOf(value).SetInt(int64(data))
case uint32:
field.ReflectValuer(value).SetInt(int64(data))
field.ReflectValueOf(value).SetInt(int64(data))
case uint64:
field.ReflectValuer(value).SetInt(int64(data))
field.ReflectValueOf(value).SetInt(int64(data))
case float32:
field.ReflectValuer(value).SetInt(int64(data))
field.ReflectValueOf(value).SetInt(int64(data))
case float64:
field.ReflectValuer(value).SetInt(int64(data))
field.ReflectValueOf(value).SetInt(int64(data))
case []byte:
return field.Setter(value, string(data))
return field.Set(value, string(data))
case string:
if i, err := strconv.ParseInt(data, 0, 64); err == nil {
field.ReflectValuer(value).SetInt(i)
field.ReflectValueOf(value).SetInt(i)
} else {
return err
}
case time.Time:
if field.AutoCreateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano())
} else {
field.ReflectValueOf(value).SetInt(data.Unix())
}
case *time.Time:
if data != nil {
if field.AutoCreateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano())
} else {
field.ReflectValueOf(value).SetInt(data.Unix())
}
} else {
field.ReflectValueOf(value).SetInt(0)
}
default:
return recoverFunc(value, v, field.Setter)
return recoverFunc(value, v, field.Set)
}
return err
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.Setter = func(value reflect.Value, v interface{}) (err error) {
field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case uint64:
field.ReflectValuer(value).SetUint(data)
field.ReflectValueOf(value).SetUint(data)
case uint:
field.ReflectValuer(value).SetUint(uint64(data))
field.ReflectValueOf(value).SetUint(uint64(data))
case uint8:
field.ReflectValuer(value).SetUint(uint64(data))
field.ReflectValueOf(value).SetUint(uint64(data))
case uint16:
field.ReflectValuer(value).SetUint(uint64(data))
field.ReflectValueOf(value).SetUint(uint64(data))
case uint32:
field.ReflectValuer(value).SetUint(uint64(data))
field.ReflectValueOf(value).SetUint(uint64(data))
case int64:
field.ReflectValuer(value).SetUint(uint64(data))
field.ReflectValueOf(value).SetUint(uint64(data))
case int:
field.ReflectValuer(value).SetUint(uint64(data))
field.ReflectValueOf(value).SetUint(uint64(data))
case int8:
field.ReflectValuer(value).SetUint(uint64(data))
field.ReflectValueOf(value).SetUint(uint64(data))
case int16:
field.ReflectValuer(value).SetUint(uint64(data))
field.ReflectValueOf(value).SetUint(uint64(data))
case int32:
field.ReflectValuer(value).SetUint(uint64(data))
field.ReflectValueOf(value).SetUint(uint64(data))
case float32:
field.ReflectValuer(value).SetUint(uint64(data))
field.ReflectValueOf(value).SetUint(uint64(data))
case float64:
field.ReflectValuer(value).SetUint(uint64(data))
field.ReflectValueOf(value).SetUint(uint64(data))
case []byte:
return field.Setter(value, string(data))
return field.Set(value, string(data))
case string:
if i, err := strconv.ParseUint(data, 0, 64); err == nil {
field.ReflectValuer(value).SetUint(i)
field.ReflectValueOf(value).SetUint(i)
} else {
return err
}
default:
return recoverFunc(value, v, field.Setter)
return recoverFunc(value, v, field.Set)
}
return err
}
case reflect.Float32, reflect.Float64:
field.Setter = func(value reflect.Value, v interface{}) (err error) {
field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case float64:
field.ReflectValuer(value).SetFloat(data)
field.ReflectValueOf(value).SetFloat(data)
case float32:
field.ReflectValuer(value).SetFloat(float64(data))
field.ReflectValueOf(value).SetFloat(float64(data))
case int64:
field.ReflectValuer(value).SetFloat(float64(data))
field.ReflectValueOf(value).SetFloat(float64(data))
case int:
field.ReflectValuer(value).SetFloat(float64(data))
field.ReflectValueOf(value).SetFloat(float64(data))
case int8:
field.ReflectValuer(value).SetFloat(float64(data))
field.ReflectValueOf(value).SetFloat(float64(data))
case int16:
field.ReflectValuer(value).SetFloat(float64(data))
field.ReflectValueOf(value).SetFloat(float64(data))
case int32:
field.ReflectValuer(value).SetFloat(float64(data))
field.ReflectValueOf(value).SetFloat(float64(data))
case uint:
field.ReflectValuer(value).SetFloat(float64(data))
field.ReflectValueOf(value).SetFloat(float64(data))
case uint8:
field.ReflectValuer(value).SetFloat(float64(data))
field.ReflectValueOf(value).SetFloat(float64(data))
case uint16:
field.ReflectValuer(value).SetFloat(float64(data))
field.ReflectValueOf(value).SetFloat(float64(data))
case uint32:
field.ReflectValuer(value).SetFloat(float64(data))
field.ReflectValueOf(value).SetFloat(float64(data))
case uint64:
field.ReflectValuer(value).SetFloat(float64(data))
field.ReflectValueOf(value).SetFloat(float64(data))
case []byte:
return field.Setter(value, string(data))
return field.Set(value, string(data))
case string:
if i, err := strconv.ParseFloat(data, 64); err == nil {
field.ReflectValuer(value).SetFloat(i)
field.ReflectValueOf(value).SetFloat(i)
} else {
return err
}
default:
return recoverFunc(value, v, field.Setter)
return recoverFunc(value, v, field.Set)
}
return err
}
case reflect.String:
field.Setter = func(value reflect.Value, v interface{}) (err error) {
field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case string:
field.ReflectValuer(value).SetString(data)
field.ReflectValueOf(value).SetString(data)
case []byte:
field.ReflectValuer(value).SetString(string(data))
field.ReflectValueOf(value).SetString(string(data))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.ReflectValuer(value).SetString(fmt.Sprint(data))
field.ReflectValueOf(value).SetString(fmt.Sprint(data))
case float64, float32:
field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default:
return recoverFunc(value, v, field.Setter)
return recoverFunc(value, v, field.Set)
}
return err
}
@ -485,77 +529,77 @@ func (field *Field) setupValuerAndSetter() {
fieldValue := reflect.New(field.FieldType)
switch fieldValue.Elem().Interface().(type) {
case time.Time:
field.Setter = func(value reflect.Value, v interface{}) error {
field.Set = func(value reflect.Value, v interface{}) error {
switch data := v.(type) {
case time.Time:
field.ReflectValuer(value).Set(reflect.ValueOf(v))
field.ReflectValueOf(value).Set(reflect.ValueOf(v))
case *time.Time:
field.ReflectValuer(value).Set(reflect.ValueOf(v).Elem())
field.ReflectValueOf(value).Set(reflect.ValueOf(v).Elem())
case string:
if t, err := now.Parse(data); err == nil {
field.ReflectValuer(value).Set(reflect.ValueOf(t))
field.ReflectValueOf(value).Set(reflect.ValueOf(t))
} else {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return recoverFunc(value, v, field.Setter)
return recoverFunc(value, v, field.Set)
}
return nil
}
case *time.Time:
field.Setter = func(value reflect.Value, v interface{}) error {
field.Set = func(value reflect.Value, v interface{}) error {
switch data := v.(type) {
case time.Time:
field.ReflectValuer(value).Elem().Set(reflect.ValueOf(v))
field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(v))
case *time.Time:
field.ReflectValuer(value).Set(reflect.ValueOf(v))
field.ReflectValueOf(value).Set(reflect.ValueOf(v))
case string:
if t, err := now.Parse(data); err == nil {
field.ReflectValuer(value).Elem().Set(reflect.ValueOf(t))
field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(t))
} else {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return recoverFunc(value, v, field.Setter)
return recoverFunc(value, v, field.Set)
}
return nil
}
default:
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner
field.Setter = func(value reflect.Value, v interface{}) (err error) {
field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v)
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
}
} else {
err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v)
err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
}
return
}
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner
field.Setter = func(value reflect.Value, v interface{}) (err error) {
field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType))
field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
}
} else {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v)
err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
}
return
}
} else {
field.Setter = func(value reflect.Value, v interface{}) (err error) {
return recoverFunc(value, v, field.Setter)
field.Set = func(value reflect.Value, v interface{}) (err error) {
return recoverFunc(value, v, field.Set)
}
}
}

View File

@ -18,6 +18,7 @@ type Schema struct {
ModelType reflect.Type
Table string
PrioritizedPrimaryField *Field
DBNames []string
PrimaryFields []*Field
Fields []*Field
FieldsByName map[string]*Field
@ -99,6 +100,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) {
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
schema.DBNames = append(schema.DBNames, field.DBName)
}
schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field

View File

@ -198,7 +198,7 @@ func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[
var (
checker func(fv interface{}, v interface{})
field = s.FieldsByDBName[k]
fv = field.ValueOf(value)
fv, _ = field.ValueOf(value)
)
checker = func(fv interface{}, v interface{}) {