diff --git a/gorm.go b/gorm.go index 46d1843d..ecdb700b 100644 --- a/gorm.go +++ b/gorm.go @@ -187,15 +187,9 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { } } - preparedStmt := &PreparedStmtDB{ - ConnPool: db.ConnPool, - Stmts: make(map[string]*Stmt), - Mux: &sync.RWMutex{}, - PreparedSQL: make([]string, 0, 100), - } - db.cacheStore.Store(preparedStmtDBKey, preparedStmt) - if config.PrepareStmt { + preparedStmt := NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) db.ConnPool = preparedStmt } @@ -256,24 +250,30 @@ func (db *DB) Session(config *Session) *DB { } if config.PrepareStmt { + var preparedStmt *PreparedStmtDB + if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { - preparedStmt := v.(*PreparedStmtDB) - switch t := tx.Statement.ConnPool.(type) { - case Tx: - tx.Statement.ConnPool = &PreparedStmtTX{ - Tx: t, - PreparedStmtDB: preparedStmt, - } - default: - tx.Statement.ConnPool = &PreparedStmtDB{ - ConnPool: db.Config.ConnPool, - Mux: preparedStmt.Mux, - Stmts: preparedStmt.Stmts, - } - } - txConfig.ConnPool = tx.Statement.ConnPool - txConfig.PrepareStmt = true + preparedStmt = v.(*PreparedStmtDB) + } else { + preparedStmt = NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) } + + switch t := tx.Statement.ConnPool.(type) { + case Tx: + tx.Statement.ConnPool = &PreparedStmtTX{ + Tx: t, + PreparedStmtDB: preparedStmt, + } + default: + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, + } + } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true } if config.SkipHooks { diff --git a/prepare_stmt.go b/prepare_stmt.go index 4b3551c6..10fefc31 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -21,6 +21,15 @@ type PreparedStmtDB struct { ConnPool } +func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { + return &PreparedStmtDB{ + ConnPool: connPool, + Stmts: make(map[string]*Stmt), + Mux: &sync.RWMutex{}, + PreparedSQL: make([]string, 0, 100), + } +} + func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { return dbConnector.GetDBConn()