From 2dfd76d22bd75122110dc23225b31b685e6a769f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 19 Jan 2016 20:58:38 +0800 Subject: [PATCH] Refactor DataTypeOf --- dialect.go | 2 +- dialect_common.go | 20 +++++++++++++------- dialect_mssql.go | 20 +++++++++++++------- dialect_mysql.go | 24 +++++++++++++++--------- dialect_postgres.go | 24 +++++++++++++++--------- dialect_sqlite3.go | 20 +++++++++++++------- model_struct.go | 17 +---------------- scope_private.go | 12 ++---------- 8 files changed, 73 insertions(+), 66 deletions(-) diff --git a/dialect.go b/dialect.go index 72b6b2aa..dd3c032e 100644 --- a/dialect.go +++ b/dialect.go @@ -12,7 +12,7 @@ type Dialect interface { // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name Quote(key string) string // DataTypeOf return data's sql type - DataTypeOf(value reflect.Value, size int, autoIncrease bool) string + DataTypeOf(value reflect.Value, tagSettings map[string]string) string // HasIndex check has index or not HasIndex(scope *Scope, tableName string, indexName string) bool diff --git a/dialect_common.go b/dialect_common.go index efc8d642..fc717e17 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -3,6 +3,7 @@ package gorm import ( "fmt" "reflect" + "strconv" "time" ) @@ -16,17 +17,22 @@ func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } -func (commonDialect) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { - switch value.Kind() { +func (commonDialect) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string { + var size int + if num, ok := tagSettings["SIZE"]; ok { + size, _ = strconv.Atoi(num) + } + + switch dataValue.Kind() { case reflect.Bool: return "BOOLEAN" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if autoIncrease { + if _, ok := tagSettings["AUTO_INCREMENT"]; ok { return "INTEGER AUTO_INCREMENT" } return "INTEGER" case reflect.Int64, reflect.Uint64: - if autoIncrease { + if _, ok := tagSettings["AUTO_INCREMENT"]; ok { return "BIGINT AUTO_INCREMENT" } return "BIGINT" @@ -38,18 +44,18 @@ func (commonDialect) DataTypeOf(value reflect.Value, size int, autoIncrease bool } return "VARCHAR(65532)" case reflect.Struct: - if _, ok := value.Interface().(time.Time); ok { + if _, ok := dataValue.Interface().(time.Time); ok { return "TIMESTAMP" } default: - if _, ok := value.Interface().([]byte); ok { + if _, ok := dataValue.Interface().([]byte); ok { if size > 0 && size < 65532 { return fmt.Sprintf("BINARY(%d)", size) } return "BINARY(65532)" } } - panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) + panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String())) } func (c commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { diff --git a/dialect_mssql.go b/dialect_mssql.go index c3e21c97..d130badb 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -3,6 +3,7 @@ package gorm import ( "fmt" "reflect" + "strconv" "time" ) @@ -10,17 +11,22 @@ type mssql struct { commonDialect } -func (mssql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { - switch value.Kind() { +func (mssql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string { + var size int + if num, ok := tagSettings["SIZE"]; ok { + size, _ = strconv.Atoi(num) + } + + switch dataValue.Kind() { case reflect.Bool: return "bit" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if autoIncrease { + if _, ok := tagSettings["AUTO_INCREMENT"]; ok { return "int IDENTITY(1,1)" } return "int" case reflect.Int64, reflect.Uint64: - if autoIncrease { + if _, ok := tagSettings["AUTO_INCREMENT"]; ok { return "bigint IDENTITY(1,1)" } return "bigint" @@ -32,18 +38,18 @@ func (mssql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string } return "text" case reflect.Struct: - if _, ok := value.Interface().(time.Time); ok { + if _, ok := dataValue.Interface().(time.Time); ok { return "datetime2" } default: - if _, ok := value.Interface().([]byte); ok { + if _, ok := dataValue.Interface().([]byte); ok { if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) } return "text" } } - panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) + panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", dataValue.Type().Name(), dataValue.Kind().String())) } func (s mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { diff --git a/dialect_mysql.go b/dialect_mysql.go index e334c7a4..acc1f2b7 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -3,6 +3,7 @@ package gorm import ( "fmt" "reflect" + "strconv" "time" ) @@ -14,27 +15,32 @@ func (mysql) Quote(key string) string { return fmt.Sprintf("`%s`", key) } -func (mysql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { - switch value.Kind() { +func (mysql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string { + var size int + if num, ok := tagSettings["SIZE"]; ok { + size, _ = strconv.Atoi(num) + } + + switch dataValue.Kind() { case reflect.Bool: return "boolean" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: - if autoIncrease { + if _, ok := tagSettings["AUTO_INCREMENT"]; ok { return "int AUTO_INCREMENT" } return "int" case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if autoIncrease { + if _, ok := tagSettings["AUTO_INCREMENT"]; ok { return "int unsigned AUTO_INCREMENT" } return "int unsigned" case reflect.Int64: - if autoIncrease { + if _, ok := tagSettings["AUTO_INCREMENT"]; ok { return "bigint AUTO_INCREMENT" } return "bigint" case reflect.Uint64: - if autoIncrease { + if _, ok := tagSettings["AUTO_INCREMENT"]; ok { return "bigint unsigned AUTO_INCREMENT" } return "bigint unsigned" @@ -46,18 +52,18 @@ func (mysql) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string } return "longtext" case reflect.Struct: - if _, ok := value.Interface().(time.Time); ok { + if _, ok := dataValue.Interface().(time.Time); ok { return "timestamp NULL" } default: - if _, ok := value.Interface().([]byte); ok { + if _, ok := dataValue.Interface().([]byte); ok { if size > 0 && size < 65532 { return fmt.Sprintf("varbinary(%d)", size) } return "longblob" } } - panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) + panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) } func (s mysql) currentDatabase(scope *Scope) (name string) { diff --git a/dialect_postgres.go b/dialect_postgres.go index c4742aec..5215ab96 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "fmt" "reflect" + "strconv" "strings" "time" @@ -19,17 +20,22 @@ func (postgres) BindVar(i int) string { return fmt.Sprintf("$%v", i) } -func (postgres) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { - switch value.Kind() { +func (postgres) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string { + var size int + if num, ok := tagSettings["SIZE"]; ok { + size, _ = strconv.Atoi(num) + } + + switch dataValue.Kind() { case reflect.Bool: return "boolean" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if autoIncrease { + if _, ok := tagSettings["AUTO_INCREMENT"]; ok { return "serial" } return "integer" case reflect.Int64, reflect.Uint64: - if autoIncrease { + if _, ok := tagSettings["AUTO_INCREMENT"]; ok { return "bigserial" } return "bigint" @@ -41,21 +47,21 @@ func (postgres) DataTypeOf(value reflect.Value, size int, autoIncrease bool) str } return "text" case reflect.Struct: - if _, ok := value.Interface().(time.Time); ok { + if _, ok := dataValue.Interface().(time.Time); ok { return "timestamp with time zone" } case reflect.Map: - if value.Type() == hstoreType { + if dataValue.Type() == hstoreType { return "hstore" } default: - if isByteArrayOrSlice(value) { + if isByteArrayOrSlice(dataValue) { return "bytea" - } else if isUUID(value) { + } else if isUUID(dataValue) { return "uuid" } } - panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) + panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", dataValue.Type().Name(), dataValue.Kind().String())) } func (s postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index e1e35bf7..c838bcc1 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -3,6 +3,7 @@ package gorm import ( "fmt" "reflect" + "strconv" "time" ) @@ -10,17 +11,22 @@ type sqlite3 struct { commonDialect } -func (sqlite3) DataTypeOf(value reflect.Value, size int, autoIncrease bool) string { - switch value.Kind() { +func (sqlite3) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string { + var size int + if num, ok := tagSettings["SIZE"]; ok { + size, _ = strconv.Atoi(num) + } + + switch dataValue.Kind() { case reflect.Bool: return "bool" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if autoIncrease { + if _, ok := tagSettings["AUTO_INCREMENT"]; ok { return "integer primary key autoincrement" } return "integer" case reflect.Int64, reflect.Uint64: - if autoIncrease { + if _, ok := tagSettings["AUTO_INCREMENT"]; ok { return "integer primary key autoincrement" } return "bigint" @@ -32,15 +38,15 @@ func (sqlite3) DataTypeOf(value reflect.Value, size int, autoIncrease bool) stri } return "text" case reflect.Struct: - if _, ok := value.Interface().(time.Time); ok { + if _, ok := dataValue.Interface().(time.Time); ok { return "datetime" } default: - if _, ok := value.Interface().([]byte); ok { + if _, ok := dataValue.Interface().([]byte); ok { return "blob" } } - panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) + panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) } func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { diff --git a/model_struct.go b/model_struct.go index c81dcd88..a551d578 100644 --- a/model_struct.go +++ b/model_struct.go @@ -6,7 +6,6 @@ import ( "fmt" "go/ast" "reflect" - "strconv" "strings" "sync" "time" @@ -541,21 +540,7 @@ func (scope *Scope) generateSqlTag(field *StructField) string { } if sqlType == "" { - var size = 255 - - if value, ok := field.TagSettings["SIZE"]; ok { - size, _ = strconv.Atoi(value) - } - - v, autoIncrease := field.TagSettings["AUTO_INCREMENT"] - if field.IsPrimaryKey { - autoIncrease = true - } - if v == "FALSE" { - autoIncrease = false - } - - sqlType = scope.Dialect().DataTypeOf(reflectValue, size, autoIncrease) + sqlType = scope.Dialect().DataTypeOf(reflectValue, field.TagSettings) } if strings.TrimSpace(additionalType) == "" { diff --git a/scope_private.go b/scope_private.go index d5d384af..138bd6fd 100644 --- a/scope_private.go +++ b/scope_private.go @@ -516,11 +516,7 @@ func (scope *Scope) createJoinTable(field *StructField) { for idx, fieldName := range relationship.ForeignFieldNames { if field, ok := scope.Fields()[fieldName]; ok { value := reflect.Indirect(reflect.New(field.Struct.Type)) - primaryKeySqlType := field.TagSettings["TYPE"] - if primaryKeySqlType == "" { - primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false) - } - sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType) + sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings)) primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) } } @@ -528,11 +524,7 @@ func (scope *Scope) createJoinTable(field *StructField) { for idx, fieldName := range relationship.AssociationForeignFieldNames { if field, ok := toScope.Fields()[fieldName]; ok { value := reflect.Indirect(reflect.New(field.Struct.Type)) - primaryKeySqlType := field.TagSettings["TYPE"] - if primaryKeySqlType == "" { - primaryKeySqlType = scope.Dialect().DataTypeOf(value, 255, false) - } - sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType) + sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings)) primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) } }