forked from mirror/gorm
Quote column name to avoid issue when it same as SQL reserved words
This commit is contained in:
parent
1959d99646
commit
41d8e2d132
|
@ -6,6 +6,7 @@ type Dialect interface {
|
||||||
SqlTag(column interface{}, size int) string
|
SqlTag(column interface{}, size int) string
|
||||||
PrimaryKeyTag(column interface{}, size int) string
|
PrimaryKeyTag(column interface{}, size int) string
|
||||||
ReturningStr(key string) string
|
ReturningStr(key string) string
|
||||||
|
Quote(key string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(driver string) Dialect {
|
func New(driver string) Dialect {
|
||||||
|
|
|
@ -60,3 +60,7 @@ func (s *mysql) PrimaryKeyTag(column interface{}, size int) string {
|
||||||
func (s *mysql) ReturningStr(key string) (str string) {
|
func (s *mysql) ReturningStr(key string) (str string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *mysql) Quote(key string) (str string) {
|
||||||
|
return fmt.Sprintf("`%s`", key)
|
||||||
|
}
|
||||||
|
|
|
@ -56,3 +56,7 @@ func (s *postgres) PrimaryKeyTag(column interface{}, size int) string {
|
||||||
func (s *postgres) ReturningStr(key string) (str string) {
|
func (s *postgres) ReturningStr(key string) (str string) {
|
||||||
return fmt.Sprintf("RETURNING \"%v\"", key)
|
return fmt.Sprintf("RETURNING \"%v\"", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *postgres) Quote(key string) (str string) {
|
||||||
|
return fmt.Sprintf("\"%s\"", key)
|
||||||
|
}
|
||||||
|
|
|
@ -48,3 +48,7 @@ func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string {
|
||||||
func (s *sqlite3) ReturningStr(key string) (str string) {
|
func (s *sqlite3) ReturningStr(key string) (str string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *sqlite3) Quote(key string) (str string) {
|
||||||
|
return fmt.Sprintf("\"%s\"", key)
|
||||||
|
}
|
||||||
|
|
43
do.go
43
do.go
|
@ -46,6 +46,10 @@ func (s *Do) dialect() dialect.Dialect {
|
||||||
return s.db.parent.dialect
|
return s.db.parent.dialect
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Do) quote(str string) string {
|
||||||
|
return s.dialect().Quote(str)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Do) err(err error) error {
|
func (s *Do) err(err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.db.err(err)
|
s.db.err(err)
|
||||||
|
@ -98,7 +102,7 @@ func (s *Do) prepareCreateSql() {
|
||||||
var sqls, columns []string
|
var sqls, columns []string
|
||||||
|
|
||||||
for key, value := range s.model.columnsAndValues("create") {
|
for key, value := range s.model.columnsAndValues("create") {
|
||||||
columns = append(columns, key)
|
columns = append(columns, s.quote(key))
|
||||||
sqls = append(sqls, s.addToVars(value))
|
sqls = append(sqls, s.addToVars(value))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -243,12 +247,12 @@ func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do
|
||||||
func (s *Do) prepareUpdateSql(include_self bool) {
|
func (s *Do) prepareUpdateSql(include_self bool) {
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for key, value := range s.update_attrs {
|
for key, value := range s.update_attrs {
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value)))
|
||||||
}
|
}
|
||||||
|
|
||||||
if include_self {
|
if include_self {
|
||||||
for key, value := range s.model.columnsAndValues("update") {
|
for key, value := range s.model.columnsAndValues("update") {
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value)))
|
sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -362,7 +366,7 @@ func (s *Do) related(value interface{}, foreign_keys ...string) *Do {
|
||||||
if from_from {
|
if from_from {
|
||||||
s.where(foreign_value).query()
|
s.where(foreign_value).query()
|
||||||
} else {
|
} else {
|
||||||
query := fmt.Sprintf("%v = %v", toSnake(foreign_key), s.addToVars(foreign_value))
|
query := fmt.Sprintf("%v = %v", s.quote(toSnake(foreign_key)), s.addToVars(foreign_value))
|
||||||
s.where(query).query()
|
s.where(query).query()
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
|
@ -464,7 +468,7 @@ func (s *Do) pluck(column string, value interface{}) *Do {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) primaryCondiation(value interface{}) string {
|
func (s *Do) primaryCondiation(value interface{}) string {
|
||||||
return fmt.Sprintf("(%v = %v)", s.model.primaryKeyDb(), value)
|
return fmt.Sprintf("(%v = %v)", s.quote(s.model.primaryKeyDb()), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||||
|
@ -481,19 +485,19 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
||||||
case sql.NullInt64:
|
case sql.NullInt64:
|
||||||
return s.primaryCondiation(s.addToVars(value.Int64))
|
return s.primaryCondiation(s.addToVars(value.Int64))
|
||||||
case []int64, []int, []int32, []string:
|
case []int64, []int, []int32, []string:
|
||||||
str = fmt.Sprintf("(%v in (?))", s.model.primaryKeyDb())
|
str = fmt.Sprintf("(%v in (?))", s.quote(s.model.primaryKeyDb()))
|
||||||
clause["args"] = []interface{}{value}
|
clause["args"] = []interface{}{value}
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for key, value := range value {
|
for key, value := range value {
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", key, s.addToVars(value)))
|
sqls = append(sqls, fmt.Sprintf("(%v = %v)", s.quote(key), s.addToVars(value)))
|
||||||
}
|
}
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
case interface{}:
|
case interface{}:
|
||||||
m := &Model{data: value, do: s}
|
m := &Model{data: value, do: s}
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for _, field := range m.columnsHasValue("other") {
|
for _, field := range m.columnsHasValue("other") {
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.dbName, s.addToVars(field.Value)))
|
sqls = append(sqls, fmt.Sprintf("(%v = %v)", s.quote(field.dbName), s.addToVars(field.Value)))
|
||||||
}
|
}
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
}
|
}
|
||||||
|
@ -526,19 +530,19 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||||
case string:
|
case string:
|
||||||
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) {
|
||||||
id, _ := strconv.Atoi(value)
|
id, _ := strconv.Atoi(value)
|
||||||
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), id)
|
return fmt.Sprintf("(%v <> %v)", s.quote(s.model.primaryKeyDb()), id)
|
||||||
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
|
} else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) {
|
||||||
str = fmt.Sprintf(" NOT (%v) ", value)
|
str = fmt.Sprintf(" NOT (%v) ", value)
|
||||||
not_equal_sql = fmt.Sprintf("NOT (%v)", value)
|
not_equal_sql = fmt.Sprintf("NOT (%v)", value)
|
||||||
} else {
|
} else {
|
||||||
str = fmt.Sprintf("(%v NOT IN (?))", value)
|
str = fmt.Sprintf("(%v NOT IN (?))", s.quote(value))
|
||||||
not_equal_sql = fmt.Sprintf("(%v <> ?)", value)
|
not_equal_sql = fmt.Sprintf("(%v <> ?)", s.quote(value))
|
||||||
}
|
}
|
||||||
case int, int64, int32:
|
case int, int64, int32:
|
||||||
return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), value)
|
return fmt.Sprintf("(%v <> %v)", s.quote(s.model.primaryKeyDb()), value)
|
||||||
case []int64, []int, []int32, []string:
|
case []int64, []int, []int32, []string:
|
||||||
if reflect.ValueOf(value).Len() > 0 {
|
if reflect.ValueOf(value).Len() > 0 {
|
||||||
str = fmt.Sprintf("(%v not in (?))", s.model.primaryKeyDb())
|
str = fmt.Sprintf("(%v not in (?))", s.quote(s.model.primaryKeyDb()))
|
||||||
clause["args"] = []interface{}{value}
|
clause["args"] = []interface{}{value}
|
||||||
} else {
|
} else {
|
||||||
return ""
|
return ""
|
||||||
|
@ -546,14 +550,14 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for key, value := range value {
|
for key, value := range value {
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", key, s.addToVars(value)))
|
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", s.quote(key), s.addToVars(value)))
|
||||||
}
|
}
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
case interface{}:
|
case interface{}:
|
||||||
m := &Model{data: value, do: s}
|
m := &Model{data: value, do: s}
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for _, field := range m.columnsHasValue("other") {
|
for _, field := range m.columnsHasValue("other") {
|
||||||
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.dbName, s.addToVars(field.Value)))
|
sqls = append(sqls, fmt.Sprintf("(%v <> %v)", s.quote(field.dbName), s.addToVars(field.Value)))
|
||||||
}
|
}
|
||||||
return strings.Join(sqls, " AND ")
|
return strings.Join(sqls, " AND ")
|
||||||
}
|
}
|
||||||
|
@ -689,10 +693,9 @@ func (s *Do) createTable() *Do {
|
||||||
var sqls []string
|
var sqls []string
|
||||||
for _, field := range s.model.fields("migration") {
|
for _, field := range s.model.fields("migration") {
|
||||||
if len(field.sqlTag()) > 0 {
|
if len(field.sqlTag()) > 0 {
|
||||||
sqls = append(sqls, field.dbName+" "+field.sqlTag())
|
sqls = append(sqls, s.quote(field.dbName)+" "+field.sqlTag())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.setSql(fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ",")))
|
s.setSql(fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ",")))
|
||||||
s.exec()
|
s.exec()
|
||||||
return s
|
return s
|
||||||
|
@ -705,12 +708,12 @@ func (s *Do) dropTable() *Do {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) modifyColumn(column string, typ string) {
|
func (s *Do) modifyColumn(column string, typ string) {
|
||||||
s.setSql(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), column, typ))
|
s.setSql(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), s.quote(column), typ))
|
||||||
s.exec()
|
s.exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Do) dropColumn(column string) {
|
func (s *Do) dropColumn(column string) {
|
||||||
s.setSql(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), column))
|
s.setSql(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), s.quote(column)))
|
||||||
s.exec()
|
s.exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -722,7 +725,7 @@ func (s *Do) addIndex(column string, names ...string) {
|
||||||
index_name = fmt.Sprintf("index_%v_on_%v", s.table(), column)
|
index_name = fmt.Sprintf("index_%v_on_%v", s.table(), column)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.setSql(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), column))
|
s.setSql(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), s.quote(column)))
|
||||||
s.exec()
|
s.exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ type User struct {
|
||||||
BillingAddressId sql.NullInt64 // Embedded struct's foreign key
|
BillingAddressId sql.NullInt64 // Embedded struct's foreign key
|
||||||
ShippingAddress Address // Embedded struct
|
ShippingAddress Address // Embedded struct
|
||||||
ShippingAddressId int64 // Embedded struct's foreign key
|
ShippingAddressId int64 // Embedded struct's foreign key
|
||||||
|
When time.Time
|
||||||
CreditCard CreditCard
|
CreditCard CreditCard
|
||||||
PasswordHash []byte
|
PasswordHash []byte
|
||||||
IgnoreMe int64 `sql:"-"`
|
IgnoreMe int64 `sql:"-"`
|
||||||
|
@ -133,7 +134,7 @@ func init() {
|
||||||
t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00")
|
t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00")
|
||||||
t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00")
|
t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00")
|
||||||
t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00")
|
t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00")
|
||||||
db.Save(&User{Name: "1", Age: 18, Birthday: t1})
|
db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()})
|
||||||
db.Save(&User{Name: "2", Age: 20, Birthday: t2})
|
db.Save(&User{Name: "2", Age: 20, Birthday: t2})
|
||||||
db.Save(&User{Name: "3", Age: 22, Birthday: t3})
|
db.Save(&User{Name: "3", Age: 22, Birthday: t3})
|
||||||
db.Save(&User{Name: "3", Age: 24, Birthday: t4})
|
db.Save(&User{Name: "3", Age: 24, Birthday: t4})
|
||||||
|
|
Loading…
Reference in New Issue