mirror of https://github.com/go-gorm/gorm.git
Refactor DataTypeOf
This commit is contained in:
parent
d92c5db9e7
commit
2dfd76d22b
|
@ -12,7 +12,7 @@ type Dialect interface {
|
|||
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
||||
Quote(key string) string
|
||||
// DataTypeOf return data's sql type
|
||||
DataTypeOf(value reflect.Value, size int, autoIncrease bool) string
|
||||
DataTypeOf(value reflect.Value, tagSettings map[string]string) string
|
||||
|
||||
// HasIndex check has index or not
|
||||
HasIndex(scope *Scope, tableName string, indexName string) bool
|
||||
|
|
|
@ -3,6 +3,7 @@ package gorm
|
|||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -16,17 +17,22 @@ func (commonDialect) Quote(key string) string {
|
|||
return fmt.Sprintf(`"%s"`, key)
|
||||
}
|
||||
|
||||
func (commonDialect) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
func (commonDialect) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
||||
var size int
|
||||
if num, ok := tagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
}
|
||||
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
return "BOOLEAN"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "INTEGER AUTO_INCREMENT"
|
||||
}
|
||||
return "INTEGER"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "BIGINT AUTO_INCREMENT"
|
||||
}
|
||||
return "BIGINT"
|
||||
|
@ -38,18 +44,18 @@ func (commonDialect) DataTypeOf(value reflect.Value, size int, autoIncrease bool
|
|||
}
|
||||
return "VARCHAR(65532)"
|
||||
case reflect.Struct:
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
return "TIMESTAMP"
|
||||
}
|
||||
default:
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
if _, ok := dataValue.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("BINARY(%d)", size)
|
||||
}
|
||||
return "BINARY(65532)"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
|
|
|
@ -3,6 +3,7 @@ package gorm
|
|||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -10,17 +11,22 @@ type mssql struct {
|
|||
commonDialect
|
||||
}
|
||||
|
||||
func (mssql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
func (mssql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
||||
var size int
|
||||
if num, ok := tagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
}
|
||||
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
return "bit"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "int IDENTITY(1,1)"
|
||||
}
|
||||
return "int"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "bigint IDENTITY(1,1)"
|
||||
}
|
||||
return "bigint"
|
||||
|
@ -32,18 +38,18 @@ func (mssql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string
|
|||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
return "datetime2"
|
||||
}
|
||||
default:
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
if _, ok := dataValue.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
}
|
||||
return "text"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
|
|
|
@ -3,6 +3,7 @@ package gorm
|
|||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -14,27 +15,32 @@ func (mysql) Quote(key string) string {
|
|||
return fmt.Sprintf("`%s`", key)
|
||||
}
|
||||
|
||||
func (mysql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
func (mysql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
||||
var size int
|
||||
if num, ok := tagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
}
|
||||
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
|
||||
if autoIncrease {
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "int AUTO_INCREMENT"
|
||||
}
|
||||
return "int"
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "int unsigned AUTO_INCREMENT"
|
||||
}
|
||||
return "int unsigned"
|
||||
case reflect.Int64:
|
||||
if autoIncrease {
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "bigint AUTO_INCREMENT"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Uint64:
|
||||
if autoIncrease {
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "bigint unsigned AUTO_INCREMENT"
|
||||
}
|
||||
return "bigint unsigned"
|
||||
|
@ -46,18 +52,18 @@ func (mysql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string
|
|||
}
|
||||
return "longtext"
|
||||
case reflect.Struct:
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
return "timestamp NULL"
|
||||
}
|
||||
default:
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
if _, ok := dataValue.Interface().([]byte); ok {
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varbinary(%d)", size)
|
||||
}
|
||||
return "longblob"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
func (s mysql) currentDatabase(scope *Scope) (name string) {
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -19,17 +20,22 @@ func (postgres) BindVar(i int) string {
|
|||
return fmt.Sprintf("$%v", i)
|
||||
}
|
||||
|
||||
func (postgres) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
func (postgres) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
||||
var size int
|
||||
if num, ok := tagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
}
|
||||
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "serial"
|
||||
}
|
||||
return "integer"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "bigserial"
|
||||
}
|
||||
return "bigint"
|
||||
|
@ -41,21 +47,21 @@ func (postgres) DataTypeOf(value reflect.Value, size int, autoIncrease bool) str
|
|||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
return "timestamp with time zone"
|
||||
}
|
||||
case reflect.Map:
|
||||
if value.Type() == hstoreType {
|
||||
if dataValue.Type() == hstoreType {
|
||||
return "hstore"
|
||||
}
|
||||
default:
|
||||
if isByteArrayOrSlice(value) {
|
||||
if isByteArrayOrSlice(dataValue) {
|
||||
return "bytea"
|
||||
} else if isUUID(value) {
|
||||
} else if isUUID(dataValue) {
|
||||
return "uuid"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
|
|
|
@ -3,6 +3,7 @@ package gorm
|
|||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -10,17 +11,22 @@ type sqlite3 struct {
|
|||
commonDialect
|
||||
}
|
||||
|
||||
func (sqlite3) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
func (sqlite3) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
||||
var size int
|
||||
if num, ok := tagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
}
|
||||
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
return "bool"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "integer primary key autoincrement"
|
||||
}
|
||||
return "integer"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "integer primary key autoincrement"
|
||||
}
|
||||
return "bigint"
|
||||
|
@ -32,15 +38,15 @@ func (sqlite3) DataTypeOf(value reflect.Value, size int, autoIncrease bool) stri
|
|||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if _, ok := value.Interface().(time.Time); ok {
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
return "datetime"
|
||||
}
|
||||
default:
|
||||
if _, ok := value.Interface().([]byte); ok {
|
||||
if _, ok := dataValue.Interface().([]byte); ok {
|
||||
return "blob"
|
||||
}
|
||||
}
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -541,21 +540,7 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
|
|||
}
|
||||
|
||||
if sqlType == "" {
|
||||
var size = 255
|
||||
|
||||
if value, ok := field.TagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(value)
|
||||
}
|
||||
|
||||
v, autoIncrease := field.TagSettings["AUTO_INCREMENT"]
|
||||
if field.IsPrimaryKey {
|
||||
autoIncrease = true
|
||||
}
|
||||
if v == "FALSE" {
|
||||
autoIncrease = false
|
||||
}
|
||||
|
||||
sqlType = scope.Dialect().DataTypeOf(reflectValue, size, autoIncrease)
|
||||
sqlType = scope.Dialect().DataTypeOf(reflectValue, field.TagSettings)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
|
|
|
@ -516,11 +516,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
if field, ok := scope.Fields()[fieldName]; ok {
|
||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
||||
primaryKeySqlType := field.TagSettings["TYPE"]
|
||||
if primaryKeySqlType == "" {
|
||||
primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false)
|
||||
}
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
|
||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
||||
}
|
||||
}
|
||||
|
@ -528,11 +524,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||
for idx, fieldName := range relationship.AssociationForeignFieldNames {
|
||||
if field, ok := toScope.Fields()[fieldName]; ok {
|
||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
||||
primaryKeySqlType := field.TagSettings["TYPE"]
|
||||
if primaryKeySqlType == "" {
|
||||
primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false)
|
||||
}
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
|
||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue