Refactor DataTypeOf for sqlite

This commit is contained in:
Jinzhu 2016-02-13 23:51:36 +08:00
parent dc435d2225
commit 552d9bf455
8 changed files with 129 additions and 91 deletions

View File

@ -1,8 +1,11 @@
package gorm
import (
"database/sql"
"fmt"
"reflect"
"strconv"
"strings"
)
// Dialect interface contains behaviors that differ across SQL database
@ -12,7 +15,7 @@ type Dialect interface {
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
Quote(key string) string
// DataTypeOf return data's sql type
DataTypeOf(value reflect.Value, tagSettings map[string]string) string
DataTypeOf(field *StructField) string
// HasIndex check has index or not
HasIndex(scope *Scope, tableName string, indexName string) bool
@ -48,3 +51,40 @@ func NewDialect(driver string) Dialect {
}
return d
}
// ParseFieldStructForDialect parse field struct for dialect
func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
// Get redirected field type
var reflectType = field.Struct.Type
for reflectType.Kind() == reflect.Ptr {
reflectType = reflectType.Elem()
}
// Get redirected field value
fieldValue = reflect.Indirect(reflect.New(reflectType))
// Get scanner's real value
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
fieldValue = value
if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
getScannerValue(fieldValue.Field(0))
}
}
getScannerValue(fieldValue)
// Default Size
if num, ok := field.TagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
} else {
size = 255
}
// Default type from tag setting
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = additionalType + " DEFAULT " + value
}
return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType)
}

View File

@ -17,8 +17,13 @@ func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (commonDialect) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
func (commonDialect) DataTypeOf(field *StructField) string {
var (
size int
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
tagSettings = field.TagSettings
)
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}

View File

@ -11,8 +11,13 @@ type mssql struct {
commonDialect
}
func (mssql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
func (mssql) DataTypeOf(field *StructField) string {
var (
size int
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
tagSettings = field.TagSettings
)
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}

View File

@ -15,8 +15,13 @@ func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}
func (mysql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
func (mysql) DataTypeOf(field *StructField) string {
var (
size int
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
tagSettings = field.TagSettings
)
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}

View File

@ -20,8 +20,13 @@ func (postgres) BindVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (postgres) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
func (postgres) DataTypeOf(field *StructField) string {
var (
size int
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
tagSettings = field.TagSettings
)
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}

View File

@ -3,7 +3,7 @@ package gorm
import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
@ -11,44 +11,57 @@ type sqlite3 struct {
commonDialect
}
func (sqlite3) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
var size int
if num, ok := tagSettings["SIZE"]; ok {
size, _ = strconv.Atoi(num)
}
// Get Data Type for Sqlite Dialect
func (sqlite3) DataTypeOf(field *StructField) string {
var (
dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
return "bool"
sqlType = "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "integer primary key autoincrement"
if field.IsPrimaryKey {
sqlType = "integer primary key autoincrement"
} else {
sqlType = "integer"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
return "integer primary key autoincrement"
if field.IsPrimaryKey {
sqlType = "integer primary key autoincrement"
} else {
sqlType = "bigint"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "real"
sqlType = "real"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "text"
}
return "text"
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
return "datetime"
sqlType = "datetime"
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
return "blob"
sqlType = "blob"
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)

View File

@ -3,7 +3,6 @@ package gorm
import (
"database/sql"
"errors"
"fmt"
"go/ast"
"reflect"
"strings"
@ -511,44 +510,6 @@ func (scope *Scope) GetStructFields() (fields []*StructField) {
return scope.GetModelStruct().StructFields
}
func (scope *Scope) generateSqlTag(field *StructField) string {
var sqlType string
structType := field.Struct.Type
if structType.Kind() == reflect.Ptr {
structType = structType.Elem()
}
reflectValue := reflect.Indirect(reflect.New(structType))
if value, ok := field.TagSettings["TYPE"]; ok {
sqlType = value
}
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = additionalType + " DEFAULT " + value
}
if field.IsScanner {
var getScannerValue func(reflect.Value)
getScannerValue = func(value reflect.Value) {
reflectValue = value
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
getScannerValue(reflectValue.Field(0))
}
}
getScannerValue(reflectValue)
}
if sqlType == "" {
sqlType = scope.Dialect().DataTypeOf(reflectValue, field.TagSettings)
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func parseTagSetting(tags reflect.StructTag) map[string]string {
setting := map[string]string{}
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {

View File

@ -511,7 +511,7 @@ func (scope *Scope) getTableOptions() string {
return tableOptions.(string)
}
func (scope *Scope) createJoinTable(field *StructField) {
func (scope *Scope) createJoinTable(field *Field) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTableHandler := relationship.JoinTableHandler
joinTable := joinTableHandler.Table(scope.db)
@ -521,16 +521,20 @@ func (scope *Scope) createJoinTable(field *StructField) {
var sqlTypes, primaryKeys []string
for idx, fieldName := range relationship.ForeignFieldNames {
if field, ok := scope.Fields()[fieldName]; ok {
value := reflect.Indirect(reflect.New(field.Struct.Type))
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
foreignKeyStruct := field.StructField.clone()
foreignKeyStruct.IsPrimaryKey = false
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
}
}
for idx, fieldName := range relationship.AssociationForeignFieldNames {
if field, ok := toScope.Fields()[fieldName]; ok {
value := reflect.Indirect(reflect.New(field.Struct.Type))
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
foreignKeyStruct := field.StructField.clone()
foreignKeyStruct.IsPrimaryKey = false
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
}
}
@ -545,9 +549,9 @@ func (scope *Scope) createTable() *Scope {
var tags []string
var primaryKeys []string
var primaryKeyInColumnType = false
for _, field := range scope.GetStructFields() {
for _, field := range scope.Fields() {
if field.IsNormal {
sqlTag := scope.generateSqlTag(field)
sqlTag := scope.Dialect().DataTypeOf(field.StructField)
// Check if the primary key constraint was specified as
// part of the column type. If so, we can only support
@ -632,10 +636,10 @@ func (scope *Scope) autoMigrate() *Scope {
if !scope.Dialect().HasTable(scope, tableName) {
scope.createTable()
} else {
for _, field := range scope.GetStructFields() {
for _, field := range scope.Fields() {
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
if field.IsNormal {
sqlTag := scope.generateSqlTag(field)
sqlTag := scope.Dialect().DataTypeOf(field.StructField)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
}
}