mirror of https://github.com/go-gorm/gorm.git
Refactor Tx interface
This commit is contained in:
parent
996b96e812
commit
4e523499d1
|
@ -600,13 +600,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||||
opt = opts[0]
|
opt = opts[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
|
switch beginner := tx.Statement.ConnPool.(type) {
|
||||||
|
case TxBeginner:
|
||||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||||
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
|
case ConnPoolBeginner:
|
||||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||||
} else if beginner, ok := tx.Statement.ConnPool.(TxConnPoolBeginner); ok {
|
default:
|
||||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
|
||||||
} else {
|
|
||||||
err = ErrInvalidTransaction
|
err = ErrInvalidTransaction
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -50,11 +50,6 @@ type ConnPoolBeginner interface {
|
||||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
|
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TxConnPoolBeginner tx conn pool beginner
|
|
||||||
type TxConnPoolBeginner interface {
|
|
||||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TxCommitter tx committer
|
// TxCommitter tx committer
|
||||||
type TxCommitter interface {
|
type TxCommitter interface {
|
||||||
Commit() error
|
Commit() error
|
||||||
|
@ -64,8 +59,7 @@ type TxCommitter interface {
|
||||||
// Tx sql.Tx interface
|
// Tx sql.Tx interface
|
||||||
type Tx interface {
|
type Tx interface {
|
||||||
ConnPool
|
ConnPool
|
||||||
Commit() error
|
TxCommitter
|
||||||
Rollback() error
|
|
||||||
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
|
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -73,9 +73,6 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
|
||||||
if beginner, ok := db.ConnPool.(TxBeginner); ok {
|
if beginner, ok := db.ConnPool.(TxBeginner); ok {
|
||||||
tx, err := beginner.BeginTx(ctx, opt)
|
tx, err := beginner.BeginTx(ctx, opt)
|
||||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||||
} else if beginner, ok := db.ConnPool.(TxConnPoolBeginner); ok {
|
|
||||||
tx, err := beginner.BeginTx(ctx, opt)
|
|
||||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
|
||||||
}
|
}
|
||||||
return nil, ErrInvalidTransaction
|
return nil, ErrInvalidTransaction
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,15 +3,12 @@ package tests_test
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"log"
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"gorm.io/driver/mysql"
|
"gorm.io/driver/mysql"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/logger"
|
|
||||||
. "gorm.io/gorm/utils/tests"
|
. "gorm.io/gorm/utils/tests"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -55,7 +52,7 @@ func (c *wrapperConnPool) Ping() error {
|
||||||
// return c.db.BeginTx(ctx, opts)
|
// return c.db.BeginTx(ctx, opts)
|
||||||
// }
|
// }
|
||||||
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
|
// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
|
||||||
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.Tx, error) {
|
func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
|
||||||
tx, err := c.db.BeginTx(ctx, opts)
|
tx, err := c.db.BeginTx(ctx, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -119,14 +116,7 @@ func TestConnPoolWrapper(t *testing.T) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
l := logger.New(log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{
|
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
|
||||||
SlowThreshold: 200 * time.Millisecond,
|
|
||||||
LogLevel: logger.Info,
|
|
||||||
IgnoreRecordNotFoundError: false,
|
|
||||||
Colorful: true,
|
|
||||||
})
|
|
||||||
|
|
||||||
db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}), &gorm.Config{Logger: l})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Should open db success, but got %v", err)
|
t.Fatalf("Should open db success, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue