diff --git a/README.md b/README.md index e53b3625..89077c43 100644 --- a/README.md +++ b/README.md @@ -712,7 +712,11 @@ for rows.Next() { ## Run Raw SQl ```go +// Raw sql db.Exec("drop table users;") + +// Raw sql with arguments +db.Exec("update orders set shipped_at=? where id in (?)", time.Now, []int64{11,22,33}) ``` ## Error Handling diff --git a/do.go b/do.go index 61e08406..c414aa32 100644 --- a/do.go +++ b/do.go @@ -67,12 +67,14 @@ func (s *Do) trace(t time.Time) { } } -func (s *Do) exec(sqls ...string) *Do { +func (s *Do) raw(query string, values ...interface{}) *Do { + s.sql = s.buildWhereCondition(map[string]interface{}{"query": query, "args": values}) + return s +} + +func (s *Do) exec() *Do { defer s.trace(time.Now()) if !s.db.hasError() { - if len(sqls) > 0 { - s.sql = sqls[0] - } _, err := s.db.db.Exec(s.sql, s.sqlVars...) s.err(err) } @@ -447,7 +449,7 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) { id, _ := strconv.Atoi(value) return s.primaryCondiation(s.addToVars(id)) } else { - str = "(" + value + ")" + str = value } case int, int64, int32: return s.primaryCondiation(s.addToVars(value)) diff --git a/gorm_test.go b/gorm_test.go index 6dc6231a..b32e6ed6 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -1406,7 +1406,7 @@ func TestGroup(t *testing.T) { } func TestHaving(t *testing.T) { - rows, err := db.Debug().Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows() + rows, err := db.Select("name, count(*) as total").Table("users").Group("name").Having("name IN (?)", []string{"2", "3"}).Rows() if err == nil { defer rows.Close() @@ -1427,6 +1427,13 @@ func TestHaving(t *testing.T) { } } +func TestExecRawSql(t *testing.T) { + db.Exec("update users set name=? where name in (?)", "jinzhu", []string{"1", "2", "3"}) + if db.Where("name in (?)", []string{"1", "2", "3"}).First(&User{}).Error != RecordNotFound { + t.Error("Raw sql should be able to parse argument") + } +} + func BenchmarkGorm(b *testing.B) { b.N = 2000 for x := 0; x < b.N; x++ { diff --git a/main.go b/main.go index f1f0b985..84ac2d76 100644 --- a/main.go +++ b/main.go @@ -161,8 +161,8 @@ func (s *DB) Delete(value interface{}) *DB { return s.clone().do(value).begin().delete().commit_or_rollback().db } -func (s *DB) Exec(sql string) *DB { - return s.clone().do(nil).exec(sql).db +func (s *DB) Exec(sql string, values ...interface{}) *DB { + return s.clone().do(nil).raw(sql, values...).exec().db } func (s *DB) Model(value interface{}) *DB { diff --git a/search.go b/search.go index f3ff69ad..b627a1e8 100644 --- a/search.go +++ b/search.go @@ -2,7 +2,6 @@ package gorm import ( "regexp" - "strconv" )