diff --git a/finisher_api.go b/finisher_api.go index dd0eb83a..355d89bd 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -534,9 +534,7 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) { defer conn.Close() tx.Statement.ConnPool = conn - err = fc(tx) - - return + return fc(tx) } // Transaction start a transaction as a block, return error will rollback, otherwise to commit. @@ -547,6 +545,10 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er // nested transaction if !db.DisableNestedTransaction { err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error + if err != nil { + return + } + defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { @@ -555,11 +557,12 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er }() } - if err == nil { - err = fc(db.Session(&Session{})) - } + err = fc(db.Session(&Session{})) } else { tx := db.Begin(opts...) + if tx.Error != nil { + return tx.Error + } defer func() { // Make sure to rollback when panic, Block error or Commit error @@ -568,12 +571,9 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() - if err = tx.Error; err == nil { - err = fc(tx) - } - - if err == nil { - err = tx.Commit().Error + if err = fc(tx); err == nil { + panicked = false + return tx.Commit().Error } } diff --git a/tests/connection_test.go b/tests/connection_test.go index 9b5dcd05..92b13dd6 100644 --- a/tests/connection_test.go +++ b/tests/connection_test.go @@ -2,13 +2,13 @@ package tests_test import ( "fmt" + "testing" + "gorm.io/driver/mysql" "gorm.io/gorm" - "testing" ) func TestWithSingleConnection(t *testing.T) { - var expectedName = "test" var actualName string @@ -35,7 +35,6 @@ func TestWithSingleConnection(t *testing.T) { if actualName != expectedName { t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName) } - } func getSetSQL(driverName string) (string, string) {