mirror of https://github.com/go-gorm/gorm.git
feat: rm contextconnpool method
This commit is contained in:
parent
bae684b363
commit
a219acca4b
2
go.sum
2
go.sum
|
@ -1,6 +1,4 @@
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
|
|
||||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
|
||||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
|
|
7
gorm.go
7
gorm.go
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -373,10 +374,10 @@ func (db *DB) AddError(err error) error {
|
||||||
|
|
||||||
// DB returns `*sql.DB`
|
// DB returns `*sql.DB`
|
||||||
func (db *DB) DB() (*sql.DB, error) {
|
func (db *DB) DB() (*sql.DB, error) {
|
||||||
connPool := db.ConnPool
|
connPool := db.Statement.ConnPool
|
||||||
|
|
||||||
if connector, ok := connPool.(GetDBConnectorWithContext); ok && connector != nil {
|
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
|
||||||
return connector.GetDBConnWithContext(db)
|
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||||
|
|
|
@ -77,12 +77,6 @@ type GetDBConnector interface {
|
||||||
GetDBConn() (*sql.DB, error)
|
GetDBConn() (*sql.DB, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDBConnectorWithContext represents SQL db connector which takes into
|
|
||||||
// account the current database context
|
|
||||||
type GetDBConnectorWithContext interface {
|
|
||||||
GetDBConnWithContext(db *DB) (*sql.DB, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rows rows interface
|
// Rows rows interface
|
||||||
type Rows interface {
|
type Rows interface {
|
||||||
Columns() ([]string, error)
|
Columns() ([]string, error)
|
||||||
|
|
|
@ -30,15 +30,11 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) {
|
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||||
return sqldb, nil
|
return sqldb, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil {
|
|
||||||
return connector.GetDBConnWithContext(gormdb)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
||||||
return dbConnector.GetDBConn()
|
return dbConnector.GetDBConn()
|
||||||
}
|
}
|
||||||
|
@ -131,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
|
||||||
tx, err := beginner.BeginTx(ctx, opt)
|
tx, err := beginner.BeginTx(ctx, opt)
|
||||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
beginner, ok := db.ConnPool.(ConnPoolBeginner)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrInvalidTransaction
|
||||||
|
}
|
||||||
|
|
||||||
|
connPool, err := beginner.BeginTx(ctx, opt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if tx, ok := connPool.(Tx); ok {
|
||||||
|
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
|
||||||
|
}
|
||||||
return nil, ErrInvalidTransaction
|
return nil, ErrInvalidTransaction
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,6 +185,10 @@ type PreparedStmtTX struct {
|
||||||
PreparedStmtDB *PreparedStmtDB
|
PreparedStmtDB *PreparedStmtDB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
|
||||||
|
return db.PreparedStmtDB.GetDBConn()
|
||||||
|
}
|
||||||
|
|
||||||
func (tx *PreparedStmtTX) Commit() error {
|
func (tx *PreparedStmtTX) Commit() error {
|
||||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||||
return tx.Tx.Commit()
|
return tx.Tx.Commit()
|
||||||
|
|
Loading…
Reference in New Issue