diff --git a/errors.go b/errors.go index ce3a25c0..832fa9b0 100644 --- a/errors.go +++ b/errors.go @@ -18,40 +18,38 @@ var ( ErrUnaddressable = errors.New("using unaddressable value") ) -type errorsInterface interface { - GetErrors() []error -} - // Errors contains all happened errors -type Errors struct { - errors []error -} +type Errors []error -// GetErrors get all happened errors +// GetErrors gets all happened errors func (errs Errors) GetErrors() []error { - return errs.errors + return errs } -// Add add an error -func (errs *Errors) Add(err error) { - if errors, ok := err.(errorsInterface); ok { - for _, err := range errors.GetErrors() { - errs.Add(err) - } - } else { - for _, e := range errs.errors { - if err == e { - return +// Add adds an error +func (errs Errors) Add(newErrors ...error) Errors { + for _, err := range newErrors { + if errors, ok := err.(Errors); ok { + errs = errs.Add(errors...) + } else { + ok = true + for _, e := range errs { + if err == e { + ok = false + } + } + if ok { + errs = append(errs, err) } } - errs.errors = append(errs.errors, err) } + return errs } // Error format happened errors func (errs Errors) Error() string { var errors = []string{} - for _, e := range errs.errors { + for _, e := range errs { errors = append(errors, e.Error()) } return strings.Join(errors, "; ") diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 00000000..9a428dec --- /dev/null +++ b/errors_test.go @@ -0,0 +1,20 @@ +package gorm_test + +import ( + "errors" + "testing" + + "github.com/jinzhu/gorm" +) + +func TestErrorsCanBeUsedOutsideGorm(t *testing.T) { + errs := []error{errors.New("First"), errors.New("Second")} + + gErrs := gorm.Errors(errs) + gErrs = gErrs.Add(errors.New("Third")) + gErrs = gErrs.Add(gErrs) + + if gErrs.Error() != "First; Second; Third" { + t.Fatalf("Gave wrong error, got %s", gErrs.Error()) + } +} diff --git a/main.go b/main.go index e4af5873..192dbd7c 100644 --- a/main.go +++ b/main.go @@ -655,9 +655,9 @@ func (s *DB) AddError(err error) error { s.log(err) } - errors := Errors{errors: s.GetErrors()} + errors := Errors(s.GetErrors()) errors.Add(err) - if len(errors.GetErrors()) > 1 { + if len(errors) > 1 { err = errors } } @@ -669,8 +669,8 @@ func (s *DB) AddError(err error) error { // GetErrors get happened errors from the db func (s *DB) GetErrors() (errors []error) { - if errs, ok := s.Error.(errorsInterface); ok { - return errs.GetErrors() + if errs, ok := s.Error.(Errors); ok { + return errs } else if s.Error != nil { return []error{s.Error} }