diff --git a/gorm.go b/gorm.go index 88212e94..9323c46d 100644 --- a/gorm.go +++ b/gorm.go @@ -331,8 +331,8 @@ func (db *DB) AddError(err error) error { func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool - if stmtDB, ok := connPool.(*PreparedStmtDB); ok { - connPool = stmtDB.ConnPool + if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() } if sqldb, ok := connPool.(*sql.DB); ok { diff --git a/interfaces.go b/interfaces.go index e933952b..44b2fced 100644 --- a/interfaces.go +++ b/interfaces.go @@ -57,3 +57,7 @@ type TxCommitter interface { type Valuer interface { GormValue(context.Context, *DB) clause.Expr } + +type GetDBConnector interface { + GetDBConn() (*sql.DB, error) +} diff --git a/prepare_stmt.go b/prepare_stmt.go index 78a8adb4..bc7ef180 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -18,6 +18,18 @@ type PreparedStmtDB struct { ConnPool } +func (db *PreparedStmtDB) GetDB() (*sql.DB, error) { + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + + if sqldb, ok := db.ConnPool.(*sql.DB); ok { + return sqldb, nil + } + + return nil, ErrInvaildDB +} + func (db *PreparedStmtDB) Close() { db.Mux.Lock() for _, query := range db.PreparedSQL {