forked from mirror/gorm
Add tests to make sure no unexpected things happen with invalid SQL
This commit is contained in:
parent
13a8d98d8f
commit
592df904e2
2
chain.go
2
chain.go
|
@ -170,7 +170,7 @@ func (s *Chain) Or(querystring interface{}, args ...interface{}) *Chain {
|
|||
}
|
||||
|
||||
func (s *Chain) CreateTable(value interface{}) *Chain {
|
||||
s.do(value).createTable().Exec()
|
||||
s.do(value).createTable().exec()
|
||||
return s
|
||||
}
|
||||
|
||||
|
|
9
do.go
9
do.go
|
@ -58,7 +58,7 @@ func (s *Do) addToVars(value interface{}) string {
|
|||
return fmt.Sprintf("$%d", len(s.SqlVars))
|
||||
}
|
||||
|
||||
func (s *Do) Exec(sql ...string) {
|
||||
func (s *Do) exec(sql ...string) {
|
||||
if s.hasError() {
|
||||
return
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func (s *Do) update() *Do {
|
|||
s.err(s.model.callMethod("BeforeUpdate"))
|
||||
s.err(s.model.callMethod("BeforeSave"))
|
||||
if len(s.Errors) == 0 {
|
||||
s.prepareUpdateSql().Exec()
|
||||
s.prepareUpdateSql().exec()
|
||||
}
|
||||
s.err(s.model.callMethod("AfterUpdate"))
|
||||
s.err(s.model.callMethod("AfterSave"))
|
||||
|
@ -161,7 +161,7 @@ func (s *Do) prepareDeleteSql() *Do {
|
|||
func (s *Do) delete() *Do {
|
||||
s.err(s.model.callMethod("BeforeDelete"))
|
||||
if len(s.Errors) == 0 {
|
||||
s.prepareDeleteSql().Exec()
|
||||
s.prepareDeleteSql().exec()
|
||||
}
|
||||
s.err(s.model.callMethod("AfterDelete"))
|
||||
return s
|
||||
|
@ -254,6 +254,9 @@ func (s *Do) pluck(value interface{}) *Do {
|
|||
s.prepareQuerySql()
|
||||
rows, err := s.db.Query(s.Sql, s.SqlVars...)
|
||||
s.err(err)
|
||||
if err != nil {
|
||||
return s
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
|
|
26
gorm_test.go
26
gorm_test.go
|
@ -532,9 +532,14 @@ func TestRunCallbacksAndGetErrors(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNoPanicInAnyCases(t *testing.T) {
|
||||
func TestNoUnExpectedHappenWithInvalidSql(t *testing.T) {
|
||||
var columns []string
|
||||
db.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns)
|
||||
if db.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
||||
t.Errorf("Should got error with invalid SQL")
|
||||
}
|
||||
if db.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
||||
t.Errorf("Should got error with invalid SQL")
|
||||
}
|
||||
|
||||
type Article struct {
|
||||
Name string
|
||||
|
@ -542,6 +547,21 @@ func TestNoPanicInAnyCases(t *testing.T) {
|
|||
db.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&Article{})
|
||||
|
||||
db.Where("name = ?", "3").Find(&[]User{})
|
||||
var count1, count2 int64
|
||||
db.Model(&User{}).Count(&count1)
|
||||
if count1 <= 0 {
|
||||
t.Errorf("Should find some users")
|
||||
}
|
||||
|
||||
q := db.Where("name = ?", "jinzhu; delete * from users").First(&User{})
|
||||
if q.Error == nil {
|
||||
t.Errorf("Can't find user")
|
||||
}
|
||||
|
||||
db.Model(&User{}).Count(&count2)
|
||||
if count1 != count2 {
|
||||
t.Errorf("Users should not be deleted by invalid SQL")
|
||||
}
|
||||
|
||||
db.Where("unexisting = ?", "3").Find(&[]User{})
|
||||
db.Where("unexisting = ?", "3").First(&User{})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue