diff --git a/finisher_api.go b/finisher_api.go index 98a877f2..03bcd20f 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -498,13 +498,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction - err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error - defer func() { - // Make sure to rollback when panic, Block error or Commit error - if panicked || err != nil { - db.RollbackTo(fmt.Sprintf("sp%p", fc)) - } - }() + if !db.DisableNestedTransaction { + err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + db.RollbackTo(fmt.Sprintf("sp%p", fc)) + } + }() + } if err == nil { err = fc(db.Session(&Session{})) diff --git a/gorm.go b/gorm.go index ae1cf2c9..ae94daf4 100644 --- a/gorm.go +++ b/gorm.go @@ -34,6 +34,8 @@ type Config struct { DisableAutomaticPing bool // DisableForeignKeyConstraintWhenMigrating DisableForeignKeyConstraintWhenMigrating bool + // DisableNestedTransaction disable nested transaction + DisableNestedTransaction bool // AllowGlobalUpdate allow global update AllowGlobalUpdate bool // QueryFields executes the SQL query with all fields of the table @@ -65,18 +67,19 @@ type DB struct { // Session session config when create session with Session() method type Session struct { - DryRun bool - PrepareStmt bool - NewDB bool - SkipHooks bool - SkipDefaultTransaction bool - AllowGlobalUpdate bool - FullSaveAssociations bool - QueryFields bool - Context context.Context - Logger logger.Interface - NowFunc func() time.Time - CreateBatchSize int + DryRun bool + PrepareStmt bool + NewDB bool + SkipHooks bool + SkipDefaultTransaction bool + DisableNestedTransaction bool + AllowGlobalUpdate bool + FullSaveAssociations bool + QueryFields bool + Context context.Context + Logger logger.Interface + NowFunc func() time.Time + CreateBatchSize int } // Open initialize db session based on dialector @@ -206,6 +209,10 @@ func (db *DB) Session(config *Session) *DB { tx.Statement.SkipHooks = true } + if config.DisableNestedTransaction { + txConfig.DisableNestedTransaction = true + } + if !config.NewDB { tx.clone = 2 } diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 334600b8..c17fea3b 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -283,6 +283,69 @@ func TestNestedTransactionWithBlock(t *testing.T) { } } +func TestDisabledNestedTransaction(t *testing.T) { + var ( + user = *GetUser("transaction-nested", Config{}) + user1 = *GetUser("transaction-nested-1", Config{}) + user2 = *GetUser("transaction-nested-2", Config{}) + ) + + if err := DB.Session(&gorm.Session{DisableNestedTransaction: true}).Transaction(func(tx *gorm.DB) error { + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.Transaction(func(tx1 *gorm.DB) error { + tx1.Create(&user1) + + if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("rollback") + }); err == nil { + t.Fatalf("nested transaction should returns error") + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should not rollback record if disabled nested transaction support") + } + + if err := tx.Transaction(func(tx2 *gorm.DB) error { + tx2.Create(&user2) + + if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return nil + }); err != nil { + t.Fatalf("nested transaction returns error: %v", err) + } + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }); err != nil { + t.Fatalf("no error should return, but got %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should not rollback record if disabled nested transaction support") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } +} + func TestTransactionOnClosedConn(t *testing.T) { DB, err := OpenTestConnection() if err != nil {