Close cached prepared stmt when got error

This commit is contained in:
Jinzhu 2020-07-03 10:26:18 +08:00
parent 8100ac7663
commit f93345afa8
1 changed files with 36 additions and 42 deletions

View File

@ -54,41 +54,38 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
return nil, ErrInvalidTransaction return nil, ErrInvalidTransaction
} }
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := db.prepare(query) stmt, err := db.prepare(query)
if err == nil { if err == nil {
return stmt.ExecContext(ctx, args...) result, err = stmt.ExecContext(ctx, args...)
} else { if err != nil {
db.mux.Lock() db.mux.Lock()
stmt.Close() stmt.Close()
delete(db.Stmts, query) delete(db.Stmts, query)
db.mux.Unlock() db.mux.Unlock()
}
} }
return nil, err return result, err
} }
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := db.prepare(query) stmt, err := db.prepare(query)
if err == nil { if err == nil {
return stmt.QueryContext(ctx, args...) rows, err = stmt.QueryContext(ctx, args...)
} else { if err != nil {
db.mux.Lock() db.mux.Lock()
stmt.Close() stmt.Close()
delete(db.Stmts, query) delete(db.Stmts, query)
db.mux.Unlock() db.mux.Unlock()
}
} }
return nil, err return rows, err
} }
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := db.prepare(query) stmt, err := db.prepare(query)
if err == nil { if err == nil {
return stmt.QueryRowContext(ctx, args...) return stmt.QueryRowContext(ctx, args...)
} else {
db.mux.Lock()
stmt.Close()
delete(db.Stmts, query)
db.mux.Unlock()
} }
return &sql.Row{} return &sql.Row{}
} }
@ -98,41 +95,38 @@ type PreparedStmtTX struct {
PreparedStmtDB *PreparedStmtDB PreparedStmtDB *PreparedStmtDB
} }
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := tx.PreparedStmtDB.prepare(query) stmt, err := tx.PreparedStmtDB.prepare(query)
if err == nil { if err == nil {
return tx.Tx.Stmt(stmt).ExecContext(ctx, args...) result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...)
} else { if err != nil {
tx.PreparedStmtDB.mux.Lock() tx.PreparedStmtDB.mux.Lock()
stmt.Close() stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query) delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.mux.Unlock() tx.PreparedStmtDB.mux.Unlock()
}
} }
return nil, err return result, err
} }
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(query) stmt, err := tx.PreparedStmtDB.prepare(query)
if err == nil { if err == nil {
return tx.Tx.Stmt(stmt).QueryContext(ctx, args...) rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...)
} else { if err != nil {
tx.PreparedStmtDB.mux.Lock() tx.PreparedStmtDB.mux.Lock()
stmt.Close() stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query) delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.mux.Unlock() tx.PreparedStmtDB.mux.Unlock()
}
} }
return nil, err return rows, err
} }
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := tx.PreparedStmtDB.prepare(query) stmt, err := tx.PreparedStmtDB.prepare(query)
if err == nil { if err == nil {
return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...) return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...)
} else {
tx.PreparedStmtDB.mux.Lock()
stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.mux.Unlock()
} }
return &sql.Row{} return &sql.Row{}
} }