Add tests to make sure no unexpected things happen with invalid SQL

This commit is contained in:
Jinzhu 2013-10-28 12:43:02 +08:00
parent 13a8d98d8f
commit 592df904e2
3 changed files with 30 additions and 7 deletions

View File

@ -170,7 +170,7 @@ func (s *Chain) Or(querystring interface{}, args ...interface{}) *Chain {
} }
func (s *Chain) CreateTable(value interface{}) *Chain { func (s *Chain) CreateTable(value interface{}) *Chain {
s.do(value).createTable().Exec() s.do(value).createTable().exec()
return s return s
} }

9
do.go
View File

@ -58,7 +58,7 @@ func (s *Do) addToVars(value interface{}) string {
return fmt.Sprintf("$%d", len(s.SqlVars)) return fmt.Sprintf("$%d", len(s.SqlVars))
} }
func (s *Do) Exec(sql ...string) { func (s *Do) exec(sql ...string) {
if s.hasError() { if s.hasError() {
return return
} }
@ -146,7 +146,7 @@ func (s *Do) update() *Do {
s.err(s.model.callMethod("BeforeUpdate")) s.err(s.model.callMethod("BeforeUpdate"))
s.err(s.model.callMethod("BeforeSave")) s.err(s.model.callMethod("BeforeSave"))
if len(s.Errors) == 0 { if len(s.Errors) == 0 {
s.prepareUpdateSql().Exec() s.prepareUpdateSql().exec()
} }
s.err(s.model.callMethod("AfterUpdate")) s.err(s.model.callMethod("AfterUpdate"))
s.err(s.model.callMethod("AfterSave")) s.err(s.model.callMethod("AfterSave"))
@ -161,7 +161,7 @@ func (s *Do) prepareDeleteSql() *Do {
func (s *Do) delete() *Do { func (s *Do) delete() *Do {
s.err(s.model.callMethod("BeforeDelete")) s.err(s.model.callMethod("BeforeDelete"))
if len(s.Errors) == 0 { if len(s.Errors) == 0 {
s.prepareDeleteSql().Exec() s.prepareDeleteSql().exec()
} }
s.err(s.model.callMethod("AfterDelete")) s.err(s.model.callMethod("AfterDelete"))
return s return s
@ -254,6 +254,9 @@ func (s *Do) pluck(value interface{}) *Do {
s.prepareQuerySql() s.prepareQuerySql()
rows, err := s.db.Query(s.Sql, s.SqlVars...) rows, err := s.db.Query(s.Sql, s.SqlVars...)
s.err(err) s.err(err)
if err != nil {
return s
}
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {

View File

@ -532,9 +532,14 @@ func TestRunCallbacksAndGetErrors(t *testing.T) {
} }
} }
func TestNoPanicInAnyCases(t *testing.T) { func TestNoUnExpectedHappenWithInvalidSql(t *testing.T) {
var columns []string var columns []string
db.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns) if db.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
t.Errorf("Should got error with invalid SQL")
}
if db.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
t.Errorf("Should got error with invalid SQL")
}
type Article struct { type Article struct {
Name string Name string
@ -542,6 +547,21 @@ func TestNoPanicInAnyCases(t *testing.T) {
db.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&Article{}) db.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&Article{})
db.Where("name = ?", "3").Find(&[]User{}) db.Where("name = ?", "3").Find(&[]User{})
var count1, count2 int64
db.Model(&User{}).Count(&count1)
if count1 <= 0 {
t.Errorf("Should find some users")
}
q := db.Where("name = ?", "jinzhu; delete * from users").First(&User{})
if q.Error == nil {
t.Errorf("Can't find user")
}
db.Model(&User{}).Count(&count2)
if count1 != count2 {
t.Errorf("Users should not be deleted by invalid SQL")
}
db.Where("unexisting = ?", "3").Find(&[]User{}) db.Where("unexisting = ?", "3").Find(&[]User{})
db.Where("unexisting = ?", "3").First(&User{})
} }