From f93345afa8e17725660d370f52608c3b0014bdc0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 3 Jul 2020 10:26:18 +0800 Subject: [PATCH] Close cached prepared stmt when got error --- prepare_stmt.go | 78 +++++++++++++++++++++++-------------------------- 1 file changed, 36 insertions(+), 42 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index e017bb23..197c257c 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -54,41 +54,38 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn return nil, ErrInvalidTransaction } -func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { +func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { stmt, err := db.prepare(query) if err == nil { - return stmt.ExecContext(ctx, args...) - } else { - db.mux.Lock() - stmt.Close() - delete(db.Stmts, query) - db.mux.Unlock() + result, err = stmt.ExecContext(ctx, args...) + if err != nil { + db.mux.Lock() + stmt.Close() + delete(db.Stmts, query) + db.mux.Unlock() + } } - return nil, err + return result, err } -func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { +func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { stmt, err := db.prepare(query) if err == nil { - return stmt.QueryContext(ctx, args...) - } else { - db.mux.Lock() - stmt.Close() - delete(db.Stmts, query) - db.mux.Unlock() + rows, err = stmt.QueryContext(ctx, args...) + if err != nil { + db.mux.Lock() + stmt.Close() + delete(db.Stmts, query) + db.mux.Unlock() + } } - return nil, err + return rows, err } func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { stmt, err := db.prepare(query) if err == nil { return stmt.QueryRowContext(ctx, args...) - } else { - db.mux.Lock() - stmt.Close() - delete(db.Stmts, query) - db.mux.Unlock() } return &sql.Row{} } @@ -98,41 +95,38 @@ type PreparedStmtTX struct { PreparedStmtDB *PreparedStmtDB } -func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { +func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { stmt, err := tx.PreparedStmtDB.prepare(query) if err == nil { - return tx.Tx.Stmt(stmt).ExecContext(ctx, args...) - } else { - tx.PreparedStmtDB.mux.Lock() - stmt.Close() - delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.mux.Unlock() + result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...) + if err != nil { + tx.PreparedStmtDB.mux.Lock() + stmt.Close() + delete(tx.PreparedStmtDB.Stmts, query) + tx.PreparedStmtDB.mux.Unlock() + } } - return nil, err + return result, err } -func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { +func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { stmt, err := tx.PreparedStmtDB.prepare(query) if err == nil { - return tx.Tx.Stmt(stmt).QueryContext(ctx, args...) - } else { - tx.PreparedStmtDB.mux.Lock() - stmt.Close() - delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.mux.Unlock() + rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) + if err != nil { + tx.PreparedStmtDB.mux.Lock() + stmt.Close() + delete(tx.PreparedStmtDB.Stmts, query) + tx.PreparedStmtDB.mux.Unlock() + } } - return nil, err + return rows, err } func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { stmt, err := tx.PreparedStmtDB.prepare(query) if err == nil { return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...) - } else { - tx.PreparedStmtDB.mux.Lock() - stmt.Close() - delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.mux.Unlock() } return &sql.Row{} }