Refactor Prepared Statement

This commit is contained in:
Jinzhu 2020-07-28 14:26:09 +08:00
parent f4cfa9411b
commit c7667e9299
3 changed files with 25 additions and 12 deletions

14
gorm.go
View File

@ -108,11 +108,15 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
err = config.Dialector.Initialize(db) err = config.Dialector.Initialize(db)
} }
if config.PrepareStmt { preparedStmt := &PreparedStmtDB{
db.ConnPool = &PreparedStmtDB{
ConnPool: db.ConnPool, ConnPool: db.ConnPool,
Stmts: map[string]*sql.Stmt{}, Stmts: map[string]*sql.Stmt{},
PreparedSQL: make([]string, 0, 100),
} }
db.cacheStore.Store("preparedStmt", preparedStmt)
if config.PrepareStmt {
db.ConnPool = preparedStmt
} }
db.Statement = &Statement{ db.Statement = &Statement{
@ -157,9 +161,13 @@ func (db *DB) Session(config *Session) *DB {
} }
if config.PrepareStmt { if config.PrepareStmt {
if v, ok := db.cacheStore.Load("preparedStmt"); ok {
preparedStmt := v.(*PreparedStmtDB)
tx.Statement.ConnPool = &PreparedStmtDB{ tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool, ConnPool: db.Config.ConnPool,
Stmts: map[string]*sql.Stmt{}, mux: preparedStmt.mux,
Stmts: preparedStmt.Stmts,
}
} }
} }

View File

@ -8,16 +8,19 @@ import (
type PreparedStmtDB struct { type PreparedStmtDB struct {
Stmts map[string]*sql.Stmt Stmts map[string]*sql.Stmt
PreparedSQL []string
mux sync.RWMutex mux sync.RWMutex
ConnPool ConnPool
} }
func (db *PreparedStmtDB) Close() { func (db *PreparedStmtDB) Close() {
db.mux.Lock() db.mux.Lock()
for k, stmt := range db.Stmts { for _, query := range db.PreparedSQL {
delete(db.Stmts, k) if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query)
stmt.Close() stmt.Close()
} }
}
db.mux.Unlock() db.mux.Unlock()
} }
@ -40,6 +43,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) {
stmt, err := db.ConnPool.PrepareContext(context.Background(), query) stmt, err := db.ConnPool.PrepareContext(context.Background(), query)
if err == nil { if err == nil {
db.Stmts[query] = stmt db.Stmts[query] = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
} }
db.mux.Unlock() db.mux.Unlock()

1
tests/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
go.sum