Add DisableNestedTransaction support

This commit is contained in:
Jinzhu 2020-12-16 19:33:35 +08:00
parent 6848ae872f
commit 468152d45b
3 changed files with 91 additions and 19 deletions

View File

@ -498,13 +498,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%p", fc))
}
}()
if !db.DisableNestedTransaction {
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%p", fc))
}
}()
}
if err == nil {
err = fc(db.Session(&Session{}))

31
gorm.go
View File

@ -34,6 +34,8 @@ type Config struct {
DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool
// DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool
// AllowGlobalUpdate allow global update
AllowGlobalUpdate bool
// QueryFields executes the SQL query with all fields of the table
@ -65,18 +67,19 @@ type DB struct {
// Session session config when create session with Session() method
type Session struct {
DryRun bool
PrepareStmt bool
NewDB bool
SkipHooks bool
SkipDefaultTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
QueryFields bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
CreateBatchSize int
DryRun bool
PrepareStmt bool
NewDB bool
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
QueryFields bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
CreateBatchSize int
}
// Open initialize db session based on dialector
@ -206,6 +209,10 @@ func (db *DB) Session(config *Session) *DB {
tx.Statement.SkipHooks = true
}
if config.DisableNestedTransaction {
txConfig.DisableNestedTransaction = true
}
if !config.NewDB {
tx.clone = 2
}

View File

@ -283,6 +283,69 @@ func TestNestedTransactionWithBlock(t *testing.T) {
}
}
func TestDisabledNestedTransaction(t *testing.T) {
var (
user = *GetUser("transaction-nested", Config{})
user1 = *GetUser("transaction-nested-1", Config{})
user2 = *GetUser("transaction-nested-2", Config{})
)
if err := DB.Session(&gorm.Session{DisableNestedTransaction: true}).Transaction(func(tx *gorm.DB) error {
tx.Create(&user)
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
if err := tx.Transaction(func(tx1 *gorm.DB) error {
tx1.Create(&user1)
if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
return errors.New("rollback")
}); err == nil {
t.Fatalf("nested transaction should returns error")
}
if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should not rollback record if disabled nested transaction support")
}
if err := tx.Transaction(func(tx2 *gorm.DB) error {
tx2.Create(&user2)
if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
return nil
}); err != nil {
t.Fatalf("nested transaction returns error: %v", err)
}
if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
return nil
}); err != nil {
t.Fatalf("no error should return, but got %v", err)
}
if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
if err := DB.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should not rollback record if disabled nested transaction support")
}
if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
}
func TestTransactionOnClosedConn(t *testing.T) {
DB, err := OpenTestConnection()
if err != nil {