Generate unique savepoint names for nested transactions (#7174)

* Generate unique savepoint names

* Add a test for deeply nested wrapped transactions
This commit is contained in:
Leo Sjöberg 2024-09-14 13:58:29 +01:00 committed by GitHub
parent 0daaf1747c
commit 7f75b12bb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 74 additions and 4 deletions

View File

@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"fmt"
"hash/maphash"
"reflect"
"strings"
@ -623,14 +624,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction
if !db.DisableNestedTransaction {
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
spID := new(maphash.Hash).Sum64()
err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
if err != nil {
return
}
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%p", fc))
db.RollbackTo(fmt.Sprintf("sp%d", spID))
}
}()
}

View File

@ -29,8 +29,8 @@ require (
github.com/microsoft/go-mssqldb v1.7.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
golang.org/x/crypto v0.24.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/crypto v0.26.0 // indirect
golang.org/x/text v0.17.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -297,6 +297,74 @@ func TestNestedTransactionWithBlock(t *testing.T) {
}
}
func TestDeeplyNestedTransactionWithBlockAndWrappedCallback(t *testing.T) {
transaction := func(ctx context.Context, db *gorm.DB, callback func(ctx context.Context, db *gorm.DB) error) error {
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
return callback(ctx, tx)
})
}
var (
user = *GetUser("transaction-nested", Config{})
user1 = *GetUser("transaction-nested-1", Config{})
user2 = *GetUser("transaction-nested-2", Config{})
)
if err := transaction(context.Background(), DB, func(ctx context.Context, tx *gorm.DB) error {
tx.Create(&user)
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
if err := transaction(ctx, tx, func(ctx context.Context, tx1 *gorm.DB) error {
tx1.Create(&user1)
if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
if err := transaction(ctx, tx1, func(ctx context.Context, tx2 *gorm.DB) error {
tx2.Create(&user2)
if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
return errors.New("inner rollback")
}); err == nil {
t.Fatalf("nested transaction has no error")
}
return errors.New("rollback")
}); err == nil {
t.Fatalf("nested transaction should returns error")
}
if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil {
t.Fatalf("Should not find rollbacked record")
}
if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
return nil
}); err != nil {
t.Fatalf("no error should return, but got %v", err)
}
if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil {
t.Fatalf("Should not find rollbacked parent record")
}
if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should not find rollbacked nested record")
}
}
func TestDisabledNestedTransaction(t *testing.T) {
var (
user = *GetUser("transaction-nested", Config{})