mirror of https://github.com/go-gorm/gorm.git
Generate unique savepoint names for nested transactions (#7174)
* Generate unique savepoint names * Add a test for deeply nested wrapped transactions
This commit is contained in:
parent
0daaf1747c
commit
7f75b12bb2
|
@ -4,6 +4,7 @@ import (
|
|||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
|
@ -623,14 +624,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
|||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
// nested transaction
|
||||
if !db.DisableNestedTransaction {
|
||||
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
|
||||
spID := new(maphash.Hash).Sum64()
|
||||
err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if panicked || err != nil {
|
||||
db.RollbackTo(fmt.Sprintf("sp%p", fc))
|
||||
db.RollbackTo(fmt.Sprintf("sp%d", spID))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -29,8 +29,8 @@ require (
|
|||
github.com/microsoft/go-mssqldb v1.7.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.12.0 // indirect
|
||||
golang.org/x/crypto v0.24.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
golang.org/x/crypto v0.26.0 // indirect
|
||||
golang.org/x/text v0.17.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
|
|
|
@ -297,6 +297,74 @@ func TestNestedTransactionWithBlock(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDeeplyNestedTransactionWithBlockAndWrappedCallback(t *testing.T) {
|
||||
transaction := func(ctx context.Context, db *gorm.DB, callback func(ctx context.Context, db *gorm.DB) error) error {
|
||||
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
return callback(ctx, tx)
|
||||
})
|
||||
}
|
||||
var (
|
||||
user = *GetUser("transaction-nested", Config{})
|
||||
user1 = *GetUser("transaction-nested-1", Config{})
|
||||
user2 = *GetUser("transaction-nested-2", Config{})
|
||||
)
|
||||
|
||||
if err := transaction(context.Background(), DB, func(ctx context.Context, tx *gorm.DB) error {
|
||||
tx.Create(&user)
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
|
||||
if err := transaction(ctx, tx, func(ctx context.Context, tx1 *gorm.DB) error {
|
||||
tx1.Create(&user1)
|
||||
|
||||
if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil {
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
|
||||
if err := transaction(ctx, tx1, func(ctx context.Context, tx2 *gorm.DB) error {
|
||||
tx2.Create(&user2)
|
||||
|
||||
if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
|
||||
return errors.New("inner rollback")
|
||||
}); err == nil {
|
||||
t.Fatalf("nested transaction has no error")
|
||||
}
|
||||
|
||||
return errors.New("rollback")
|
||||
}); err == nil {
|
||||
t.Fatalf("nested transaction should returns error")
|
||||
}
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil {
|
||||
t.Fatalf("Should not find rollbacked record")
|
||||
}
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatalf("no error should return, but got %v", err)
|
||||
}
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||
t.Fatalf("Should find saved record")
|
||||
}
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil {
|
||||
t.Fatalf("Should not find rollbacked parent record")
|
||||
}
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||
t.Fatalf("Should not find rollbacked nested record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisabledNestedTransaction(t *testing.T) {
|
||||
var (
|
||||
user = *GetUser("transaction-nested", Config{})
|
||||
|
|
Loading…
Reference in New Issue