Add some tests for Error handling

This commit is contained in:
Jinzhu 2013-10-28 12:12:12 +08:00
parent 0906333291
commit 13a8d98d8f
3 changed files with 57 additions and 20 deletions

36
do.go
View File

@ -41,10 +41,16 @@ func (s *Do) err(err error) {
}
}
func (s *Do) hasError() bool {
return len(s.Errors) > 0
}
func (s *Do) setModel(value interface{}) {
s.value = value
s.model = &Model{Data: value, driver: s.driver}
s.TableName = s.model.TableName()
var err error
s.TableName, err = s.model.TableName()
s.err(err)
}
func (s *Do) addToVars(value interface{}) string {
@ -53,6 +59,10 @@ func (s *Do) addToVars(value interface{}) string {
}
func (s *Do) Exec(sql ...string) {
if s.hasError() {
return
}
var err error
if len(sql) == 0 {
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
@ -179,9 +189,16 @@ func (s *Do) query(where ...interface{}) {
}
s.prepareQuerySql()
rows, err := s.db.Query(s.Sql, s.SqlVars...)
defer rows.Close()
s.err(err)
if err != nil {
return
}
defer rows.Close()
if rows.Err() != nil {
s.err(rows.Err())
}
@ -228,6 +245,10 @@ func (s *Do) count(value interface{}) {
}
func (s *Do) pluck(value interface{}) *Do {
if s.hasError() {
return s
}
dest_out := reflect.Indirect(reflect.ValueOf(value))
dest_type := dest_out.Type().Elem()
s.prepareQuerySql()
@ -372,6 +393,15 @@ func (s *Do) combinedSql() string {
}
func (s *Do) createTable() *Do {
s.Sql = s.model.CreateTable()
var sqls []string
for _, field := range s.model.Fields("null") {
sqls = append(sqls, field.DbName+" "+field.SqlType)
}
s.Sql = fmt.Sprintf(
"CREATE TABLE \"%v\" (%v)",
s.TableName,
strings.Join(sqls, ","),
)
return s
}

View File

@ -531,3 +531,17 @@ func TestRunCallbacksAndGetErrors(t *testing.T) {
t.Errorf("Should not delete record due to errors happened in callback")
}
}
func TestNoPanicInAnyCases(t *testing.T) {
var columns []string
db.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns)
type Article struct {
Name string
}
db.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&Article{})
db.Where("name = ?", "3").Find(&[]User{})
db.Where("unexisting = ?", "3").Find(&[]User{})
db.Where("unexisting = ?", "3").First(&User{})
}

View File

@ -1,10 +1,11 @@
package gorm
import (
"errors"
"fmt"
"reflect"
"regexp"
"strings"
"time"
)
@ -101,7 +102,12 @@ func (m *Model) ColumnsAndValues(operation string) (columns []string, values []i
return
}
func (m *Model) TableName() string {
func (m *Model) TableName() (str string, err error) {
if m.Data == nil {
err = errors.New("Model haven't been set")
return
}
t := reflect.TypeOf(m.Data)
for {
c := false
@ -115,7 +121,8 @@ func (m *Model) TableName() string {
}
}
reg, _ := regexp.Compile("s*$")
return reg.ReplaceAllString(toSnake(t.Name()), "s")
str = reg.ReplaceAllString(toSnake(t.Name()), "s")
return
}
func (m *Model) callMethod(method string) error {
@ -135,20 +142,6 @@ func (model *Model) MissingColumns() (results []string) {
return
}
func (model *Model) CreateTable() (sql string) {
var sqls []string
for _, field := range model.Fields("null") {
sqls = append(sqls, field.DbName+" "+field.SqlType)
}
sql = fmt.Sprintf(
"CREATE TABLE \"%v\" (%v)",
model.TableName(),
strings.Join(sqls, ","),
)
return
}
func (model *Model) ReturningStr() (str string) {
if model.driver == "postgres" {
str = fmt.Sprintf("RETURNING \"%v\"", model.PrimaryKeyDb())