mirror of https://github.com/go-gorm/gorm.git
Add some tests for Error handling
This commit is contained in:
parent
0906333291
commit
13a8d98d8f
36
do.go
36
do.go
|
@ -41,10 +41,16 @@ func (s *Do) err(err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Do) hasError() bool {
|
||||
return len(s.Errors) > 0
|
||||
}
|
||||
|
||||
func (s *Do) setModel(value interface{}) {
|
||||
s.value = value
|
||||
s.model = &Model{Data: value, driver: s.driver}
|
||||
s.TableName = s.model.TableName()
|
||||
var err error
|
||||
s.TableName, err = s.model.TableName()
|
||||
s.err(err)
|
||||
}
|
||||
|
||||
func (s *Do) addToVars(value interface{}) string {
|
||||
|
@ -53,6 +59,10 @@ func (s *Do) addToVars(value interface{}) string {
|
|||
}
|
||||
|
||||
func (s *Do) Exec(sql ...string) {
|
||||
if s.hasError() {
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
if len(sql) == 0 {
|
||||
s.SqlResult, err = s.db.Exec(s.Sql, s.SqlVars...)
|
||||
|
@ -179,9 +189,16 @@ func (s *Do) query(where ...interface{}) {
|
|||
}
|
||||
|
||||
s.prepareQuerySql()
|
||||
|
||||
rows, err := s.db.Query(s.Sql, s.SqlVars...)
|
||||
defer rows.Close()
|
||||
s.err(err)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Err() != nil {
|
||||
s.err(rows.Err())
|
||||
}
|
||||
|
@ -228,6 +245,10 @@ func (s *Do) count(value interface{}) {
|
|||
}
|
||||
|
||||
func (s *Do) pluck(value interface{}) *Do {
|
||||
if s.hasError() {
|
||||
return s
|
||||
}
|
||||
|
||||
dest_out := reflect.Indirect(reflect.ValueOf(value))
|
||||
dest_type := dest_out.Type().Elem()
|
||||
s.prepareQuerySql()
|
||||
|
@ -372,6 +393,15 @@ func (s *Do) combinedSql() string {
|
|||
}
|
||||
|
||||
func (s *Do) createTable() *Do {
|
||||
s.Sql = s.model.CreateTable()
|
||||
var sqls []string
|
||||
for _, field := range s.model.Fields("null") {
|
||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
||||
}
|
||||
|
||||
s.Sql = fmt.Sprintf(
|
||||
"CREATE TABLE \"%v\" (%v)",
|
||||
s.TableName,
|
||||
strings.Join(sqls, ","),
|
||||
)
|
||||
return s
|
||||
}
|
||||
|
|
|
@ -531,3 +531,17 @@ func TestRunCallbacksAndGetErrors(t *testing.T) {
|
|||
t.Errorf("Should not delete record due to errors happened in callback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoPanicInAnyCases(t *testing.T) {
|
||||
var columns []string
|
||||
db.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns)
|
||||
|
||||
type Article struct {
|
||||
Name string
|
||||
}
|
||||
db.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&Article{})
|
||||
|
||||
db.Where("name = ?", "3").Find(&[]User{})
|
||||
db.Where("unexisting = ?", "3").Find(&[]User{})
|
||||
db.Where("unexisting = ?", "3").First(&User{})
|
||||
}
|
27
model.go
27
model.go
|
@ -1,10 +1,11 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -101,7 +102,12 @@ func (m *Model) ColumnsAndValues(operation string) (columns []string, values []i
|
|||
return
|
||||
}
|
||||
|
||||
func (m *Model) TableName() string {
|
||||
func (m *Model) TableName() (str string, err error) {
|
||||
if m.Data == nil {
|
||||
err = errors.New("Model haven't been set")
|
||||
return
|
||||
}
|
||||
|
||||
t := reflect.TypeOf(m.Data)
|
||||
for {
|
||||
c := false
|
||||
|
@ -115,7 +121,8 @@ func (m *Model) TableName() string {
|
|||
}
|
||||
}
|
||||
reg, _ := regexp.Compile("s*$")
|
||||
return reg.ReplaceAllString(toSnake(t.Name()), "s")
|
||||
str = reg.ReplaceAllString(toSnake(t.Name()), "s")
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Model) callMethod(method string) error {
|
||||
|
@ -135,20 +142,6 @@ func (model *Model) MissingColumns() (results []string) {
|
|||
return
|
||||
}
|
||||
|
||||
func (model *Model) CreateTable() (sql string) {
|
||||
var sqls []string
|
||||
for _, field := range model.Fields("null") {
|
||||
sqls = append(sqls, field.DbName+" "+field.SqlType)
|
||||
}
|
||||
|
||||
sql = fmt.Sprintf(
|
||||
"CREATE TABLE \"%v\" (%v)",
|
||||
model.TableName(),
|
||||
strings.Join(sqls, ","),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func (model *Model) ReturningStr() (str string) {
|
||||
if model.driver == "postgres" {
|
||||
str = fmt.Sprintf("RETURNING \"%v\"", model.PrimaryKeyDb())
|
||||
|
|
Loading…
Reference in New Issue