mirror of https://github.com/go-gorm/gorm.git
Refactor SQL Tag
This commit is contained in:
parent
49454839bd
commit
a549b6bd49
|
@ -36,6 +36,7 @@ type User struct {
|
|||
Birthday time.Time
|
||||
Age int
|
||||
Name string `sql:"size:255"` // Default size for string is 255, you could reset it with this tag
|
||||
Num int `sql:"AUTO_INCREMENT"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt time.Time
|
||||
|
|
|
@ -21,13 +21,19 @@ func (s *commonDialect) HasTop() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
|
||||
func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "BOOLEAN"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "INTEGER AUTO_INCREMENT"
|
||||
}
|
||||
return "INTEGER"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "BIGINT AUTO_INCREMENT"
|
||||
}
|
||||
return "BIGINT"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "FLOAT"
|
||||
|
@ -51,18 +57,6 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
|
|||
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *commonDialect) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
suffix := " NOT NULL PRIMARY KEY"
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "INTEGER" + suffix
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "BIGINT" + suffix
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *commonDialect) ReturningStr(tableName, key string) string {
|
||||
return ""
|
||||
}
|
||||
|
|
|
@ -9,8 +9,7 @@ type Dialect interface {
|
|||
BinVar(i int) string
|
||||
SupportLastInsertId() bool
|
||||
HasTop() bool
|
||||
SqlTag(value reflect.Value, size int) string
|
||||
PrimaryKeyTag(value reflect.Value, size int) string
|
||||
SqlTag(value reflect.Value, size int, autoIncrease bool) string
|
||||
ReturningStr(tableName, key string) string
|
||||
SelectFromDummyTable() string
|
||||
Quote(key string) string
|
||||
|
|
|
@ -27,7 +27,6 @@ type StructField struct {
|
|||
IsIgnored bool
|
||||
IsScanner bool
|
||||
HasDefaultValue bool
|
||||
SqlTag string
|
||||
Tag reflect.StructTag
|
||||
Struct reflect.StructField
|
||||
IsForeignKey bool
|
||||
|
@ -44,7 +43,6 @@ func (structField *StructField) clone() *StructField {
|
|||
IsIgnored: structField.IsIgnored,
|
||||
IsScanner: structField.IsScanner,
|
||||
HasDefaultValue: structField.HasDefaultValue,
|
||||
SqlTag: structField.SqlTag,
|
||||
Tag: structField.Tag,
|
||||
Struct: structField.Struct,
|
||||
IsForeignKey: structField.IsForeignKey,
|
||||
|
@ -281,10 +279,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||
field.IsPrimaryKey = true
|
||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||
}
|
||||
|
||||
if scope.db != nil {
|
||||
scope.generateSqlTag(field)
|
||||
}
|
||||
}
|
||||
}
|
||||
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
||||
|
@ -301,7 +295,7 @@ func (scope *Scope) GetStructFields() (fields []*StructField) {
|
|||
return scope.GetModelStruct().StructFields
|
||||
}
|
||||
|
||||
func (scope *Scope) generateSqlTag(field *StructField) {
|
||||
func (scope *Scope) generateSqlTag(field *StructField) string {
|
||||
var sqlType string
|
||||
structType := field.Struct.Type
|
||||
if structType.Kind() == reflect.Ptr {
|
||||
|
@ -337,17 +331,18 @@ func (scope *Scope) generateSqlTag(field *StructField) {
|
|||
size, _ = strconv.Atoi(value)
|
||||
}
|
||||
|
||||
_, autoIncrease := sqlSettings["AUTO_INCREMENT"]
|
||||
if field.IsPrimaryKey {
|
||||
sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size)
|
||||
} else {
|
||||
sqlType = scope.Dialect().SqlTag(reflectValue, size)
|
||||
autoIncrease = true
|
||||
}
|
||||
|
||||
sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(additionalType) == "" {
|
||||
field.SqlTag = sqlType
|
||||
return sqlType
|
||||
} else {
|
||||
field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
return fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
20
mssql.go
20
mssql.go
|
@ -21,13 +21,19 @@ func (s *mssql) HasTop() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (s *mssql) SqlTag(value reflect.Value, size int) string {
|
||||
func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "bit"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "int IDENTITY(1,1)"
|
||||
}
|
||||
return "int"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "bigint IDENTITY(1,1)"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "float"
|
||||
|
@ -51,18 +57,6 @@ func (s *mssql) SqlTag(value reflect.Value, size int) string {
|
|||
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
suffix := " IDENTITY(1,1) PRIMARY KEY"
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "int" + suffix
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigint" + suffix
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *mssql) ReturningStr(tableName, key string) string {
|
||||
return ""
|
||||
}
|
||||
|
|
20
mysql.go
20
mysql.go
|
@ -21,13 +21,19 @@ func (s *mysql) HasTop() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (s *mysql) SqlTag(value reflect.Value, size int) string {
|
||||
func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "int AUTO_INCREMENT"
|
||||
}
|
||||
return "int"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "bigint AUTO_INCREMENT"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "double"
|
||||
|
@ -51,18 +57,6 @@ func (s *mysql) SqlTag(value reflect.Value, size int) string {
|
|||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
suffix := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "int" + suffix
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigint" + suffix
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *mysql) ReturningStr(tableName, key string) string {
|
||||
return ""
|
||||
}
|
||||
|
|
19
postgres.go
19
postgres.go
|
@ -25,13 +25,19 @@ func (s *postgres) HasTop() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (s *postgres) SqlTag(value reflect.Value, size int) string {
|
||||
func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "serial"
|
||||
}
|
||||
return "integer"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "bigserial"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "numeric"
|
||||
|
@ -56,17 +62,6 @@ func (s *postgres) SqlTag(value reflect.Value, size int) string {
|
|||
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "serial PRIMARY KEY"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
return "bigserial PRIMARY KEY"
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *postgres) ReturningStr(tableName, key string) string {
|
||||
return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key)
|
||||
}
|
||||
|
|
|
@ -447,7 +447,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
|
||||
joinTable := joinTableHandler.Table(scope.db, relationship)
|
||||
if !scope.Dialect().HasTable(scope, joinTable) {
|
||||
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255)
|
||||
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
||||
scope.Quote(joinTable),
|
||||
strings.Join([]string{
|
||||
|
@ -460,14 +460,25 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||
}
|
||||
|
||||
func (scope *Scope) createTable() *Scope {
|
||||
var sqls []string
|
||||
for _, structField := range scope.GetStructFields() {
|
||||
if structField.IsNormal {
|
||||
sqls = append(sqls, scope.Quote(structField.DBName)+" "+structField.SqlTag)
|
||||
var tags []string
|
||||
var primaryKeys []string
|
||||
for _, field := range scope.GetStructFields() {
|
||||
if field.IsNormal {
|
||||
sqlTag := scope.generateSqlTag(field)
|
||||
tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
|
||||
}
|
||||
scope.createJoinTable(structField)
|
||||
|
||||
if field.IsPrimaryKey {
|
||||
primaryKeys = append(primaryKeys, field.DBName)
|
||||
}
|
||||
scope.createJoinTable(field)
|
||||
}
|
||||
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.QuotedTableName(), strings.Join(sqls, ","))).Exec()
|
||||
|
||||
var primaryKeyStr string
|
||||
if len(primaryKeys) > 0 {
|
||||
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
|
||||
}
|
||||
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec()
|
||||
return scope
|
||||
}
|
||||
|
||||
|
@ -530,7 +541,8 @@ func (scope *Scope) autoMigrate() *Scope {
|
|||
for _, field := range scope.GetStructFields() {
|
||||
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
|
||||
if field.IsNormal {
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, field.SqlTag)).Exec()
|
||||
sqlTag := scope.generateSqlTag(field)
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec()
|
||||
}
|
||||
}
|
||||
scope.createJoinTable(field)
|
||||
|
|
14
sqlite3.go
14
sqlite3.go
|
@ -20,13 +20,16 @@ func (s *sqlite3) HasTop() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
|
||||
func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "bool"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
return "integer"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "integer"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "real"
|
||||
|
@ -47,15 +50,6 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
|
|||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
|
||||
}
|
||||
|
||||
func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64:
|
||||
return "INTEGER PRIMARY KEY"
|
||||
default:
|
||||
panic("Invalid primary key type")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *sqlite3) ReturningStr(tableName, key string) string {
|
||||
return ""
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue