diff --git a/finisher_api.go b/finisher_api.go index b97f2301..e493b406 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -310,28 +310,36 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } // Begin begins a transaction -func (db *DB) Begin(opts ...*sql.TxOptions) (tx *DB) { - tx = db.getInstance() - if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { - var opt *sql.TxOptions - var err error - if len(opts) > 0 { - opt = opts[0] - } +func (db *DB) Begin(opts ...*sql.TxOptions) *DB { + var ( + tx = db.getInstance() + opt *sql.TxOptions + err error + ) - if tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt); err != nil { - tx.AddError(err) - } - } else { - tx.AddError(ErrInvalidTransaction) + if len(opts) > 0 { + opt = opts[0] } - return + + if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { + tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) + } else { + err = ErrInvalidTransaction + } + + if err != nil { + tx.AddError(err) + } + + return tx } // Commit commit a transaction func (db *DB) Commit() *DB { - if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { - db.AddError(comminter.Commit()) + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + db.AddError(committer.Commit()) } else { db.AddError(ErrInvalidTransaction) } @@ -340,8 +348,8 @@ func (db *DB) Commit() *DB { // Rollback rollback a transaction func (db *DB) Rollback() *DB { - if comminter, ok := db.Statement.ConnPool.(TxCommiter); ok && comminter != nil { - db.AddError(comminter.Rollback()) + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + db.AddError(committer.Rollback()) } else { db.AddError(ErrInvalidTransaction) } diff --git a/gorm.go b/gorm.go index 8a801d68..e6a28635 100644 --- a/gorm.go +++ b/gorm.go @@ -2,6 +2,7 @@ package gorm import ( "context" + "database/sql" "fmt" "sync" "time" @@ -25,6 +26,9 @@ type Config struct { // DryRun generate sql without execute DryRun bool + // PrepareStmt executes the given query in cached statement + PrepareStmt bool + // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder // ConnPool db conn pool @@ -48,6 +52,7 @@ type DB struct { // Session session config when create session with Session() method type Session struct { DryRun bool + PrepareStmt bool WithConditions bool Context context.Context Logger logger.Interface @@ -92,6 +97,22 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { err = dialector.Initialize(db) } + if config.PrepareStmt { + db.ConnPool = &PreparedStmtDB{ + ConnPool: db.ConnPool, + stmts: map[string]*sql.Stmt{}, + } + } + + if db.Statement == nil { + db.Statement = &Statement{ + DB: db, + ConnPool: db.ConnPool, + Context: context.Background(), + Clauses: map[string]clause.Clause{}, + } + } + if err == nil { if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok { err = pinger.Ping() @@ -131,6 +152,13 @@ func (db *DB) Session(config *Session) *DB { tx.Statement.Context = config.Context } + if config.PrepareStmt { + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + stmts: map[string]*sql.Stmt{}, + } + } + if config.WithConditions { tx.clone = 3 } @@ -256,6 +284,12 @@ func (db *DB) getInstance() *DB { switch db.clone { case 1: // clone with new statement + tx.Statement = &Statement{ + DB: tx, + ConnPool: db.Statement.ConnPool, + Context: db.Statement.Context, + Clauses: map[string]clause.Clause{}, + } case 2: // with old statement, generate new statement for future call, used to pass to callbacks db.clone = 1 tx.Statement = db.Statement @@ -266,21 +300,6 @@ func (db *DB) getInstance() *DB { } } - if tx.Statement == nil { - tx.Statement = &Statement{ - DB: tx, - Clauses: map[string]clause.Clause{}, - } - } - - if db.Statement != nil { - tx.Statement.Context = db.Statement.Context - tx.Statement.ConnPool = db.Statement.ConnPool - } else { - tx.Statement.Context = context.Background() - tx.Statement.ConnPool = db.ConnPool - } - return tx } diff --git a/interfaces.go b/interfaces.go index 6d9c6212..4be54565 100644 --- a/interfaces.go +++ b/interfaces.go @@ -21,8 +21,8 @@ type Dialector interface { // ConnPool db conns pool interface type ConnPool interface { - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } @@ -31,7 +31,11 @@ type TxBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } -type TxCommiter interface { +type ConnPoolBeginner interface { + BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error) +} + +type TxCommitter interface { Commit() error Rollback() error } diff --git a/prepare_stmt.go b/prepare_stmt.go new file mode 100644 index 00000000..bc11abbf --- /dev/null +++ b/prepare_stmt.go @@ -0,0 +1,92 @@ +package gorm + +import ( + "context" + "database/sql" + "sync" +) + +type PreparedStmtDB struct { + stmts map[string]*sql.Stmt + mux sync.RWMutex + ConnPool +} + +func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { + db.mux.RLock() + if stmt, ok := db.stmts[query]; ok { + db.mux.RUnlock() + return stmt, nil + } + db.mux.RUnlock() + + db.mux.Lock() + stmt, err := db.ConnPool.PrepareContext(context.Background(), query) + if err == nil { + db.stmts[query] = stmt + } + db.mux.Unlock() + + return stmt, err +} + +func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { + if beginner, ok := db.ConnPool.(TxBeginner); ok { + tx, err := beginner.BeginTx(ctx, opt) + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err + } + return nil, ErrInvalidTransaction +} + +func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + stmt, err := db.prepare(query) + if err == nil { + return stmt.ExecContext(ctx, args...) + } + return nil, err +} + +func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + stmt, err := db.prepare(query) + if err == nil { + return stmt.QueryContext(ctx, args...) + } + return nil, err +} + +func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + stmt, err := db.prepare(query) + if err == nil { + return stmt.QueryRowContext(ctx, args...) + } + return &sql.Row{} +} + +type PreparedStmtTX struct { + *sql.Tx + PreparedStmtDB *PreparedStmtDB +} + +func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + stmt, err := tx.PreparedStmtDB.prepare(query) + if err == nil { + return tx.Tx.Stmt(stmt).ExecContext(ctx, args...) + } + return nil, err +} + +func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + stmt, err := tx.PreparedStmtDB.prepare(query) + if err == nil { + return tx.Tx.Stmt(stmt).QueryContext(ctx, args...) + } + return nil, err +} + +func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + stmt, err := tx.PreparedStmtDB.prepare(query) + if err == nil { + return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...) + } + return &sql.Row{} +} diff --git a/tests/transaction_test.go b/tests/transaction_test.go index b810e3bb..0c04e2ed 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -1,7 +1,6 @@ package tests_test import ( - "database/sql" "errors" "testing" @@ -21,7 +20,7 @@ func TestTransaction(t *testing.T) { t.Fatalf("Should find saved record, but got %v", err) } - if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil { + if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { t.Fatalf("Should return the underlying sql.Tx") }