feat: rm contextconnpool method

This commit is contained in:
qqxhb 2023-08-18 18:55:58 +08:00
parent bae684b363
commit a219acca4b
4 changed files with 22 additions and 16 deletions

2
go.sum
View File

@ -1,6 +1,4 @@
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"reflect"
"sort" "sort"
"sync" "sync"
"time" "time"
@ -373,10 +374,10 @@ func (db *DB) AddError(err error) error {
// DB returns `*sql.DB` // DB returns `*sql.DB`
func (db *DB) DB() (*sql.DB, error) { func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool connPool := db.Statement.ConnPool
if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil { if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
return connector.GetDBConnWithContext(db) return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
} }
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {

View File

@ -77,12 +77,6 @@ type GetDBConnector interface {
GetDBConn() (*sql.DB, error) GetDBConn() (*sql.DB, error)
} }
// GetDBConnectorWithContext represents SQL db connector which takes into
// account the current database context
type GetDBConnectorWithContext interface {
GetDBConnWithContext(db *DB) (*sql.DB, error)
}
// Rows rows interface // Rows rows interface
type Rows interface { type Rows interface {
Columns() ([]string, error) Columns() ([]string, error)

View File

@ -30,15 +30,11 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
} }
} }
func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) { func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if sqldb, ok := db.ConnPool.(*sql.DB); ok { if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil return sqldb, nil
} }
if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil {
return connector.GetDBConnWithContext(gormdb)
}
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn() return dbConnector.GetDBConn()
} }
@ -131,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
tx, err := beginner.BeginTx(ctx, opt) tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
} }
beginner, ok := db.ConnPool.(ConnPoolBeginner)
if !ok {
return nil, ErrInvalidTransaction
}
connPool, err := beginner.BeginTx(ctx, opt)
if err != nil {
return nil, err
}
if tx, ok := connPool.(Tx); ok {
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
}
return nil, ErrInvalidTransaction return nil, ErrInvalidTransaction
} }
@ -176,6 +185,10 @@ type PreparedStmtTX struct {
PreparedStmtDB *PreparedStmtDB PreparedStmtDB *PreparedStmtDB
} }
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
return db.PreparedStmtDB.GetDBConn()
}
func (tx *PreparedStmtTX) Commit() error { func (tx *PreparedStmtTX) Commit() error {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Commit() return tx.Tx.Commit()