diff --git a/main.go b/main.go index 48d22c85..24fd8382 100644 --- a/main.go +++ b/main.go @@ -528,12 +528,12 @@ func (s *DB) Debug() *DB { // Transaction start a transaction as a block, // return error will rollback, otherwise to commit. func (s *DB) Transaction(fc func(tx *DB) error) (err error) { + panicked := true tx := s.Begin() defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("%s", r) + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { tx.Rollback() - return } }() @@ -543,10 +543,7 @@ func (s *DB) Transaction(fc func(tx *DB) error) (err error) { err = tx.Commit().Error } - // Makesure rollback when Block error or Commit error - if err != nil { - tx.Rollback() - } + panicked = false return } diff --git a/main_test.go b/main_test.go index 134672b7..98ea4694 100644 --- a/main_test.go +++ b/main_test.go @@ -470,6 +470,15 @@ func TestTransaction(t *testing.T) { } } +func assertPanic(t *testing.T, f func()) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + f() +} + func TestTransactionWithBlock(t *testing.T) { // rollback err := DB.Transaction(func(tx *gorm.DB) error { @@ -511,17 +520,19 @@ func TestTransactionWithBlock(t *testing.T) { } // panic will rollback - DB.Transaction(func(tx *gorm.DB) error { - u3 := User{Name: "transcation-3"} - if err := tx.Save(&u3).Error; err != nil { - t.Errorf("No error should raise") - } + assertPanic(t, func() { + DB.Transaction(func(tx *gorm.DB) error { + u3 := User{Name: "transcation-3"} + if err := tx.Save(&u3).Error; err != nil { + t.Errorf("No error should raise") + } - if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil { - t.Errorf("Should find saved record") - } + if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil { + t.Errorf("Should find saved record") + } - panic("force panic") + panic("force panic") + }) }) if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil {