Add Errors

This commit is contained in:
Jinzhu 2015-08-13 16:42:13 +08:00
parent e1ce3b7066
commit 309740983e
5 changed files with 60 additions and 26 deletions

View File

@ -1,6 +1,9 @@
package gorm package gorm
import "errors" import (
"errors"
"strings"
)
var ( var (
RecordNotFound = errors.New("record not found") RecordNotFound = errors.New("record not found")
@ -9,3 +12,23 @@ var (
NoValidTransaction = errors.New("no valid transaction") NoValidTransaction = errors.New("no valid transaction")
CantStartTransaction = errors.New("can't start transaction") CantStartTransaction = errors.New("can't start transaction")
) )
type Errors struct {
errors []error
}
func (errs Errors) Errors() []error {
return errs.errors
}
func (errs *Errors) Add(err error) {
errs.errors = append(errs.errors, err)
}
func (errs Errors) Error() string {
var errors = []string{}
for _, e := range errs.errors {
errors = append(errors, e.Error())
}
return strings.Join(errors, "; ")
}

43
main.go
View File

@ -258,9 +258,9 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
if !result.RecordNotFound() { if !result.RecordNotFound() {
return result return result
} }
c.err(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates).db.Error) c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates).db.Error)
} else if len(c.search.assignAttrs) > 0 { } else if len(c.search.assignAttrs) > 0 {
c.err(c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates).db.Error) c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates).db.Error)
} }
return c return c
} }
@ -339,27 +339,27 @@ func (s *DB) Begin() *DB {
if db, ok := c.db.(sqlDb); ok { if db, ok := c.db.(sqlDb); ok {
tx, err := db.Begin() tx, err := db.Begin()
c.db = interface{}(tx).(sqlCommon) c.db = interface{}(tx).(sqlCommon)
c.err(err) c.AddError(err)
} else { } else {
c.err(CantStartTransaction) c.AddError(CantStartTransaction)
} }
return c return c
} }
func (s *DB) Commit() *DB { func (s *DB) Commit() *DB {
if db, ok := s.db.(sqlTx); ok { if db, ok := s.db.(sqlTx); ok {
s.err(db.Commit()) s.AddError(db.Commit())
} else { } else {
s.err(NoValidTransaction) s.AddError(NoValidTransaction)
} }
return s return s
} }
func (s *DB) Rollback() *DB { func (s *DB) Rollback() *DB {
if db, ok := s.db.(sqlTx); ok { if db, ok := s.db.(sqlTx); ok {
s.err(db.Rollback()) s.AddError(db.Rollback())
} else { } else {
s.err(NoValidTransaction) s.AddError(NoValidTransaction)
} }
return s return s
} }
@ -389,7 +389,7 @@ func (s *DB) HasTable(value interface{}) bool {
scope := s.clone().NewScope(value) scope := s.clone().NewScope(value)
tableName := scope.TableName() tableName := scope.TableName()
has := scope.Dialect().HasTable(scope, tableName) has := scope.Dialect().HasTable(scope, tableName)
s.err(scope.db.Error) s.AddError(scope.db.Error)
return has return has
} }
@ -508,3 +508,28 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
} }
} }
} }
func (s *DB) AddError(err error) error {
if err != nil {
if err != RecordNotFound {
if s.logMode == 0 {
go s.print(fileWithLineNum(), err)
} else {
s.log(err)
}
err = Errors{errors: append(s.Errors(), err)}
}
s.Error = err
}
return err
}
func (s *DB) Errors() []error {
if errs, ok := s.Error.(Errors); ok {
return errs.errors
} else {
return []error{s.Error}
}
}

View File

@ -19,20 +19,6 @@ func (s *DB) clone() *DB {
return &db return &db
} }
func (s *DB) err(err error) error {
if err != nil {
if err != RecordNotFound {
if s.logMode == 0 {
go s.print(fileWithLineNum(), err)
} else {
s.log(err)
}
}
s.Error = err
}
return err
}
func (s *DB) print(v ...interface{}) { func (s *DB) print(v ...interface{}) {
s.logger.(logger).Print(v...) s.logger.(logger).Print(v...)
} }

View File

@ -103,7 +103,7 @@ func (scope *Scope) Dialect() Dialect {
// Err write error // Err write error
func (scope *Scope) Err(err error) error { func (scope *Scope) Err(err error) error {
if err != nil { if err != nil {
scope.db.err(err) scope.db.AddError(err)
} }
return err return err
} }

View File

@ -139,7 +139,7 @@ func (s *search) getInterfaceAsSql(value interface{}) (str string) {
case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
str = fmt.Sprintf("%v", value) str = fmt.Sprintf("%v", value)
default: default:
s.db.err(InvalidSql) s.db.AddError(InvalidSql)
} }
if str == "-1" { if str == "-1" {