forked from mirror/gorm
Add DisableNestedTransaction support
This commit is contained in:
parent
6848ae872f
commit
468152d45b
|
@ -498,13 +498,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
||||||
|
|
||||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||||
// nested transaction
|
// nested transaction
|
||||||
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
|
if !db.DisableNestedTransaction {
|
||||||
defer func() {
|
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
|
||||||
// Make sure to rollback when panic, Block error or Commit error
|
defer func() {
|
||||||
if panicked || err != nil {
|
// Make sure to rollback when panic, Block error or Commit error
|
||||||
db.RollbackTo(fmt.Sprintf("sp%p", fc))
|
if panicked || err != nil {
|
||||||
}
|
db.RollbackTo(fmt.Sprintf("sp%p", fc))
|
||||||
}()
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = fc(db.Session(&Session{}))
|
err = fc(db.Session(&Session{}))
|
||||||
|
|
31
gorm.go
31
gorm.go
|
@ -34,6 +34,8 @@ type Config struct {
|
||||||
DisableAutomaticPing bool
|
DisableAutomaticPing bool
|
||||||
// DisableForeignKeyConstraintWhenMigrating
|
// DisableForeignKeyConstraintWhenMigrating
|
||||||
DisableForeignKeyConstraintWhenMigrating bool
|
DisableForeignKeyConstraintWhenMigrating bool
|
||||||
|
// DisableNestedTransaction disable nested transaction
|
||||||
|
DisableNestedTransaction bool
|
||||||
// AllowGlobalUpdate allow global update
|
// AllowGlobalUpdate allow global update
|
||||||
AllowGlobalUpdate bool
|
AllowGlobalUpdate bool
|
||||||
// QueryFields executes the SQL query with all fields of the table
|
// QueryFields executes the SQL query with all fields of the table
|
||||||
|
@ -65,18 +67,19 @@ type DB struct {
|
||||||
|
|
||||||
// Session session config when create session with Session() method
|
// Session session config when create session with Session() method
|
||||||
type Session struct {
|
type Session struct {
|
||||||
DryRun bool
|
DryRun bool
|
||||||
PrepareStmt bool
|
PrepareStmt bool
|
||||||
NewDB bool
|
NewDB bool
|
||||||
SkipHooks bool
|
SkipHooks bool
|
||||||
SkipDefaultTransaction bool
|
SkipDefaultTransaction bool
|
||||||
AllowGlobalUpdate bool
|
DisableNestedTransaction bool
|
||||||
FullSaveAssociations bool
|
AllowGlobalUpdate bool
|
||||||
QueryFields bool
|
FullSaveAssociations bool
|
||||||
Context context.Context
|
QueryFields bool
|
||||||
Logger logger.Interface
|
Context context.Context
|
||||||
NowFunc func() time.Time
|
Logger logger.Interface
|
||||||
CreateBatchSize int
|
NowFunc func() time.Time
|
||||||
|
CreateBatchSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open initialize db session based on dialector
|
// Open initialize db session based on dialector
|
||||||
|
@ -206,6 +209,10 @@ func (db *DB) Session(config *Session) *DB {
|
||||||
tx.Statement.SkipHooks = true
|
tx.Statement.SkipHooks = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.DisableNestedTransaction {
|
||||||
|
txConfig.DisableNestedTransaction = true
|
||||||
|
}
|
||||||
|
|
||||||
if !config.NewDB {
|
if !config.NewDB {
|
||||||
tx.clone = 2
|
tx.clone = 2
|
||||||
}
|
}
|
||||||
|
|
|
@ -283,6 +283,69 @@ func TestNestedTransactionWithBlock(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDisabledNestedTransaction(t *testing.T) {
|
||||||
|
var (
|
||||||
|
user = *GetUser("transaction-nested", Config{})
|
||||||
|
user1 = *GetUser("transaction-nested-1", Config{})
|
||||||
|
user2 = *GetUser("transaction-nested-2", Config{})
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := DB.Session(&gorm.Session{DisableNestedTransaction: true}).Transaction(func(tx *gorm.DB) error {
|
||||||
|
tx.Create(&user)
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Transaction(func(tx1 *gorm.DB) error {
|
||||||
|
tx1.Create(&user1)
|
||||||
|
|
||||||
|
if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.New("rollback")
|
||||||
|
}); err == nil {
|
||||||
|
t.Fatalf("nested transaction should returns error")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should not rollback record if disabled nested transaction support")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Transaction(func(tx2 *gorm.DB) error {
|
||||||
|
tx2.Create(&user2)
|
||||||
|
|
||||||
|
if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("nested transaction returns error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("no error should return, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", user1.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should not rollback record if disabled nested transaction support")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTransactionOnClosedConn(t *testing.T) {
|
func TestTransactionOnClosedConn(t *testing.T) {
|
||||||
DB, err := OpenTestConnection()
|
DB, err := OpenTestConnection()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in New Issue