From 592df904e268e301b910b14bcbbd3cdacac61b82 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 28 Oct 2013 12:43:02 +0800 Subject: [PATCH] Add tests to make sure no unexpected things happen with invalid SQL --- chain.go | 2 +- do.go | 9 ++++++--- gorm_test.go | 26 +++++++++++++++++++++++--- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/chain.go b/chain.go index 17c6cc92..ca863746 100644 --- a/chain.go +++ b/chain.go @@ -170,7 +170,7 @@ func (s *Chain) Or(querystring interface{}, args ...interface{}) *Chain { } func (s *Chain) CreateTable(value interface{}) *Chain { - s.do(value).createTable().Exec() + s.do(value).createTable().exec() return s } diff --git a/do.go b/do.go index b47a793a..f531699c 100644 --- a/do.go +++ b/do.go @@ -58,7 +58,7 @@ func (s *Do) addToVars(value interface{}) string { return fmt.Sprintf("$%d", len(s.SqlVars)) } -func (s *Do) Exec(sql ...string) { +func (s *Do) exec(sql ...string) { if s.hasError() { return } @@ -146,7 +146,7 @@ func (s *Do) update() *Do { s.err(s.model.callMethod("BeforeUpdate")) s.err(s.model.callMethod("BeforeSave")) if len(s.Errors) == 0 { - s.prepareUpdateSql().Exec() + s.prepareUpdateSql().exec() } s.err(s.model.callMethod("AfterUpdate")) s.err(s.model.callMethod("AfterSave")) @@ -161,7 +161,7 @@ func (s *Do) prepareDeleteSql() *Do { func (s *Do) delete() *Do { s.err(s.model.callMethod("BeforeDelete")) if len(s.Errors) == 0 { - s.prepareDeleteSql().Exec() + s.prepareDeleteSql().exec() } s.err(s.model.callMethod("AfterDelete")) return s @@ -254,6 +254,9 @@ func (s *Do) pluck(value interface{}) *Do { s.prepareQuerySql() rows, err := s.db.Query(s.Sql, s.SqlVars...) s.err(err) + if err != nil { + return s + } defer rows.Close() for rows.Next() { diff --git a/gorm_test.go b/gorm_test.go index 77356dc8..ddc1b81a 100644 --- a/gorm_test.go +++ b/gorm_test.go @@ -532,9 +532,14 @@ func TestRunCallbacksAndGetErrors(t *testing.T) { } } -func TestNoPanicInAnyCases(t *testing.T) { +func TestNoUnExpectedHappenWithInvalidSql(t *testing.T) { var columns []string - db.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns) + if db.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + t.Errorf("Should got error with invalid SQL") + } + if db.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { + t.Errorf("Should got error with invalid SQL") + } type Article struct { Name string @@ -542,6 +547,21 @@ func TestNoPanicInAnyCases(t *testing.T) { db.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&Article{}) db.Where("name = ?", "3").Find(&[]User{}) + var count1, count2 int64 + db.Model(&User{}).Count(&count1) + if count1 <= 0 { + t.Errorf("Should find some users") + } + + q := db.Where("name = ?", "jinzhu; delete * from users").First(&User{}) + if q.Error == nil { + t.Errorf("Can't find user") + } + + db.Model(&User{}).Count(&count2) + if count1 != count2 { + t.Errorf("Users should not be deleted by invalid SQL") + } + db.Where("unexisting = ?", "3").Find(&[]User{}) - db.Where("unexisting = ?", "3").First(&User{}) }