From 552d9bf4550d378e6b62f85881191173fd0a0f61 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 13 Feb 2016 23:51:36 +0800 Subject: [PATCH] Refactor DataTypeOf for sqlite --- dialect.go | 42 ++++++++++++++++++++++- dialect_common.go | 9 +++-- dialect_mssql.go | 9 +++-- dialect_mysql.go | 9 +++-- dialect_postgres.go | 9 +++-- dialect_sqlite3.go | 81 ++++++++++++++++++++++++++------------------- model_struct.go | 39 ---------------------- scope_private.go | 22 +++++++----- 8 files changed, 129 insertions(+), 91 deletions(-) diff --git a/dialect.go b/dialect.go index dd3c032e..61220a42 100644 --- a/dialect.go +++ b/dialect.go @@ -1,8 +1,11 @@ package gorm import ( + "database/sql" "fmt" "reflect" + "strconv" + "strings" ) // Dialect interface contains behaviors that differ across SQL database @@ -12,7 +15,7 @@ type Dialect interface { // Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name Quote(key string) string // DataTypeOf return data's sql type - DataTypeOf(value reflect.Value, tagSettings map[string]string) string + DataTypeOf(field *StructField) string // HasIndex check has index or not HasIndex(scope *Scope, tableName string, indexName string) bool @@ -48,3 +51,40 @@ func NewDialect(driver string) Dialect { } return d } + +// ParseFieldStructForDialect parse field struct for dialect +func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { + // Get redirected field type + var reflectType = field.Struct.Type + for reflectType.Kind() == reflect.Ptr { + reflectType = reflectType.Elem() + } + + // Get redirected field value + fieldValue = reflect.Indirect(reflect.New(reflectType)) + + // Get scanner's real value + var getScannerValue func(reflect.Value) + getScannerValue = func(value reflect.Value) { + fieldValue = value + if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { + getScannerValue(fieldValue.Field(0)) + } + } + getScannerValue(fieldValue) + + // Default Size + if num, ok := field.TagSettings["SIZE"]; ok { + size, _ = strconv.Atoi(num) + } else { + size = 255 + } + + // Default type from tag setting + additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] + if value, ok := field.TagSettings["DEFAULT"]; ok { + additionalType = additionalType + " DEFAULT " + value + } + + return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType) +} diff --git a/dialect_common.go b/dialect_common.go index fc717e17..f95f3903 100644 --- a/dialect_common.go +++ b/dialect_common.go @@ -17,8 +17,13 @@ func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } -func (commonDialect) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string { - var size int +func (commonDialect) DataTypeOf(field *StructField) string { + var ( + size int + dataValue = reflect.Indirect(reflect.New(field.Struct.Type)) + tagSettings = field.TagSettings + ) + if num, ok := tagSettings["SIZE"]; ok { size, _ = strconv.Atoi(num) } diff --git a/dialect_mssql.go b/dialect_mssql.go index d130badb..aa3f7f5d 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -11,8 +11,13 @@ type mssql struct { commonDialect } -func (mssql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string { - var size int +func (mssql) DataTypeOf(field *StructField) string { + var ( + size int + dataValue = reflect.Indirect(reflect.New(field.Struct.Type)) + tagSettings = field.TagSettings + ) + if num, ok := tagSettings["SIZE"]; ok { size, _ = strconv.Atoi(num) } diff --git a/dialect_mysql.go b/dialect_mysql.go index acc1f2b7..15849abc 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -15,8 +15,13 @@ func (mysql) Quote(key string) string { return fmt.Sprintf("`%s`", key) } -func (mysql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string { - var size int +func (mysql) DataTypeOf(field *StructField) string { + var ( + size int + dataValue = reflect.Indirect(reflect.New(field.Struct.Type)) + tagSettings = field.TagSettings + ) + if num, ok := tagSettings["SIZE"]; ok { size, _ = strconv.Atoi(num) } diff --git a/dialect_postgres.go b/dialect_postgres.go index 5215ab96..e49df3d2 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -20,8 +20,13 @@ func (postgres) BindVar(i int) string { return fmt.Sprintf("$%v", i) } -func (postgres) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string { - var size int +func (postgres) DataTypeOf(field *StructField) string { + var ( + size int + dataValue = reflect.Indirect(reflect.New(field.Struct.Type)) + tagSettings = field.TagSettings + ) + if num, ok := tagSettings["SIZE"]; ok { size, _ = strconv.Atoi(num) } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index c838bcc1..0bf2aa8c 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -3,7 +3,7 @@ package gorm import ( "fmt" "reflect" - "strconv" + "strings" "time" ) @@ -11,42 +11,55 @@ type sqlite3 struct { commonDialect } -func (sqlite3) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string { - var size int - if num, ok := tagSettings["SIZE"]; ok { - size, _ = strconv.Atoi(num) +// Get Data Type for Sqlite Dialect +func (sqlite3) DataTypeOf(field *StructField) string { + var ( + dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) + ) + + if sqlType == "" { + switch dataValue.Kind() { + case reflect.Bool: + sqlType = "bool" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if field.IsPrimaryKey { + sqlType = "integer primary key autoincrement" + } else { + sqlType = "integer" + } + case reflect.Int64, reflect.Uint64: + if field.IsPrimaryKey { + sqlType = "integer primary key autoincrement" + } else { + sqlType = "bigint" + } + case reflect.Float32, reflect.Float64: + sqlType = "real" + case reflect.String: + if size > 0 && size < 65532 { + sqlType = fmt.Sprintf("varchar(%d)", size) + } else { + sqlType = "text" + } + case reflect.Struct: + if _, ok := dataValue.Interface().(time.Time); ok { + sqlType = "datetime" + } + default: + if _, ok := dataValue.Interface().([]byte); ok { + sqlType = "blob" + } + } } - switch dataValue.Kind() { - case reflect.Bool: - return "bool" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := tagSettings["AUTO_INCREMENT"]; ok { - return "integer primary key autoincrement" - } - return "integer" - case reflect.Int64, reflect.Uint64: - if _, ok := tagSettings["AUTO_INCREMENT"]; ok { - return "integer primary key autoincrement" - } - return "bigint" - case reflect.Float32, reflect.Float64: - return "real" - case reflect.String: - if size > 0 && size < 65532 { - return fmt.Sprintf("varchar(%d)", size) - } - return "text" - case reflect.Struct: - if _, ok := dataValue.Interface().(time.Time); ok { - return "datetime" - } - default: - if _, ok := dataValue.Interface().([]byte); ok { - return "blob" - } + if sqlType == "" { + panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) } - panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String())) + + if strings.TrimSpace(additionalType) == "" { + return sqlType + } + return fmt.Sprintf("%v %v", sqlType, additionalType) } func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { diff --git a/model_struct.go b/model_struct.go index a551d578..a17d2257 100644 --- a/model_struct.go +++ b/model_struct.go @@ -3,7 +3,6 @@ package gorm import ( "database/sql" "errors" - "fmt" "go/ast" "reflect" "strings" @@ -511,44 +510,6 @@ func (scope *Scope) GetStructFields() (fields []*StructField) { return scope.GetModelStruct().StructFields } -func (scope *Scope) generateSqlTag(field *StructField) string { - var sqlType string - structType := field.Struct.Type - if structType.Kind() == reflect.Ptr { - structType = structType.Elem() - } - reflectValue := reflect.Indirect(reflect.New(structType)) - - if value, ok := field.TagSettings["TYPE"]; ok { - sqlType = value - } - - additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] - if value, ok := field.TagSettings["DEFAULT"]; ok { - additionalType = additionalType + " DEFAULT " + value - } - - if field.IsScanner { - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - reflectValue = value - if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct { - getScannerValue(reflectValue.Field(0)) - } - } - getScannerValue(reflectValue) - } - - if sqlType == "" { - sqlType = scope.Dialect().DataTypeOf(reflectValue, field.TagSettings) - } - - if strings.TrimSpace(additionalType) == "" { - return sqlType - } - return fmt.Sprintf("%v %v", sqlType, additionalType) -} - func parseTagSetting(tags reflect.StructTag) map[string]string { setting := map[string]string{} for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { diff --git a/scope_private.go b/scope_private.go index 4fd48833..d8dd9b93 100644 --- a/scope_private.go +++ b/scope_private.go @@ -511,7 +511,7 @@ func (scope *Scope) getTableOptions() string { return tableOptions.(string) } -func (scope *Scope) createJoinTable(field *StructField) { +func (scope *Scope) createJoinTable(field *Field) { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { joinTableHandler := relationship.JoinTableHandler joinTable := joinTableHandler.Table(scope.db) @@ -521,16 +521,20 @@ func (scope *Scope) createJoinTable(field *StructField) { var sqlTypes, primaryKeys []string for idx, fieldName := range relationship.ForeignFieldNames { if field, ok := scope.Fields()[fieldName]; ok { - value := reflect.Indirect(reflect.New(field.Struct.Type)) - sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings)) + foreignKeyStruct := field.StructField.clone() + foreignKeyStruct.IsPrimaryKey = false + foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" + sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) } } for idx, fieldName := range relationship.AssociationForeignFieldNames { if field, ok := toScope.Fields()[fieldName]; ok { - value := reflect.Indirect(reflect.New(field.Struct.Type)) - sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings)) + foreignKeyStruct := field.StructField.clone() + foreignKeyStruct.IsPrimaryKey = false + foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" + sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) } } @@ -545,9 +549,9 @@ func (scope *Scope) createTable() *Scope { var tags []string var primaryKeys []string var primaryKeyInColumnType = false - for _, field := range scope.GetStructFields() { + for _, field := range scope.Fields() { if field.IsNormal { - sqlTag := scope.generateSqlTag(field) + sqlTag := scope.Dialect().DataTypeOf(field.StructField) // Check if the primary key constraint was specified as // part of the column type. If so, we can only support @@ -632,10 +636,10 @@ func (scope *Scope) autoMigrate() *Scope { if !scope.Dialect().HasTable(scope, tableName) { scope.createTable() } else { - for _, field := range scope.GetStructFields() { + for _, field := range scope.Fields() { if !scope.Dialect().HasColumn(scope, tableName, field.DBName) { if field.IsNormal { - sqlTag := scope.generateSqlTag(field) + sqlTag := scope.Dialect().DataTypeOf(field.StructField) scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() } }