Refactor Tx interface

This commit is contained in:
Jinzhu 2022-03-01 16:48:46 +08:00
parent 996b96e812
commit 4e523499d1
4 changed files with 7 additions and 27 deletions

View File

@ -600,13 +600,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
opt = opts[0] opt = opts[0]
} }
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok { default:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else {
err = ErrInvalidTransaction err = ErrInvalidTransaction
} }

View File

@ -50,11 +50,6 @@ type ConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
} }
// TxConnPoolBeginner tx conn pool beginner
type TxConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error)
}
// TxCommitter tx committer // TxCommitter tx committer
type TxCommitter interface { type TxCommitter interface {
Commit() error Commit() error
@ -64,8 +59,7 @@ type TxCommitter interface {
// Tx sql.Tx interface // Tx sql.Tx interface
type Tx interface { type Tx interface {
ConnPool ConnPool
Commit() error TxCommitter
Rollback() error
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
} }

View File

@ -73,9 +73,6 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
if beginner, ok := db.ConnPool.(TxBeginner); ok { if beginner, ok := db.ConnPool.(TxBeginner); ok {
tx, err := beginner.BeginTx(ctx, opt) tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
} else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok {
tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
} }
return nil, ErrInvalidTransaction return nil, ErrInvalidTransaction
} }

View File

@ -3,15 +3,12 @@ package tests_test
import ( import (
"context" "context"
"database/sql" "database/sql"
"log"
"os" "os"
"reflect" "reflect"
"testing" "testing"
"time"
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger"
. "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests"
) )
@ -55,7 +52,7 @@ func (c *wrapperConnPool) Ping() error {
// return c.db.BeginTx(ctx, opts) // return c.db.BeginTx(ctx, opts)
// } // }
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) { func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
tx, err := c.db.BeginTx(ctx, opts) tx, err := c.db.BeginTx(ctx, opts)
if err != nil { if err != nil {
return nil, err return nil, err
@ -119,14 +116,7 @@ func TestConnPoolWrapper(t *testing.T) {
} }
}() }()
l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
SlowThreshold: 200 * time.Millisecond,
LogLevel: logger.Info,
IgnoreRecordNotFoundError: false,
Colorful: true,
})
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l})
if err != nil { if err != nil {
t.Fatalf("Should open db success, but got %v", err) t.Fatalf("Should open db success, but got %v", err)
} }