Add DisableNestedTransaction support

This commit is contained in:
Jinzhu 2020-12-16 19:33:35 +08:00
parent 6848ae872f
commit 468152d45b
3 changed files with 91 additions and 19 deletions

View File

@ -498,13 +498,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction // nested transaction
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error if !db.DisableNestedTransaction {
defer func() { err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
// Make sure to rollback when panic, Block error or Commit error defer func() {
if panicked || err != nil { // Make sure to rollback when panic, Block error or Commit error
db.RollbackTo(fmt.Sprintf("sp%p", fc)) if panicked || err != nil {
} db.RollbackTo(fmt.Sprintf("sp%p", fc))
}() }
}()
}
if err == nil { if err == nil {
err = fc(db.Session(&Session{})) err = fc(db.Session(&Session{}))

31
gorm.go
View File

@ -34,6 +34,8 @@ type Config struct {
DisableAutomaticPing bool DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating // DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool DisableForeignKeyConstraintWhenMigrating bool
// DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool
// AllowGlobalUpdate allow global update // AllowGlobalUpdate allow global update
AllowGlobalUpdate bool AllowGlobalUpdate bool
// QueryFields executes the SQL query with all fields of the table // QueryFields executes the SQL query with all fields of the table
@ -65,18 +67,19 @@ type DB struct {
// Session session config when create session with Session() method // Session session config when create session with Session() method
type Session struct { type Session struct {
DryRun bool DryRun bool
PrepareStmt bool PrepareStmt bool
NewDB bool NewDB bool
SkipHooks bool SkipHooks bool
SkipDefaultTransaction bool SkipDefaultTransaction bool
AllowGlobalUpdate bool DisableNestedTransaction bool
FullSaveAssociations bool AllowGlobalUpdate bool
QueryFields bool FullSaveAssociations bool
Context context.Context QueryFields bool
Logger logger.Interface Context context.Context
NowFunc func() time.Time Logger logger.Interface
CreateBatchSize int NowFunc func() time.Time
CreateBatchSize int
} }
// Open initialize db session based on dialector // Open initialize db session based on dialector
@ -206,6 +209,10 @@ func (db *DB) Session(config *Session) *DB {
tx.Statement.SkipHooks = true tx.Statement.SkipHooks = true
} }
if config.DisableNestedTransaction {
txConfig.DisableNestedTransaction = true
}
if !config.NewDB { if !config.NewDB {
tx.clone = 2 tx.clone = 2
} }

View File

@ -283,6 +283,69 @@ func TestNestedTransactionWithBlock(t *testing.T) {
} }
} }
func TestDisabledNestedTransaction(t *testing.T) {
var (
user = *GetUser("transaction-nested", Config{})
user1 = *GetUser("transaction-nested-1", Config{})
user2 = *GetUser("transaction-nested-2", Config{})
)
if err := DB.Session(&gorm.Session{DisableNestedTransaction: true}).Transaction(func(tx *gorm.DB) error {
tx.Create(&user)
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
if err := tx.Transaction(func(tx1 *gorm.DB) error {
tx1.Create(&user1)
if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
return errors.New("rollback")
}); err == nil {
t.Fatalf("nested transaction should returns error")
}
if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should not rollback record if disabled nested transaction support")
}
if err := tx.Transaction(func(tx2 *gorm.DB) error {
tx2.Create(&user2)
if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
return nil
}); err != nil {
t.Fatalf("nested transaction returns error: %v", err)
}
if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
return nil
}); err != nil {
t.Fatalf("no error should return, but got %v", err)
}
if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
if err := DB.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should not rollback record if disabled nested transaction support")
}
if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
}
func TestTransactionOnClosedConn(t *testing.T) { func TestTransactionOnClosedConn(t *testing.T) {
DB, err := OpenTestConnection() DB, err := OpenTestConnection()
if err != nil { if err != nil {