Add GetDBConnector interface

This commit is contained in:
Jinzhu 2021-03-19 15:54:32 +08:00
parent 220349ccf2
commit a9fe025ef5
3 changed files with 18 additions and 2 deletions

View File

@ -331,8 +331,8 @@ func (db *DB) AddError(err error) error {
func (db *DB) DB() (*sql.DB, error) { func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool connPool := db.ConnPool
if stmtDB, ok := connPool.(*PreparedStmtDB); ok { if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
connPool = stmtDB.ConnPool return dbConnector.GetDBConn()
} }
if sqldb, ok := connPool.(*sql.DB); ok { if sqldb, ok := connPool.(*sql.DB); ok {

View File

@ -57,3 +57,7 @@ type TxCommitter interface {
type Valuer interface { type Valuer interface {
GormValue(context.Context, *DB) clause.Expr GormValue(context.Context, *DB) clause.Expr
} }
type GetDBConnector interface {
GetDBConn() (*sql.DB, error)
}

View File

@ -18,6 +18,18 @@ type PreparedStmtDB struct {
ConnPool ConnPool
} }
func (db *PreparedStmtDB) GetDB() (*sql.DB, error) {
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
}
return nil, ErrInvaildDB
}
func (db *PreparedStmtDB) Close() { func (db *PreparedStmtDB) Close() {
db.Mux.Lock() db.Mux.Lock()
for _, query := range db.PreparedSQL { for _, query := range db.PreparedSQL {