Update schema

This commit is contained in:
Jinzhu 2020-02-18 22:56:37 +08:00
parent 98ad29f2c2
commit cbbf8f3d49
3 changed files with 188 additions and 140 deletions

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@ -14,6 +15,13 @@ import (
type DataType string type DataType string
type TimeType int64
const (
UnixSecond TimeType = 1
UnixNanosecond TimeType = 2
)
const ( const (
Bool DataType = "bool" Bool DataType = "bool"
Int = "int" Int = "int"
@ -35,7 +43,10 @@ type Field struct {
Creatable bool Creatable bool
Updatable bool Updatable bool
HasDefaultValue bool HasDefaultValue bool
AutoCreateTime TimeType
AutoUpdateTime TimeType
DefaultValue string DefaultValue string
DefaultValueInterface interface{}
NotNull bool NotNull bool
Unique bool Unique bool
Comment string Comment string
@ -48,9 +59,9 @@ type Field struct {
TagSettings map[string]string TagSettings map[string]string
Schema *Schema Schema *Schema
EmbeddedSchema *Schema EmbeddedSchema *Schema
ReflectValuer func(reflect.Value) reflect.Value ReflectValueOf func(reflect.Value) reflect.Value
Valuer func(reflect.Value) interface{} ValueOf func(reflect.Value) (value interface{}, zero bool)
Setter func(reflect.Value, interface{}) error Set func(reflect.Value, interface{}) error
} }
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
@ -73,7 +84,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
fieldValue := reflect.New(field.IndirectFieldType) fieldValue := reflect.New(field.IndirectFieldType)
// if field is valuer, used its value or first fields as data type // if field is valuer, used its value or first fields as data type
if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer { if valuer, isValueOf := fieldValue.Interface().(driver.Valuer); isValueOf {
var overrideFieldValue bool var overrideFieldValue bool
if v, err := valuer.Value(); v != nil && err == nil { if v, err := valuer.Value(); v != nil && err == nil {
overrideFieldValue = true overrideFieldValue = true
@ -150,17 +161,48 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
field.DBDataType = val field.DBDataType = val
} }
if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int)) {
if strings.ToUpper(v) == "NANO" {
field.AutoCreateTime = UnixNanosecond
} else {
field.AutoCreateTime = UnixSecond
}
}
if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int)) {
if strings.ToUpper(v) == "NANO" {
field.AutoUpdateTime = UnixNanosecond
} else {
field.AutoUpdateTime = UnixSecond
}
}
switch fieldValue.Elem().Kind() { switch fieldValue.Elem().Kind() {
case reflect.Bool: case reflect.Bool:
field.DataType = Bool field.DataType = Bool
if field.HasDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue)
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.DataType = Int field.DataType = Int
if field.HasDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64)
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.DataType = Uint field.DataType = Uint
if field.HasDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64)
}
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
field.DataType = Float field.DataType = Float
if field.HasDefaultValue {
field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64)
}
case reflect.String: case reflect.String:
field.DataType = String field.DataType = String
if field.HasDefaultValue {
field.DefaultValueInterface = field.DefaultValue
}
case reflect.Struct: case reflect.Struct:
if _, ok := fieldValue.Interface().(*time.Time); ok { if _, ok := fieldValue.Interface().(*time.Time); ok {
field.DataType = Time field.DataType = Time
@ -216,36 +258,22 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
return field return field
} }
// ValueOf field value of
func (field *Field) ValueOf(value reflect.Value) interface{} {
if field != nil {
return field.Valuer(value)
}
return nil
}
func (field *Field) Set(value reflect.Value, v interface{}) error {
if field != nil {
return field.Setter(value, v)
}
return fmt.Errorf("failed to set field value: %v", field.Name)
}
// create valuer, setter when parse struct // create valuer, setter when parse struct
func (field *Field) setupValuerAndSetter() { func (field *Field) setupValuerAndSetter() {
// Valuer // ValueOf
switch { switch {
case len(field.StructField.Index) == 1: case len(field.StructField.Index) == 1:
field.Valuer = func(value reflect.Value) interface{} { field.ValueOf = func(value reflect.Value) (interface{}, bool) {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Interface() fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
return fieldValue.Interface(), fieldValue.IsZero()
} }
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0: case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
field.Valuer = func(value reflect.Value) interface{} { field.ValueOf = func(value reflect.Value) (interface{}, bool) {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]).Interface() fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
return fieldValue.Interface(), fieldValue.IsZero()
} }
default: default:
field.Valuer = func(value reflect.Value) interface{} { field.ValueOf = func(value reflect.Value) (interface{}, bool) {
v := reflect.Indirect(value) v := reflect.Indirect(value)
for _, idx := range field.StructField.Index { for _, idx := range field.StructField.Index {
@ -259,19 +287,19 @@ func (field *Field) setupValuerAndSetter() {
v = v.Elem() v = v.Elem()
} }
} else { } else {
return nil return nil, true
} }
} }
} }
return v.Interface() return v.Interface(), v.IsZero()
} }
} }
// ReflectValuer // ReflectValueOf
switch { switch {
case len(field.StructField.Index) == 1: case len(field.StructField.Index) == 1:
if field.FieldType.Kind() == reflect.Ptr { if field.FieldType.Kind() == reflect.Ptr {
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValueOf = func(value reflect.Value) reflect.Value {
fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]) fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
if fieldValue.IsNil() { if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem())) fieldValue.Set(reflect.New(field.FieldType.Elem()))
@ -279,16 +307,16 @@ func (field *Field) setupValuerAndSetter() {
return fieldValue return fieldValue
} }
} else { } else {
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValueOf = func(value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0]) return reflect.Indirect(value).Field(field.StructField.Index[0])
} }
} }
case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr: case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValueOf = func(value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1]) return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
} }
default: default:
field.ReflectValuer = func(value reflect.Value) reflect.Value { field.ReflectValueOf = func(value reflect.Value) reflect.Value {
v := reflect.Indirect(value) v := reflect.Indirect(value)
for _, idx := range field.StructField.Index { for _, idx := range field.StructField.Index {
if idx >= 0 { if idx >= 0 {
@ -316,168 +344,184 @@ func (field *Field) setupValuerAndSetter() {
recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) { if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok { } else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil { if v, err = valuer.Value(); err == nil {
return setter(value, v) return setter(value, v)
} }
} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { } else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else if reflectV.Kind() == reflect.Ptr { } else if reflectV.Kind() == reflect.Ptr {
return field.Setter(value, reflectV.Elem().Interface()) return field.Set(value, reflectV.Elem().Interface())
} else { } else {
return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
} }
return err return err
} }
// Setter // Set
switch field.FieldType.Kind() { switch field.FieldType.Kind() {
case reflect.Bool: case reflect.Bool:
field.Setter = func(value reflect.Value, v interface{}) error { field.Set = func(value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case bool: case bool:
field.ReflectValuer(value).SetBool(data) field.ReflectValueOf(value).SetBool(data)
case *bool: case *bool:
field.ReflectValuer(value).SetBool(*data) field.ReflectValueOf(value).SetBool(*data)
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return nil return nil
} }
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case int64: case int64:
field.ReflectValuer(value).SetInt(data) field.ReflectValueOf(value).SetInt(data)
case int: case int:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case int8: case int8:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case int16: case int16:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case int32: case int32:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case uint: case uint:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case uint8: case uint8:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case uint16: case uint16:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case uint32: case uint32:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case uint64: case uint64:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case float32: case float32:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case float64: case float64:
field.ReflectValuer(value).SetInt(int64(data)) field.ReflectValueOf(value).SetInt(int64(data))
case []byte: case []byte:
return field.Setter(value, string(data)) return field.Set(value, string(data))
case string: case string:
if i, err := strconv.ParseInt(data, 0, 64); err == nil { if i, err := strconv.ParseInt(data, 0, 64); err == nil {
field.ReflectValuer(value).SetInt(i) field.ReflectValueOf(value).SetInt(i)
} else { } else {
return err return err
} }
case time.Time:
if field.AutoCreateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano())
} else {
field.ReflectValueOf(value).SetInt(data.Unix())
}
case *time.Time:
if data != nil {
if field.AutoCreateTime == UnixNanosecond {
field.ReflectValueOf(value).SetInt(data.UnixNano())
} else {
field.ReflectValueOf(value).SetInt(data.Unix())
}
} else {
field.ReflectValueOf(value).SetInt(0)
}
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return err return err
} }
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case uint64: case uint64:
field.ReflectValuer(value).SetUint(data) field.ReflectValueOf(value).SetUint(data)
case uint: case uint:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case uint8: case uint8:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case uint16: case uint16:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case uint32: case uint32:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case int64: case int64:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case int: case int:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case int8: case int8:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case int16: case int16:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case int32: case int32:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case float32: case float32:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case float64: case float64:
field.ReflectValuer(value).SetUint(uint64(data)) field.ReflectValueOf(value).SetUint(uint64(data))
case []byte: case []byte:
return field.Setter(value, string(data)) return field.Set(value, string(data))
case string: case string:
if i, err := strconv.ParseUint(data, 0, 64); err == nil { if i, err := strconv.ParseUint(data, 0, 64); err == nil {
field.ReflectValuer(value).SetUint(i) field.ReflectValueOf(value).SetUint(i)
} else { } else {
return err return err
} }
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return err return err
} }
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case float64: case float64:
field.ReflectValuer(value).SetFloat(data) field.ReflectValueOf(value).SetFloat(data)
case float32: case float32:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case int64: case int64:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case int: case int:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case int8: case int8:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case int16: case int16:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case int32: case int32:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case uint: case uint:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case uint8: case uint8:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case uint16: case uint16:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case uint32: case uint32:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case uint64: case uint64:
field.ReflectValuer(value).SetFloat(float64(data)) field.ReflectValueOf(value).SetFloat(float64(data))
case []byte: case []byte:
return field.Setter(value, string(data)) return field.Set(value, string(data))
case string: case string:
if i, err := strconv.ParseFloat(data, 64); err == nil { if i, err := strconv.ParseFloat(data, 64); err == nil {
field.ReflectValuer(value).SetFloat(i) field.ReflectValueOf(value).SetFloat(i)
} else { } else {
return err return err
} }
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return err return err
} }
case reflect.String: case reflect.String:
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
switch data := v.(type) { switch data := v.(type) {
case string: case string:
field.ReflectValuer(value).SetString(data) field.ReflectValueOf(value).SetString(data)
case []byte: case []byte:
field.ReflectValuer(value).SetString(string(data)) field.ReflectValueOf(value).SetString(string(data))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.ReflectValuer(value).SetString(fmt.Sprint(data)) field.ReflectValueOf(value).SetString(fmt.Sprint(data))
case float64, float32: case float64, float32:
field.ReflectValuer(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return err return err
} }
@ -485,77 +529,77 @@ func (field *Field) setupValuerAndSetter() {
fieldValue := reflect.New(field.FieldType) fieldValue := reflect.New(field.FieldType)
switch fieldValue.Elem().Interface().(type) { switch fieldValue.Elem().Interface().(type) {
case time.Time: case time.Time:
field.Setter = func(value reflect.Value, v interface{}) error { field.Set = func(value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case time.Time: case time.Time:
field.ReflectValuer(value).Set(reflect.ValueOf(v)) field.ReflectValueOf(value).Set(reflect.ValueOf(v))
case *time.Time: case *time.Time:
field.ReflectValuer(value).Set(reflect.ValueOf(v).Elem()) field.ReflectValueOf(value).Set(reflect.ValueOf(v).Elem())
case string: case string:
if t, err := now.Parse(data); err == nil { if t, err := now.Parse(data); err == nil {
field.ReflectValuer(value).Set(reflect.ValueOf(t)) field.ReflectValueOf(value).Set(reflect.ValueOf(t))
} else { } else {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
} }
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return nil return nil
} }
case *time.Time: case *time.Time:
field.Setter = func(value reflect.Value, v interface{}) error { field.Set = func(value reflect.Value, v interface{}) error {
switch data := v.(type) { switch data := v.(type) {
case time.Time: case time.Time:
field.ReflectValuer(value).Elem().Set(reflect.ValueOf(v)) field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(v))
case *time.Time: case *time.Time:
field.ReflectValuer(value).Set(reflect.ValueOf(v)) field.ReflectValueOf(value).Set(reflect.ValueOf(v))
case string: case string:
if t, err := now.Parse(data); err == nil { if t, err := now.Parse(data); err == nil {
field.ReflectValuer(value).Elem().Set(reflect.ValueOf(t)) field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(t))
} else { } else {
return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
} }
default: default:
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
return nil return nil
} }
default: default:
if _, ok := fieldValue.Interface().(sql.Scanner); ok { if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner // struct scanner
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) { if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if valuer, ok := v.(driver.Valuer); ok { } else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil { if v, err = valuer.Value(); err == nil {
err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
} }
} else { } else {
err = field.ReflectValuer(value).Addr().Interface().(sql.Scanner).Scan(v) err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
} }
return return
} }
} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner // pointer scanner
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v) reflectV := reflect.ValueOf(v)
if reflectV.Type().ConvertibleTo(field.FieldType) { if reflectV.Type().ConvertibleTo(field.FieldType) {
field.ReflectValuer(value).Set(reflectV.Convert(field.FieldType)) field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { } else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
field.ReflectValuer(value).Elem().Set(reflectV.Convert(field.FieldType.Elem())) field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
} else if valuer, ok := v.(driver.Valuer); ok { } else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil { if v, err = valuer.Value(); err == nil {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
} }
} else { } else {
err = field.ReflectValuer(value).Interface().(sql.Scanner).Scan(v) err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
} }
return return
} }
} else { } else {
field.Setter = func(value reflect.Value, v interface{}) (err error) { field.Set = func(value reflect.Value, v interface{}) (err error) {
return recoverFunc(value, v, field.Setter) return recoverFunc(value, v, field.Set)
} }
} }
} }

View File

@ -18,6 +18,7 @@ type Schema struct {
ModelType reflect.Type ModelType reflect.Type
Table string Table string
PrioritizedPrimaryField *Field PrioritizedPrimaryField *Field
DBNames []string
PrimaryFields []*Field PrimaryFields []*Field
Fields []*Field Fields []*Field
FieldsByName map[string]*Field FieldsByName map[string]*Field
@ -99,6 +100,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
if field.DBName != "" { if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission // nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) { if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) {
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
schema.DBNames = append(schema.DBNames, field.DBName)
}
schema.FieldsByDBName[field.DBName] = field schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field schema.FieldsByName[field.Name] = field

View File

@ -198,7 +198,7 @@ func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[
var ( var (
checker func(fv interface{}, v interface{}) checker func(fv interface{}, v interface{})
field = s.FieldsByDBName[k] field = s.FieldsByDBName[k]
fv = field.ValueOf(value) fv, _ = field.ValueOf(value)
) )
checker = func(fv interface{}, v interface{}) { checker = func(fv interface{}, v interface{}) {