package schema

import (
	"database/sql"
	"database/sql/driver"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jinzhu/gorm/utils"
	"github.com/jinzhu/now"
)

// DataType is the database-independent data type inferred for a field.
type DataType string

// TimeType is the integer resolution used when a field automatically
// records its create/update time.
type TimeType int64

const (
	// UnixSecond records time as seconds since the Unix epoch.
	UnixSecond TimeType = 1
	// UnixNanosecond records time as nanoseconds since the Unix epoch.
	UnixNanosecond TimeType = 2
)

// Data types assigned to Field.DataType. Only Bool is declared with type
// DataType; the others are untyped string constants (kept that way for
// compatibility), which still assign to and compare with DataType values.
const (
	Bool   DataType = "bool"
	Int             = "int"
	Uint            = "uint"
	Float           = "float"
	String          = "string"
	Time            = "time"
	Bytes           = "bytes"
)

// Field describes one field of a model struct, as parsed from its Go type
// and its `gorm` struct tag.
type Field struct {
	// Name is the Go struct field name, DBName the column name (ParseField
	// sets it from the COLUMN tag setting), and BindNames the chain of Go
	// field names from the model root (prefixed with the embedding field's
	// name for fields of embedded structs).
	Name      string
	DBName    string
	BindNames []string

	// DataType is the generic type inferred from the Go type; DBDataType is
	// the explicit database type from the TYPE tag setting.
	DataType   DataType
	DBDataType string

	// Key, permission and default-value flags derived from tag settings.
	PrimaryKey      bool
	AutoIncrement   bool
	Creatable       bool
	Updatable       bool
	Readable        bool
	HasDefaultValue bool

	// AutoCreateTime / AutoUpdateTime hold the resolution used to record
	// create/update time. They stay zero unless the AUTOCREATETIME /
	// AUTOUPDATETIME tag is present or the field is a Time/Int field named
	// CreatedAt / UpdatedAt.
	AutoCreateTime TimeType
	AutoUpdateTime TimeType

	// DefaultValue is the raw DEFAULT tag setting; DefaultValueInterface is
	// that value parsed according to DataType (only for bool, numeric and
	// string kinds).
	DefaultValue          string
	DefaultValueInterface interface{}

	// Column constraints and metadata from tag settings. Size is -1 when the
	// SIZE setting is not an integer and otherwise defaults to the bit width
	// of numeric kinds.
	NotNull   bool
	Unique    bool
	Comment   string
	Size      int
	Precision int

	// FieldType is the declared Go type and IndirectFieldType the same type
	// with all pointer levels removed. StructField.Index is the path from the
	// model root; a negative entry -(i+1) marks an embedded struct reached
	// through the pointer stored in field i.
	FieldType         reflect.Type
	IndirectFieldType reflect.Type
	StructField       reflect.StructField
	Tag               reflect.StructTag
	TagSettings       map[string]string

	// Schema is the schema that owns the field; EmbeddedSchema is the parsed
	// schema of an embedded struct field.
	Schema         *Schema
	EmbeddedSchema *Schema

	// Accessors installed by setupValuerAndSetter. ReflectValueOf returns the
	// field's reflect.Value (it may allocate nil pointers along the path),
	// ValueOf returns the field's value and whether it is zero, and Set
	// assigns a value, converting from common Go types.
	ReflectValueOf func(reflect.Value) reflect.Value
	ValueOf        func(reflect.Value) (value interface{}, zero bool)
	Set            func(reflect.Value, interface{}) error
}

// ParseField builds the Field for one struct field of the model described by
// schema. It applies the field's `gorm` tag settings (permissions, column
// name, keys, defaults, size, precision, constraints and auto-time options),
// infers DataType from the Go type (or from the value returned by the
// field's driver.Valuer implementation), and, for embedded or anonymous
// fields, parses the nested struct into EmbeddedSchema and re-roots its
// fields onto schema. An error from parsing an embedded struct is recorded
// in schema.err.
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
	field := &Field{
		Name:              fieldStruct.Name,
		BindNames:         []string{fieldStruct.Name},
		FieldType:         fieldStruct.Type,
		IndirectFieldType: fieldStruct.Type,
		StructField:       fieldStruct,
		Creatable:         true,
		Updatable:         true,
		Readable:          true,
		Tag:               fieldStruct.Tag,
		TagSettings:       ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";"),
		Schema:            schema,
	}

	for field.IndirectFieldType.Kind() == reflect.Ptr {
		field.IndirectFieldType = field.IndirectFieldType.Elem()
	}

	fieldValue := reflect.New(field.IndirectFieldType)
	// If the field is a driver.Valuer, infer the data type from the value it
	// produces, or else from the type of its first struct field. After this
	// fieldValue may hold a non-pointer value (whatever Value() returned), so
	// the code below must go through reflect.Indirect rather than Elem.
	if valuer, isValuer := fieldValue.Interface().(driver.Valuer); isValuer {
		var overrideFieldValue bool
		if v, err := valuer.Value(); v != nil && err == nil {
			overrideFieldValue = true
			fieldValue = reflect.ValueOf(v)
		}

		if field.IndirectFieldType.Kind() == reflect.Struct {
			for i := 0; i < field.IndirectFieldType.NumField(); i++ {
				if !overrideFieldValue {
					newFieldType := field.IndirectFieldType.Field(i).Type
					for newFieldType.Kind() == reflect.Ptr {
						newFieldType = newFieldType.Elem()
					}

					fieldValue = reflect.New(newFieldType)
					overrideFieldValue = true
				}

				// copy tag settings from the valuer's fields; settings on the
				// field itself take precedence
				for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") {
					if _, ok := field.TagSettings[key]; !ok {
						field.TagSettings[key] = value
					}
				}
			}
		}
	}

	// setup permission
	if _, ok := field.TagSettings["-"]; ok {
		field.Creatable = false
		field.Updatable = false
		field.Readable = false
	}

	// "<-:..." restricts writes to the listed operations ("create" and/or
	// "update"); a bare "<-" leaves both as they were. Either form also
	// clears Readable.
	if v, ok := field.TagSettings["<-"]; ok {
		if v != "<-" {
			if !strings.Contains(v, "create") {
				field.Creatable = false
			}

			if !strings.Contains(v, "update") {
				field.Updatable = false
			}
		}

		field.Readable = false
	}

	// "->" makes the field read-only.
	if _, ok := field.TagSettings["->"]; ok {
		field.Creatable = false
		field.Updatable = false
		field.Readable = true
	}

	if dbName, ok := field.TagSettings["COLUMN"]; ok {
		field.DBName = dbName
	}

	if val, ok := field.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
		field.PrimaryKey = true
	} else if val, ok := field.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
		field.PrimaryKey = true
	}

	if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
		field.AutoIncrement = true
		field.HasDefaultValue = true
	}

	if v, ok := field.TagSettings["DEFAULT"]; ok {
		field.HasDefaultValue = true
		field.DefaultValue = v
	}

	if num, ok := field.TagSettings["SIZE"]; ok {
		var err error
		if field.Size, err = strconv.Atoi(num); err != nil {
			field.Size = -1
		}
	}

	if p, ok := field.TagSettings["PRECISION"]; ok {
		// parse errors are deliberately ignored
		field.Precision, _ = strconv.Atoi(p)
	}

	if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) {
		field.NotNull = true
	}

	if val, ok := field.TagSettings["UNIQUE"]; ok && utils.CheckTruth(val) {
		field.Unique = true
	}

	if val, ok := field.TagSettings["COMMENT"]; ok {
		field.Comment = val
	}

	if val, ok := field.TagSettings["TYPE"]; ok {
		field.DBDataType = val
	}

	// Parse errors for DefaultValueInterface are deliberately ignored: the
	// interface then holds whatever strconv returned alongside the error.
	indirectValue := reflect.Indirect(fieldValue)
	switch indirectValue.Kind() {
	case reflect.Bool:
		field.DataType = Bool
		if field.HasDefaultValue {
			field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue)
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		field.DataType = Int
		if field.HasDefaultValue {
			field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64)
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		field.DataType = Uint
		if field.HasDefaultValue {
			field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64)
		}
	case reflect.Float32, reflect.Float64:
		field.DataType = Float
		if field.HasDefaultValue {
			field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64)
		}
	case reflect.String:
		field.DataType = String
		if field.HasDefaultValue {
			field.DefaultValueInterface = field.DefaultValue
		}
	case reflect.Struct:
		if _, ok := fieldValue.Interface().(*time.Time); ok {
			field.DataType = Time
		} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) {
			field.DataType = Time
		} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) {
			// a Valuer returned a time value rather than a pointer
			field.DataType = Time
		}
	case reflect.Array, reflect.Slice:
		// inspect the pointed-to type: fieldValue is usually a *[]T / *[N]T
		if indirectValue.Type().Elem() == reflect.TypeOf(uint8(0)) {
			field.DataType = Bytes
		}
	}

	if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int)) {
		if strings.ToUpper(v) == "NANO" {
			field.AutoCreateTime = UnixNanosecond
		} else {
			field.AutoCreateTime = UnixSecond
		}
	}

	if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int)) {
		if strings.ToUpper(v) == "NANO" {
			field.AutoUpdateTime = UnixNanosecond
		} else {
			field.AutoUpdateTime = UnixSecond
		}
	}

	// default Size to the bit width of numeric kinds
	if field.Size == 0 {
		switch indirectValue.Kind() {
		case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
			field.Size = 64
		case reflect.Int8, reflect.Uint8:
			field.Size = 8
		case reflect.Int16, reflect.Uint16:
			field.Size = 16
		case reflect.Int32, reflect.Uint32, reflect.Float32:
			field.Size = 32
		}
	}

	if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
		var err error
		field.Creatable = false
		field.Updatable = false
		field.Readable = false

		if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
			schema.err = err
		}

		// Parse may not return a schema when it fails; only re-root fields
		// when there is one, instead of dereferencing nil.
		if field.EmbeddedSchema != nil {
			prefix, hasPrefix := field.TagSettings["EMBEDDEDPREFIX"]
			for _, ef := range field.EmbeddedSchema.Fields {
				ef.Schema = schema
				ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
				// a negative index -(i+1) means field i holds a pointer to the embedded struct
				if field.FieldType.Kind() == reflect.Struct {
					ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
				} else {
					ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
				}

				if hasPrefix {
					ef.DBName = prefix + ef.DBName
				}

				for k, v := range field.TagSettings {
					ef.TagSettings[k] = v
				}
			}
		}
	}

	return field
}

// setupValuerAndSetter installs field.ValueOf, field.ReflectValueOf and
// field.Set, choosing a fast path for direct (and once-nested) fields and a
// generic walk over StructField.Index otherwise. It is meant to run while
// the owning struct is being parsed.
func (field *Field) setupValuerAndSetter() {
	// ValueOf
	switch {
	case len(field.StructField.Index) == 1:
		field.ValueOf = func(value reflect.Value) (interface{}, bool) {
			fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
			return fieldValue.Interface(), fieldValue.IsZero()
		}
	case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
		field.ValueOf = func(value reflect.Value) (interface{}, bool) {
			fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
			return fieldValue.Interface(), fieldValue.IsZero()
		}
	default:
		field.ValueOf = func(value reflect.Value) (interface{}, bool) {
			v := reflect.Indirect(value)
			for _, idx := range field.StructField.Index {
				if idx >= 0 {
					v = v.Field(idx)
					continue
				}

				// negative index: embedded struct reached through a pointer
				v = v.Field(-idx - 1)
				if v.Type().Elem().Kind() != reflect.Struct || v.IsNil() {
					// nothing to read through a nil embedded pointer; report zero
					return nil, true
				}
				v = v.Elem()
			}
			return v.Interface(), v.IsZero()
		}
	}

	// ReflectValueOf
	switch {
	case len(field.StructField.Index) == 1:
		if field.FieldType.Kind() == reflect.Ptr {
			field.ReflectValueOf = func(value reflect.Value) reflect.Value {
				fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
				if fieldValue.IsNil() {
					fieldValue.Set(reflect.New(field.FieldType.Elem()))
				}
				return fieldValue
			}
		} else {
			field.ReflectValueOf = func(value reflect.Value) reflect.Value {
				return reflect.Indirect(value).Field(field.StructField.Index[0])
			}
		}
	case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
		field.ReflectValueOf = func(value reflect.Value) reflect.Value {
			return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
		}
	default:
		field.ReflectValueOf = func(value reflect.Value) reflect.Value {
			v := reflect.Indirect(value)
			last := len(field.StructField.Index) - 1
			for pos, idx := range field.StructField.Index {
				if idx >= 0 {
					v = v.Field(idx)
				} else {
					v = v.Field(-idx - 1)
				}

				if v.Kind() == reflect.Ptr {
					if v.Type().Elem().Kind() == reflect.Struct && v.IsNil() {
						v.Set(reflect.New(v.Type().Elem()))
					}

					// Dereference intermediate pointers only (compare the loop
					// position, not the index value): the target field itself is
					// returned with its declared FieldType, as the setters expect.
					if pos < last {
						v = v.Elem()
					}
				}
			}
			return v
		}
	}

	// recoverFunc is the fallback shared by the setters below. It handles nil
	// (resetting the field to its zero value), values convertible to the
	// field type, driver.Valuer inputs (re-dispatched through setter), values
	// convertible to the pointed-to type of a pointer field, and pointers
	// (re-dispatched with the pointed-to value).
	recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) {
		reflectV := reflect.ValueOf(v)
		if !reflectV.IsValid() {
			field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
		} else if reflectV.Type().ConvertibleTo(field.FieldType) {
			field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
		} else if valuer, ok := v.(driver.Valuer); ok {
			if v, err = valuer.Value(); err == nil {
				return setter(value, v)
			}
		} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
			field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
		} else if reflectV.Kind() == reflect.Ptr {
			if reflectV.IsNil() {
				field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
			} else {
				return setter(value, reflectV.Elem().Interface())
			}
		} else {
			return fmt.Errorf("failed to set value %+v to field %v", v, field.Name)
		}
		return err
	}

	// unixTimeOf converts t to the integer resolution configured by either
	// auto-time setting; nanoseconds win when either asks for them.
	unixTimeOf := func(t time.Time) int64 {
		if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
			return t.UnixNano()
		}
		return t.Unix()
	}

	// Set
	switch field.FieldType.Kind() {
	case reflect.Bool:
		field.Set = func(value reflect.Value, v interface{}) error {
			switch data := v.(type) {
			case bool:
				field.ReflectValueOf(value).SetBool(data)
			case *bool:
				if data != nil {
					field.ReflectValueOf(value).SetBool(*data)
				} else {
					field.ReflectValueOf(value).SetBool(false)
				}
			default:
				return recoverFunc(value, v, field.Set)
			}
			return nil
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		field.Set = func(value reflect.Value, v interface{}) error {
			switch data := v.(type) {
			case int64:
				field.ReflectValueOf(value).SetInt(data)
			case int:
				field.ReflectValueOf(value).SetInt(int64(data))
			case int8:
				field.ReflectValueOf(value).SetInt(int64(data))
			case int16:
				field.ReflectValueOf(value).SetInt(int64(data))
			case int32:
				field.ReflectValueOf(value).SetInt(int64(data))
			case uint:
				field.ReflectValueOf(value).SetInt(int64(data))
			case uint8:
				field.ReflectValueOf(value).SetInt(int64(data))
			case uint16:
				field.ReflectValueOf(value).SetInt(int64(data))
			case uint32:
				field.ReflectValueOf(value).SetInt(int64(data))
			case uint64:
				field.ReflectValueOf(value).SetInt(int64(data))
			case float32:
				field.ReflectValueOf(value).SetInt(int64(data))
			case float64:
				field.ReflectValueOf(value).SetInt(int64(data))
			case []byte:
				return field.Set(value, string(data))
			case string:
				i, err := strconv.ParseInt(data, 0, 64)
				if err != nil {
					return err
				}
				field.ReflectValueOf(value).SetInt(i)
			case time.Time:
				field.ReflectValueOf(value).SetInt(unixTimeOf(data))
			case *time.Time:
				if data != nil {
					field.ReflectValueOf(value).SetInt(unixTimeOf(*data))
				} else {
					field.ReflectValueOf(value).SetInt(0)
				}
			default:
				return recoverFunc(value, v, field.Set)
			}
			return nil
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		field.Set = func(value reflect.Value, v interface{}) error {
			switch data := v.(type) {
			case uint64:
				field.ReflectValueOf(value).SetUint(data)
			case uint:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case uint8:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case uint16:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case uint32:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case int64:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case int:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case int8:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case int16:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case int32:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case float32:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case float64:
				field.ReflectValueOf(value).SetUint(uint64(data))
			case []byte:
				return field.Set(value, string(data))
			case string:
				i, err := strconv.ParseUint(data, 0, 64)
				if err != nil {
					return err
				}
				field.ReflectValueOf(value).SetUint(i)
			default:
				return recoverFunc(value, v, field.Set)
			}
			return nil
		}
	case reflect.Float32, reflect.Float64:
		field.Set = func(value reflect.Value, v interface{}) error {
			switch data := v.(type) {
			case float64:
				field.ReflectValueOf(value).SetFloat(data)
			case float32:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case int64:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case int:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case int8:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case int16:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case int32:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case uint:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case uint8:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case uint16:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case uint32:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case uint64:
				field.ReflectValueOf(value).SetFloat(float64(data))
			case []byte:
				return field.Set(value, string(data))
			case string:
				f, err := strconv.ParseFloat(data, 64)
				if err != nil {
					return err
				}
				field.ReflectValueOf(value).SetFloat(f)
			default:
				return recoverFunc(value, v, field.Set)
			}
			return nil
		}
	case reflect.String:
		field.Set = func(value reflect.Value, v interface{}) error {
			switch data := v.(type) {
			case string:
				field.ReflectValueOf(value).SetString(data)
			case []byte:
				field.ReflectValueOf(value).SetString(string(data))
			case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
				field.ReflectValueOf(value).SetString(fmt.Sprint(data))
			case float64, float32:
				// formatted with exactly field.Precision decimal places
				field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
			default:
				return recoverFunc(value, v, field.Set)
			}
			return nil
		}
	default:
		fieldValue := reflect.New(field.FieldType)
		switch fieldValue.Elem().Interface().(type) {
		case time.Time:
			field.Set = func(value reflect.Value, v interface{}) error {
				switch data := v.(type) {
				case time.Time:
					field.ReflectValueOf(value).Set(reflect.ValueOf(v))
				case *time.Time:
					if data != nil {
						field.ReflectValueOf(value).Set(reflect.ValueOf(data).Elem())
					} else {
						// a nil pointer resets the field instead of panicking
						field.ReflectValueOf(value).Set(reflect.ValueOf(time.Time{}))
					}
				case string:
					t, err := now.Parse(data)
					if err != nil {
						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
					}
					field.ReflectValueOf(value).Set(reflect.ValueOf(t))
				default:
					return recoverFunc(value, v, field.Set)
				}
				return nil
			}
		case *time.Time:
			field.Set = func(value reflect.Value, v interface{}) error {
				switch data := v.(type) {
				case time.Time:
					field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(v))
				case *time.Time:
					field.ReflectValueOf(value).Set(reflect.ValueOf(v))
				case string:
					t, err := now.Parse(data)
					if err != nil {
						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
					}
					field.ReflectValueOf(value).Elem().Set(reflect.ValueOf(t))
				default:
					return recoverFunc(value, v, field.Set)
				}
				return nil
			}
		default:
			if _, ok := fieldValue.Interface().(sql.Scanner); ok {
				// struct scanner: *FieldType implements sql.Scanner
				field.Set = func(value reflect.Value, v interface{}) (err error) {
					reflectV := reflect.ValueOf(v)
					if !reflectV.IsValid() {
						// untyped nil (e.g. NULL): reset to the zero value
						field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
					} else if reflectV.Type().ConvertibleTo(field.FieldType) {
						field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
					} else if valuer, ok := v.(driver.Valuer); ok {
						if v, err = valuer.Value(); err == nil {
							err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
						}
					} else {
						err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v)
					}
					return
				}
			} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
				// pointer scanner: FieldType is itself a pointer implementing sql.Scanner
				field.Set = func(value reflect.Value, v interface{}) (err error) {
					reflectV := reflect.ValueOf(v)
					if !reflectV.IsValid() {
						// untyped nil (e.g. NULL): reset to the zero (nil) pointer
						field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem())
					} else if reflectV.Type().ConvertibleTo(field.FieldType) {
						field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType))
					} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) {
						field.ReflectValueOf(value).Elem().Set(reflectV.Convert(field.FieldType.Elem()))
					} else if valuer, ok := v.(driver.Valuer); ok {
						if v, err = valuer.Value(); err == nil {
							err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
						}
					} else {
						err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v)
					}
					return
				}
			} else {
				field.Set = func(value reflect.Value, v interface{}) (err error) {
					return recoverFunc(value, v, field.Set)
				}
			}
		}
	}
}