feat: ajust PreparedStmtDB unlock location and BuildCondition if logic (#4681)

This commit is contained in:
heige 2021-10-08 11:16:58 +08:00 committed by GitHub
parent c13f3011f9
commit e3fc49a694
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 11 deletions

View File

@ -32,14 +32,14 @@ func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
func (db *PreparedStmtDB) Close() { func (db *PreparedStmtDB) Close() {
db.Mux.Lock() db.Mux.Lock()
defer db.Mux.Unlock()
for _, query := range db.PreparedSQL { for _, query := range db.PreparedSQL {
if stmt, ok := db.Stmts[query]; ok { if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query) delete(db.Stmts, query)
go stmt.Close() go stmt.Close()
} }
} }
db.Mux.Unlock()
} }
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
@ -51,9 +51,10 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
db.Mux.RUnlock() db.Mux.RUnlock()
db.Mux.Lock() db.Mux.Lock()
defer db.Mux.Unlock()
// double check // double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
return stmt, nil return stmt, nil
} else if ok { } else if ok {
go stmt.Close() go stmt.Close()
@ -64,7 +65,6 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
db.PreparedSQL = append(db.PreparedSQL, query) db.PreparedSQL = append(db.PreparedSQL, query)
} }
defer db.Mux.Unlock()
return db.Stmts[query], err return db.Stmts[query], err
} }
@ -83,9 +83,9 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
result, err = stmt.ExecContext(ctx, args...) result, err = stmt.ExecContext(ctx, args...)
if err != nil { if err != nil {
db.Mux.Lock() db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close() go stmt.Close()
delete(db.Stmts, query) delete(db.Stmts, query)
db.Mux.Unlock()
} }
} }
return result, err return result, err
@ -97,9 +97,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
rows, err = stmt.QueryContext(ctx, args...) rows, err = stmt.QueryContext(ctx, args...)
if err != nil { if err != nil {
db.Mux.Lock() db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close() go stmt.Close()
delete(db.Stmts, query) delete(db.Stmts, query)
db.Mux.Unlock()
} }
} }
return rows, err return rows, err
@ -138,9 +139,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if err != nil { if err != nil {
tx.PreparedStmtDB.Mux.Lock() tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close() go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query) delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.Mux.Unlock()
} }
} }
return result, err return result, err
@ -152,9 +154,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
if err != nil { if err != nil {
tx.PreparedStmtDB.Mux.Lock() tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close() go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query) delete(tx.PreparedStmtDB.Stmts, query)
tx.PreparedStmtDB.Mux.Unlock()
} }
} }
return rows, err return rows, err

View File

@ -271,13 +271,19 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if _, err := strconv.Atoi(s); err != nil { if _, err := strconv.Atoi(s); err != nil {
if s == "" && len(args) == 0 { if s == "" && len(args) == 0 {
return nil return nil
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { }
if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
// looks like a where condition // looks like a where condition
return []clause.Expression{clause.Expr{SQL: s, Vars: args}} return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
} else if len(args) > 0 && strings.Contains(s, "@") { }
if len(args) > 0 && strings.Contains(s, "@") {
// looks like a named query // looks like a named query
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
} else if len(args) == 1 { }
if len(args) == 1 {
return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
} }
} }