mirror of https://github.com/go-gorm/gorm.git
fix memory leaks in PrepareStatementDB (#7142)
* fix memory leaks in PrepareStatementDB * Fix CR: 1) Fix potential Segmentation Fault in Reset function 2) Setting db.Stmts to nil map when Close to avoid further using * Add Test: 1) TestPreparedStmtConcurrentReset 2) TestPreparedStmtConcurrentClose * Fix test, create new connection to keep away from other tests --------- Co-authored-by: Zehui Chen <zehui@ssc-hn.com>
This commit is contained in:
parent
4a50b36f63
commit
0dbfda5d7e
|
@ -18,7 +18,6 @@ type Stmt struct {
|
||||||
|
|
||||||
type PreparedStmtDB struct {
|
type PreparedStmtDB struct {
|
||||||
Stmts map[string]*Stmt
|
Stmts map[string]*Stmt
|
||||||
PreparedSQL []string
|
|
||||||
Mux *sync.RWMutex
|
Mux *sync.RWMutex
|
||||||
ConnPool
|
ConnPool
|
||||||
}
|
}
|
||||||
|
@ -28,7 +27,6 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
|
||||||
ConnPool: connPool,
|
ConnPool: connPool,
|
||||||
Stmts: make(map[string]*Stmt),
|
Stmts: make(map[string]*Stmt),
|
||||||
Mux: &sync.RWMutex{},
|
Mux: &sync.RWMutex{},
|
||||||
PreparedSQL: make([]string, 0, 100),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,12 +46,17 @@ func (db *PreparedStmtDB) Close() {
|
||||||
db.Mux.Lock()
|
db.Mux.Lock()
|
||||||
defer db.Mux.Unlock()
|
defer db.Mux.Unlock()
|
||||||
|
|
||||||
for _, query := range db.PreparedSQL {
|
for _, stmt := range db.Stmts {
|
||||||
if stmt, ok := db.Stmts[query]; ok {
|
go func(s *Stmt) {
|
||||||
delete(db.Stmts, query)
|
// make sure the stmt must finish preparation first
|
||||||
go stmt.Close()
|
<-s.prepared
|
||||||
|
if s.Stmt != nil {
|
||||||
|
_ = s.Close()
|
||||||
}
|
}
|
||||||
|
}(stmt)
|
||||||
}
|
}
|
||||||
|
// setting db.Stmts to nil to avoid further using
|
||||||
|
db.Stmts = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sdb *PreparedStmtDB) Reset() {
|
func (sdb *PreparedStmtDB) Reset() {
|
||||||
|
@ -61,9 +64,14 @@ func (sdb *PreparedStmtDB) Reset() {
|
||||||
defer sdb.Mux.Unlock()
|
defer sdb.Mux.Unlock()
|
||||||
|
|
||||||
for _, stmt := range sdb.Stmts {
|
for _, stmt := range sdb.Stmts {
|
||||||
go stmt.Close()
|
go func(s *Stmt) {
|
||||||
|
// make sure the stmt must finish preparation first
|
||||||
|
<-s.prepared
|
||||||
|
if s.Stmt != nil {
|
||||||
|
_ = s.Close()
|
||||||
|
}
|
||||||
|
}(stmt)
|
||||||
}
|
}
|
||||||
sdb.PreparedSQL = make([]string, 0, 100)
|
|
||||||
sdb.Stmts = make(map[string]*Stmt)
|
sdb.Stmts = make(map[string]*Stmt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,7 +101,12 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
|
||||||
|
|
||||||
return *stmt, nil
|
return *stmt, nil
|
||||||
}
|
}
|
||||||
|
// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
|
||||||
|
// which cause by calling Close and executing SQL concurrently
|
||||||
|
if db.Stmts == nil {
|
||||||
|
db.Mux.Unlock()
|
||||||
|
return Stmt{}, ErrInvalidDB
|
||||||
|
}
|
||||||
// cache preparing stmt first
|
// cache preparing stmt first
|
||||||
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
|
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
|
||||||
db.Stmts[query] = &cacheStmt
|
db.Stmts[query] = &cacheStmt
|
||||||
|
@ -118,7 +131,6 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
|
||||||
|
|
||||||
db.Mux.Lock()
|
db.Mux.Lock()
|
||||||
cacheStmt.Stmt = stmt
|
cacheStmt.Stmt = stmt
|
||||||
db.PreparedSQL = append(db.PreparedSQL, query)
|
|
||||||
db.Mux.Unlock()
|
db.Mux.Unlock()
|
||||||
|
|
||||||
return cacheStmt, nil
|
return cacheStmt, nil
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -167,3 +168,149 @@ func TestPreparedStmtReset(t *testing.T) {
|
||||||
t.Fatalf("prepared stmt should be empty")
|
t.Fatalf("prepared stmt should be empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isUsingClosedConnError(err error) bool {
|
||||||
|
// https://github.com/golang/go/blob/e705a2d16e4ece77e08e80c168382cdb02890f5b/src/database/sql/sql.go#L2717
|
||||||
|
return err.Error() == "sql: statement is closed"
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently
|
||||||
|
// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
|
||||||
|
func TestPreparedStmtConcurrentReset(t *testing.T) {
|
||||||
|
name := "prepared_stmt_concurrent_reset"
|
||||||
|
user := *GetUser(name, Config{})
|
||||||
|
createTx := DB.Session(&gorm.Session{}).Create(&user)
|
||||||
|
if createTx.Error != nil {
|
||||||
|
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a new connection to keep away from other tests
|
||||||
|
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open test connection due to %s", err)
|
||||||
|
}
|
||||||
|
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
loopCount := 100
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var unexpectedError bool
|
||||||
|
writerFinish := make(chan struct{})
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id uint) {
|
||||||
|
defer wg.Done()
|
||||||
|
defer close(writerFinish)
|
||||||
|
|
||||||
|
for j := 0; j < loopCount; j++ {
|
||||||
|
var tmp User
|
||||||
|
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
|
||||||
|
if err == nil || isUsingClosedConnError(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.Errorf("failed to read user of id %d due to %s, there should not be error", id, err)
|
||||||
|
unexpectedError = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}(user.ID)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
<-writerFinish
|
||||||
|
pdb.Reset()
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if unexpectedError {
|
||||||
|
t.Fatalf("should is a unexpected error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
|
||||||
|
// for example: one goroutine found error and just close the database, and others are executing SQL
|
||||||
|
// this test making sure that the gorm would not get a Segmentation Fault,
|
||||||
|
// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB
|
||||||
|
// and all of the goroutine must got gorm.ErrInvalidDB after database close
|
||||||
|
func TestPreparedStmtConcurrentClose(t *testing.T) {
|
||||||
|
name := "prepared_stmt_concurrent_close"
|
||||||
|
user := *GetUser(name, Config{})
|
||||||
|
createTx := DB.Session(&gorm.Session{}).Create(&user)
|
||||||
|
if createTx.Error != nil {
|
||||||
|
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a new connection to keep away from other tests
|
||||||
|
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open test connection due to %s", err)
|
||||||
|
}
|
||||||
|
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
loopCount := 100
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var lastErr error
|
||||||
|
closeValid := make(chan struct{}, loopCount)
|
||||||
|
closeStartIdx := loopCount / 2 // close the database at the middle of the execution
|
||||||
|
var lastRunIndex int
|
||||||
|
var closeFinishedAt int64
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id uint) {
|
||||||
|
defer wg.Done()
|
||||||
|
defer close(closeValid)
|
||||||
|
for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ {
|
||||||
|
if lastRunIndex == closeStartIdx {
|
||||||
|
closeValid <- struct{}{}
|
||||||
|
}
|
||||||
|
var tmp User
|
||||||
|
now := time.Now().UnixNano()
|
||||||
|
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
|
||||||
|
if err == nil {
|
||||||
|
closeFinishedAt := atomic.LoadInt64(&closeFinishedAt)
|
||||||
|
if (closeFinishedAt != 0) && (now > closeFinishedAt) {
|
||||||
|
lastErr = errors.New("must got error after database closed")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}(user.ID)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for range closeValid {
|
||||||
|
for i := 0; i < loopCount; i++ {
|
||||||
|
pdb.Close() // the Close method must can be call multiple times
|
||||||
|
atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
var tmp User
|
||||||
|
err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error
|
||||||
|
if err != gorm.ErrInvalidDB {
|
||||||
|
t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// must be error
|
||||||
|
if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) {
|
||||||
|
t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr)
|
||||||
|
}
|
||||||
|
if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx {
|
||||||
|
t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex)
|
||||||
|
}
|
||||||
|
if pdb.Stmts != nil {
|
||||||
|
t.Fatalf("stmts must be nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue