Refactor SQL Tag

This commit is contained in:
Jinzhu 2015-03-11 17:05:58 +08:00
parent 49454839bd
commit a549b6bd49
9 changed files with 61 additions and 83 deletions

View File

@ -36,6 +36,7 @@ type User struct {
Birthday time.Time Birthday time.Time
Age int Age int
Name string `sql:"size:255"` // Default size for string is 255, you could reset it with this tag Name string `sql:"size:255"` // Default size for string is 255, you could reset it with this tag
Num int `sql:"AUTO_INCREMENT"`
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
DeletedAt time.Time DeletedAt time.Time

View File

@ -21,13 +21,19 @@ func (s *commonDialect) HasTop() bool {
return false return false
} }
func (s *commonDialect) SqlTag(value reflect.Value, size int) string { func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "BOOLEAN" return "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "INTEGER AUTO_INCREMENT"
}
return "INTEGER" return "INTEGER"
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "BIGINT AUTO_INCREMENT"
}
return "BIGINT" return "BIGINT"
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
return "FLOAT" return "FLOAT"
@ -51,18 +57,6 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
} }
func (s *commonDialect) PrimaryKeyTag(value reflect.Value, size int) string {
suffix := " NOT NULL PRIMARY KEY"
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "INTEGER" + suffix
case reflect.Int64, reflect.Uint64:
return "BIGINT" + suffix
default:
panic("Invalid primary key type")
}
}
func (s *commonDialect) ReturningStr(tableName, key string) string { func (s *commonDialect) ReturningStr(tableName, key string) string {
return "" return ""
} }

View File

@ -9,8 +9,7 @@ type Dialect interface {
BinVar(i int) string BinVar(i int) string
SupportLastInsertId() bool SupportLastInsertId() bool
HasTop() bool HasTop() bool
SqlTag(value reflect.Value, size int) string SqlTag(value reflect.Value, size int, autoIncrease bool) string
PrimaryKeyTag(value reflect.Value, size int) string
ReturningStr(tableName, key string) string ReturningStr(tableName, key string) string
SelectFromDummyTable() string SelectFromDummyTable() string
Quote(key string) string Quote(key string) string

View File

@ -27,7 +27,6 @@ type StructField struct {
IsIgnored bool IsIgnored bool
IsScanner bool IsScanner bool
HasDefaultValue bool HasDefaultValue bool
SqlTag string
Tag reflect.StructTag Tag reflect.StructTag
Struct reflect.StructField Struct reflect.StructField
IsForeignKey bool IsForeignKey bool
@ -44,7 +43,6 @@ func (structField *StructField) clone() *StructField {
IsIgnored: structField.IsIgnored, IsIgnored: structField.IsIgnored,
IsScanner: structField.IsScanner, IsScanner: structField.IsScanner,
HasDefaultValue: structField.HasDefaultValue, HasDefaultValue: structField.HasDefaultValue,
SqlTag: structField.SqlTag,
Tag: structField.Tag, Tag: structField.Tag,
Struct: structField.Struct, Struct: structField.Struct,
IsForeignKey: structField.IsForeignKey, IsForeignKey: structField.IsForeignKey,
@ -281,10 +279,6 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
field.IsPrimaryKey = true field.IsPrimaryKey = true
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
} }
if scope.db != nil {
scope.generateSqlTag(field)
}
} }
} }
modelStruct.StructFields = append(modelStruct.StructFields, field) modelStruct.StructFields = append(modelStruct.StructFields, field)
@ -301,7 +295,7 @@ func (scope *Scope) GetStructFields() (fields []*StructField) {
return scope.GetModelStruct().StructFields return scope.GetModelStruct().StructFields
} }
func (scope *Scope) generateSqlTag(field *StructField) { func (scope *Scope) generateSqlTag(field *StructField) string {
var sqlType string var sqlType string
structType := field.Struct.Type structType := field.Struct.Type
if structType.Kind() == reflect.Ptr { if structType.Kind() == reflect.Ptr {
@ -337,17 +331,18 @@ func (scope *Scope) generateSqlTag(field *StructField) {
size, _ = strconv.Atoi(value) size, _ = strconv.Atoi(value)
} }
_, autoIncrease := sqlSettings["AUTO_INCREMENT"]
if field.IsPrimaryKey { if field.IsPrimaryKey {
sqlType = scope.Dialect().PrimaryKeyTag(reflectValue, size) autoIncrease = true
} else {
sqlType = scope.Dialect().SqlTag(reflectValue, size)
} }
sqlType = scope.Dialect().SqlTag(reflectValue, size, autoIncrease)
} }
if strings.TrimSpace(additionalType) == "" { if strings.TrimSpace(additionalType) == "" {
field.SqlTag = sqlType return sqlType
} else { } else {
field.SqlTag = fmt.Sprintf("%v %v", sqlType, additionalType) return fmt.Sprintf("%v %v", sqlType, additionalType)
} }
} }

View File

@ -21,13 +21,19 @@ func (s *mssql) HasTop() bool {
return true return true
} }
func (s *mssql) SqlTag(value reflect.Value, size int) string { func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "bit" return "bit"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "int IDENTITY(1,1)"
}
return "int" return "int"
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "bigint IDENTITY(1,1)"
}
return "bigint" return "bigint"
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
return "float" return "float"
@ -51,18 +57,6 @@ func (s *mssql) SqlTag(value reflect.Value, size int) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
} }
func (s *mssql) PrimaryKeyTag(value reflect.Value, size int) string {
suffix := " IDENTITY(1,1) PRIMARY KEY"
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int" + suffix
case reflect.Int64, reflect.Uint64:
return "bigint" + suffix
default:
panic("Invalid primary key type")
}
}
func (s *mssql) ReturningStr(tableName, key string) string { func (s *mssql) ReturningStr(tableName, key string) string {
return "" return ""
} }

View File

@ -21,13 +21,19 @@ func (s *mysql) HasTop() bool {
return false return false
} }
func (s *mysql) SqlTag(value reflect.Value, size int) string { func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "boolean" return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "int AUTO_INCREMENT"
}
return "int" return "int"
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "bigint AUTO_INCREMENT"
}
return "bigint" return "bigint"
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
return "double" return "double"
@ -51,18 +57,6 @@ func (s *mysql) SqlTag(value reflect.Value, size int) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
} }
func (s *mysql) PrimaryKeyTag(value reflect.Value, size int) string {
suffix := " NOT NULL AUTO_INCREMENT PRIMARY KEY"
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "int" + suffix
case reflect.Int64, reflect.Uint64:
return "bigint" + suffix
default:
panic("Invalid primary key type")
}
}
func (s *mysql) ReturningStr(tableName, key string) string { func (s *mysql) ReturningStr(tableName, key string) string {
return "" return ""
} }

View File

@ -25,13 +25,19 @@ func (s *postgres) HasTop() bool {
return false return false
} }
func (s *postgres) SqlTag(value reflect.Value, size int) string { func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "boolean" return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "serial"
}
return "integer" return "integer"
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "bigserial"
}
return "bigint" return "bigint"
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
return "numeric" return "numeric"
@ -56,17 +62,6 @@ func (s *postgres) SqlTag(value reflect.Value, size int) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
} }
func (s *postgres) PrimaryKeyTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "serial PRIMARY KEY"
case reflect.Int64, reflect.Uint64:
return "bigserial PRIMARY KEY"
default:
panic("Invalid primary key type")
}
}
func (s *postgres) ReturningStr(tableName, key string) string { func (s *postgres) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key) return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key)
} }

View File

@ -447,7 +447,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable) joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
joinTable := joinTableHandler.Table(scope.db, relationship) joinTable := joinTableHandler.Table(scope.db, relationship)
if !scope.Dialect().HasTable(scope, joinTable) { if !scope.Dialect().HasTable(scope, joinTable) {
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255) primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
scope.Quote(joinTable), scope.Quote(joinTable),
strings.Join([]string{ strings.Join([]string{
@ -460,14 +460,25 @@ func (scope *Scope) createJoinTable(field *StructField) {
} }
func (scope *Scope) createTable() *Scope { func (scope *Scope) createTable() *Scope {
var sqls []string var tags []string
for _, structField := range scope.GetStructFields() { var primaryKeys []string
if structField.IsNormal { for _, field := range scope.GetStructFields() {
sqls = append(sqls, scope.Quote(structField.DBName)+" "+structField.SqlTag) if field.IsNormal {
sqlTag := scope.generateSqlTag(field)
tags = append(tags, scope.Quote(field.DBName)+" "+sqlTag)
} }
scope.createJoinTable(structField)
if field.IsPrimaryKey {
primaryKeys = append(primaryKeys, field.DBName)
} }
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v)", scope.QuotedTableName(), strings.Join(sqls, ","))).Exec() scope.createJoinTable(field)
}
var primaryKeyStr string
if len(primaryKeys) > 0 {
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
}
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec()
return scope return scope
} }
@ -530,7 +541,8 @@ func (scope *Scope) autoMigrate() *Scope {
for _, field := range scope.GetStructFields() { for _, field := range scope.GetStructFields() {
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) { if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
if field.IsNormal { if field.IsNormal {
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, field.SqlTag)).Exec() sqlTag := scope.generateSqlTag(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec()
} }
} }
scope.createJoinTable(field) scope.createJoinTable(field)

View File

@ -20,13 +20,16 @@ func (s *sqlite3) HasTop() bool {
return false return false
} }
func (s *sqlite3) SqlTag(value reflect.Value, size int) string { func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() { switch value.Kind() {
case reflect.Bool: case reflect.Bool:
return "bool" return "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
return "integer" return "integer"
case reflect.Int64, reflect.Uint64: case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "integer"
}
return "bigint" return "bigint"
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
return "real" return "real"
@ -47,15 +50,6 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
} }
func (s *sqlite3) PrimaryKeyTag(value reflect.Value, size int) string {
switch value.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr, reflect.Int64, reflect.Uint64:
return "INTEGER PRIMARY KEY"
default:
panic("Invalid primary key type")
}
}
func (s *sqlite3) ReturningStr(tableName, key string) string { func (s *sqlite3) ReturningStr(tableName, key string) string {
return "" return ""
} }