diff --git a/gorm.go b/gorm.go index fec4310b..ed01ccfe 100644 --- a/gorm.go +++ b/gorm.go @@ -176,6 +176,8 @@ func (db *DB) Session(config *Session) *DB { Mux: preparedStmt.Mux, Stmts: preparedStmt.Stmts, } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true } } diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index af610165..6b10b6dc 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -12,6 +12,10 @@ import ( func TestPreparedStmt(t *testing.T) { tx := DB.Session(&gorm.Session{PrepareStmt: true}) + if _, ok := tx.ConnPool.(*gorm.PreparedStmtDB); !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() txCtx := tx.WithContext(ctx)