Add PrepareStmt support

This commit is contained in:
Jinzhu 2020-06-05 10:08:22 +08:00
parent 9934207c42
commit c8e7878b3e
5 changed files with 159 additions and 37 deletions

View File

@ -310,28 +310,36 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
}
// Begin begins a transaction
func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
tx = db.getInstance()
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
var opt *sql.TxOptions
var err error
if len(opts) > 0 {
opt = opts[0]
}
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var (
tx = db.getInstance()
opt *sql.TxOptions
err error
)
if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil {
tx.AddError(err)
}
} else {
tx.AddError(ErrInvalidTransaction)
if len(opts) > 0 {
opt = opts[0]
}
return
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
} else {
err = ErrInvalidTransaction
}
if err != nil {
tx.AddError(err)
}
return tx
}
// Commit commit a transaction
func (db *DB) Commit() *DB {
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
db.AddError(comminter.Commit())
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
db.AddError(committer.Commit())
} else {
db.AddError(ErrInvalidTransaction)
}
@ -340,8 +348,8 @@ func (db *DB) Commit() *DB {
// Rollback rollback a transaction
func (db *DB) Rollback() *DB {
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
db.AddError(comminter.Rollback())
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
db.AddError(committer.Rollback())
} else {
db.AddError(ErrInvalidTransaction)
}

49
gorm.go
View File

@ -2,6 +2,7 @@ package gorm
import (
"context"
"database/sql"
"fmt"
"sync"
"time"
@ -25,6 +26,9 @@ type Config struct {
// DryRun generate sql without execute
DryRun bool
// PrepareStmt executes the given query in cached statement
PrepareStmt bool
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
// ConnPool db conn pool
@ -48,6 +52,7 @@ type DB struct {
// Session session config when create session with Session() method
type Session struct {
DryRun bool
PrepareStmt bool
WithConditions bool
Context context.Context
Logger logger.Interface
@ -92,6 +97,22 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
err = dialector.Initialize(db)
}
if config.PrepareStmt {
db.ConnPool = &PreparedStmtDB{
ConnPool: db.ConnPool,
stmts: map[string]*sql.Stmt{},
}
}
if db.Statement == nil {
db.Statement = &Statement{
DB: db,
ConnPool: db.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
}
}
if err == nil {
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
err = pinger.Ping()
@ -131,6 +152,13 @@ func (db *DB) Session(config *Session) *DB {
tx.Statement.Context = config.Context
}
if config.PrepareStmt {
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
stmts: map[string]*sql.Stmt{},
}
}
if config.WithConditions {
tx.clone = 3
}
@ -256,6 +284,12 @@ func (db *DB) getInstance() *DB {
switch db.clone {
case 1: // clone with new statement
tx.Statement = &Statement{
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
}
case 2: // with old statement, generate new statement for future call, used to pass to callbacks
db.clone = 1
tx.Statement = db.Statement
@ -266,21 +300,6 @@ func (db *DB) getInstance() *DB {
}
}
if tx.Statement == nil {
tx.Statement = &Statement{
DB: tx,
Clauses: map[string]clause.Clause{},
}
}
if db.Statement != nil {
tx.Statement.Context = db.Statement.Context
tx.Statement.ConnPool = db.Statement.ConnPool
} else {
tx.Statement.Context = context.Background()
tx.Statement.ConnPool = db.ConnPool
}
return tx
}

View File

@ -21,8 +21,8 @@ type Dialector interface {
// ConnPool db conns pool interface
type ConnPool interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
@ -31,7 +31,11 @@ type TxBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
type TxCommiter interface {
type ConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
}
type TxCommitter interface {
Commit() error
Rollback() error
}

92
prepare_stmt.go Normal file
View File

@ -0,0 +1,92 @@
package gorm
import (
"context"
"database/sql"
"sync"
)
type PreparedStmtDB struct {
stmts map[string]*sql.Stmt
mux sync.RWMutex
ConnPool
}
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
db.mux.RLock()
if stmt, ok := db.stmts[query]; ok {
db.mux.RUnlock()
return stmt, nil
}
db.mux.RUnlock()
db.mux.Lock()
stmt, err := db.ConnPool.PrepareContext(context.Background(), query)
if err == nil {
db.stmts[query] = stmt
}
db.mux.Unlock()
return stmt, err
}
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
if beginner, ok := db.ConnPool.(TxBeginner); ok {
tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
}
return nil, ErrInvalidTransaction
}
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
stmt, err := db.prepare(query)
if err == nil {
return stmt.ExecContext(ctx, args...)
}
return nil, err
}
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
stmt, err := db.prepare(query)
if err == nil {
return stmt.QueryContext(ctx, args...)
}
return nil, err
}
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := db.prepare(query)
if err == nil {
return stmt.QueryRowContext(ctx, args...)
}
return &sql.Row{}
}
type PreparedStmtTX struct {
*sql.Tx
PreparedStmtDB *PreparedStmtDB
}
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
stmt, err := tx.PreparedStmtDB.prepare(query)
if err == nil {
return tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
}
return nil, err
}
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
stmt, err := tx.PreparedStmtDB.prepare(query)
if err == nil {
return tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
}
return nil, err
}
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := tx.PreparedStmtDB.prepare(query)
if err == nil {
return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...)
}
return &sql.Row{}
}

View File

@ -1,7 +1,6 @@
package tests_test
import (
"database/sql"
"errors"
"testing"
@ -21,7 +20,7 @@ func TestTransaction(t *testing.T) {
t.Fatalf("Should find saved record, but got %v", err)
}
if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil {
if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
t.Fatalf("Should return the underlying sql.Tx")
}