forked from mirror/gorm
Add PrepareStmt support
This commit is contained in:
parent
9934207c42
commit
c8e7878b3e
|
@ -310,28 +310,36 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin begins a transaction
|
// Begin begins a transaction
|
||||||
func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) {
|
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||||
tx = db.getInstance()
|
var (
|
||||||
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
|
tx = db.getInstance()
|
||||||
var opt *sql.TxOptions
|
opt *sql.TxOptions
|
||||||
var err error
|
err error
|
||||||
if len(opts) > 0 {
|
)
|
||||||
opt = opts[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil {
|
if len(opts) > 0 {
|
||||||
tx.AddError(err)
|
opt = opts[0]
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tx.AddError(ErrInvalidTransaction)
|
|
||||||
}
|
}
|
||||||
return
|
|
||||||
|
if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
|
||||||
|
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||||
|
} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
|
||||||
|
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||||
|
} else {
|
||||||
|
err = ErrInvalidTransaction
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
tx.AddError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit commit a transaction
|
// Commit commit a transaction
|
||||||
func (db *DB) Commit() *DB {
|
func (db *DB) Commit() *DB {
|
||||||
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
|
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||||
db.AddError(comminter.Commit())
|
db.AddError(committer.Commit())
|
||||||
} else {
|
} else {
|
||||||
db.AddError(ErrInvalidTransaction)
|
db.AddError(ErrInvalidTransaction)
|
||||||
}
|
}
|
||||||
|
@ -340,8 +348,8 @@ func (db *DB) Commit() *DB {
|
||||||
|
|
||||||
// Rollback rollback a transaction
|
// Rollback rollback a transaction
|
||||||
func (db *DB) Rollback() *DB {
|
func (db *DB) Rollback() *DB {
|
||||||
if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil {
|
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||||
db.AddError(comminter.Rollback())
|
db.AddError(committer.Rollback())
|
||||||
} else {
|
} else {
|
||||||
db.AddError(ErrInvalidTransaction)
|
db.AddError(ErrInvalidTransaction)
|
||||||
}
|
}
|
||||||
|
|
49
gorm.go
49
gorm.go
|
@ -2,6 +2,7 @@ package gorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -25,6 +26,9 @@ type Config struct {
|
||||||
// DryRun generate sql without execute
|
// DryRun generate sql without execute
|
||||||
DryRun bool
|
DryRun bool
|
||||||
|
|
||||||
|
// PrepareStmt executes the given query in cached statement
|
||||||
|
PrepareStmt bool
|
||||||
|
|
||||||
// ClauseBuilders clause builder
|
// ClauseBuilders clause builder
|
||||||
ClauseBuilders map[string]clause.ClauseBuilder
|
ClauseBuilders map[string]clause.ClauseBuilder
|
||||||
// ConnPool db conn pool
|
// ConnPool db conn pool
|
||||||
|
@ -48,6 +52,7 @@ type DB struct {
|
||||||
// Session session config when create session with Session() method
|
// Session session config when create session with Session() method
|
||||||
type Session struct {
|
type Session struct {
|
||||||
DryRun bool
|
DryRun bool
|
||||||
|
PrepareStmt bool
|
||||||
WithConditions bool
|
WithConditions bool
|
||||||
Context context.Context
|
Context context.Context
|
||||||
Logger logger.Interface
|
Logger logger.Interface
|
||||||
|
@ -92,6 +97,22 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
|
||||||
err = dialector.Initialize(db)
|
err = dialector.Initialize(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.PrepareStmt {
|
||||||
|
db.ConnPool = &PreparedStmtDB{
|
||||||
|
ConnPool: db.ConnPool,
|
||||||
|
stmts: map[string]*sql.Stmt{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if db.Statement == nil {
|
||||||
|
db.Statement = &Statement{
|
||||||
|
DB: db,
|
||||||
|
ConnPool: db.ConnPool,
|
||||||
|
Context: context.Background(),
|
||||||
|
Clauses: map[string]clause.Clause{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
|
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
|
||||||
err = pinger.Ping()
|
err = pinger.Ping()
|
||||||
|
@ -131,6 +152,13 @@ func (db *DB) Session(config *Session) *DB {
|
||||||
tx.Statement.Context = config.Context
|
tx.Statement.Context = config.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.PrepareStmt {
|
||||||
|
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||||
|
ConnPool: db.Config.ConnPool,
|
||||||
|
stmts: map[string]*sql.Stmt{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if config.WithConditions {
|
if config.WithConditions {
|
||||||
tx.clone = 3
|
tx.clone = 3
|
||||||
}
|
}
|
||||||
|
@ -256,6 +284,12 @@ func (db *DB) getInstance() *DB {
|
||||||
|
|
||||||
switch db.clone {
|
switch db.clone {
|
||||||
case 1: // clone with new statement
|
case 1: // clone with new statement
|
||||||
|
tx.Statement = &Statement{
|
||||||
|
DB: tx,
|
||||||
|
ConnPool: db.Statement.ConnPool,
|
||||||
|
Context: db.Statement.Context,
|
||||||
|
Clauses: map[string]clause.Clause{},
|
||||||
|
}
|
||||||
case 2: // with old statement, generate new statement for future call, used to pass to callbacks
|
case 2: // with old statement, generate new statement for future call, used to pass to callbacks
|
||||||
db.clone = 1
|
db.clone = 1
|
||||||
tx.Statement = db.Statement
|
tx.Statement = db.Statement
|
||||||
|
@ -266,21 +300,6 @@ func (db *DB) getInstance() *DB {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if tx.Statement == nil {
|
|
||||||
tx.Statement = &Statement{
|
|
||||||
DB: tx,
|
|
||||||
Clauses: map[string]clause.Clause{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if db.Statement != nil {
|
|
||||||
tx.Statement.Context = db.Statement.Context
|
|
||||||
tx.Statement.ConnPool = db.Statement.ConnPool
|
|
||||||
} else {
|
|
||||||
tx.Statement.Context = context.Background()
|
|
||||||
tx.Statement.ConnPool = db.ConnPool
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx
|
return tx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,8 +21,8 @@ type Dialector interface {
|
||||||
|
|
||||||
// ConnPool db conns pool interface
|
// ConnPool db conns pool interface
|
||||||
type ConnPool interface {
|
type ConnPool interface {
|
||||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
|
||||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||||
|
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||||
}
|
}
|
||||||
|
@ -31,7 +31,11 @@ type TxBeginner interface {
|
||||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type TxCommiter interface {
|
type ConnPoolBeginner interface {
|
||||||
|
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TxCommitter interface {
|
||||||
Commit() error
|
Commit() error
|
||||||
Rollback() error
|
Rollback() error
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PreparedStmtDB struct {
|
||||||
|
stmts map[string]*sql.Stmt
|
||||||
|
mux sync.RWMutex
|
||||||
|
ConnPool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
|
||||||
|
db.mux.RLock()
|
||||||
|
if stmt, ok := db.stmts[query]; ok {
|
||||||
|
db.mux.RUnlock()
|
||||||
|
return stmt, nil
|
||||||
|
}
|
||||||
|
db.mux.RUnlock()
|
||||||
|
|
||||||
|
db.mux.Lock()
|
||||||
|
stmt, err := db.ConnPool.PrepareContext(context.Background(), query)
|
||||||
|
if err == nil {
|
||||||
|
db.stmts[query] = stmt
|
||||||
|
}
|
||||||
|
db.mux.Unlock()
|
||||||
|
|
||||||
|
return stmt, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
||||||
|
if beginner, ok := db.ConnPool.(TxBeginner); ok {
|
||||||
|
tx, err := beginner.BeginTx(ctx, opt)
|
||||||
|
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||||
|
}
|
||||||
|
return nil, ErrInvalidTransaction
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
stmt, err := db.prepare(query)
|
||||||
|
if err == nil {
|
||||||
|
return stmt.ExecContext(ctx, args...)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
stmt, err := db.prepare(query)
|
||||||
|
if err == nil {
|
||||||
|
return stmt.QueryContext(ctx, args...)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||||
|
stmt, err := db.prepare(query)
|
||||||
|
if err == nil {
|
||||||
|
return stmt.QueryRowContext(ctx, args...)
|
||||||
|
}
|
||||||
|
return &sql.Row{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type PreparedStmtTX struct {
|
||||||
|
*sql.Tx
|
||||||
|
PreparedStmtDB *PreparedStmtDB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
|
stmt, err := tx.PreparedStmtDB.prepare(query)
|
||||||
|
if err == nil {
|
||||||
|
return tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
stmt, err := tx.PreparedStmtDB.prepare(query)
|
||||||
|
if err == nil {
|
||||||
|
return tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||||
|
stmt, err := tx.PreparedStmtDB.prepare(query)
|
||||||
|
if err == nil {
|
||||||
|
return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...)
|
||||||
|
}
|
||||||
|
return &sql.Row{}
|
||||||
|
}
|
|
@ -1,7 +1,6 @@
|
||||||
package tests_test
|
package tests_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"errors"
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -21,7 +20,7 @@ func TestTransaction(t *testing.T) {
|
||||||
t.Fatalf("Should find saved record, but got %v", err)
|
t.Fatalf("Should find saved record, but got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil {
|
if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
|
||||||
t.Fatalf("Should return the underlying sql.Tx")
|
t.Fatalf("Should return the underlying sql.Tx")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue