// gorm/prepare_stmt.go (151 lines, 3.5 KiB, Go).
// Exported from a repository blame view; earliest revision shown:
// 2020-06-05 05:08:22 +03:00.

package gorm
import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"sync"
)
type PreparedStmtDB struct {
2020-07-28 09:26:09 +03:00
Stmts map[string]*sql.Stmt
PreparedSQL []string
Mux *sync.RWMutex
2020-06-05 05:08:22 +03:00
ConnPool
}
func (db *PreparedStmtDB) Close() {
2020-08-03 16:48:36 +03:00
db.Mux.Lock()
2020-07-28 09:26:09 +03:00
for _, query := range db.PreparedSQL {
if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query)
stmt.Close()
}
}
2020-08-03 16:48:36 +03:00
db.Mux.Unlock()
}
func (db *PreparedStmtDB) prepare(ctx context.Context, query string) (*sql.Stmt, error) {
2020-08-03 16:48:36 +03:00
db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok {
2020-08-03 16:48:36 +03:00
db.Mux.RUnlock()
2020-06-05 05:08:22 +03:00
return stmt, nil
}
2020-08-03 16:48:36 +03:00
db.Mux.RUnlock()
2020-06-05 05:08:22 +03:00
2020-08-03 16:48:36 +03:00
db.Mux.Lock()
// double check
if stmt, ok := db.Stmts[query]; ok {
2020-08-03 16:48:36 +03:00
db.Mux.Unlock()
return stmt, nil
}
stmt, err := db.ConnPool.PrepareContext(ctx, query)
2020-06-05 05:08:22 +03:00
if err == nil {
db.Stmts[query] = stmt
2020-07-28 09:26:09 +03:00
db.PreparedSQL = append(db.PreparedSQL, query)
2020-06-05 05:08:22 +03:00
}
2020-08-03 16:48:36 +03:00
db.Mux.Unlock()
2020-06-05 05:08:22 +03:00
return stmt, err
}
// BeginTx starts a transaction on the underlying ConnPool and wraps it in a
// PreparedStmtTX that shares this statement cache.
//
// If ConnPool does not implement TxBeginner, it returns ErrInvalidTransaction.
// Otherwise the wrapper is returned together with whatever error the
// underlying BeginTx produced, so callers must check the error first.
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
	txStarter, canBegin := db.ConnPool.(TxBeginner)
	if !canBegin {
		return nil, ErrInvalidTransaction
	}

	sqlTx, beginErr := txStarter.BeginTx(ctx, opt)
	wrapped := &PreparedStmtTX{PreparedStmtDB: db, Tx: sqlTx}
	return wrapped, beginErr
}
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := db.prepare(ctx, query)
2020-06-05 05:08:22 +03:00
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
if err != nil {
2020-08-03 16:48:36 +03:00
db.Mux.Lock()
stmt.Close()
delete(db.Stmts, query)
2020-08-03 16:48:36 +03:00
db.Mux.Unlock()
}
2020-06-05 05:08:22 +03:00
}
return result, err
2020-06-05 05:08:22 +03:00
}
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := db.prepare(ctx, query)
2020-06-05 05:08:22 +03:00
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if err != nil {
2020-08-03 16:48:36 +03:00
db.Mux.Lock()
stmt.Close()
delete(db.Stmts, query)
2020-08-03 16:48:36 +03:00
db.Mux.Unlock()
}
2020-06-05 05:08:22 +03:00
}
return rows, err
2020-06-05 05:08:22 +03:00
}
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := db.prepare(ctx, query)
2020-06-05 05:08:22 +03:00
if err == nil {
return stmt.QueryRowContext(ctx, args...)
}
return &sql.Row{}
}
// PreparedStmtTX is the ConnPool returned by PreparedStmtDB.BeginTx. It runs
// statements inside the embedded transaction while reusing the prepared
// statements cached on PreparedStmtDB.
type PreparedStmtTX struct {
	// Tx is the underlying transaction. It may be nil; Commit and Rollback
	// check for that and return ErrInvalidTransaction.
	*sql.Tx
	// PreparedStmtDB owns the shared prepared-statement cache.
	PreparedStmtDB *PreparedStmtDB
}
// Commit commits the wrapped transaction. It returns ErrInvalidTransaction
// when no transaction is attached (Tx is nil).
func (tx *PreparedStmtTX) Commit() error {
	if tx.Tx == nil {
		return ErrInvalidTransaction
	}
	return tx.Tx.Commit()
}
// Rollback aborts the wrapped transaction. It returns ErrInvalidTransaction
// when no transaction is attached (Tx is nil).
func (tx *PreparedStmtTX) Rollback() error {
	if tx.Tx == nil {
		return ErrInvalidTransaction
	}
	return tx.Tx.Rollback()
}
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
2020-06-05 05:08:22 +03:00
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...)
if err != nil {
2020-08-03 16:48:36 +03:00
tx.PreparedStmtDB.Mux.Lock()
stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
2020-08-03 16:48:36 +03:00
tx.PreparedStmtDB.Mux.Unlock()
}
2020-06-05 05:08:22 +03:00
}
return result, err
2020-06-05 05:08:22 +03:00
}
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
2020-06-05 05:08:22 +03:00
if err == nil {
rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
if err != nil {
2020-08-03 16:48:36 +03:00
tx.PreparedStmtDB.Mux.Lock()
stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
2020-08-03 16:48:36 +03:00
tx.PreparedStmtDB.Mux.Unlock()
}
2020-06-05 05:08:22 +03:00
}
return rows, err
2020-06-05 05:08:22 +03:00
}
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := tx.PreparedStmtDB.prepare(ctx, query)
2020-06-05 05:08:22 +03:00
if err == nil {
return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...)
2020-06-05 05:08:22 +03:00
}
return &sql.Row{}
}