mirror of https://github.com/go-gorm/gorm.git
Quote column name to avoid issue when it same as SQL reserved words
This commit is contained in:
parent
1959d99646
commit
41d8e2d132
|
@ -6,6 +6,7 @@ type Dialect interface {
|
|||
SqlTag(column interface{}, size int) string
|
||||
PrimaryKeyTag(column interface{}, size int) string
|
||||
ReturningStr(key string) string
|
||||
Quote(key string) string
|
||||
}
|
||||
|
||||
func New(driver string) Dialect {
|
||||
|
|
|
@ -60,3 +60,7 @@ func (s *mysql) PrimaryKeyTag(column interface{}, size int) string {
|
|||
func (s *mysql) ReturningStr(key string) (str string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *mysql) Quote(key string) (str string) {
|
||||
return fmt.Sprintf("`%s`", key)
|
||||
}
|
||||
|
|
|
@ -56,3 +56,7 @@ func (s *postgres) PrimaryKeyTag(column interface{}, size int) string {
|
|||
func (s *postgres) ReturningStr(key string) (str string) {
|
||||
return fmt.Sprintf("RETURNING \"%v\"", key)
|
||||
}
|
||||
|
||||
func (s *postgres) Quote(key string) (str string) {
|
||||
return fmt.Sprintf("\"%s\"", key)
|
||||
}
|
||||
|
|
|
@ -48,3 +48,7 @@ func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string {
|
|||
func (s *sqlite3) ReturningStr(key string) (str string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *sqlite3) Quote(key string) (str string) {
|
||||
return fmt.Sprintf("\"%s\"", key)
|
||||
}
|
||||
|
|
43
do.go
43
do.go
|
@ -46,6 +46,10 @@ func (s *Do) dialect() dialect.Dialect {
|
|||
return s.db.parent.dialect
|
||||
}
|
||||
|
||||
func (s *Do) quote(str string) string {
|
||||
return s.dialect().Quote(str)
|
||||
}
|
||||
|
||||
func (s *Do) err(err error) error {
|
||||
if err != nil {
|
||||
s.db.err(err)
|
||||
|
@ -98,7 +102,7 @@ func (s *Do) prepareCreateSql() {
|
|||
var sqls, columns []string
|
||||
|
||||
for key, value := range s.model.columnsAndValues("create") {
|
||||
columns = append(columns, key)
|
||||
columns = append(columns, s.quote(key))
|
||||
sqls = append(sqls, s.addToVars(value))
|
||||
}
|
||||
|
||||
|
@ -243,12 +247,12 @@ func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do
|
|||
func (s *Do) prepareUpdateSql(include_self bool) {
|
||||
var sqls []string
|
||||
for key, value := range s.update_attrs {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value)))
|
||||
}
|
||||
|
||||
if include_self {
|
||||
for key, value := range s.model.columnsAndValues("update") {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value)))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -362,7 +366,7 @@ func (s *Do) related(value interface{}, foreign_keys ...string) *Do {
|
|||
if from_from {
|
||||
s.where(foreign_value).query()
|
||||
} else {
|
||||
query := fmt.Sprintf("%v = %v", toSnake(foreign_key), s.addToVars(foreign_value))
|
||||
query := fmt.Sprintf("%v = %v", s.quote(toSnake(foreign_key)), s.addToVars(foreign_value))
|
||||
s.where(query).query()
|
||||
}
|
||||
return s
|
||||
|
@ -464,7 +468,7 @@ func (s *Do) pluck(column string, value interface{}) *Do {
|
|||
}
|
||||
|
||||
func (s *Do) primaryCondiation(value interface{}) string {
|
||||
return fmt.Sprintf("(%v = %v)", s.model.primaryKeyDb(), value)
|
||||
return fmt.Sprintf("(%v = %v)", s.quote(s.model.primaryKeyDb()), value)
|
||||
}
|
||||
|
||||
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||
|
@ -481,19 +485,19 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
|||
case sql.NullInt64:
|
||||
return s.primaryCondiation(s.addToVars(value.Int64))
|
||||
case []int64, []int, []int32, []string:
|
||||
str = fmt.Sprintf("(%v in (?))", s.model.primaryKeyDb())
|
||||
str = fmt.Sprintf("(%v in (?))", s.quote(s.model.primaryKeyDb()))
|
||||
clause["args"] = []interface{}{value}
|
||||
case map[string]interface{}:
|
||||
var sqls []string
|
||||
for key, value := range value {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", key, s.addToVars(value)))
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", s.quote(key), s.addToVars(value)))
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
case interface{}:
|
||||
m := &Model{data: value, do: s}
|
||||
var sqls []string
|
||||
for _, field := range m.columnsHasValue("other") {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.dbName, s.addToVars(field.Value)))
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", s.quote(field.dbName), s.addToVars(field.Value)))
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
}
|
||||
|
@ -526,19 +530,19 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
|||
case string:
|
||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||
id, _ := strconv.Atoi(value)
|
||||
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), id)
|
||||
return fmt.Sprintf("(%v <> %v)", s.quote(s.model.primaryKeyDb()), id)
|
||||
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
|
||||
str = fmt.Sprintf(" NOT (%v) ", value)
|
||||
not_equal_sql = fmt.Sprintf("NOT (%v)", value)
|
||||
} else {
|
||||
str = fmt.Sprintf("(%v NOT IN (?))", value)
|
||||
not_equal_sql = fmt.Sprintf("(%v <> ?)", value)
|
||||
str = fmt.Sprintf("(%v NOT IN (?))", s.quote(value))
|
||||
not_equal_sql = fmt.Sprintf("(%v <> ?)", s.quote(value))
|
||||
}
|
||||
case int, int64, int32:
|
||||
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), value)
|
||||
return fmt.Sprintf("(%v <> %v)", s.quote(s.model.primaryKeyDb()), value)
|
||||
case []int64, []int, []int32, []string:
|
||||
if reflect.ValueOf(value).Len() > 0 {
|
||||
str = fmt.Sprintf("(%v not in (?))", s.model.primaryKeyDb())
|
||||
str = fmt.Sprintf("(%v not in (?))", s.quote(s.model.primaryKeyDb()))
|
||||
clause["args"] = []interface{}{value}
|
||||
} else {
|
||||
return ""
|
||||
|
@ -546,14 +550,14 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
|||
case map[string]interface{}:
|
||||
var sqls []string
|
||||
for key, value := range value {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", key, s.addToVars(value)))
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", s.quote(key), s.addToVars(value)))
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
case interface{}:
|
||||
m := &Model{data: value, do: s}
|
||||
var sqls []string
|
||||
for _, field := range m.columnsHasValue("other") {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.dbName, s.addToVars(field.Value)))
|
||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", s.quote(field.dbName), s.addToVars(field.Value)))
|
||||
}
|
||||
return strings.Join(sqls, " AND ")
|
||||
}
|
||||
|
@ -689,10 +693,9 @@ func (s *Do) createTable() *Do {
|
|||
var sqls []string
|
||||
for _, field := range s.model.fields("migration") {
|
||||
if len(field.sqlTag()) > 0 {
|
||||
sqls = append(sqls, field.dbName+" "+field.sqlTag())
|
||||
sqls = append(sqls, s.quote(field.dbName)+" "+field.sqlTag())
|
||||
}
|
||||
}
|
||||
|
||||
s.setSql(fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ",")))
|
||||
s.exec()
|
||||
return s
|
||||
|
@ -705,12 +708,12 @@ func (s *Do) dropTable() *Do {
|
|||
}
|
||||
|
||||
func (s *Do) modifyColumn(column string, typ string) {
|
||||
s.setSql(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), column, typ))
|
||||
s.setSql(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), s.quote(column), typ))
|
||||
s.exec()
|
||||
}
|
||||
|
||||
func (s *Do) dropColumn(column string) {
|
||||
s.setSql(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), column))
|
||||
s.setSql(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), s.quote(column)))
|
||||
s.exec()
|
||||
}
|
||||
|
||||
|
@ -722,7 +725,7 @@ func (s *Do) addIndex(column string, names ...string) {
|
|||
index_name = fmt.Sprintf("index_%v_on_%v", s.table(), column)
|
||||
}
|
||||
|
||||
s.setSql(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), column))
|
||||
s.setSql(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), s.quote(column)))
|
||||
s.exec()
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@ type User struct {
|
|||
BillingAddressId sql.NullInt64 // Embedded struct's foreign key
|
||||
ShippingAddress Address // Embedded struct
|
||||
ShippingAddressId int64 // Embedded struct's foreign key
|
||||
When time.Time
|
||||
CreditCard CreditCard
|
||||
PasswordHash []byte
|
||||
IgnoreMe int64 `sql:"-"`
|
||||
|
@ -133,7 +134,7 @@ func init() {
|
|||
t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00")
|
||||
t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00")
|
||||
t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00")
|
||||
db.Save(&User{Name: "1", Age: 18, Birthday: t1})
|
||||
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()})
|
||||
db.Save(&User{Name: "2", Age: 20, Birthday: t2})
|
||||
db.Save(&User{Name: "3", Age: 22, Birthday: t3})
|
||||
db.Save(&User{Name: "3", Age: 24, Birthday: t4})
|
||||
|
|
Loading…
Reference in New Issue