diff --git a/prepare_stmt.go b/prepare_stmt.go index 4d533885..094bb477 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -17,18 +17,16 @@ type Stmt struct { } type PreparedStmtDB struct { - Stmts map[string]*Stmt - PreparedSQL []string - Mux *sync.RWMutex + Stmts map[string]*Stmt + Mux *sync.RWMutex ConnPool } func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { return &PreparedStmtDB{ - ConnPool: connPool, - Stmts: make(map[string]*Stmt), - Mux: &sync.RWMutex{}, - PreparedSQL: make([]string, 0, 100), + ConnPool: connPool, + Stmts: make(map[string]*Stmt), + Mux: &sync.RWMutex{}, } } @@ -48,12 +46,17 @@ func (db *PreparedStmtDB) Close() { db.Mux.Lock() defer db.Mux.Unlock() - for _, query := range db.PreparedSQL { - if stmt, ok := db.Stmts[query]; ok { - delete(db.Stmts, query) - go stmt.Close() - } + for _, stmt := range db.Stmts { + go func(s *Stmt) { + // make sure the stmt must finish preparation first + <-s.prepared + if s.Stmt != nil { + _ = s.Close() + } + }(stmt) } + // setting db.Stmts to nil to avoid further using + db.Stmts = nil } func (sdb *PreparedStmtDB) Reset() { @@ -61,9 +64,14 @@ func (sdb *PreparedStmtDB) Reset() { defer sdb.Mux.Unlock() for _, stmt := range sdb.Stmts { - go stmt.Close() + go func(s *Stmt) { + // make sure the stmt must finish preparation first + <-s.prepared + if s.Stmt != nil { + _ = s.Close() + } + }(stmt) } - sdb.PreparedSQL = make([]string, 0, 100) sdb.Stmts = make(map[string]*Stmt) } @@ -93,7 +101,12 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact return *stmt, nil } - + // check db.Stmts first to avoid Segmentation Fault(setting value to nil map) + // which cause by calling Close and executing SQL concurrently + if db.Stmts == nil { + db.Mux.Unlock() + return Stmt{}, ErrInvalidDB + } // cache preparing stmt first cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})} db.Stmts[query] = &cacheStmt @@ -118,7 +131,6 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact db.Mux.Lock() cacheStmt.Stmt = stmt - db.PreparedSQL = append(db.PreparedSQL, query) db.Mux.Unlock() return cacheStmt, nil diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index b86bc3d6..20a4f730 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "sync/atomic" "testing" "time" @@ -167,3 +168,149 @@ func TestPreparedStmtReset(t *testing.T) { t.Fatalf("prepared stmt should be empty") } } + +func isUsingClosedConnError(err error) bool { + // https://github.com/golang/go/blob/e705a2d16e4ece77e08e80c168382cdb02890f5b/src/database/sql/sql.go#L2717 + return err.Error() == "sql: statement is closed" +} + +// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently +// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt +func TestPreparedStmtConcurrentReset(t *testing.T) { + name := "prepared_stmt_concurrent_reset" + user := *GetUser(name, Config{}) + createTx := DB.Session(&gorm.Session{}).Create(&user) + if createTx.Error != nil { + t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error) + } + + // create a new connection to keep away from other tests + tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true}) + if err != nil { + t.Fatalf("failed to open test connection due to %s", err) + } + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + if !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + loopCount := 100 + var wg sync.WaitGroup + var unexpectedError bool + writerFinish := make(chan struct{}) + + wg.Add(1) + go func(id uint) { + defer wg.Done() + defer close(writerFinish) + + for j := 0; j < loopCount; j++ { + var tmp User + err := tx.Session(&gorm.Session{}).First(&tmp, id).Error + if err == nil || isUsingClosedConnError(err) { + continue + } + t.Errorf("failed to read user of id %d due to %s, there should not be error", id, err) + unexpectedError = true + break + } + }(user.ID) + + wg.Add(1) + go func() { + defer wg.Done() + <-writerFinish + pdb.Reset() + }() + + wg.Wait() + + if unexpectedError { + t.Fatalf("should is a unexpected error") + } +} + +// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently +// for example: one goroutine found error and just close the database, and others are executing SQL +// this test making sure that the gorm would not get a Segmentation Fault, +// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB +// and all of the goroutine must got gorm.ErrInvalidDB after database close +func TestPreparedStmtConcurrentClose(t *testing.T) { + name := "prepared_stmt_concurrent_close" + user := *GetUser(name, Config{}) + createTx := DB.Session(&gorm.Session{}).Create(&user) + if createTx.Error != nil { + t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error) + } + + // create a new connection to keep away from other tests + tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true}) + if err != nil { + t.Fatalf("failed to open test connection due to %s", err) + } + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) + if !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + loopCount := 100 + var wg sync.WaitGroup + var lastErr error + closeValid := make(chan struct{}, loopCount) + closeStartIdx := loopCount / 2 // close the database at the middle of the execution + var lastRunIndex int + var closeFinishedAt int64 + + wg.Add(1) + go func(id uint) { + defer wg.Done() + defer close(closeValid) + for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ { + if lastRunIndex == closeStartIdx { + closeValid <- struct{}{} + } + var tmp User + now := time.Now().UnixNano() + err := tx.Session(&gorm.Session{}).First(&tmp, id).Error + if err == nil { + closeFinishedAt := atomic.LoadInt64(&closeFinishedAt) + if (closeFinishedAt != 0) && (now > closeFinishedAt) { + lastErr = errors.New("must got error after database closed") + break + } + continue + } + lastErr = err + break + } + }(user.ID) + + wg.Add(1) + go func() { + defer wg.Done() + for range closeValid { + for i := 0; i < loopCount; i++ { + pdb.Close() // the Close method must can be call multiple times + atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano()) + } + } + }() + + wg.Wait() + var tmp User + err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error + if err != gorm.ErrInvalidDB { + t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err) + } + + // must be error + if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) { + t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr) + } + if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx { + t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex) + } + if pdb.Stmts != nil { + t.Fatalf("stmts must be nil") + } +}