mirror of https://github.com/go-gorm/gorm.git
Add SavePoint/RollbackTo/NestedTransaction
This commit is contained in:
parent
2c1b04a2cf
commit
7dc255acfe
|
@ -25,4 +25,6 @@ var (
|
||||||
ErrorPrimaryKeyRequired = errors.New("primary key required")
|
ErrorPrimaryKeyRequired = errors.New("primary key required")
|
||||||
// ErrorModelValueRequired model value required
|
// ErrorModelValueRequired model value required
|
||||||
ErrorModelValueRequired = errors.New("model value required")
|
ErrorModelValueRequired = errors.New("model value required")
|
||||||
|
// ErrUnsupportedDriver unsupported driver
|
||||||
|
ErrUnsupportedDriver = errors.New("unsupported driver")
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,6 +3,7 @@ package gorm
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -343,18 +344,33 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||||
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
|
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
|
||||||
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
|
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
|
||||||
panicked := true
|
panicked := true
|
||||||
tx := db.Begin(opts...)
|
|
||||||
defer func() {
|
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||||
// Make sure to rollback when panic, Block error or Commit error
|
// nested transaction
|
||||||
if panicked || err != nil {
|
db.SavePoint(fmt.Sprintf("sp%p", fc))
|
||||||
tx.Rollback()
|
defer func() {
|
||||||
|
// Make sure to rollback when panic, Block error or Commit error
|
||||||
|
if panicked || err != nil {
|
||||||
|
db.RollbackTo(fmt.Sprintf("sp%p", fc))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = fc(db.Session(&Session{WithConditions: true}))
|
||||||
|
} else {
|
||||||
|
tx := db.Begin(opts...)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
// Make sure to rollback when panic, Block error or Commit error
|
||||||
|
if panicked || err != nil {
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = fc(tx)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
err = tx.Commit().Error
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
err = fc(tx)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
err = tx.Commit().Error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
panicked = false
|
panicked = false
|
||||||
|
@ -409,6 +425,24 @@ func (db *DB) Rollback() *DB {
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *DB) SavePoint(name string) *DB {
|
||||||
|
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||||
|
savePointer.SavePoint(db, name)
|
||||||
|
} else {
|
||||||
|
db.AddError(ErrUnsupportedDriver)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) RollbackTo(name string) *DB {
|
||||||
|
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||||
|
savePointer.RollbackTo(db, name)
|
||||||
|
} else {
|
||||||
|
db.AddError(ErrUnsupportedDriver)
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
// Exec execute raw sql
|
// Exec execute raw sql
|
||||||
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
||||||
tx = db.getInstance()
|
tx = db.getInstance()
|
||||||
|
|
|
@ -27,6 +27,11 @@ type ConnPool interface {
|
||||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SavePointerDialectorInterface interface {
|
||||||
|
SavePoint(tx *DB, name string) error
|
||||||
|
RollbackTo(tx *DB, name string) error
|
||||||
|
}
|
||||||
|
|
||||||
type TxBeginner interface {
|
type TxBeginner interface {
|
||||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||||
}
|
}
|
||||||
|
|
10
tests/go.mod
10
tests/go.mod
|
@ -6,11 +6,11 @@ require (
|
||||||
github.com/google/uuid v1.1.1
|
github.com/google/uuid v1.1.1
|
||||||
github.com/jinzhu/now v1.1.1
|
github.com/jinzhu/now v1.1.1
|
||||||
github.com/lib/pq v1.6.0
|
github.com/lib/pq v1.6.0
|
||||||
gorm.io/driver/mysql v0.2.0
|
gorm.io/driver/mysql v0.2.1
|
||||||
gorm.io/driver/postgres v0.2.0
|
gorm.io/driver/postgres v0.2.1
|
||||||
gorm.io/driver/sqlite v1.0.2
|
gorm.io/driver/sqlite v1.0.4
|
||||||
gorm.io/driver/sqlserver v0.2.0
|
gorm.io/driver/sqlserver v0.2.1
|
||||||
gorm.io/gorm v0.0.0-00010101000000-000000000000
|
gorm.io/gorm v0.2.7
|
||||||
)
|
)
|
||||||
|
|
||||||
replace gorm.io/gorm => ../
|
replace gorm.io/gorm => ../
|
||||||
|
|
|
@ -142,3 +142,123 @@ func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) {
|
||||||
t.Fatalf("Rollback after commit should raise error")
|
t.Fatalf("Rollback after commit should raise error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTransactionWithSavePoint(t *testing.T) {
|
||||||
|
tx := DB.Begin()
|
||||||
|
|
||||||
|
user := *GetUser("transaction-save-point", Config{})
|
||||||
|
tx.Create(&user)
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.SavePoint("save_point1").Error; err != nil {
|
||||||
|
t.Fatalf("Failed to save point, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user1 := *GetUser("transaction-save-point-1", Config{})
|
||||||
|
tx.Create(&user1)
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.RollbackTo("save_point1").Error; err != nil {
|
||||||
|
t.Fatalf("Failed to save point, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil {
|
||||||
|
t.Fatalf("Should not find rollbacked record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.SavePoint("save_point2").Error; err != nil {
|
||||||
|
t.Fatalf("Failed to save point, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
user2 := *GetUser("transaction-save-point-2", Config{})
|
||||||
|
tx.Create(&user2)
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit().Error; err != nil {
|
||||||
|
t.Fatalf("Failed to commit, got error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil {
|
||||||
|
t.Fatalf("Should not find rollbacked record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNestedTransactionWithBlock(t *testing.T) {
|
||||||
|
var (
|
||||||
|
user = *GetUser("transaction-nested", Config{})
|
||||||
|
user1 = *GetUser("transaction-nested-1", Config{})
|
||||||
|
user2 = *GetUser("transaction-nested-2", Config{})
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := DB.Transaction(func(tx *gorm.DB) error {
|
||||||
|
tx.Create(&user)
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Transaction(func(tx1 *gorm.DB) error {
|
||||||
|
tx1.Create(&user1)
|
||||||
|
|
||||||
|
if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.New("rollback")
|
||||||
|
}); err == nil {
|
||||||
|
t.Fatalf("nested transaction should returns error")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil {
|
||||||
|
t.Fatalf("Should not find rollbacked record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Transaction(func(tx2 *gorm.DB) error {
|
||||||
|
tx2.Create(&user2)
|
||||||
|
|
||||||
|
if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("nested transaction returns error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("no error should return, but got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil {
|
||||||
|
t.Fatalf("Should not find rollbacked record")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
|
||||||
|
t.Fatalf("Should find saved record")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -124,9 +124,3 @@ build:
|
||||||
name: test mssql
|
name: test mssql
|
||||||
code: |
|
code: |
|
||||||
GORM_DIALECT=mssql GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" ./tests/tests_all.sh
|
GORM_DIALECT=mssql GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" ./tests/tests_all.sh
|
||||||
|
|
||||||
- script:
|
|
||||||
name: codecov
|
|
||||||
code: |
|
|
||||||
go test -race -coverprofile=coverage.txt -covermode=atomic ./...
|
|
||||||
bash <(curl -s https://codecov.io/bash)
|
|
||||||
|
|
Loading…
Reference in New Issue