diff --git a/do.go b/do.go index 8b0141a7..b47a793a 100644 --- a/do.go +++ b/do.go @@ -41,10 +41,16 @@ func (s *Do) err(err error) { } } +func (s *Do) hasError() bool { + return len(s.Errors) > 0 +} + func (s *Do) setModel(value interface{}) { s.value = value s.model = &Model{Data: value, driver: s.driver} - s.TableName = s.model.TableName() + var err error + s.TableName, err = s.model.TableName() + s.err(err) } func (s *Do) addToVars(value interface{}) string { @@ -53,6 +59,10 @@ func (s *Do) addToVars(value interface{}) string { } func (s *Do) Exec(sql ...string) { + if s.hasError() { + return + } + var err error if len(sql) == 0 { s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...) @@ -179,9 +189,16 @@ func (s *Do) query(where ...interface{}) { } s.prepareQuerySql() + rows, err := s.db.Query(s.Sql, s.SqlVars...) - defer rows.Close() s.err(err) + + if err != nil { + return + } + + defer rows.Close() + if rows.Err() != nil { s.err(rows.Err()) } @@ -228,6 +245,10 @@ func (s *Do) count(value interface{}) { } func (s *Do) pluck(value interface{}) *Do { + if s.hasError() { + return s + } + dest_out := reflect.Indirect(reflect.ValueOf(value)) dest_type := dest_out.Type().Elem() s.prepareQuerySql() @@ -372,6 +393,15 @@ func (s *Do) combinedSql() string { } func (s *Do) createTable() *Do { - s.Sql = s.model.CreateTable() + var sqls []string + for _, field := range s.model.Fields("null") { + sqls = append(sqls, field.DbName+" "+field.SqlType) + } + + s.Sql = fmt.Sprintf( + "CREATE TABLE \"%v\" (%v)", + s.TableName, + strings.Join(sqls, ","), + ) return s } diff --git a/orm_test.go b/gorm_test.go similarity index 97% rename from orm_test.go rename to gorm_test.go index 9399edba..77356dc8 100644 --- a/orm_test.go +++ b/gorm_test.go @@ -531,3 +531,17 @@ func TestRunCallbacksAndGetErrors(t *testing.T) { t.Errorf("Should not delete record due to errors happened in callback") } } + +func TestNoPanicInAnyCases(t *testing.T) { + var columns []string + db.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns) + + type Article struct { + Name string + } + db.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&Article{}) + + db.Where("name = ?", "3").Find(&[]User{}) + db.Where("unexisting = ?", "3").Find(&[]User{}) + db.Where("unexisting = ?", "3").First(&User{}) +} diff --git a/model.go b/model.go index 9be449cf..7cf25d25 100644 --- a/model.go +++ b/model.go @@ -1,10 +1,11 @@ package gorm import ( + "errors" "fmt" "reflect" "regexp" - "strings" + "time" ) @@ -101,7 +102,12 @@ func (m *Model) ColumnsAndValues(operation string) (columns []string, values []i return } -func (m *Model) TableName() string { +func (m *Model) TableName() (str string, err error) { + if m.Data == nil { + err = errors.New("Model haven't been set") + return + } + t := reflect.TypeOf(m.Data) for { c := false @@ -115,7 +121,8 @@ func (m *Model) TableName() string { } } reg, _ := regexp.Compile("s*$") - return reg.ReplaceAllString(toSnake(t.Name()), "s") + str = reg.ReplaceAllString(toSnake(t.Name()), "s") + return } func (m *Model) callMethod(method string) error { @@ -135,20 +142,6 @@ func (model *Model) MissingColumns() (results []string) { return } -func (model *Model) CreateTable() (sql string) { - var sqls []string - for _, field := range model.Fields("null") { - sqls = append(sqls, field.DbName+" "+field.SqlType) - } - - sql = fmt.Sprintf( - "CREATE TABLE \"%v\" (%v)", - model.TableName(), - strings.Join(sqls, ","), - ) - return -} - func (model *Model) ReturningStr() (str string) { if model.driver == "postgres" { str = fmt.Sprintf("RETURNING \"%v\"", model.PrimaryKeyDb())