Refactor SQL Tag

This commit is contained in:
Jinzhu 2015-03-11 17:05:58 +08:00
parent 49454839bd
commit a549b6bd49
9 changed files with 61 additions and 83 deletions

View File

@ -36,6 +36,7 @@ type User struct {
Birthday time.Time
Age int
Name string `sql:"size:255"` // Default size for string is 255, you could reset it with this tag
Num int `sql:"AUTO_INCREMENT"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt time.Time

View File

@ -21,13 +21,19 @@ func (s *commonDialect) HasTop() bool {
return false
}
func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "INTEGER AUTO_INCREMENT"
}
return "INTEGER"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "BIGINT AUTO_INCREMENT"
}
return "BIGINT"
case reflect.Float32, reflect.Float64:
return "FLOAT"
@ -51,18 +57,6 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
}
func (s *commonDialect) PrimaryKeyTag(value reflect.Value, size int) string {
suffix := " NOT NULL PRIMARY KEY"
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "INTEGER" + suffix
case reflect.Int64, reflect.Uint64:
return "BIGINT" + suffix
default:
panic("Invalid primary key type")
}
}
func (s *commonDialect) ReturningStr(tableName, key string) string {
return ""
}

View File

@ -9,8 +9,7 @@ type Dialect interface {
BinVar(i int) string
SupportLastInsertId() bool
HasTop() bool
SqlTag(value reflect.Value, size int) string
PrimaryKeyTag(value reflect.Value, size int) string
SqlTag(value reflect.Value, size int, autoIncrease bool) string
ReturningStr(tableName, key string) string
SelectFromDummyTable() string
Quote(key string) string

View File

@ -27,7 +27,6 @@ type StructField struct {
IsIgnored bool
IsScanner bool
HasDefaultValue bool
SqlTag string
Tag reflect.StructTag
Struct reflect.StructField
IsForeignKey bool
@ -44,7 +43,6 @@ func (structField *StructField) clone() *StructField {
IsIgnored: structField.IsIgnored,
IsScanner: structField.IsScanner,
HasDefaultValue: structField.HasDefaultValue,
SqlTag: structField.SqlTag,
Tag: structField.Tag,
Struct: structField.Struct,
IsForeignKey: structField.IsForeignKey,
@ -281,10 +279,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
field.IsPrimaryKey = true
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
}
if scope.db != nil {
scope.generateSqlTag(field)
}
}
}
modelStruct.StructFields = append(modelStruct.StructFields, field)
@ -301,7 +295,7 @@ func (scope *Scope) GetStructFields() (fields []*StructField) {
return scope.GetModelStruct().StructFields
}
func (scope *Scope) generateSqlTag(field *StructField) {
func (scope *Scope) generateSqlTag(field *StructField) string {
var sqlType string
structType := field.Struct.Type
if structType.Kind() == reflect.Ptr {
@ -337,17 +331,18 @@ func (scope *Scope) generateSqlTag(field *StructField) {
size, _ = strconv.Atoi(value)
}
_, autoIncrease := sqlSettings["AUTO_INCREMENT"]
if field.IsPrimaryKey {
sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size)
} else {
sqlType = scope.Dialect().SqlTag(reflectValue, size)
autoIncrease = true
}
sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
}
if strings.TrimSpace(additionalType) == "" {
field.SqlTag = sqlType
return sqlType
} else {
field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType)
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
}

View File

@ -21,13 +21,19 @@ func (s *mssql) HasTop() bool {
return true
}
func (s *mssql) SqlTag(value reflect.Value, size int) string {
func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "int IDENTITY(1,1)"
}
return "int"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "bigint IDENTITY(1,1)"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "float"
@ -51,18 +57,6 @@ func (s *mssql) SqlTag(value reflect.Value, size int) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
}
func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string {
suffix := " IDENTITY(1,1) PRIMARY KEY"
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int" + suffix
case reflect.Int64, reflect.Uint64:
return "bigint" + suffix
default:
panic("Invalid primary key type")
}
}
func (s *mssql) ReturningStr(tableName, key string) string {
return ""
}

View File

@ -21,13 +21,19 @@ func (s *mysql) HasTop() bool {
return false
}
func (s *mysql) SqlTag(value reflect.Value, size int) string {
func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "int AUTO_INCREMENT"
}
return "int"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "bigint AUTO_INCREMENT"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "double"
@ -51,18 +57,6 @@ func (s *mysql) SqlTag(value reflect.Value, size int) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
}
func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
suffix := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int" + suffix
case reflect.Int64, reflect.Uint64:
return "bigint" + suffix
default:
panic("Invalid primary key type")
}
}
func (s *mysql) ReturningStr(tableName, key string) string {
return ""
}

View File

@ -25,13 +25,19 @@ func (s *postgres) HasTop() bool {
return false
}
func (s *postgres) SqlTag(value reflect.Value, size int) string {
func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "serial"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "bigserial"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "numeric"
@ -56,17 +62,6 @@ func (s *postgres) SqlTag(value reflect.Value, size int) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
}
func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "serial PRIMARY KEY"
case reflect.Int64, reflect.Uint64:
return "bigserial PRIMARY KEY"
default:
panic("Invalid primary key type")
}
}
func (s *postgres) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key)
}

View File

@ -447,7 +447,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
joinTable := joinTableHandler.Table(scope.db, relationship)
if !scope.Dialect().HasTable(scope, joinTable) {
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255)
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
scope.Quote(joinTable),
strings.Join([]string{
@ -460,14 +460,25 @@ func (scope *Scope) createJoinTable(field *StructField) {
}
func (scope *Scope) createTable() *Scope {
var sqls []string
for _, structField := range scope.GetStructFields() {
if structField.IsNormal {
sqls = append(sqls, scope.Quote(structField.DBName)+" "+structField.SqlTag)
var tags []string
var primaryKeys []string
for _, field := range scope.GetStructFields() {
if field.IsNormal {
sqlTag := scope.generateSqlTag(field)
tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
}
scope.createJoinTable(structField)
if field.IsPrimaryKey {
primaryKeys = append(primaryKeys, field.DBName)
}
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.QuotedTableName(), strings.Join(sqls, ","))).Exec()
scope.createJoinTable(field)
}
var primaryKeyStr string
if len(primaryKeys) > 0 {
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
}
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec()
return scope
}
@ -530,7 +541,8 @@ func (scope *Scope) autoMigrate() *Scope {
for _, field := range scope.GetStructFields() {
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
if field.IsNormal {
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, field.SqlTag)).Exec()
sqlTag := scope.generateSqlTag(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec()
}
}
scope.createJoinTable(field)

View File

@ -20,13 +20,16 @@ func (s *sqlite3) HasTop() bool {
return false
}
func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "integer"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "real"
@ -47,15 +50,6 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
}
func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64:
return "INTEGER PRIMARY KEY"
default:
panic("Invalid primary key type")
}
}
func (s *sqlite3) ReturningStr(tableName, key string) string {
return ""
}