diff --git a/prepare_stmt.go b/prepare_stmt.go index 5faea995..88bec4e9 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -32,14 +32,14 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { func (db *PreparedStmtDB) Close() { db.Mux.Lock() + defer db.Mux.Unlock() + for _, query := range db.PreparedSQL { if stmt, ok := db.Stmts[query]; ok { delete(db.Stmts, query) go stmt.Close() } } - - db.Mux.Unlock() } func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { @@ -51,9 +51,10 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact db.Mux.RUnlock() db.Mux.Lock() + defer db.Mux.Unlock() + // double check if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { - db.Mux.Unlock() return stmt, nil } else if ok { go stmt.Close() @@ -64,7 +65,6 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} db.PreparedSQL = append(db.PreparedSQL, query) } - defer db.Mux.Unlock() return db.Stmts[query], err } @@ -83,9 +83,9 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. result, err = stmt.ExecContext(ctx, args...) if err != nil { db.Mux.Lock() + defer db.Mux.Unlock() go stmt.Close() delete(db.Stmts, query) - db.Mux.Unlock() } } return result, err @@ -97,9 +97,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . rows, err = stmt.QueryContext(ctx, args...) if err != nil { db.Mux.Lock() + defer db.Mux.Unlock() + go stmt.Close() delete(db.Stmts, query) - db.Mux.Unlock() } } return rows, err @@ -138,9 +139,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() + defer tx.PreparedStmtDB.Mux.Unlock() + go stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.Mux.Unlock() } } return result, err @@ -152,9 +154,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() + defer tx.PreparedStmtDB.Mux.Unlock() + go stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.Mux.Unlock() } } return rows, err diff --git a/statement.go b/statement.go index 347f88ff..3b76f653 100644 --- a/statement.go +++ b/statement.go @@ -271,13 +271,19 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if _, err := strconv.Atoi(s); err != nil { if s == "" && len(args) == 0 { return nil - } else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { + } + + if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { // looks like a where condition return []clause.Expression{clause.Expr{SQL: s, Vars: args}} - } else if len(args) > 0 && strings.Contains(s, "@") { + } + + if len(args) > 0 && strings.Contains(s, "@") { // looks like a named query return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} - } else if len(args) == 1 { + } + + if len(args) == 1 { return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} } }