Refactor DataTypeOf

This commit is contained in:
Jinzhu 2016-01-19 20:58:38 +08:00
parent d92c5db9e7
commit 2dfd76d22b
8 changed files with 73 additions and 66 deletions

View File

@ -12,7 +12,7 @@ type Dialect interface {
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
Quote(key string) string Quote(key string) string
// DataTypeOf return data's sql type // DataTypeOf return data's sql type
DataTypeOf(value reflect.Value, size int, autoIncrease bool) string DataTypeOf(value reflect.Value, tagSettings map[string]string) string
// HasIndex check has index or not // HasIndex check has index or not
HasIndex(scope *Scope, tableName string, indexName string) bool HasIndex(scope *Scope, tableName string, indexName string) bool

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strconv"
"time" "time"
) )
@ -16,17 +17,22 @@ func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key) return fmt.Sprintf(`"%s"`, key)
} }
func (commonDialect) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { func (commonDialect) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
switch value.Kind() { var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
switch dataValue.Kind() {
case reflect.Bool: case reflect.Bool:
return "BOOLEAN" return "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease { if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "INTEGER AUTO_INCREMENT" return "INTEGER AUTO_INCREMENT"
} }
return "INTEGER" return "INTEGER"
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
if autoIncrease { if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "BIGINT AUTO_INCREMENT" return "BIGINT AUTO_INCREMENT"
} }
return "BIGINT" return "BIGINT"
@ -38,18 +44,18 @@ func (commonDialect) DataTypeOf(value reflect.Value, size int, autoIncrease bool
} }
return "VARCHAR(65532)" return "VARCHAR(65532)"
case reflect.Struct: case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok { if _, ok := dataValue.Interface().(time.Time); ok {
return "TIMESTAMP" return "TIMESTAMP"
} }
default: default:
if _, ok := value.Interface().([]byte); ok { if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("BINARY(%d)", size) return fmt.Sprintf("BINARY(%d)", size)
} }
return "BINARY(65532)" return "BINARY(65532)"
} }
} }
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
} }
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strconv"
"time" "time"
) )
@ -10,17 +11,22 @@ type mssql struct {
commonDialect commonDialect
} }
func (mssql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { func (mssql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
switch value.Kind() { var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
switch dataValue.Kind() {
case reflect.Bool: case reflect.Bool:
return "bit" return "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease { if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "int IDENTITY(1,1)" return "int IDENTITY(1,1)"
} }
return "int" return "int"
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
if autoIncrease { if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "bigint IDENTITY(1,1)" return "bigint IDENTITY(1,1)"
} }
return "bigint" return "bigint"
@ -32,18 +38,18 @@ func (mssql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string
} }
return "text" return "text"
case reflect.Struct: case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok { if _, ok := dataValue.Interface().(time.Time); ok {
return "datetime2" return "datetime2"
} }
default: default:
if _, ok := value.Interface().([]byte); ok { if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size) return fmt.Sprintf("varchar(%d)", size)
} }
return "text" return "text"
} }
} }
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
} }
func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strconv"
"time" "time"
) )
@ -14,27 +15,32 @@ func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key) return fmt.Sprintf("`%s`", key)
} }
func (mysql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { func (mysql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
switch value.Kind() { var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
switch dataValue.Kind() {
case reflect.Bool: case reflect.Bool:
return "boolean" return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
if autoIncrease { if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "int AUTO_INCREMENT" return "int AUTO_INCREMENT"
} }
return "int" return "int"
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease { if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "int unsigned AUTO_INCREMENT" return "int unsigned AUTO_INCREMENT"
} }
return "int unsigned" return "int unsigned"
case reflect.Int64: case reflect.Int64:
if autoIncrease { if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "bigint AUTO_INCREMENT" return "bigint AUTO_INCREMENT"
} }
return "bigint" return "bigint"
case reflect.Uint64: case reflect.Uint64:
if autoIncrease { if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "bigint unsigned AUTO_INCREMENT" return "bigint unsigned AUTO_INCREMENT"
} }
return "bigint unsigned" return "bigint unsigned"
@ -46,18 +52,18 @@ func (mysql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string
} }
return "longtext" return "longtext"
case reflect.Struct: case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok { if _, ok := dataValue.Interface().(time.Time); ok {
return "timestamp NULL" return "timestamp NULL"
} }
default: default:
if _, ok := value.Interface().([]byte); ok { if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 { if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size) return fmt.Sprintf("varbinary(%d)", size)
} }
return "longblob" return "longblob"
} }
} }
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
} }
func (s mysql) currentDatabase(scope *Scope) (name string) { func (s mysql) currentDatabase(scope *Scope) (name string) {

View File

@ -5,6 +5,7 @@ import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"strconv"
"strings" "strings"
"time" "time"
@ -19,17 +20,22 @@ func (postgres) BindVar(i int) string {
return fmt.Sprintf("$%v", i) return fmt.Sprintf("$%v", i)
} }
func (postgres) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { func (postgres) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
switch value.Kind() { var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
switch dataValue.Kind() {
case reflect.Bool: case reflect.Bool:
return "boolean" return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease { if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "serial" return "serial"
} }
return "integer" return "integer"
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
if autoIncrease { if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "bigserial" return "bigserial"
} }
return "bigint" return "bigint"
@ -41,21 +47,21 @@ func (postgres) DataTypeOf(value reflect.Value, size int, autoIncrease bool) str
} }
return "text" return "text"
case reflect.Struct: case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok { if _, ok := dataValue.Interface().(time.Time); ok {
return "timestamp with time zone" return "timestamp with time zone"
} }
case reflect.Map: case reflect.Map:
if value.Type() == hstoreType { if dataValue.Type() == hstoreType {
return "hstore" return "hstore"
} }
default: default:
if isByteArrayOrSlice(value) { if isByteArrayOrSlice(dataValue) {
return "bytea" return "bytea"
} else if isUUID(value) { } else if isUUID(dataValue) {
return "uuid" return "uuid"
} }
} }
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
} }
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strconv"
"time" "time"
) )
@ -10,17 +11,22 @@ type sqlite3 struct {
commonDialect commonDialect
} }
func (sqlite3) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { func (sqlite3) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
switch value.Kind() { var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
switch dataValue.Kind() {
case reflect.Bool: case reflect.Bool:
return "bool" return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease { if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "integer primary key autoincrement" return "integer primary key autoincrement"
} }
return "integer" return "integer"
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
if autoIncrease { if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "integer primary key autoincrement" return "integer primary key autoincrement"
} }
return "bigint" return "bigint"
@ -32,15 +38,15 @@ func (sqlite3) DataTypeOf(value reflect.Value, size int, autoIncrease bool) stri
} }
return "text" return "text"
case reflect.Struct: case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok { if _, ok := dataValue.Interface().(time.Time); ok {
return "datetime" return "datetime"
} }
default: default:
if _, ok := value.Interface().([]byte); ok { if _, ok := dataValue.Interface().([]byte); ok {
return "blob" return "blob"
} }
} }
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
} }
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"go/ast" "go/ast"
"reflect" "reflect"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -541,21 +540,7 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
} }
if sqlType == "" { if sqlType == "" {
var size = 255 sqlType = scope.Dialect().DataTypeOf(reflectValue, field.TagSettings)
if value, ok := field.TagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(value)
}
v, autoIncrease := field.TagSettings["AUTO_INCREMENT"]
if field.IsPrimaryKey {
autoIncrease = true
}
if v == "FALSE" {
autoIncrease = false
}
sqlType = scope.Dialect().DataTypeOf(reflectValue, size, autoIncrease)
} }
if strings.TrimSpace(additionalType) == "" { if strings.TrimSpace(additionalType) == "" {

View File

@ -516,11 +516,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
for idx, fieldName := range relationship.ForeignFieldNames { for idx, fieldName := range relationship.ForeignFieldNames {
if field, ok := scope.Fields()[fieldName]; ok { if field, ok := scope.Fields()[fieldName]; ok {
value := reflect.Indirect(reflect.New(field.Struct.Type)) value := reflect.Indirect(reflect.New(field.Struct.Type))
primaryKeySqlType := field.TagSettings["TYPE"] sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
if primaryKeySqlType == "" {
primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false)
}
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
} }
} }
@ -528,11 +524,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
for idx, fieldName := range relationship.AssociationForeignFieldNames { for idx, fieldName := range relationship.AssociationForeignFieldNames {
if field, ok := toScope.Fields()[fieldName]; ok { if field, ok := toScope.Fields()[fieldName]; ok {
value := reflect.Indirect(reflect.New(field.Struct.Type)) value := reflect.Indirect(reflect.New(field.Struct.Type))
primaryKeySqlType := field.TagSettings["TYPE"] sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
if primaryKeySqlType == "" {
primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false)
}
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
} }
} }