From 7f75b12bb2c3c44b0b894d5cbf3128e000834add Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leo=20Sj=C3=B6berg?= Date: Sat, 14 Sep 2024 13:58:29 +0100 Subject: [PATCH] Generate unique savepoint names for nested transactions (#7174) * Generate unique savepoint names * Add a test for deeply nested wrapped transactions --- finisher_api.go | 6 ++-- tests/go.mod | 4 +-- tests/transaction_test.go | 68 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 4 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index f97571ed..6802945c 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "hash/maphash" "reflect" "strings" @@ -623,14 +624,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction if !db.DisableNestedTransaction { - err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error + spID := new(maphash.Hash).Sum64() + err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error if err != nil { return } defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { - db.RollbackTo(fmt.Sprintf("sp%p", fc)) + db.RollbackTo(fmt.Sprintf("sp%d", spID)) } }() } diff --git a/tests/go.mod b/tests/go.mod index 350d1794..8eeab51d 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -29,8 +29,8 @@ require ( github.com/microsoft/go-mssqldb v1.7.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect - golang.org/x/crypto v0.24.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/crypto v0.26.0 // indirect + golang.org/x/text v0.17.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/tests/transaction_test.go b/tests/transaction_test.go index d2cbc9a9..9f0f067c 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -297,6 +297,74 @@ func TestNestedTransactionWithBlock(t *testing.T) { } } +func TestDeeplyNestedTransactionWithBlockAndWrappedCallback(t *testing.T) { + transaction := func(ctx context.Context, db *gorm.DB, callback func(ctx context.Context, db *gorm.DB) error) error { + return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return callback(ctx, tx) + }) + } + var ( + user = *GetUser("transaction-nested", Config{}) + user1 = *GetUser("transaction-nested-1", Config{}) + user2 = *GetUser("transaction-nested-2", Config{}) + ) + + if err := transaction(context.Background(), DB, func(ctx context.Context, tx *gorm.DB) error { + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := transaction(ctx, tx, func(ctx context.Context, tx1 *gorm.DB) error { + tx1.Create(&user1) + + if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := transaction(ctx, tx1, func(ctx context.Context, tx2 *gorm.DB) error { + tx2.Create(&user2) + + if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("inner rollback") + }); err == nil { + t.Fatalf("nested transaction has no error") + } + + return errors.New("rollback") + }); err == nil { + t.Fatalf("nested transaction should returns error") + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }); err != nil { + t.Fatalf("no error should return, but got %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked parent record") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should not find rollbacked nested record") + } +} + func TestDisabledNestedTransaction(t *testing.T) { var ( user = *GetUser("transaction-nested", Config{})