diff --git a/main.go b/main.go index e39a868a..48d22c85 100644 --- a/main.go +++ b/main.go @@ -525,6 +525,31 @@ func (s *DB) Debug() *DB { return s.clone().LogMode(true) } +// Transaction start a transaction as a block, +// return error will rollback, otherwise to commit. +func (s *DB) Transaction(fc func(tx *DB) error) (err error) { + tx := s.Begin() + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("%s", r) + tx.Rollback() + return + } + }() + + err = fc(tx) + + if err == nil { + err = tx.Commit().Error + } + + // Makesure rollback when Block error or Commit error + if err != nil { + tx.Rollback() + } + return +} + // Begin begins a transaction func (s *DB) Begin() *DB { return s.BeginTx(context.Background(), &sql.TxOptions{}) diff --git a/main_test.go b/main_test.go index 68bf7419..134672b7 100644 --- a/main_test.go +++ b/main_test.go @@ -8,6 +8,7 @@ import ( "context" "database/sql" "database/sql/driver" + "errors" "fmt" "os" "path/filepath" @@ -469,6 +470,65 @@ func TestTransaction(t *testing.T) { } } +func TestTransactionWithBlock(t *testing.T) { + // rollback + err := DB.Transaction(func(tx *gorm.DB) error { + u := User{Name: "transcation"} + if err := tx.Save(&u).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { + t.Errorf("Should find saved record") + } + + return errors.New("the error message") + }) + + if err.Error() != "the error message" { + t.Errorf("Transaction return error will equal the block returns error") + } + + if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after rollback") + } + + // commit + DB.Transaction(func(tx *gorm.DB) error { + u2 := User{Name: "transcation-2"} + if err := tx.Save(&u2).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should find saved record") + } + return nil + }) + + if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { + t.Errorf("Should be able to find committed record") + } + + // panic will rollback + DB.Transaction(func(tx *gorm.DB) error { + u3 := User{Name: "transcation-3"} + if err := tx.Save(&u3).Error; err != nil { + t.Errorf("No error should raise") + } + + if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil { + t.Errorf("Should find saved record") + } + + panic("force panic") + }) + + if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { + t.Errorf("Should not find record after panic rollback") + } +} + func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { tx := DB.Begin() u := User{Name: "transcation"}