Add PrepareStmt support

This commit is contained in:
Jinzhu 2020-06-05 10:08:22 +08:00
parent 9934207c42
commit c8e7878b3e
5 changed files with 159 additions and 37 deletions

View File

@ -310,28 +310,36 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
} }
// Begin begins a transaction // Begin begins a transaction
func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
tx = db.getInstance() var (
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { tx = db.getInstance()
var opt *sql.TxOptions opt *sql.TxOptions
var err error err error
if len(opts) > 0 { )
opt = opts[0]
}
if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil { if len(opts) > 0 {
tx.AddError(err) opt = opts[0]
}
} else {
tx.AddError(ErrInvalidTransaction)
} }
return
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else {
err = ErrInvalidTransaction
}
if err != nil {
tx.AddError(err)
}
return tx
} }
// Commit commit a transaction // Commit commit a transaction
func (db *DB) Commit() *DB { func (db *DB) Commit() *DB {
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
db.AddError(comminter.Commit()) db.AddError(committer.Commit())
} else { } else {
db.AddError(ErrInvalidTransaction) db.AddError(ErrInvalidTransaction)
} }
@ -340,8 +348,8 @@ func (db *DB) Commit() *DB {
// Rollback rollback a transaction // Rollback rollback a transaction
func (db *DB) Rollback() *DB { func (db *DB) Rollback() *DB {
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
db.AddError(comminter.Rollback()) db.AddError(committer.Rollback())
} else { } else {
db.AddError(ErrInvalidTransaction) db.AddError(ErrInvalidTransaction)
} }

49
gorm.go
View File

@ -2,6 +2,7 @@ package gorm
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"sync" "sync"
"time" "time"
@ -25,6 +26,9 @@ type Config struct {
// DryRun generate sql without execute // DryRun generate sql without execute
DryRun bool DryRun bool
// PrepareStmt executes the given query in cached statement
PrepareStmt bool
// ClauseBuilders clause builder // ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder ClauseBuilders map[string]clause.ClauseBuilder
// ConnPool db conn pool // ConnPool db conn pool
@ -48,6 +52,7 @@ type DB struct {
// Session session config when create session with Session() method // Session session config when create session with Session() method
type Session struct { type Session struct {
DryRun bool DryRun bool
PrepareStmt bool
WithConditions bool WithConditions bool
Context context.Context Context context.Context
Logger logger.Interface Logger logger.Interface
@ -92,6 +97,22 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
err = dialector.Initialize(db) err = dialector.Initialize(db)
} }
if config.PrepareStmt {
db.ConnPool = &PreparedStmtDB{
ConnPool: db.ConnPool,
stmts: map[string]*sql.Stmt{},
}
}
if db.Statement == nil {
db.Statement = &Statement{
DB: db,
ConnPool: db.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
}
}
if err == nil { if err == nil {
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
err = pinger.Ping() err = pinger.Ping()
@ -131,6 +152,13 @@ func (db *DB) Session(config *Session) *DB {
tx.Statement.Context = config.Context tx.Statement.Context = config.Context
} }
if config.PrepareStmt {
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
stmts: map[string]*sql.Stmt{},
}
}
if config.WithConditions { if config.WithConditions {
tx.clone = 3 tx.clone = 3
} }
@ -256,6 +284,12 @@ func (db *DB) getInstance() *DB {
switch db.clone { switch db.clone {
case 1: // clone with new statement case 1: // clone with new statement
tx.Statement = &Statement{
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
}
case 2: // with old statement, generate new statement for future call, used to pass to callbacks case 2: // with old statement, generate new statement for future call, used to pass to callbacks
db.clone = 1 db.clone = 1
tx.Statement = db.Statement tx.Statement = db.Statement
@ -266,21 +300,6 @@ func (db *DB) getInstance() *DB {
} }
} }
if tx.Statement == nil {
tx.Statement = &Statement{
DB: tx,
Clauses: map[string]clause.Clause{},
}
}
if db.Statement != nil {
tx.Statement.Context = db.Statement.Context
tx.Statement.ConnPool = db.Statement.ConnPool
} else {
tx.Statement.Context = context.Background()
tx.Statement.ConnPool = db.ConnPool
}
return tx return tx
} }

View File

@ -21,8 +21,8 @@ type Dialector interface {
// ConnPool db conns pool interface // ConnPool db conns pool interface
type ConnPool interface { type ConnPool interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
} }
@ -31,7 +31,11 @@ type TxBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
} }
type TxCommiter interface { type ConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
}
type TxCommitter interface {
Commit() error Commit() error
Rollback() error Rollback() error
} }

92
prepare_stmt.go Normal file
View File

@ -0,0 +1,92 @@
package gorm
import (
"context"
"database/sql"
"sync"
)
type PreparedStmtDB struct {
stmts map[string]*sql.Stmt
mux sync.RWMutex
ConnPool
}
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
db.mux.RLock()
if stmt, ok := db.stmts[query]; ok {
db.mux.RUnlock()
return stmt, nil
}
db.mux.RUnlock()
db.mux.Lock()
stmt, err := db.ConnPool.PrepareContext(context.Background(), query)
if err == nil {
db.stmts[query] = stmt
}
db.mux.Unlock()
return stmt, err
}
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
if beginner, ok := db.ConnPool.(TxBeginner); ok {
tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
}
return nil, ErrInvalidTransaction
}
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
stmt, err := db.prepare(query)
if err == nil {
return stmt.ExecContext(ctx, args...)
}
return nil, err
}
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
stmt, err := db.prepare(query)
if err == nil {
return stmt.QueryContext(ctx, args...)
}
return nil, err
}
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := db.prepare(query)
if err == nil {
return stmt.QueryRowContext(ctx, args...)
}
return &sql.Row{}
}
type PreparedStmtTX struct {
*sql.Tx
PreparedStmtDB *PreparedStmtDB
}
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
stmt, err := tx.PreparedStmtDB.prepare(query)
if err == nil {
return tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
}
return nil, err
}
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
stmt, err := tx.PreparedStmtDB.prepare(query)
if err == nil {
return tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
}
return nil, err
}
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := tx.PreparedStmtDB.prepare(query)
if err == nil {
return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...)
}
return &sql.Row{}
}

View File

@ -1,7 +1,6 @@
package tests_test package tests_test
import ( import (
"database/sql"
"errors" "errors"
"testing" "testing"
@ -21,7 +20,7 @@ func TestTransaction(t *testing.T) {
t.Fatalf("Should find saved record, but got %v", err) t.Fatalf("Should find saved record, but got %v", err)
} }
if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil { if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
t.Fatalf("Should return the underlying sql.Tx") t.Fatalf("Should return the underlying sql.Tx")
} }