mirror of https://github.com/go-gorm/gorm.git
Add PrepareStmt support
This commit is contained in:
parent
9934207c42
commit
c8e7878b3e
|
@ -310,28 +310,36 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
|||
}
|
||||
|
||||
// Begin begins a transaction
|
||||
func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
|
||||
var opt *sql.TxOptions
|
||||
var err error
|
||||
if len(opts) > 0 {
|
||||
opt = opts[0]
|
||||
}
|
||||
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||
var (
|
||||
tx = db.getInstance()
|
||||
opt *sql.TxOptions
|
||||
err error
|
||||
)
|
||||
|
||||
if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil {
|
||||
tx.AddError(err)
|
||||
}
|
||||
} else {
|
||||
tx.AddError(ErrInvalidTransaction)
|
||||
if len(opts) > 0 {
|
||||
opt = opts[0]
|
||||
}
|
||||
return
|
||||
|
||||
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
} else {
|
||||
err = ErrInvalidTransaction
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
tx.AddError(err)
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
// Commit commit a transaction
|
||||
func (db *DB) Commit() *DB {
|
||||
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
|
||||
db.AddError(comminter.Commit())
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
db.AddError(committer.Commit())
|
||||
} else {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
|
@ -340,8 +348,8 @@ func (db *DB) Commit() *DB {
|
|||
|
||||
// Rollback rollback a transaction
|
||||
func (db *DB) Rollback() *DB {
|
||||
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
|
||||
db.AddError(comminter.Rollback())
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
db.AddError(committer.Rollback())
|
||||
} else {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
|
|
49
gorm.go
49
gorm.go
|
@ -2,6 +2,7 @@ package gorm
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -25,6 +26,9 @@ type Config struct {
|
|||
// DryRun generate sql without execute
|
||||
DryRun bool
|
||||
|
||||
// PrepareStmt executes the given query in cached statement
|
||||
PrepareStmt bool
|
||||
|
||||
// ClauseBuilders clause builder
|
||||
ClauseBuilders map[string]clause.ClauseBuilder
|
||||
// ConnPool db conn pool
|
||||
|
@ -48,6 +52,7 @@ type DB struct {
|
|||
// Session session config when create session with Session() method
|
||||
type Session struct {
|
||||
DryRun bool
|
||||
PrepareStmt bool
|
||||
WithConditions bool
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
|
@ -92,6 +97,22 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
|||
err = dialector.Initialize(db)
|
||||
}
|
||||
|
||||
if config.PrepareStmt {
|
||||
db.ConnPool = &PreparedStmtDB{
|
||||
ConnPool: db.ConnPool,
|
||||
stmts: map[string]*sql.Stmt{},
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement == nil {
|
||||
db.Statement = &Statement{
|
||||
DB: db,
|
||||
ConnPool: db.ConnPool,
|
||||
Context: context.Background(),
|
||||
Clauses: map[string]clause.Clause{},
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
|
||||
err = pinger.Ping()
|
||||
|
@ -131,6 +152,13 @@ func (db *DB) Session(config *Session) *DB {
|
|||
tx.Statement.Context = config.Context
|
||||
}
|
||||
|
||||
if config.PrepareStmt {
|
||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||
ConnPool: db.Config.ConnPool,
|
||||
stmts: map[string]*sql.Stmt{},
|
||||
}
|
||||
}
|
||||
|
||||
if config.WithConditions {
|
||||
tx.clone = 3
|
||||
}
|
||||
|
@ -256,6 +284,12 @@ func (db *DB) getInstance() *DB {
|
|||
|
||||
switch db.clone {
|
||||
case 1: // clone with new statement
|
||||
tx.Statement = &Statement{
|
||||
DB: tx,
|
||||
ConnPool: db.Statement.ConnPool,
|
||||
Context: db.Statement.Context,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
}
|
||||
case 2: // with old statement, generate new statement for future call, used to pass to callbacks
|
||||
db.clone = 1
|
||||
tx.Statement = db.Statement
|
||||
|
@ -266,21 +300,6 @@ func (db *DB) getInstance() *DB {
|
|||
}
|
||||
}
|
||||
|
||||
if tx.Statement == nil {
|
||||
tx.Statement = &Statement{
|
||||
DB: tx,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement != nil {
|
||||
tx.Statement.Context = db.Statement.Context
|
||||
tx.Statement.ConnPool = db.Statement.ConnPool
|
||||
} else {
|
||||
tx.Statement.Context = context.Background()
|
||||
tx.Statement.ConnPool = db.ConnPool
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
|
|
|
@ -21,8 +21,8 @@ type Dialector interface {
|
|||
|
||||
// ConnPool db conns pool interface
|
||||
type ConnPool interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
@ -31,7 +31,11 @@ type TxBeginner interface {
|
|||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
type TxCommiter interface {
|
||||
type ConnPoolBeginner interface {
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
|
||||
}
|
||||
|
||||
type TxCommitter interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type PreparedStmtDB struct {
|
||||
stmts map[string]*sql.Stmt
|
||||
mux sync.RWMutex
|
||||
ConnPool
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
|
||||
db.mux.RLock()
|
||||
if stmt, ok := db.stmts[query]; ok {
|
||||
db.mux.RUnlock()
|
||||
return stmt, nil
|
||||
}
|
||||
db.mux.RUnlock()
|
||||
|
||||
db.mux.Lock()
|
||||
stmt, err := db.ConnPool.PrepareContext(context.Background(), query)
|
||||
if err == nil {
|
||||
db.stmts[query] = stmt
|
||||
}
|
||||
db.mux.Unlock()
|
||||
|
||||
return stmt, err
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
||||
if beginner, ok := db.ConnPool.(TxBeginner); ok {
|
||||
tx, err := beginner.BeginTx(ctx, opt)
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||
}
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
stmt, err := db.prepare(query)
|
||||
if err == nil {
|
||||
return stmt.ExecContext(ctx, args...)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
stmt, err := db.prepare(query)
|
||||
if err == nil {
|
||||
return stmt.QueryContext(ctx, args...)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
stmt, err := db.prepare(query)
|
||||
if err == nil {
|
||||
return stmt.QueryRowContext(ctx, args...)
|
||||
}
|
||||
return &sql.Row{}
|
||||
}
|
||||
|
||||
type PreparedStmtTX struct {
|
||||
*sql.Tx
|
||||
PreparedStmtDB *PreparedStmtDB
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(query)
|
||||
if err == nil {
|
||||
return tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(query)
|
||||
if err == nil {
|
||||
return tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(query)
|
||||
if err == nil {
|
||||
return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...)
|
||||
}
|
||||
return &sql.Row{}
|
||||
}
|
|
@ -1,7 +1,6 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
|
@ -21,7 +20,7 @@ func TestTransaction(t *testing.T) {
|
|||
t.Fatalf("Should find saved record, but got %v", err)
|
||||
}
|
||||
|
||||
if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil {
|
||||
if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
|
||||
t.Fatalf("Should return the underlying sql.Tx")
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue