From 41d8e2d132454a16d9fdb70b4d65da3988c0ae8d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 30 Nov 2013 14:52:01 +0800 Subject: [PATCH] Quote column name to avoid issue when it same as SQL reserved words --- dialect/dialect.go | 1 + dialect/mysql.go | 4 ++++ dialect/postgres.go | 4 ++++ dialect/sqlite3.go | 4 ++++ do.go | 43 +++++++++++++++++++++++-------------------- gorm_test.go | 3 ++- 6 files changed, 38 insertions(+), 21 deletions(-) diff --git a/dialect/dialect.go b/dialect/dialect.go index f21e1876..9418e533 100644 --- a/dialect/dialect.go +++ b/dialect/dialect.go @@ -6,6 +6,7 @@ type Dialect interface { SqlTag(column interface{}, size int) string PrimaryKeyTag(column interface{}, size int) string ReturningStr(key string) string + Quote(key string) string } func New(driver string) Dialect { diff --git a/dialect/mysql.go b/dialect/mysql.go index 8e611c3b..151c1076 100644 --- a/dialect/mysql.go +++ b/dialect/mysql.go @@ -60,3 +60,7 @@ func (s *mysql) PrimaryKeyTag(column interface{}, size int) string { func (s *mysql) ReturningStr(key string) (str string) { return } + +func (s *mysql) Quote(key string) (str string) { + return fmt.Sprintf("`%s`", key) +} diff --git a/dialect/postgres.go b/dialect/postgres.go index 13a0afa0..c31eb7e1 100644 --- a/dialect/postgres.go +++ b/dialect/postgres.go @@ -56,3 +56,7 @@ func (s *postgres) PrimaryKeyTag(column interface{}, size int) string { func (s *postgres) ReturningStr(key string) (str string) { return fmt.Sprintf("RETURNING \"%v\"", key) } + +func (s *postgres) Quote(key string) (str string) { + return fmt.Sprintf("\"%s\"", key) +} diff --git a/dialect/sqlite3.go b/dialect/sqlite3.go index cc91b122..92063786 100644 --- a/dialect/sqlite3.go +++ b/dialect/sqlite3.go @@ -48,3 +48,7 @@ func (s *sqlite3) PrimaryKeyTag(column interface{}, size int) string { func (s *sqlite3) ReturningStr(key string) (str string) { return } + +func (s *sqlite3) Quote(key string) (str string) { + return fmt.Sprintf("\"%s\"", key) +} diff --git a/do.go b/do.go index e3aeff47..8337dc22 100644 --- a/do.go +++ b/do.go @@ -46,6 +46,10 @@ func (s *Do) dialect() dialect.Dialect { return s.db.parent.dialect } +func (s *Do) quote(str string) string { + return s.dialect().Quote(str) +} + func (s *Do) err(err error) error { if err != nil { s.db.err(err) @@ -98,7 +102,7 @@ func (s *Do) prepareCreateSql() { var sqls, columns []string for key, value := range s.model.columnsAndValues("create") { - columns = append(columns, key) + columns = append(columns, s.quote(key)) sqls = append(sqls, s.addToVars(value)) } @@ -243,12 +247,12 @@ func (s *Do) updateAttrs(values interface{}, ignore_protected_attrs ...bool) *Do func (s *Do) prepareUpdateSql(include_self bool) { var sqls []string for key, value := range s.update_attrs { - sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value))) + sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value))) } if include_self { for key, value := range s.model.columnsAndValues("update") { - sqls = append(sqls, fmt.Sprintf("%v = %v", key, s.addToVars(value))) + sqls = append(sqls, fmt.Sprintf("%v = %v", s.quote(key), s.addToVars(value))) } } @@ -362,7 +366,7 @@ func (s *Do) related(value interface{}, foreign_keys ...string) *Do { if from_from { s.where(foreign_value).query() } else { - query := fmt.Sprintf("%v = %v", toSnake(foreign_key), s.addToVars(foreign_value)) + query := fmt.Sprintf("%v = %v", s.quote(toSnake(foreign_key)), s.addToVars(foreign_value)) s.where(query).query() } return s @@ -464,7 +468,7 @@ func (s *Do) pluck(column string, value interface{}) *Do { } func (s *Do) primaryCondiation(value interface{}) string { - return fmt.Sprintf("(%v = %v)", s.model.primaryKeyDb(), value) + return fmt.Sprintf("(%v = %v)", s.quote(s.model.primaryKeyDb()), value) } func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { @@ -481,19 +485,19 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { case sql.NullInt64: return s.primaryCondiation(s.addToVars(value.Int64)) case []int64, []int, []int32, []string: - str = fmt.Sprintf("(%v in (?))", s.model.primaryKeyDb()) + str = fmt.Sprintf("(%v in (?))", s.quote(s.model.primaryKeyDb())) clause["args"] = []interface{}{value} case map[string]interface{}: var sqls []string for key, value := range value { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", key, s.addToVars(value))) + sqls = append(sqls, fmt.Sprintf("(%v = %v)", s.quote(key), s.addToVars(value))) } return strings.Join(sqls, " AND ") case interface{}: m := &Model{data: value, do: s} var sqls []string for _, field := range m.columnsHasValue("other") { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", field.dbName, s.addToVars(field.Value))) + sqls = append(sqls, fmt.Sprintf("(%v = %v)", s.quote(field.dbName), s.addToVars(field.Value))) } return strings.Join(sqls, " AND ") } @@ -526,19 +530,19 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { case string: if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), id) + return fmt.Sprintf("(%v <> %v)", s.quote(s.model.primaryKeyDb()), id) } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS) ").MatchString(value) { str = fmt.Sprintf(" NOT (%v) ", value) not_equal_sql = fmt.Sprintf("NOT (%v)", value) } else { - str = fmt.Sprintf("(%v NOT IN (?))", value) - not_equal_sql = fmt.Sprintf("(%v <> ?)", value) + str = fmt.Sprintf("(%v NOT IN (?))", s.quote(value)) + not_equal_sql = fmt.Sprintf("(%v <> ?)", s.quote(value)) } case int, int64, int32: - return fmt.Sprintf("(%v <> %v)", s.model.primaryKeyDb(), value) + return fmt.Sprintf("(%v <> %v)", s.quote(s.model.primaryKeyDb()), value) case []int64, []int, []int32, []string: if reflect.ValueOf(value).Len() > 0 { - str = fmt.Sprintf("(%v not in (?))", s.model.primaryKeyDb()) + str = fmt.Sprintf("(%v not in (?))", s.quote(s.model.primaryKeyDb())) clause["args"] = []interface{}{value} } else { return "" @@ -546,14 +550,14 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) { case map[string]interface{}: var sqls []string for key, value := range value { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", key, s.addToVars(value))) + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", s.quote(key), s.addToVars(value))) } return strings.Join(sqls, " AND ") case interface{}: m := &Model{data: value, do: s} var sqls []string for _, field := range m.columnsHasValue("other") { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", field.dbName, s.addToVars(field.Value))) + sqls = append(sqls, fmt.Sprintf("(%v <> %v)", s.quote(field.dbName), s.addToVars(field.Value))) } return strings.Join(sqls, " AND ") } @@ -689,10 +693,9 @@ func (s *Do) createTable() *Do { var sqls []string for _, field := range s.model.fields("migration") { if len(field.sqlTag()) > 0 { - sqls = append(sqls, field.dbName+" "+field.sqlTag()) + sqls = append(sqls, s.quote(field.dbName)+" "+field.sqlTag()) } } - s.setSql(fmt.Sprintf("CREATE TABLE %v (%v)", s.table(), strings.Join(sqls, ","))) s.exec() return s @@ -705,12 +708,12 @@ func (s *Do) dropTable() *Do { } func (s *Do) modifyColumn(column string, typ string) { - s.setSql(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), column, typ)) + s.setSql(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", s.table(), s.quote(column), typ)) s.exec() } func (s *Do) dropColumn(column string) { - s.setSql(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), column)) + s.setSql(fmt.Sprintf("ALTER TABLE %v DROP COLUMN %v", s.table(), s.quote(column))) s.exec() } @@ -722,7 +725,7 @@ func (s *Do) addIndex(column string, names ...string) { index_name = fmt.Sprintf("index_%v_on_%v", s.table(), column) } - s.setSql(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), column)) + s.setSql(fmt.Sprintf("CREATE INDEX %v ON %v(%v);", index_name, s.table(), s.quote(column))) s.exec() } diff --git a/gorm_test.go b/gorm_test.go index d130a728..5a1edb2e 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -27,6 +27,7 @@ type User struct { BillingAddressId sql.NullInt64 // Embedded struct's foreign key ShippingAddress Address // Embedded struct ShippingAddressId int64 // Embedded struct's foreign key + When time.Time CreditCard CreditCard PasswordHash []byte IgnoreMe int64 `sql:"-"` @@ -133,7 +134,7 @@ func init() { t3, _ = time.Parse(shortForm, "2005-01-01 00:00:00") t4, _ = time.Parse(shortForm, "2010-01-01 00:00:00") t5, _ = time.Parse(shortForm, "2020-01-01 00:00:00") - db.Save(&User{Name: "1", Age: 18, Birthday: t1}) + db.Save(&User{Name: "1", Age: 18, Birthday: t1, When: time.Now()}) db.Save(&User{Name: "2", Age: 20, Birthday: t2}) db.Save(&User{Name: "3", Age: 22, Birthday: t3}) db.Save(&User{Name: "3", Age: 24, Birthday: t4})