From 80a7e6296da71e14db241523e2bf0abd1c45116a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Oct 2013 10:17:11 +0800 Subject: [PATCH] Add method to do basic check for sql valid or not --- orm.go | 4 ++++ orm_test.go | 6 ++++++ sql.go | 9 +++++++++ 3 files changed, 19 insertions(+) diff --git a/orm.go b/orm.go index 39e6959c..440a7027 100644 --- a/orm.go +++ b/orm.go @@ -63,6 +63,7 @@ func (s *Orm) Offset(value interface{}) *Orm { } func (s *Orm) Order(value string, reorder ...bool) *Orm { + defer s.validSql(s.orderStr) if len(reorder) > 0 && reorder[0] { s.orderStr = value } else { @@ -76,12 +77,15 @@ func (s *Orm) Count() int64 { } func (s *Orm) Select(value interface{}) *Orm { + defer func() { s.validSql(s.selectStr) }() + switch value := value.(type) { case string: s.selectStr = value default: s.Error = errors.New("Can' understand the value of Select, Should be string") } + return s } diff --git a/orm_test.go b/orm_test.go index 390dd94f..808e66bf 100644 --- a/orm_test.go +++ b/orm_test.go @@ -234,6 +234,12 @@ func TestSelect(t *testing.T) { if user.Name != "3" { t.Errorf("Should got Name = 3 when searching it, %+v", user.Id) } + + query := db.Where("name = ?", "3").Select("nam;e") + if query.Error == nil { + t.Errorf("Should got error with invalid select string") + } + debug(query.Error) } func TestPluck(t *testing.T) { diff --git a/sql.go b/sql.go index 7cffbffe..10f66ace 100644 --- a/sql.go +++ b/sql.go @@ -4,9 +4,18 @@ import ( "errors" "fmt" "reflect" + "regexp" "strings" ) +func (s *Orm) validSql(str string) (result bool) { + result = regexp.MustCompile("^\\s*[\\w][\\w\\s,.]*[\\w]\\s*$").MatchString(str) + if !result { + s.Error = errors.New(fmt.Sprintf("SQL is not valid, %s", str)) + } + return +} + func (s *Orm) explain(value interface{}, operation string) *Orm { s.Model(value)