forked from mirror/gorm
Refactor DataTypeOf for sqlite
This commit is contained in:
parent
dc435d2225
commit
552d9bf455
42
dialect.go
42
dialect.go
|
@ -1,8 +1,11 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Dialect interface contains behaviors that differ across SQL database
|
||||
|
@ -12,7 +15,7 @@ type Dialect interface {
|
|||
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
||||
Quote(key string) string
|
||||
// DataTypeOf return data's sql type
|
||||
DataTypeOf(value reflect.Value, tagSettings map[string]string) string
|
||||
DataTypeOf(field *StructField) string
|
||||
|
||||
// HasIndex check has index or not
|
||||
HasIndex(scope *Scope, tableName string, indexName string) bool
|
||||
|
@ -48,3 +51,40 @@ func NewDialect(driver string) Dialect {
|
|||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// ParseFieldStructForDialect parse field struct for dialect
|
||||
func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
|
||||
// Get redirected field type
|
||||
var reflectType = field.Struct.Type
|
||||
for reflectType.Kind() == reflect.Ptr {
|
||||
reflectType = reflectType.Elem()
|
||||
}
|
||||
|
||||
// Get redirected field value
|
||||
fieldValue = reflect.Indirect(reflect.New(reflectType))
|
||||
|
||||
// Get scanner's real value
|
||||
var getScannerValue func(reflect.Value)
|
||||
getScannerValue = func(value reflect.Value) {
|
||||
fieldValue = value
|
||||
if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
|
||||
getScannerValue(fieldValue.Field(0))
|
||||
}
|
||||
}
|
||||
getScannerValue(fieldValue)
|
||||
|
||||
// Default Size
|
||||
if num, ok := field.TagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
} else {
|
||||
size = 255
|
||||
}
|
||||
|
||||
// Default type from tag setting
|
||||
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||
additionalType = additionalType + " DEFAULT " + value
|
||||
}
|
||||
|
||||
return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType)
|
||||
}
|
||||
|
|
|
@ -17,8 +17,13 @@ func (commonDialect) Quote(key string) string {
|
|||
return fmt.Sprintf(`"%s"`, key)
|
||||
}
|
||||
|
||||
func (commonDialect) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
||||
var size int
|
||||
func (commonDialect) DataTypeOf(field *StructField) string {
|
||||
var (
|
||||
size int
|
||||
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
|
||||
tagSettings = field.TagSettings
|
||||
)
|
||||
|
||||
if num, ok := tagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
}
|
||||
|
|
|
@ -11,8 +11,13 @@ type mssql struct {
|
|||
commonDialect
|
||||
}
|
||||
|
||||
func (mssql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
||||
var size int
|
||||
func (mssql) DataTypeOf(field *StructField) string {
|
||||
var (
|
||||
size int
|
||||
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
|
||||
tagSettings = field.TagSettings
|
||||
)
|
||||
|
||||
if num, ok := tagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
}
|
||||
|
|
|
@ -15,8 +15,13 @@ func (mysql) Quote(key string) string {
|
|||
return fmt.Sprintf("`%s`", key)
|
||||
}
|
||||
|
||||
func (mysql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
||||
var size int
|
||||
func (mysql) DataTypeOf(field *StructField) string {
|
||||
var (
|
||||
size int
|
||||
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
|
||||
tagSettings = field.TagSettings
|
||||
)
|
||||
|
||||
if num, ok := tagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
}
|
||||
|
|
|
@ -20,8 +20,13 @@ func (postgres) BindVar(i int) string {
|
|||
return fmt.Sprintf("$%v", i)
|
||||
}
|
||||
|
||||
func (postgres) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
||||
var size int
|
||||
func (postgres) DataTypeOf(field *StructField) string {
|
||||
var (
|
||||
size int
|
||||
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
|
||||
tagSettings = field.TagSettings
|
||||
)
|
||||
|
||||
if num, ok := tagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ package gorm
|
|||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -11,44 +11,57 @@ type sqlite3 struct {
|
|||
commonDialect
|
||||
}
|
||||
|
||||
func (sqlite3) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
||||
var size int
|
||||
if num, ok := tagSettings["SIZE"]; ok {
|
||||
size, _ = strconv.Atoi(num)
|
||||
}
|
||||
// Get Data Type for Sqlite Dialect
|
||||
func (sqlite3) DataTypeOf(field *StructField) string {
|
||||
var (
|
||||
dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||
)
|
||||
|
||||
if sqlType == "" {
|
||||
switch dataValue.Kind() {
|
||||
case reflect.Bool:
|
||||
return "bool"
|
||||
sqlType = "bool"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "integer primary key autoincrement"
|
||||
if field.IsPrimaryKey {
|
||||
sqlType = "integer primary key autoincrement"
|
||||
} else {
|
||||
sqlType = "integer"
|
||||
}
|
||||
return "integer"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
||||
return "integer primary key autoincrement"
|
||||
if field.IsPrimaryKey {
|
||||
sqlType = "integer primary key autoincrement"
|
||||
} else {
|
||||
sqlType = "bigint"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "real"
|
||||
sqlType = "real"
|
||||
case reflect.String:
|
||||
if size > 0 && size < 65532 {
|
||||
return fmt.Sprintf("varchar(%d)", size)
|
||||
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||
} else {
|
||||
sqlType = "text"
|
||||
}
|
||||
return "text"
|
||||
case reflect.Struct:
|
||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||
return "datetime"
|
||||
sqlType = "datetime"
|
||||
}
|
||||
default:
|
||||
if _, ok := dataValue.Interface().([]byte); ok {
|
||||
return "blob"
|
||||
sqlType = "blob"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||
var count int
|
||||
s.RawScanInt(scope, &count, fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName)
|
||||
|
|
|
@ -3,7 +3,6 @@ package gorm
|
|||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
@ -511,44 +510,6 @@ func (scope *Scope) GetStructFields() (fields []*StructField) {
|
|||
return scope.GetModelStruct().StructFields
|
||||
}
|
||||
|
||||
func (scope *Scope) generateSqlTag(field *StructField) string {
|
||||
var sqlType string
|
||||
structType := field.Struct.Type
|
||||
if structType.Kind() == reflect.Ptr {
|
||||
structType = structType.Elem()
|
||||
}
|
||||
reflectValue := reflect.Indirect(reflect.New(structType))
|
||||
|
||||
if value, ok := field.TagSettings["TYPE"]; ok {
|
||||
sqlType = value
|
||||
}
|
||||
|
||||
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||
additionalType = additionalType + " DEFAULT " + value
|
||||
}
|
||||
|
||||
if field.IsScanner {
|
||||
var getScannerValue func(reflect.Value)
|
||||
getScannerValue = func(value reflect.Value) {
|
||||
reflectValue = value
|
||||
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
|
||||
getScannerValue(reflectValue.Field(0))
|
||||
}
|
||||
}
|
||||
getScannerValue(reflectValue)
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
sqlType = scope.Dialect().DataTypeOf(reflectValue, field.TagSettings)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
return sqlType
|
||||
}
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
|
||||
func parseTagSetting(tags reflect.StructTag) map[string]string {
|
||||
setting := map[string]string{}
|
||||
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
|
||||
|
|
|
@ -511,7 +511,7 @@ func (scope *Scope) getTableOptions() string {
|
|||
return tableOptions.(string)
|
||||
}
|
||||
|
||||
func (scope *Scope) createJoinTable(field *StructField) {
|
||||
func (scope *Scope) createJoinTable(field *Field) {
|
||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||
joinTableHandler := relationship.JoinTableHandler
|
||||
joinTable := joinTableHandler.Table(scope.db)
|
||||
|
@ -521,16 +521,20 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||
var sqlTypes, primaryKeys []string
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
if field, ok := scope.Fields()[fieldName]; ok {
|
||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
|
||||
foreignKeyStruct := field.StructField.clone()
|
||||
foreignKeyStruct.IsPrimaryKey = false
|
||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
||||
}
|
||||
}
|
||||
|
||||
for idx, fieldName := range relationship.AssociationForeignFieldNames {
|
||||
if field, ok := toScope.Fields()[fieldName]; ok {
|
||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
|
||||
foreignKeyStruct := field.StructField.clone()
|
||||
foreignKeyStruct.IsPrimaryKey = false
|
||||
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
||||
}
|
||||
}
|
||||
|
@ -545,9 +549,9 @@ func (scope *Scope) createTable() *Scope {
|
|||
var tags []string
|
||||
var primaryKeys []string
|
||||
var primaryKeyInColumnType = false
|
||||
for _, field := range scope.GetStructFields() {
|
||||
for _, field := range scope.Fields() {
|
||||
if field.IsNormal {
|
||||
sqlTag := scope.generateSqlTag(field)
|
||||
sqlTag := scope.Dialect().DataTypeOf(field.StructField)
|
||||
|
||||
// Check if the primary key constraint was specified as
|
||||
// part of the column type. If so, we can only support
|
||||
|
@ -632,10 +636,10 @@ func (scope *Scope) autoMigrate() *Scope {
|
|||
if !scope.Dialect().HasTable(scope, tableName) {
|
||||
scope.createTable()
|
||||
} else {
|
||||
for _, field := range scope.GetStructFields() {
|
||||
for _, field := range scope.Fields() {
|
||||
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
|
||||
if field.IsNormal {
|
||||
sqlTag := scope.generateSqlTag(field)
|
||||
sqlTag := scope.Dialect().DataTypeOf(field.StructField)
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue