forked from mirror/gorm
Refactor SQL Tag
This commit is contained in:
parent
49454839bd
commit
a549b6bd49
|
@ -36,6 +36,7 @@ type User struct {
|
||||||
Birthday time.Time
|
Birthday time.Time
|
||||||
Age int
|
Age int
|
||||||
Name string `sql:"size:255"` // Default size for string is 255, you could reset it with this tag
|
Name string `sql:"size:255"` // Default size for string is 255, you could reset it with this tag
|
||||||
|
Num int `sql:"AUTO_INCREMENT"`
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
DeletedAt time.Time
|
DeletedAt time.Time
|
||||||
|
|
|
@ -21,13 +21,19 @@ func (s *commonDialect) HasTop() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
|
func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "BOOLEAN"
|
return "BOOLEAN"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if autoIncrease {
|
||||||
|
return "INTEGER AUTO_INCREMENT"
|
||||||
|
}
|
||||||
return "INTEGER"
|
return "INTEGER"
|
||||||
case reflect.Int64, reflect.Uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if autoIncrease {
|
||||||
|
return "BIGINT AUTO_INCREMENT"
|
||||||
|
}
|
||||||
return "BIGINT"
|
return "BIGINT"
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return "FLOAT"
|
return "FLOAT"
|
||||||
|
@ -51,18 +57,6 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) PrimaryKeyTag(value reflect.Value, size int) string {
|
|
||||||
suffix := " NOT NULL PRIMARY KEY"
|
|
||||||
switch value.Kind() {
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
|
||||||
return "INTEGER" + suffix
|
|
||||||
case reflect.Int64, reflect.Uint64:
|
|
||||||
return "BIGINT" + suffix
|
|
||||||
default:
|
|
||||||
panic("Invalid primary key type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *commonDialect) ReturningStr(tableName, key string) string {
|
func (s *commonDialect) ReturningStr(tableName, key string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,8 +9,7 @@ type Dialect interface {
|
||||||
BinVar(i int) string
|
BinVar(i int) string
|
||||||
SupportLastInsertId() bool
|
SupportLastInsertId() bool
|
||||||
HasTop() bool
|
HasTop() bool
|
||||||
SqlTag(value reflect.Value, size int) string
|
SqlTag(value reflect.Value, size int, autoIncrease bool) string
|
||||||
PrimaryKeyTag(value reflect.Value, size int) string
|
|
||||||
ReturningStr(tableName, key string) string
|
ReturningStr(tableName, key string) string
|
||||||
SelectFromDummyTable() string
|
SelectFromDummyTable() string
|
||||||
Quote(key string) string
|
Quote(key string) string
|
||||||
|
|
|
@ -27,7 +27,6 @@ type StructField struct {
|
||||||
IsIgnored bool
|
IsIgnored bool
|
||||||
IsScanner bool
|
IsScanner bool
|
||||||
HasDefaultValue bool
|
HasDefaultValue bool
|
||||||
SqlTag string
|
|
||||||
Tag reflect.StructTag
|
Tag reflect.StructTag
|
||||||
Struct reflect.StructField
|
Struct reflect.StructField
|
||||||
IsForeignKey bool
|
IsForeignKey bool
|
||||||
|
@ -44,7 +43,6 @@ func (structField *StructField) clone() *StructField {
|
||||||
IsIgnored: structField.IsIgnored,
|
IsIgnored: structField.IsIgnored,
|
||||||
IsScanner: structField.IsScanner,
|
IsScanner: structField.IsScanner,
|
||||||
HasDefaultValue: structField.HasDefaultValue,
|
HasDefaultValue: structField.HasDefaultValue,
|
||||||
SqlTag: structField.SqlTag,
|
|
||||||
Tag: structField.Tag,
|
Tag: structField.Tag,
|
||||||
Struct: structField.Struct,
|
Struct: structField.Struct,
|
||||||
IsForeignKey: structField.IsForeignKey,
|
IsForeignKey: structField.IsForeignKey,
|
||||||
|
@ -281,10 +279,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||||
field.IsPrimaryKey = true
|
field.IsPrimaryKey = true
|
||||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scope.db != nil {
|
|
||||||
scope.generateSqlTag(field)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
||||||
|
@ -301,7 +295,7 @@ func (scope *Scope) GetStructFields() (fields []*StructField) {
|
||||||
return scope.GetModelStruct().StructFields
|
return scope.GetModelStruct().StructFields
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) generateSqlTag(field *StructField) {
|
func (scope *Scope) generateSqlTag(field *StructField) string {
|
||||||
var sqlType string
|
var sqlType string
|
||||||
structType := field.Struct.Type
|
structType := field.Struct.Type
|
||||||
if structType.Kind() == reflect.Ptr {
|
if structType.Kind() == reflect.Ptr {
|
||||||
|
@ -337,17 +331,18 @@ func (scope *Scope) generateSqlTag(field *StructField) {
|
||||||
size, _ = strconv.Atoi(value)
|
size, _ = strconv.Atoi(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, autoIncrease := sqlSettings["AUTO_INCREMENT"]
|
||||||
if field.IsPrimaryKey {
|
if field.IsPrimaryKey {
|
||||||
sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size)
|
autoIncrease = true
|
||||||
} else {
|
|
||||||
sqlType = scope.Dialect().SqlTag(reflectValue, size)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.TrimSpace(additionalType) == "" {
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
field.SqlTag = sqlType
|
return sqlType
|
||||||
} else {
|
} else {
|
||||||
field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType)
|
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
20
mssql.go
20
mssql.go
|
@ -21,13 +21,19 @@ func (s *mssql) HasTop() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mssql) SqlTag(value reflect.Value, size int) string {
|
func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "bit"
|
return "bit"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if autoIncrease {
|
||||||
|
return "int IDENTITY(1,1)"
|
||||||
|
}
|
||||||
return "int"
|
return "int"
|
||||||
case reflect.Int64, reflect.Uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if autoIncrease {
|
||||||
|
return "bigint IDENTITY(1,1)"
|
||||||
|
}
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return "float"
|
return "float"
|
||||||
|
@ -51,18 +57,6 @@ func (s *mssql) SqlTag(value reflect.Value, size int) string {
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string {
|
|
||||||
suffix := " IDENTITY(1,1) PRIMARY KEY"
|
|
||||||
switch value.Kind() {
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
|
||||||
return "int" + suffix
|
|
||||||
case reflect.Int64, reflect.Uint64:
|
|
||||||
return "bigint" + suffix
|
|
||||||
default:
|
|
||||||
panic("Invalid primary key type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mssql) ReturningStr(tableName, key string) string {
|
func (s *mssql) ReturningStr(tableName, key string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
20
mysql.go
20
mysql.go
|
@ -21,13 +21,19 @@ func (s *mysql) HasTop() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mysql) SqlTag(value reflect.Value, size int) string {
|
func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "boolean"
|
return "boolean"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if autoIncrease {
|
||||||
|
return "int AUTO_INCREMENT"
|
||||||
|
}
|
||||||
return "int"
|
return "int"
|
||||||
case reflect.Int64, reflect.Uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if autoIncrease {
|
||||||
|
return "bigint AUTO_INCREMENT"
|
||||||
|
}
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return "double"
|
return "double"
|
||||||
|
@ -51,18 +57,6 @@ func (s *mysql) SqlTag(value reflect.Value, size int) string {
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
|
|
||||||
suffix := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
|
|
||||||
switch value.Kind() {
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
|
||||||
return "int" + suffix
|
|
||||||
case reflect.Int64, reflect.Uint64:
|
|
||||||
return "bigint" + suffix
|
|
||||||
default:
|
|
||||||
panic("Invalid primary key type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mysql) ReturningStr(tableName, key string) string {
|
func (s *mysql) ReturningStr(tableName, key string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
19
postgres.go
19
postgres.go
|
@ -25,13 +25,19 @@ func (s *postgres) HasTop() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *postgres) SqlTag(value reflect.Value, size int) string {
|
func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "boolean"
|
return "boolean"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if autoIncrease {
|
||||||
|
return "serial"
|
||||||
|
}
|
||||||
return "integer"
|
return "integer"
|
||||||
case reflect.Int64, reflect.Uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if autoIncrease {
|
||||||
|
return "bigserial"
|
||||||
|
}
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return "numeric"
|
return "numeric"
|
||||||
|
@ -56,17 +62,6 @@ func (s *postgres) SqlTag(value reflect.Value, size int) string {
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string {
|
|
||||||
switch value.Kind() {
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
|
||||||
return "serial PRIMARY KEY"
|
|
||||||
case reflect.Int64, reflect.Uint64:
|
|
||||||
return "bigserial PRIMARY KEY"
|
|
||||||
default:
|
|
||||||
panic("Invalid primary key type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *postgres) ReturningStr(tableName, key string) string {
|
func (s *postgres) ReturningStr(tableName, key string) string {
|
||||||
return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key)
|
return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key)
|
||||||
}
|
}
|
||||||
|
|
|
@ -447,7 +447,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
|
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
|
||||||
joinTable := joinTableHandler.Table(scope.db, relationship)
|
joinTable := joinTableHandler.Table(scope.db, relationship)
|
||||||
if !scope.Dialect().HasTable(scope, joinTable) {
|
if !scope.Dialect().HasTable(scope, joinTable) {
|
||||||
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255)
|
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
||||||
scope.Quote(joinTable),
|
scope.Quote(joinTable),
|
||||||
strings.Join([]string{
|
strings.Join([]string{
|
||||||
|
@ -460,14 +460,25 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) createTable() *Scope {
|
func (scope *Scope) createTable() *Scope {
|
||||||
var sqls []string
|
var tags []string
|
||||||
for _, structField := range scope.GetStructFields() {
|
var primaryKeys []string
|
||||||
if structField.IsNormal {
|
for _, field := range scope.GetStructFields() {
|
||||||
sqls = append(sqls, scope.Quote(structField.DBName)+" "+structField.SqlTag)
|
if field.IsNormal {
|
||||||
|
sqlTag := scope.generateSqlTag(field)
|
||||||
|
tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
|
||||||
}
|
}
|
||||||
scope.createJoinTable(structField)
|
|
||||||
|
if field.IsPrimaryKey {
|
||||||
|
primaryKeys = append(primaryKeys, field.DBName)
|
||||||
}
|
}
|
||||||
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.QuotedTableName(), strings.Join(sqls, ","))).Exec()
|
scope.createJoinTable(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
var primaryKeyStr string
|
||||||
|
if len(primaryKeys) > 0 {
|
||||||
|
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
|
||||||
|
}
|
||||||
|
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec()
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -530,7 +541,8 @@ func (scope *Scope) autoMigrate() *Scope {
|
||||||
for _, field := range scope.GetStructFields() {
|
for _, field := range scope.GetStructFields() {
|
||||||
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
|
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, field.SqlTag)).Exec()
|
sqlTag := scope.generateSqlTag(field)
|
||||||
|
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
scope.createJoinTable(field)
|
scope.createJoinTable(field)
|
||||||
|
|
14
sqlite3.go
14
sqlite3.go
|
@ -20,13 +20,16 @@ func (s *sqlite3) HasTop() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
|
func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "bool"
|
return "bool"
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
return "integer"
|
return "integer"
|
||||||
case reflect.Int64, reflect.Uint64:
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if autoIncrease {
|
||||||
|
return "integer"
|
||||||
|
}
|
||||||
return "bigint"
|
return "bigint"
|
||||||
case reflect.Float32, reflect.Float64:
|
case reflect.Float32, reflect.Float64:
|
||||||
return "real"
|
return "real"
|
||||||
|
@ -47,15 +50,6 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string {
|
|
||||||
switch value.Kind() {
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64:
|
|
||||||
return "INTEGER PRIMARY KEY"
|
|
||||||
default:
|
|
||||||
panic("Invalid primary key type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sqlite3) ReturningStr(tableName, key string) string {
|
func (s *sqlite3) ReturningStr(tableName, key string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue