Refactor DataTypeOf

This commit is contained in:
Jinzhu 2016-01-19 20:58:38 +08:00
parent d92c5db9e7
commit 2dfd76d22b
8 changed files with 73 additions and 66 deletions

View File

@ -12,7 +12,7 @@ type Dialect interface {
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
Quote(key string) string
// DataTypeOf return data's sql type
DataTypeOf(value reflect.Value, size int, autoIncrease bool) string
DataTypeOf(value reflect.Value, tagSettings map[string]string) string
// HasIndex check has index or not
HasIndex(scope *Scope, tableName string, indexName string) bool

View File

@ -3,6 +3,7 @@ package gorm
import (
"fmt"
"reflect"
"strconv"
"time"
)
@ -16,17 +17,22 @@ func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (commonDialect) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
func (commonDialect) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
switch dataValue.Kind() {
case reflect.Bool:
return "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "INTEGER AUTO_INCREMENT"
}
return "INTEGER"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "BIGINT AUTO_INCREMENT"
}
return "BIGINT"
@ -38,18 +44,18 @@ func (commonDialect) DataTypeOf(value reflect.Value, size int, autoIncrease bool
}
return "VARCHAR(65532)"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
if _, ok := dataValue.Interface().(time.Time); ok {
return "TIMESTAMP"
}
default:
if _, ok := value.Interface().([]byte); ok {
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("BINARY(%d)", size)
}
return "BINARY(65532)"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
}
func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {

View File

@ -3,6 +3,7 @@ package gorm
import (
"fmt"
"reflect"
"strconv"
"time"
)
@ -10,17 +11,22 @@ type mssql struct {
commonDialect
}
func (mssql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
func (mssql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
switch dataValue.Kind() {
case reflect.Bool:
return "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "int IDENTITY(1,1)"
}
return "int"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "bigint IDENTITY(1,1)"
}
return "bigint"
@ -32,18 +38,18 @@ func (mssql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string
}
return "text"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
if _, ok := dataValue.Interface().(time.Time); ok {
return "datetime2"
}
default:
if _, ok := value.Interface().([]byte); ok {
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "text"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String()))
}
func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {

View File

@ -3,6 +3,7 @@ package gorm
import (
"fmt"
"reflect"
"strconv"
"time"
)
@ -14,27 +15,32 @@ func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}
func (mysql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
func (mysql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
switch dataValue.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "int AUTO_INCREMENT"
}
return "int"
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "int unsigned AUTO_INCREMENT"
}
return "int unsigned"
case reflect.Int64:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "bigint AUTO_INCREMENT"
}
return "bigint"
case reflect.Uint64:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "bigint unsigned AUTO_INCREMENT"
}
return "bigint unsigned"
@ -46,18 +52,18 @@ func (mysql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string
}
return "longtext"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
if _, ok := dataValue.Interface().(time.Time); ok {
return "timestamp NULL"
}
default:
if _, ok := value.Interface().([]byte); ok {
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
return fmt.Sprintf("varbinary(%d)", size)
}
return "longblob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
}
func (s mysql) currentDatabase(scope *Scope) (name string) {

View File

@ -5,6 +5,7 @@ import (
"database/sql/driver"
"fmt"
"reflect"
"strconv"
"strings"
"time"
@ -19,17 +20,22 @@ func (postgres) BindVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (postgres) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
func (postgres) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
switch dataValue.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "serial"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "bigserial"
}
return "bigint"
@ -41,21 +47,21 @@ func (postgres) DataTypeOf(value reflect.Value, size int, autoIncrease bool) str
}
return "text"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
if _, ok := dataValue.Interface().(time.Time); ok {
return "timestamp with time zone"
}
case reflect.Map:
if value.Type() == hstoreType {
if dataValue.Type() == hstoreType {
return "hstore"
}
default:
if isByteArrayOrSlice(value) {
if isByteArrayOrSlice(dataValue) {
return "bytea"
} else if isUUID(value) {
} else if isUUID(dataValue) {
return "uuid"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String()))
}
func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {

View File

@ -3,6 +3,7 @@ package gorm
import (
"fmt"
"reflect"
"strconv"
"time"
)
@ -10,17 +11,22 @@ type sqlite3 struct {
commonDialect
}
func (sqlite3) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
func (sqlite3) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
switch dataValue.Kind() {
case reflect.Bool:
return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "integer primary key autoincrement"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "integer primary key autoincrement"
}
return "bigint"
@ -32,15 +38,15 @@ func (sqlite3) DataTypeOf(value reflect.Value, size int, autoIncrease bool) stri
}
return "text"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
if _, ok := dataValue.Interface().(time.Time); ok {
return "datetime"
}
default:
if _, ok := value.Interface().([]byte); ok {
if _, ok := dataValue.Interface().([]byte); ok {
return "blob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
}
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {

View File

@ -6,7 +6,6 @@ import (
"fmt"
"go/ast"
"reflect"
"strconv"
"strings"
"sync"
"time"
@ -541,21 +540,7 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
}
if sqlType == "" {
var size = 255
if value, ok := field.TagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(value)
}
v, autoIncrease := field.TagSettings["AUTO_INCREMENT"]
if field.IsPrimaryKey {
autoIncrease = true
}
if v == "FALSE" {
autoIncrease = false
}
sqlType = scope.Dialect().DataTypeOf(reflectValue, size, autoIncrease)
sqlType = scope.Dialect().DataTypeOf(reflectValue, field.TagSettings)
}
if strings.TrimSpace(additionalType) == "" {

View File

@ -516,11 +516,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
for idx, fieldName := range relationship.ForeignFieldNames {
if field, ok := scope.Fields()[fieldName]; ok {
value := reflect.Indirect(reflect.New(field.Struct.Type))
primaryKeySqlType := field.TagSettings["TYPE"]
if primaryKeySqlType == "" {
primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false)
}
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
}
}
@ -528,11 +524,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
for idx, fieldName := range relationship.AssociationForeignFieldNames {
if field, ok := toScope.Fields()[fieldName]; ok {
value := reflect.Indirect(reflect.New(field.Struct.Type))
primaryKeySqlType := field.TagSettings["TYPE"]
if primaryKeySqlType == "" {
primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false)
}
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
}
}