Update schema

This commit is contained in:
Jinzhu 2020-02-18 22:56:37 +08:00
parent 98ad29f2c2
commit cbbf8f3d49
3 changed files with 188 additions and 140 deletions

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@ -14,6 +15,13 @@ import (
type DataType string type DataType string
type TimeType int64
const (
UnixSecond TimeType = 1
UnixNanosecond TimeType = 2
)
const ( const (
Bool DataType = "bool" Bool DataType = "bool"
Int = "int" Int = "int"
@ -25,32 +33,35 @@ const (
) )
type Field struct { type Field struct {
Name string Name string
DBName string DBName string
BindNames []string BindNames []string
DataType DataType DataType DataType
DBDataType string DBDataType string
PrimaryKey bool PrimaryKey bool
AutoIncrement bool AutoIncrement bool
Creatable bool Creatable bool
Updatable bool Updatable bool
HasDefaultValue bool HasDefaultValue bool
DefaultValue string AutoCreateTime TimeType
NotNull bool AutoUpdateTime TimeType
Unique bool DefaultValue string
Comment string DefaultValueInterface interface{}
Size int NotNull bool
Precision int Unique bool
FieldType reflect.Type Comment string
IndirectFieldType reflect.Type Size int
StructField reflect.StructField Precision int
Tag reflect.StructTag FieldType reflect.Type
TagSettings map[string]string IndirectFieldType reflect.Type
Schema *Schema StructField reflect.StructField
EmbeddedSchema *Schema Tag reflect.StructTag
ReflectValuer func(reflect.Value) reflect.Value TagSettings map[string]string
Valuer func(reflect.Value) interface{} Schema *Schema
Setter func(reflect.Value, interface{}) error EmbeddedSchema *Schema
ReflectValueOf func(reflect.Value) reflect.Value
ValueOf func(reflect.Value) (value interface{}, zero bool)
Set func(reflect.Value, interface{}) error
} }
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
@ -73,7 +84,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
fieldValue := reflect.New(field.IndirectFieldType) fieldValue := reflect.New(field.IndirectFieldType)
// if field is valuer, used its value or first fields as data type // if field is valuer, used its value or first fields as data type
if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer { if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf {
var overrideFieldValue bool var overrideFieldValue bool
if v, err := valuer.Value(); v != nil && err == nil { if v, err := valuer.Value(); v != nil && err == nil {
overrideFieldValue = true overrideFieldValue = true
@ -150,17 +161,48 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.DBDataType = val field.DBDataType = val
} }
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int)) {
if strings.ToUpper(v) == "NANO" {
field.AutoCreateTime = UnixNanosecond
} else {
field.AutoCreateTime = UnixSecond
}
}
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int)) {
if strings.ToUpper(v) == "NANO" {
field.AutoUpdateTime = UnixNanosecond
} else {
field.AutoUpdateTime = UnixSecond
}
}
switch fieldValue.Elem().Kind() { switch fieldValue.Elem().Kind() {
case reflect.Bool: case reflect.Bool:
field.DataType = Bool field.DataType = Bool
if field.HasDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.DataType = Int field.DataType = Int
if field.HasDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.DataType = Uint field.DataType = Uint
if field.HasDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64)
}
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
field.DataType = Float field.DataType = Float
if field.HasDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64)
}
case reflect.String: case reflect.String:
field.DataType = String field.DataType = String
if field.HasDefaultValue {
field.DefaultValueInterface = field.DefaultValue
}
case reflect.Struct: case reflect.Struct:
if _, ok := fieldValue.Interface().(*time.Time); ok { if _, ok := fieldValue.Interface().(*time.Time); ok {
field.DataType = Time field.DataType = Time
@ -216,36 +258,22 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
return field return field
} }
// ValueOf field value of
func (field *Field) ValueOf(value reflect.Value) interface{} {
if field != nil {
return field.Valuer(value)
}
return nil
}
func (field *Field) Set(value reflect.Value, v interface{}) error {
if field != nil {
return field.Setter(value, v)
}
return fmt.Errorf("failed to set field value: %v", field.Name)
}
// create valuer, setter when parse struct // create valuer, setter when parse struct
func (field *Field) setupValuerAndSetter() { func (field *Field) setupValuerAndSetter() {
// Valuer // ValueOf
switch { switch {
case len(field.StructField.Index) == 1: case len(field.StructField.Index) == 1:
field.Valuer = func(value reflect.Value) interface{} { field.ValueOf = func(value reflect.Value) (interface{}, bool) {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Interface() fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
return fieldValue.Interface(), fieldValue.IsZero()
} }
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
field.Valuer = func(value reflect.Value) interface{} { field.ValueOf = func(value reflect.Value) (interface{}, bool) {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
return fieldValue.Interface(), fieldValue.IsZero()
} }
default: default:
field.Valuer = func(value reflect.Value) interface{} { field.ValueOf = func(value reflect.Value) (interface{}, bool) {
v := reflect.Indirect(value) v := reflect.Indirect(value)
for _, idx := range field.StructField.Index { for _, idx := range field.StructField.Index {
@ -259,19 +287,19 @@ func (field *Field) setupValuerAndSetter() {
v = v.Elem() v = v.Elem()
} }
} else { } else {
return nil return nil, true
} }
} }
} }
return v.Interface() return v.Interface(), v.IsZero()
} }
} }
// ReflectValuer // ReflectValueOf
switch { switch {
case len(field.StructField.Index) == 1: case len(field.StructField.Index) == 1:
if field.FieldType.Kind() == reflect.Ptr { if field.FieldType.Kind() == reflect.Ptr {
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValueOf = func(value reflect.Value) reflect.Value {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
if fieldValue.IsNil() { if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(field.FieldType.Elem()))
@ -279,16 +307,16 @@ func (field *Field) setupValuerAndSetter() {
return fieldValue return fieldValue
} }
} else { } else {
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValueOf = func(value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0]) return reflect.Indirect(value).Field(field.StructField.Index[0])
} }
} }
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValueOf = func(value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
} }
default: default:
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValueOf = func(value reflect.Value) reflect.Value {
v := reflect.Indirect(value) v := reflect.Indirect(value)
for _, idx := range field.StructField.Index { for _, idx := range field.StructField.Index {
if idx >= 0 { if idx >= 0 {
@ -316,168 +344,184 @@ func (field *Field) setupValuerAndSetter() {
recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) { if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok { } else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil { if v, err = valuer.Value(); err == nil {
return setter(value, v) return setter(value, v)
} }
} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
return field.Setter(value, reflectV.Elem().Interface()) return field.Set(value, reflectV.Elem().Interface())
} else { } else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
} }
return err return err
} }
// Setter // Set
switch field.FieldType.Kind() { switch field.FieldType.Kind() {
case reflect.Bool: case reflect.Bool:
field.Setter = func(value reflect.Value, v interface{}) error { field.Set = func(value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case bool: case bool:
field.ReflectValuer(value).SetBool(data) field.ReflectValueOf(value).SetBool(data)
case *bool: case *bool:
field.ReflectValuer(value).SetBool(*data) field.ReflectValueOf(value).SetBool(*data)
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return nil return nil
} }
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case int64: case int64:
field.ReflectValuer(value).SetInt(data) field.ReflectValueOf(value).SetInt(data)
case int: case int:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case int8: case int8:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case int16: case int16:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case int32: case int32:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case uint: case uint:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case uint8: case uint8:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case uint16: case uint16:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case uint32: case uint32:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case uint64: case uint64:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case float32: case float32:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case float64: case float64:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case []byte: case []byte:
return field.Setter(value, string(data)) return field.Set(value, string(data))
case string: case string:
if i, err := strconv.ParseInt(data, 0, 64); err == nil { if i, err := strconv.ParseInt(data, 0, 64); err == nil {
field.ReflectValuer(value).SetInt(i) field.ReflectValueOf(value).SetInt(i)
} else { } else {
return err return err
} }
case time.Time:
if field.AutoCreateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano())
} else {
field.ReflectValueOf(value).SetInt(data.Unix())
}
case *time.Time:
if data != nil {
if field.AutoCreateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano())
} else {
field.ReflectValueOf(value).SetInt(data.Unix())
}
} else {
field.ReflectValueOf(value).SetInt(0)
}
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return err return err
} }
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case uint64: case uint64:
field.ReflectValuer(value).SetUint(data) field.ReflectValueOf(value).SetUint(data)
case uint: case uint:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case uint8: case uint8:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case uint16: case uint16:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case uint32: case uint32:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case int64: case int64:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case int: case int:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case int8: case int8:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case int16: case int16:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case int32: case int32:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case float32: case float32:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case float64: case float64:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case []byte: case []byte:
return field.Setter(value, string(data)) return field.Set(value, string(data))
case string: case string:
if i, err := strconv.ParseUint(data, 0, 64); err == nil { if i, err := strconv.ParseUint(data, 0, 64); err == nil {
field.ReflectValuer(value).SetUint(i) field.ReflectValueOf(value).SetUint(i)
} else { } else {
return err return err
} }
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return err return err
} }
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case float64: case float64:
field.ReflectValuer(value).SetFloat(data) field.ReflectValueOf(value).SetFloat(data)
case float32: case float32:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case int64: case int64:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case int: case int:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case int8: case int8:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case int16: case int16:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case int32: case int32:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case uint: case uint:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case uint8: case uint8:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case uint16: case uint16:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case uint32: case uint32:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case uint64: case uint64:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case []byte: case []byte:
return field.Setter(value, string(data)) return field.Set(value, string(data))
case string: case string:
if i, err := strconv.ParseFloat(data, 64); err == nil { if i, err := strconv.ParseFloat(data, 64); err == nil {
field.ReflectValuer(value).SetFloat(i) field.ReflectValueOf(value).SetFloat(i)
} else { } else {
return err return err
} }
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return err return err
} }
case reflect.String: case reflect.String:
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case string: case string:
field.ReflectValuer(value).SetString(data) field.ReflectValueOf(value).SetString(data)
case []byte: case []byte:
field.ReflectValuer(value).SetString(string(data)) field.ReflectValueOf(value).SetString(string(data))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.ReflectValuer(value).SetString(fmt.Sprint(data)) field.ReflectValueOf(value).SetString(fmt.Sprint(data))
case float64, float32: case float64, float32:
field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return err return err
} }
@ -485,77 +529,77 @@ func (field *Field) setupValuerAndSetter() {
fieldValue := reflect.New(field.FieldType) fieldValue := reflect.New(field.FieldType)
switch fieldValue.Elem().Interface().(type) { switch fieldValue.Elem().Interface().(type) {
case time.Time: case time.Time:
field.Setter = func(value reflect.Value, v interface{}) error { field.Set = func(value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case time.Time: case time.Time:
field.ReflectValuer(value).Set(reflect.ValueOf(v)) field.ReflectValueOf(value).Set(reflect.ValueOf(v))
case *time.Time: case *time.Time:
field.ReflectValuer(value).Set(reflect.ValueOf(v).Elem()) field.ReflectValueOf(value).Set(reflect.ValueOf(v).Elem())
case string: case string:
if t, err := now.Parse(data); err == nil { if t, err := now.Parse(data); err == nil {
field.ReflectValuer(value).Set(reflect.ValueOf(t)) field.ReflectValueOf(value).Set(reflect.ValueOf(t))
} else { } else {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
} }
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return nil return nil
} }
case *time.Time: case *time.Time:
field.Setter = func(value reflect.Value, v interface{}) error { field.Set = func(value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case time.Time: case time.Time:
field.ReflectValuer(value).Elem().Set(reflect.ValueOf(v)) field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(v))
case *time.Time: case *time.Time:
field.ReflectValuer(value).Set(reflect.ValueOf(v)) field.ReflectValueOf(value).Set(reflect.ValueOf(v))
case string: case string:
if t, err := now.Parse(data); err == nil { if t, err := now.Parse(data); err == nil {
field.ReflectValuer(value).Elem().Set(reflect.ValueOf(t)) field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(t))
} else { } else {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
} }
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return nil return nil
} }
default: default:
if _, ok := fieldValue.Interface().(sql.Scanner); ok { if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner // struct scanner
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) { if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok { } else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil { if v, err = valuer.Value(); err == nil {
err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
} }
} else { } else {
err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
} }
return return
} }
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner // pointer scanner
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) { if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else if valuer, ok := v.(driver.Valuer); ok { } else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil { if v, err = valuer.Value(); err == nil {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
} }
} else { } else {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
} }
return return
} }
} else { } else {
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
} }
} }

View File

@ -18,6 +18,7 @@ type Schema struct {
ModelType reflect.Type ModelType reflect.Type
Table string Table string
PrioritizedPrimaryField *Field PrioritizedPrimaryField *Field
DBNames []string
PrimaryFields []*Field PrimaryFields []*Field
Fields []*Field Fields []*Field
FieldsByName map[string]*Field FieldsByName map[string]*Field
@ -99,6 +100,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if field.DBName != "" { if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission // nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) { if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) {
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
schema.DBNames = append(schema.DBNames, field.DBName)
}
schema.FieldsByDBName[field.DBName] = field schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field schema.FieldsByName[field.Name] = field

View File

@ -198,7 +198,7 @@ func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[
var ( var (
checker func(fv interface{}, v interface{}) checker func(fv interface{}, v interface{})
field = s.FieldsByDBName[k] field = s.FieldsByDBName[k]
fv = field.ValueOf(value) fv, _ = field.ValueOf(value)
) )
checker = func(fv interface{}, v interface{}) { checker = func(fv interface{}, v interface{}) {