gorm/prepare_stmt.go

261 lines
6.4 KiB
Go
Raw Normal View History

2020-06-05 05:08:22 +03:00
package gorm
import (
"context"
"database/sql"
2024-03-20 14:47:20 +03:00
"database/sql/driver"
"errors"
"reflect"
2020-06-05 05:08:22 +03:00
"sync"
)
// Stmt is a cache entry for one prepared statement.
type Stmt struct {
// Embedded database/sql statement; it stays nil until preparation succeeds.
*sql.Stmt
// Transaction records whether the entry was prepared through a transaction
// connection (the isTransaction argument passed to prepare).
Transaction bool
// prepared is closed by the goroutine that created the entry once its
// PrepareContext call has returned; other goroutines receive from it to wait.
prepared chan struct{}
// prepareErr holds the PrepareContext error, if any; it is written before
// prepared is closed.
prepareErr error
}
2020-06-05 05:08:22 +03:00
type PreparedStmtDB struct {
Stmts map[string]*Stmt
2020-07-28 09:26:09 +03:00
PreparedSQL []string
Mux *sync.RWMutex
2020-06-05 05:08:22 +03:00
ConnPool
}
// NewPreparedStmtDB returns a PreparedStmtDB that prepares statements on
// connPool, starting with an empty statement cache.
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
	db := new(PreparedStmtDB)
	db.ConnPool = connPool
	db.Stmts = map[string]*Stmt{}
	db.Mux = new(sync.RWMutex)
	// Room for 100 queries before the slice has to grow.
	db.PreparedSQL = make([]string, 0, 100)
	return db
}
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
2021-03-19 10:54:32 +03:00
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
}
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}
2021-04-19 16:03:39 +03:00
return nil, ErrInvalidDB
2021-03-19 10:54:32 +03:00
}
func (db *PreparedStmtDB) Close() {
2020-08-03 16:48:36 +03:00
db.Mux.Lock()
defer db.Mux.Unlock()
2020-07-28 09:26:09 +03:00
for _, query := range db.PreparedSQL {
if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query)
go stmt.Close()
2020-07-28 09:26:09 +03:00
}
}
}
// Reset closes every cached statement in the background and replaces the
// cache and the PreparedSQL list with fresh, empty ones. Close errors are
// discarded: closing is best-effort.
func (db *PreparedStmtDB) Reset() {
	db.Mux.Lock()
	defer db.Mux.Unlock()

	for _, stmt := range db.Stmts {
		go func(s *Stmt) {
			if s == nil {
				return
			}
			// Entries are cached before PrepareContext returns, so s.Stmt may
			// not be set yet; wait until preparation has finished.
			if s.prepared != nil {
				<-s.prepared
			}
			// Preparation may have failed, leaving no statement to close;
			// calling Close on a nil *sql.Stmt would panic.
			if s.Stmt != nil {
				_ = s.Stmt.Close()
			}
		}(stmt)
	}

	db.PreparedSQL = make([]string, 0, 100)
	db.Stmts = make(map[string]*Stmt)
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
2020-08-03 16:48:36 +03:00
db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
2020-08-03 16:48:36 +03:00
db.Mux.RUnlock()
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
2020-06-05 05:08:22 +03:00
}
2020-08-03 16:48:36 +03:00
db.Mux.RUnlock()
2020-06-05 05:08:22 +03:00
2020-08-03 16:48:36 +03:00
db.Mux.Lock()
// double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
}
// cache preparing stmt first
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
db.Mux.Unlock()
// prepare completed
defer close(cacheStmt.prepared)
// Reason why cannot lock conn.PrepareContext
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
stmt, err := conn.PrepareContext(ctx, query)
if err != nil {
cacheStmt.prepareErr = err
db.Mux.Lock()
delete(db.Stmts, query)
db.Mux.Unlock()
return Stmt{}, err
2020-06-05 05:08:22 +03:00
}
db.Mux.Lock()
cacheStmt.Stmt = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
db.Mux.Unlock()
return cacheStmt, nil
2020-06-05 05:08:22 +03:00
}
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
if beginner, ok := db.ConnPool.(TxBeginner); ok {
tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
}
beginner, ok := db.ConnPool.(ConnPoolBeginner)
if !ok {
return nil, ErrInvalidTransaction
}
connPool, err := beginner.BeginTx(ctx, opt)
if err != nil {
return nil, err
}
if tx, ok := connPool.(Tx); ok {
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
}
2020-06-05 05:08:22 +03:00
return nil, ErrInvalidTransaction
}
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
2020-06-05 05:08:22 +03:00
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
2024-03-20 14:47:20 +03:00
if errors.Is(err, driver.ErrBadConn) {
2020-08-03 16:48:36 +03:00
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
}
2020-06-05 05:08:22 +03:00
}
return result, err
2020-06-05 05:08:22 +03:00
}
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
2020-06-05 05:08:22 +03:00
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
2024-03-20 14:47:20 +03:00
if errors.Is(err, driver.ErrBadConn) {
2020-08-03 16:48:36 +03:00
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
}
2020-06-05 05:08:22 +03:00
}
return rows, err
2020-06-05 05:08:22 +03:00
}
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
2020-06-05 05:08:22 +03:00
if err == nil {
return stmt.QueryRowContext(ctx, args...)
}
return &sql.Row{}
}
// Ping resolves the underlying *sql.DB via GetDBConn and pings it. The
// GetDBConn error is returned when no *sql.DB is available.
func (db *PreparedStmtDB) Ping() error {
	sqlDB, err := db.GetDBConn()
	if err == nil {
		err = sqlDB.Ping()
	}
	return err
}
2020-06-05 05:08:22 +03:00
type PreparedStmtTX struct {
Tx
2020-06-05 05:08:22 +03:00
PreparedStmtDB *PreparedStmtDB
}
// GetDBConn returns the *sql.DB of the PreparedStmtDB this transaction
// belongs to.
func (tx *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
	owner := tx.PreparedStmtDB
	return owner.GetDBConn()
}
// Commit commits the underlying transaction. It returns ErrInvalidTransaction
// when Tx is unset or wraps a nil value (for example a typed nil pointer
// stored after a failed BeginTx).
func (tx *PreparedStmtTX) Commit() error {
	if tx.Tx == nil {
		return ErrInvalidTransaction
	}
	// reflect.Value.IsNil panics for kinds that cannot be nil (e.g. a Tx
	// implemented by a struct value), so only consult it for nillable kinds.
	switch v := reflect.ValueOf(tx.Tx); v.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		if v.IsNil() {
			return ErrInvalidTransaction
		}
	}
	return tx.Tx.Commit()
}
// Rollback rolls back the underlying transaction. It returns
// ErrInvalidTransaction when Tx is unset or wraps a nil value (for example a
// typed nil pointer stored after a failed BeginTx).
func (tx *PreparedStmtTX) Rollback() error {
	if tx.Tx == nil {
		return ErrInvalidTransaction
	}
	// reflect.Value.IsNil panics for kinds that cannot be nil (e.g. a Tx
	// implemented by a struct value), so only consult it for nillable kinds.
	switch v := reflect.ValueOf(tx.Tx); v.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		if v.IsNil() {
			return ErrInvalidTransaction
		}
	}
	return tx.Tx.Rollback()
}
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
2020-06-05 05:08:22 +03:00
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
2024-03-20 14:47:20 +03:00
if errors.Is(err, driver.ErrBadConn) {
2020-08-03 16:48:36 +03:00
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
}
2020-06-05 05:08:22 +03:00
}
return result, err
2020-06-05 05:08:22 +03:00
}
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
2020-06-05 05:08:22 +03:00
if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
2024-03-20 14:47:20 +03:00
if errors.Is(err, driver.ErrBadConn) {
2020-08-03 16:48:36 +03:00
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
}
2020-06-05 05:08:22 +03:00
}
return rows, err
2020-06-05 05:08:22 +03:00
}
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
2020-06-05 05:08:22 +03:00
if err == nil {
return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...)
2020-06-05 05:08:22 +03:00
}
return &sql.Row{}
}
// Ping resolves the underlying *sql.DB via GetDBConn and pings it. The
// GetDBConn error is returned when no *sql.DB is available.
func (tx *PreparedStmtTX) Ping() error {
	sqlDB, err := tx.GetDBConn()
	if err == nil {
		err = sqlDB.Ping()
	}
	return err
}