Quote column name to avoid issue when it same as SQL reserved words

This commit is contained in:
Jinzhu 2013-11-30 14:52:01 +08:00
parent 1959d99646
commit 41d8e2d132
6 changed files with 38 additions and 21 deletions

View File

@ -6,6 +6,7 @@ type Dialect interface {
SqlTag(column interface{}, size int) string SqlTag(column interface{}, size int) string
PrimaryKeyTag(column interface{}, size int) string PrimaryKeyTag(column interface{}, size int) string
ReturningStr(key string) string ReturningStr(key string) string
Quote(key string) string
} }
func New(driver string) Dialect { func New(driver string) Dialect {

View File

@ -60,3 +60,7 @@ func (s *mysql) PrimaryKeyTag(column interface{}, size int) string {
func (s *mysql) ReturningStr(key string) (str string) { func (s *mysql) ReturningStr(key string) (str string) {
return return
} }
func (s *mysql) Quote(key string) (str string) {
return fmt.Sprintf("`%s`", key)
}

View File

@ -56,3 +56,7 @@ func (s *postgres) PrimaryKeyTag(column interface{}, size int) string {
func (s *postgres) ReturningStr(key string) (str string) { func (s *postgres) ReturningStr(key string) (str string) {
return fmt.Sprintf("RETURNING \"%v\"", key) return fmt.Sprintf("RETURNING \"%v\"", key)
} }
func (s *postgres) Quote(key string) (str string) {
return fmt.Sprintf("\"%s\"", key)
}

View File

@ -48,3 +48,7 @@ func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string {
func (s *sqlite3) ReturningStr(key string) (str string) { func (s *sqlite3) ReturningStr(key string) (str string) {
return return
} }
func (s *sqlite3) Quote(key string) (str string) {
return fmt.Sprintf("\"%s\"", key)
}

43
do.go
View File

@ -46,6 +46,10 @@ func (s *Do) dialect() dialect.Dialect {
return s.db.parent.dialect return s.db.parent.dialect
} }
func (s *Do) quote(str string) string {
return s.dialect().Quote(str)
}
func (s *Do) err(err error) error { func (s *Do) err(err error) error {
if err != nil { if err != nil {
s.db.err(err) s.db.err(err)
@ -98,7 +102,7 @@ func (s *Do) prepareCreateSql() {
var sqls, columns []string var sqls, columns []string
for key, value := range s.model.columnsAndValues("create") { for key, value := range s.model.columnsAndValues("create") {
columns = append(columns, key) columns = append(columns, s.quote(key))
sqls = append(sqls, s.addToVars(value)) sqls = append(sqls, s.addToVars(value))
} }
@ -243,12 +247,12 @@ func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do
func (s *Do) prepareUpdateSql(include_self bool) { func (s *Do) prepareUpdateSql(include_self bool) {
var sqls []string var sqls []string
for key, value := range s.update_attrs { for key, value := range s.update_attrs {
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value))) sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value)))
} }
if include_self { if include_self {
for key, value := range s.model.columnsAndValues("update") { for key, value := range s.model.columnsAndValues("update") {
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value))) sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value)))
} }
} }
@ -362,7 +366,7 @@ func (s *Do) related(value interface{}, foreign_keys ...string) *Do {
if from_from { if from_from {
s.where(foreign_value).query() s.where(foreign_value).query()
} else { } else {
query := fmt.Sprintf("%v = %v", toSnake(foreign_key), s.addToVars(foreign_value)) query := fmt.Sprintf("%v = %v", s.quote(toSnake(foreign_key)), s.addToVars(foreign_value))
s.where(query).query() s.where(query).query()
} }
return s return s
@ -464,7 +468,7 @@ func (s *Do) pluck(column string, value interface{}) *Do {
} }
func (s *Do) primaryCondiation(value interface{}) string { func (s *Do) primaryCondiation(value interface{}) string {
return fmt.Sprintf("(%v = %v)", s.model.primaryKeyDb(), value) return fmt.Sprintf("(%v = %v)", s.quote(s.model.primaryKeyDb()), value)
} }
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
@ -481,19 +485,19 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
case sql.NullInt64: case sql.NullInt64:
return s.primaryCondiation(s.addToVars(value.Int64)) return s.primaryCondiation(s.addToVars(value.Int64))
case []int64, []int, []int32, []string: case []int64, []int, []int32, []string:
str = fmt.Sprintf("(%v in (?))", s.model.primaryKeyDb()) str = fmt.Sprintf("(%v in (?))", s.quote(s.model.primaryKeyDb()))
clause["args"] = []interface{}{value} clause["args"] = []interface{}{value}
case map[string]interface{}: case map[string]interface{}:
var sqls []string var sqls []string
for key, value := range value { for key, value := range value {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", key, s.addToVars(value))) sqls = append(sqls, fmt.Sprintf("(%v = %v)", s.quote(key), s.addToVars(value)))
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
case interface{}: case interface{}:
m := &Model{data: value, do: s} m := &Model{data: value, do: s}
var sqls []string var sqls []string
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.dbName, s.addToVars(field.Value))) sqls = append(sqls, fmt.Sprintf("(%v = %v)", s.quote(field.dbName), s.addToVars(field.Value)))
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
} }
@ -526,19 +530,19 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
case string: case string:
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
id, _ := strconv.Atoi(value) id, _ := strconv.Atoi(value)
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), id) return fmt.Sprintf("(%v <> %v)", s.quote(s.model.primaryKeyDb()), id)
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
str = fmt.Sprintf(" NOT (%v) ", value) str = fmt.Sprintf(" NOT (%v) ", value)
not_equal_sql = fmt.Sprintf("NOT (%v)", value) not_equal_sql = fmt.Sprintf("NOT (%v)", value)
} else { } else {
str = fmt.Sprintf("(%v NOT IN (?))", value) str = fmt.Sprintf("(%v NOT IN (?))", s.quote(value))
not_equal_sql = fmt.Sprintf("(%v <> ?)", value) not_equal_sql = fmt.Sprintf("(%v <> ?)", s.quote(value))
} }
case int, int64, int32: case int, int64, int32:
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), value) return fmt.Sprintf("(%v <> %v)", s.quote(s.model.primaryKeyDb()), value)
case []int64, []int, []int32, []string: case []int64, []int, []int32, []string:
if reflect.ValueOf(value).Len() > 0 { if reflect.ValueOf(value).Len() > 0 {
str = fmt.Sprintf("(%v not in (?))", s.model.primaryKeyDb()) str = fmt.Sprintf("(%v not in (?))", s.quote(s.model.primaryKeyDb()))
clause["args"] = []interface{}{value} clause["args"] = []interface{}{value}
} else { } else {
return "" return ""
@ -546,14 +550,14 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
case map[string]interface{}: case map[string]interface{}:
var sqls []string var sqls []string
for key, value := range value { for key, value := range value {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", key, s.addToVars(value))) sqls = append(sqls, fmt.Sprintf("(%v <> %v)", s.quote(key), s.addToVars(value)))
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
case interface{}: case interface{}:
m := &Model{data: value, do: s} m := &Model{data: value, do: s}
var sqls []string var sqls []string
for _, field := range m.columnsHasValue("other") { for _, field := range m.columnsHasValue("other") {
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.dbName, s.addToVars(field.Value))) sqls = append(sqls, fmt.Sprintf("(%v <> %v)", s.quote(field.dbName), s.addToVars(field.Value)))
} }
return strings.Join(sqls, " AND ") return strings.Join(sqls, " AND ")
} }
@ -689,10 +693,9 @@ func (s *Do) createTable() *Do {
var sqls []string var sqls []string
for _, field := range s.model.fields("migration") { for _, field := range s.model.fields("migration") {
if len(field.sqlTag()) > 0 { if len(field.sqlTag()) > 0 {
sqls = append(sqls, field.dbName+" "+field.sqlTag()) sqls = append(sqls, s.quote(field.dbName)+" "+field.sqlTag())
} }
} }
s.setSql(fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ","))) s.setSql(fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ",")))
s.exec() s.exec()
return s return s
@ -705,12 +708,12 @@ func (s *Do) dropTable() *Do {
} }
func (s *Do) modifyColumn(column string, typ string) { func (s *Do) modifyColumn(column string, typ string) {
s.setSql(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), column, typ)) s.setSql(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), s.quote(column), typ))
s.exec() s.exec()
} }
func (s *Do) dropColumn(column string) { func (s *Do) dropColumn(column string) {
s.setSql(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), column)) s.setSql(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), s.quote(column)))
s.exec() s.exec()
} }
@ -722,7 +725,7 @@ func (s *Do) addIndex(column string, names ...string) {
index_name = fmt.Sprintf("index_%v_on_%v", s.table(), column) index_name = fmt.Sprintf("index_%v_on_%v", s.table(), column)
} }
s.setSql(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), column)) s.setSql(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), s.quote(column)))
s.exec() s.exec()
} }

View File

@ -27,6 +27,7 @@ type User struct {
BillingAddressId sql.NullInt64 // Embedded struct's foreign key BillingAddressId sql.NullInt64 // Embedded struct's foreign key
ShippingAddress Address // Embedded struct ShippingAddress Address // Embedded struct
ShippingAddressId int64 // Embedded struct's foreign key ShippingAddressId int64 // Embedded struct's foreign key
When time.Time
CreditCard CreditCard CreditCard CreditCard
PasswordHash []byte PasswordHash []byte
IgnoreMe int64 `sql:"-"` IgnoreMe int64 `sql:"-"`
@ -133,7 +134,7 @@ func init() {
t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00") t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00")
t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00") t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00")
t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00") t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00")
db.Save(&User{Name: "1", Age: 18, Birthday: t1}) db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()})
db.Save(&User{Name: "2", Age: 20, Birthday: t2}) db.Save(&User{Name: "2", Age: 20, Birthday: t2})
db.Save(&User{Name: "3", Age: 22, Birthday: t3}) db.Save(&User{Name: "3", Age: 22, Birthday: t3})
db.Save(&User{Name: "3", Age: 24, Birthday: t4}) db.Save(&User{Name: "3", Age: 24, Birthday: t4})