diff --git a/callbacks/transaction.go b/callbacks/transaction.go index 430a341d..14d31a62 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -7,7 +7,7 @@ import ( func BeginTransaction(db *gorm.DB) { if tx := db.Begin(); tx.Error == nil { db.Statement.ConnPool = tx.Statement.ConnPool - tx.InstanceSet("gorm:started_transaction", true) + db.InstanceSet("gorm:started_transaction", true) } else { tx.Error = nil } diff --git a/finisher_api.go b/finisher_api.go index 73e42508..43aff843 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -351,7 +351,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() - err = fc(tx.Session(&Session{})) + err = fc(tx) if err == nil { err = tx.Commit().Error @@ -364,7 +364,8 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er // Begin begins a transaction func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( - tx = db.getInstance() + // clone statement + tx = db.Session(&Session{WithConditions: true, Context: db.Statement.Context}) opt *sql.TxOptions err error ) diff --git a/statement.go b/statement.go index 2a092966..e3c882ee 100644 --- a/statement.go +++ b/statement.go @@ -351,5 +351,10 @@ func (stmt *Statement) clone() *Statement { newStmt.Joins[k] = j } + stmt.Settings.Range(func(k, v interface{}) bool { + newStmt.Settings.Store(k, v) + return true + }) + return newStmt } diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 592f1321..d1bf8645 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -20,6 +20,16 @@ func TestTransaction(t *testing.T) { t.Fatalf("Should find saved record, but got %v", err) } + user1 := *GetUser("transaction1-1", Config{}) + + if err := tx.Save(&user1).Error; err != nil { + t.Fatalf("No error should raise, but got %v", err) + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record, but got %v", err) + } + if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { t.Fatalf("Should return the underlying sql.Tx") }