diff --git a/gorm.go b/gorm.go index 338a1473..c786b5a5 100644 --- a/gorm.go +++ b/gorm.go @@ -108,11 +108,15 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { err = config.Dialector.Initialize(db) } + preparedStmt := &PreparedStmtDB{ + ConnPool: db.ConnPool, + Stmts: map[string]*sql.Stmt{}, + PreparedSQL: make([]string, 0, 100), + } + db.cacheStore.Store("preparedStmt", preparedStmt) + if config.PrepareStmt { - db.ConnPool = &PreparedStmtDB{ - ConnPool: db.ConnPool, - Stmts: map[string]*sql.Stmt{}, - } + db.ConnPool = preparedStmt } db.Statement = &Statement{ @@ -157,9 +161,13 @@ func (db *DB) Session(config *Session) *DB { } if config.PrepareStmt { - tx.Statement.ConnPool = &PreparedStmtDB{ - ConnPool: db.Config.ConnPool, - Stmts: map[string]*sql.Stmt{}, + if v, ok := db.cacheStore.Load("preparedStmt"); ok { + preparedStmt := v.(*PreparedStmtDB) + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + mux: preparedStmt.mux, + Stmts: preparedStmt.Stmts, + } } } diff --git a/prepare_stmt.go b/prepare_stmt.go index 197c257c..2f4e1d57 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -7,16 +7,19 @@ import ( ) type PreparedStmtDB struct { - Stmts map[string]*sql.Stmt - mux sync.RWMutex + Stmts map[string]*sql.Stmt + PreparedSQL []string + mux sync.RWMutex ConnPool } func (db *PreparedStmtDB) Close() { db.mux.Lock() - for k, stmt := range db.Stmts { - delete(db.Stmts, k) - stmt.Close() + for _, query := range db.PreparedSQL { + if stmt, ok := db.Stmts[query]; ok { + delete(db.Stmts, query) + stmt.Close() + } } db.mux.Unlock() @@ -40,6 +43,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { stmt, err := db.ConnPool.PrepareContext(context.Background(), query) if err == nil { db.Stmts[query] = stmt + db.PreparedSQL = append(db.PreparedSQL, query) } db.mux.Unlock() diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 00000000..08cb523c --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +go.sum