mirror of https://github.com/go-gorm/gorm.git
Refactor DataTypeOf for sqlite
This commit is contained in:
parent
dc435d2225
commit
552d9bf455
42
dialect.go
42
dialect.go
|
@ -1,8 +1,11 @@
|
||||||
package gorm
|
package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Dialect interface contains behaviors that differ across SQL database
|
// Dialect interface contains behaviors that differ across SQL database
|
||||||
|
@ -12,7 +15,7 @@ type Dialect interface {
|
||||||
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
// Quote quotes field name to avoid SQL parsing exceptions by using a reserved word as a field name
|
||||||
Quote(key string) string
|
Quote(key string) string
|
||||||
// DataTypeOf return data's sql type
|
// DataTypeOf return data's sql type
|
||||||
DataTypeOf(value reflect.Value, tagSettings map[string]string) string
|
DataTypeOf(field *StructField) string
|
||||||
|
|
||||||
// HasIndex check has index or not
|
// HasIndex check has index or not
|
||||||
HasIndex(scope *Scope, tableName string, indexName string) bool
|
HasIndex(scope *Scope, tableName string, indexName string) bool
|
||||||
|
@ -48,3 +51,40 @@ func NewDialect(driver string) Dialect {
|
||||||
}
|
}
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParseFieldStructForDialect parse field struct for dialect
|
||||||
|
func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) {
|
||||||
|
// Get redirected field type
|
||||||
|
var reflectType = field.Struct.Type
|
||||||
|
for reflectType.Kind() == reflect.Ptr {
|
||||||
|
reflectType = reflectType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get redirected field value
|
||||||
|
fieldValue = reflect.Indirect(reflect.New(reflectType))
|
||||||
|
|
||||||
|
// Get scanner's real value
|
||||||
|
var getScannerValue func(reflect.Value)
|
||||||
|
getScannerValue = func(value reflect.Value) {
|
||||||
|
fieldValue = value
|
||||||
|
if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct {
|
||||||
|
getScannerValue(fieldValue.Field(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
getScannerValue(fieldValue)
|
||||||
|
|
||||||
|
// Default Size
|
||||||
|
if num, ok := field.TagSettings["SIZE"]; ok {
|
||||||
|
size, _ = strconv.Atoi(num)
|
||||||
|
} else {
|
||||||
|
size = 255
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default type from tag setting
|
||||||
|
additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
||||||
|
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||||
|
additionalType = additionalType + " DEFAULT " + value
|
||||||
|
}
|
||||||
|
|
||||||
|
return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType)
|
||||||
|
}
|
||||||
|
|
|
@ -17,8 +17,13 @@ func (commonDialect) Quote(key string) string {
|
||||||
return fmt.Sprintf(`"%s"`, key)
|
return fmt.Sprintf(`"%s"`, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (commonDialect) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
func (commonDialect) DataTypeOf(field *StructField) string {
|
||||||
var size int
|
var (
|
||||||
|
size int
|
||||||
|
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
|
||||||
|
tagSettings = field.TagSettings
|
||||||
|
)
|
||||||
|
|
||||||
if num, ok := tagSettings["SIZE"]; ok {
|
if num, ok := tagSettings["SIZE"]; ok {
|
||||||
size, _ = strconv.Atoi(num)
|
size, _ = strconv.Atoi(num)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,8 +11,13 @@ type mssql struct {
|
||||||
commonDialect
|
commonDialect
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mssql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
func (mssql) DataTypeOf(field *StructField) string {
|
||||||
var size int
|
var (
|
||||||
|
size int
|
||||||
|
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
|
||||||
|
tagSettings = field.TagSettings
|
||||||
|
)
|
||||||
|
|
||||||
if num, ok := tagSettings["SIZE"]; ok {
|
if num, ok := tagSettings["SIZE"]; ok {
|
||||||
size, _ = strconv.Atoi(num)
|
size, _ = strconv.Atoi(num)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,8 +15,13 @@ func (mysql) Quote(key string) string {
|
||||||
return fmt.Sprintf("`%s`", key)
|
return fmt.Sprintf("`%s`", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mysql) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
func (mysql) DataTypeOf(field *StructField) string {
|
||||||
var size int
|
var (
|
||||||
|
size int
|
||||||
|
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
|
||||||
|
tagSettings = field.TagSettings
|
||||||
|
)
|
||||||
|
|
||||||
if num, ok := tagSettings["SIZE"]; ok {
|
if num, ok := tagSettings["SIZE"]; ok {
|
||||||
size, _ = strconv.Atoi(num)
|
size, _ = strconv.Atoi(num)
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,8 +20,13 @@ func (postgres) BindVar(i int) string {
|
||||||
return fmt.Sprintf("$%v", i)
|
return fmt.Sprintf("$%v", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (postgres) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
func (postgres) DataTypeOf(field *StructField) string {
|
||||||
var size int
|
var (
|
||||||
|
size int
|
||||||
|
dataValue = reflect.Indirect(reflect.New(field.Struct.Type))
|
||||||
|
tagSettings = field.TagSettings
|
||||||
|
)
|
||||||
|
|
||||||
if num, ok := tagSettings["SIZE"]; ok {
|
if num, ok := tagSettings["SIZE"]; ok {
|
||||||
size, _ = strconv.Atoi(num)
|
size, _ = strconv.Atoi(num)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,7 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -11,42 +11,55 @@ type sqlite3 struct {
|
||||||
commonDialect
|
commonDialect
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sqlite3) DataTypeOf(dataValue reflect.Value, tagSettings map[string]string) string {
|
// Get Data Type for Sqlite Dialect
|
||||||
var size int
|
func (sqlite3) DataTypeOf(field *StructField) string {
|
||||||
if num, ok := tagSettings["SIZE"]; ok {
|
var (
|
||||||
size, _ = strconv.Atoi(num)
|
dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field)
|
||||||
|
)
|
||||||
|
|
||||||
|
if sqlType == "" {
|
||||||
|
switch dataValue.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
sqlType = "bool"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if field.IsPrimaryKey {
|
||||||
|
sqlType = "integer primary key autoincrement"
|
||||||
|
} else {
|
||||||
|
sqlType = "integer"
|
||||||
|
}
|
||||||
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if field.IsPrimaryKey {
|
||||||
|
sqlType = "integer primary key autoincrement"
|
||||||
|
} else {
|
||||||
|
sqlType = "bigint"
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
sqlType = "real"
|
||||||
|
case reflect.String:
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
sqlType = fmt.Sprintf("varchar(%d)", size)
|
||||||
|
} else {
|
||||||
|
sqlType = "text"
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := dataValue.Interface().(time.Time); ok {
|
||||||
|
sqlType = "datetime"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if _, ok := dataValue.Interface().([]byte); ok {
|
||||||
|
sqlType = "blob"
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch dataValue.Kind() {
|
if sqlType == "" {
|
||||||
case reflect.Bool:
|
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
|
||||||
return "bool"
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
|
||||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
|
||||||
return "integer primary key autoincrement"
|
|
||||||
}
|
|
||||||
return "integer"
|
|
||||||
case reflect.Int64, reflect.Uint64:
|
|
||||||
if _, ok := tagSettings["AUTO_INCREMENT"]; ok {
|
|
||||||
return "integer primary key autoincrement"
|
|
||||||
}
|
|
||||||
return "bigint"
|
|
||||||
case reflect.Float32, reflect.Float64:
|
|
||||||
return "real"
|
|
||||||
case reflect.String:
|
|
||||||
if size > 0 && size < 65532 {
|
|
||||||
return fmt.Sprintf("varchar(%d)", size)
|
|
||||||
}
|
|
||||||
return "text"
|
|
||||||
case reflect.Struct:
|
|
||||||
if _, ok := dataValue.Interface().(time.Time); ok {
|
|
||||||
return "datetime"
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if _, ok := dataValue.Interface().([]byte); ok {
|
|
||||||
return "blob"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", dataValue.Type().Name(), dataValue.Kind().String()))
|
|
||||||
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
return sqlType
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
func (s sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||||
|
|
|
@ -3,7 +3,6 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"go/ast"
|
"go/ast"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -511,44 +510,6 @@ func (scope *Scope) GetStructFields() (fields []*StructField) {
|
||||||
return scope.GetModelStruct().StructFields
|
return scope.GetModelStruct().StructFields
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) generateSqlTag(field *StructField) string {
|
|
||||||
var sqlType string
|
|
||||||
structType := field.Struct.Type
|
|
||||||
if structType.Kind() == reflect.Ptr {
|
|
||||||
structType = structType.Elem()
|
|
||||||
}
|
|
||||||
reflectValue := reflect.Indirect(reflect.New(structType))
|
|
||||||
|
|
||||||
if value, ok := field.TagSettings["TYPE"]; ok {
|
|
||||||
sqlType = value
|
|
||||||
}
|
|
||||||
|
|
||||||
additionalType := field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"]
|
|
||||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
|
||||||
additionalType = additionalType + " DEFAULT " + value
|
|
||||||
}
|
|
||||||
|
|
||||||
if field.IsScanner {
|
|
||||||
var getScannerValue func(reflect.Value)
|
|
||||||
getScannerValue = func(value reflect.Value) {
|
|
||||||
reflectValue = value
|
|
||||||
if _, isScanner := reflect.New(reflectValue.Type()).Interface().(sql.Scanner); isScanner && reflectValue.Kind() == reflect.Struct {
|
|
||||||
getScannerValue(reflectValue.Field(0))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
getScannerValue(reflectValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
if sqlType == "" {
|
|
||||||
sqlType = scope.Dialect().DataTypeOf(reflectValue, field.TagSettings)
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.TrimSpace(additionalType) == "" {
|
|
||||||
return sqlType
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseTagSetting(tags reflect.StructTag) map[string]string {
|
func parseTagSetting(tags reflect.StructTag) map[string]string {
|
||||||
setting := map[string]string{}
|
setting := map[string]string{}
|
||||||
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
|
for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} {
|
||||||
|
|
|
@ -511,7 +511,7 @@ func (scope *Scope) getTableOptions() string {
|
||||||
return tableOptions.(string)
|
return tableOptions.(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) createJoinTable(field *StructField) {
|
func (scope *Scope) createJoinTable(field *Field) {
|
||||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||||
joinTableHandler := relationship.JoinTableHandler
|
joinTableHandler := relationship.JoinTableHandler
|
||||||
joinTable := joinTableHandler.Table(scope.db)
|
joinTable := joinTableHandler.Table(scope.db)
|
||||||
|
@ -521,16 +521,20 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
var sqlTypes, primaryKeys []string
|
var sqlTypes, primaryKeys []string
|
||||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||||
if field, ok := scope.Fields()[fieldName]; ok {
|
if field, ok := scope.Fields()[fieldName]; ok {
|
||||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
foreignKeyStruct := field.StructField.clone()
|
||||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
|
foreignKeyStruct.IsPrimaryKey = false
|
||||||
|
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||||
|
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, fieldName := range relationship.AssociationForeignFieldNames {
|
for idx, fieldName := range relationship.AssociationForeignFieldNames {
|
||||||
if field, ok := toScope.Fields()[fieldName]; ok {
|
if field, ok := toScope.Fields()[fieldName]; ok {
|
||||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
foreignKeyStruct := field.StructField.clone()
|
||||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(value, field.TagSettings))
|
foreignKeyStruct.IsPrimaryKey = false
|
||||||
|
foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true"
|
||||||
|
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct))
|
||||||
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -545,9 +549,9 @@ func (scope *Scope) createTable() *Scope {
|
||||||
var tags []string
|
var tags []string
|
||||||
var primaryKeys []string
|
var primaryKeys []string
|
||||||
var primaryKeyInColumnType = false
|
var primaryKeyInColumnType = false
|
||||||
for _, field := range scope.GetStructFields() {
|
for _, field := range scope.Fields() {
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
sqlTag := scope.generateSqlTag(field)
|
sqlTag := scope.Dialect().DataTypeOf(field.StructField)
|
||||||
|
|
||||||
// Check if the primary key constraint was specified as
|
// Check if the primary key constraint was specified as
|
||||||
// part of the column type. If so, we can only support
|
// part of the column type. If so, we can only support
|
||||||
|
@ -632,10 +636,10 @@ func (scope *Scope) autoMigrate() *Scope {
|
||||||
if !scope.Dialect().HasTable(scope, tableName) {
|
if !scope.Dialect().HasTable(scope, tableName) {
|
||||||
scope.createTable()
|
scope.createTable()
|
||||||
} else {
|
} else {
|
||||||
for _, field := range scope.GetStructFields() {
|
for _, field := range scope.Fields() {
|
||||||
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
|
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
sqlTag := scope.generateSqlTag(field)
|
sqlTag := scope.Dialect().DataTypeOf(field.StructField)
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue