mirror of https://github.com/go-gorm/gorm.git
Add GetDBConnector interface
This commit is contained in:
parent
220349ccf2
commit
a9fe025ef5
4
gorm.go
4
gorm.go
|
@ -331,8 +331,8 @@ func (db *DB) AddError(err error) error {
|
|||
func (db *DB) DB() (*sql.DB, error) {
|
||||
connPool := db.ConnPool
|
||||
|
||||
if stmtDB, ok := connPool.(*PreparedStmtDB); ok {
|
||||
connPool = stmtDB.ConnPool
|
||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
}
|
||||
|
||||
if sqldb, ok := connPool.(*sql.DB); ok {
|
||||
|
|
|
@ -57,3 +57,7 @@ type TxCommitter interface {
|
|||
type Valuer interface {
|
||||
GormValue(context.Context, *DB) clause.Expr
|
||||
}
|
||||
|
||||
type GetDBConnector interface {
|
||||
GetDBConn() (*sql.DB, error)
|
||||
}
|
||||
|
|
|
@ -18,6 +18,18 @@ type PreparedStmtDB struct {
|
|||
ConnPool
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) GetDB() (*sql.DB, error) {
|
||||
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
}
|
||||
|
||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
return nil, ErrInvaildDB
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) Close() {
|
||||
db.Mux.Lock()
|
||||
for _, query := range db.PreparedSQL {
|
||||
|
|
Loading…
Reference in New Issue