diff --git a/prepare_stmt.go b/prepare_stmt.go index 3934bb97..7591e533 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -44,6 +44,18 @@ func (db *PreparedStmtDB) Close() { } } +func (db *PreparedStmtDB) Reset() { + db.Mux.Lock() + defer db.Mux.Unlock() + for query, stmt := range db.Stmts { + delete(db.Stmts, query) + go stmt.Close() + } + + db.PreparedSQL = make([]string, 0, 100) + db.Stmts = map[string](*Stmt){} +} + func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index c7f251f2..64baa01b 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -2,8 +2,8 @@ package tests_test import ( "context" - "sync" "errors" + "sync" "testing" "time" @@ -168,3 +168,29 @@ func TestPreparedStmtInTransaction(t *testing.T) { t.Errorf("Failed, got error: %v", err) } } + +func TestPreparedStmtReset(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + + user := *GetUser("prepared_stmt_reset", Config{}) + tx = tx.Create(&user) + + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + if !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + pdb.Mux.Lock() + if len(pdb.Stmts) == 0 { + pdb.Mux.Unlock() + t.Fatalf("prepared stmt can not be empty") + } + pdb.Mux.Unlock() + + pdb.Reset() + pdb.Mux.Lock() + defer pdb.Mux.Unlock() + if len(pdb.Stmts) != 0 { + t.Fatalf("prepared stmt should be empty") + } +}