mirror of https://github.com/go-gorm/gorm.git
feat: rm GetDBConnWithContext method (#6535)
* feat: rm contextconnpool method * feat: nil
This commit is contained in:
parent
bae684b363
commit
fef42941ba
2
go.sum
2
go.sum
|
@ -1,6 +1,4 @@
|
|||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
|
|
9
gorm.go
9
gorm.go
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -374,9 +375,11 @@ func (db *DB) AddError(err error) error {
|
|||
// DB returns `*sql.DB`
|
||||
func (db *DB) DB() (*sql.DB, error) {
|
||||
connPool := db.ConnPool
|
||||
|
||||
if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil {
|
||||
return connector.GetDBConnWithContext(db)
|
||||
if db.Statement != nil && db.Statement.ConnPool != nil {
|
||||
connPool = db.Statement.ConnPool
|
||||
}
|
||||
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
|
||||
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
|
||||
}
|
||||
|
||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
|
|
|
@ -77,12 +77,6 @@ type GetDBConnector interface {
|
|||
GetDBConn() (*sql.DB, error)
|
||||
}
|
||||
|
||||
// GetDBConnectorWithContext represents SQL db connector which takes into
|
||||
// account the current database context
|
||||
type GetDBConnectorWithContext interface {
|
||||
GetDBConnWithContext(db *DB) (*sql.DB, error)
|
||||
}
|
||||
|
||||
// Rows rows interface
|
||||
type Rows interface {
|
||||
Columns() ([]string, error)
|
||||
|
|
|
@ -30,15 +30,11 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
|
|||
}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) {
|
||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil {
|
||||
return connector.GetDBConnWithContext(gormdb)
|
||||
}
|
||||
|
||||
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
}
|
||||
|
@ -131,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
|
|||
tx, err := beginner.BeginTx(ctx, opt)
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||
}
|
||||
|
||||
beginner, ok := db.ConnPool.(ConnPoolBeginner)
|
||||
if !ok {
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
connPool, err := beginner.BeginTx(ctx, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tx, ok := connPool.(Tx); ok {
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
|
||||
}
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
|
@ -176,6 +185,10 @@ type PreparedStmtTX struct {
|
|||
PreparedStmtDB *PreparedStmtDB
|
||||
}
|
||||
|
||||
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
|
||||
return db.PreparedStmtDB.GetDBConn()
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Commit() error {
|
||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||
return tx.Tx.Commit()
|
||||
|
|
Loading…
Reference in New Issue