From 7851faa094ef6369caccd1b9ba08c344c00ca9f5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 21 Jun 2020 18:01:50 +0800 Subject: [PATCH] Allow close prepared statements, double check before prepare --- gorm.go | 4 ++-- prepare_stmt.go | 22 +++++++++++++++++++--- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/gorm.go b/gorm.go index e3193f59..6027b4bb 100644 --- a/gorm.go +++ b/gorm.go @@ -102,7 +102,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { if config.PrepareStmt { db.ConnPool = &PreparedStmtDB{ ConnPool: db.ConnPool, - stmts: map[string]*sql.Stmt{}, + Stmts: map[string]*sql.Stmt{}, } } @@ -146,7 +146,7 @@ func (db *DB) Session(config *Session) *DB { if config.PrepareStmt { tx.Statement.ConnPool = &PreparedStmtDB{ ConnPool: db.Config.ConnPool, - stmts: map[string]*sql.Stmt{}, + Stmts: map[string]*sql.Stmt{}, } } diff --git a/prepare_stmt.go b/prepare_stmt.go index bc11abbf..ba9b04b6 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -7,23 +7,39 @@ import ( ) type PreparedStmtDB struct { - stmts map[string]*sql.Stmt + Stmts map[string]*sql.Stmt mux sync.RWMutex ConnPool } +func (db *PreparedStmtDB) Close() { + db.mux.Lock() + for k, stmt := range db.Stmts { + delete(db.Stmts, k) + stmt.Close() + } + + db.mux.Unlock() +} + func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { db.mux.RLock() - if stmt, ok := db.stmts[query]; ok { + if stmt, ok := db.Stmts[query]; ok { db.mux.RUnlock() return stmt, nil } db.mux.RUnlock() db.mux.Lock() + // double check + if stmt, ok := db.Stmts[query]; ok { + db.mux.Unlock() + return stmt, nil + } + stmt, err := db.ConnPool.PrepareContext(context.Background(), query) if err == nil { - db.stmts[query] = stmt + db.Stmts[query] = stmt } db.mux.Unlock()