diff --git a/finisher_api.go b/finisher_api.go index 5d49ddf9..4b428a59 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -600,13 +600,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { opt = opts[0] } - if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { + switch beginner := tx.Statement.ConnPool.(type) { + case TxBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { + case ConnPoolBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok { - tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) - } else { + default: err = ErrInvalidTransaction } diff --git a/interfaces.go b/interfaces.go index ed7112f2..84dc94bb 100644 --- a/interfaces.go +++ b/interfaces.go @@ -50,11 +50,6 @@ type ConnPoolBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) } -// TxConnPoolBeginner tx conn pool beginner -type TxConnPoolBeginner interface { - BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) -} - // TxCommitter tx committer type TxCommitter interface { Commit() error @@ -64,8 +59,7 @@ type TxCommitter interface { // Tx sql.Tx interface type Tx interface { ConnPool - Commit() error - Rollback() error + TxCommitter StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt } diff --git a/prepare_stmt.go b/prepare_stmt.go index 94282fad..b062b0d6 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -73,9 +73,6 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn if beginner, ok := db.ConnPool.(TxBeginner); ok { tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err - } else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok { - tx, err := beginner.BeginTx(ctx, opt) - return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } return nil, ErrInvalidTransaction } diff --git a/tests/connpool_test.go b/tests/connpool_test.go index 3713ad7c..fbae2294 100644 --- a/tests/connpool_test.go +++ b/tests/connpool_test.go @@ -3,15 +3,12 @@ package tests_test import ( "context" "database/sql" - "log" "os" "reflect" "testing" - "time" "gorm.io/driver/mysql" "gorm.io/gorm" - "gorm.io/gorm/logger" . "gorm.io/gorm/utils/tests" ) @@ -55,7 +52,7 @@ func (c *wrapperConnPool) Ping() error { // return c.db.BeginTx(ctx, opts) // } // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. -func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) { +func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { tx, err := c.db.BeginTx(ctx, opts) if err != nil { return nil, err @@ -119,14 +116,7 @@ func TestConnPoolWrapper(t *testing.T) { } }() - l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ - SlowThreshold: 200 * time.Millisecond, - LogLevel: logger.Info, - IgnoreRecordNotFoundError: false, - Colorful: true, - }) - - db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l}) + db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn})) if err != nil { t.Fatalf("Should open db success, but got %v", err) }