fix: prepare deadlock (#5568)

* fix: prepare deadlock

* chore[ci skip]: code style

* chore[ci skip]: test remove unnecessary params

* fix: prepare deadlock

* fix: double check prepare

* test: more goroutines

* chore[ci skip]: improve code comments

Co-authored-by: Jinzhu <wosmvp@gmail.com>
This commit is contained in:
Cr 2022-09-30 18:13:36 +08:00 committed by GitHub
parent a3cc6c6088
commit 0b7113b618
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 107 additions and 12 deletions

View File

@ -179,7 +179,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
preparedStmt := &PreparedStmtDB{ preparedStmt := &PreparedStmtDB{
ConnPool: db.ConnPool, ConnPool: db.ConnPool,
Stmts: map[string]Stmt{}, Stmts: map[string](*Stmt){},
Mux: &sync.RWMutex{}, Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100), PreparedSQL: make([]string, 0, 100),
} }

View File

@ -9,10 +9,12 @@ import (
type Stmt struct { type Stmt struct {
*sql.Stmt *sql.Stmt
Transaction bool Transaction bool
prepared chan struct{}
prepareErr error
} }
type PreparedStmtDB struct { type PreparedStmtDB struct {
Stmts map[string]Stmt Stmts map[string]*Stmt
PreparedSQL []string PreparedSQL []string
Mux *sync.RWMutex Mux *sync.RWMutex
ConnPool ConnPool
@ -46,27 +48,57 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
db.Mux.RLock() db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock() db.Mux.RUnlock()
return stmt, nil // wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
} }
db.Mux.RUnlock() db.Mux.RUnlock()
db.Mux.Lock() db.Mux.Lock()
defer db.Mux.Unlock()
// double check // double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
return stmt, nil db.Mux.Unlock()
} else if ok { // wait for other goroutines prepared
go stmt.Close() <-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
} }
return *stmt, nil
}
// cache preparing stmt first
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
db.Mux.Unlock()
// prepare completed
defer close(cacheStmt.prepared)
// Reason why cannot lock conn.PrepareContext
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
// 1. g1 begin tx, g1 is requeued because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
stmt, err := conn.PrepareContext(ctx, query) stmt, err := conn.PrepareContext(ctx, query)
if err == nil { if err != nil {
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} cacheStmt.prepareErr = err
db.PreparedSQL = append(db.PreparedSQL, query) db.Mux.Lock()
delete(db.Stmts, query)
db.Mux.Unlock()
return Stmt{}, err
} }
return db.Stmts[query], err db.Mux.Lock()
cacheStmt.Stmt = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
db.Mux.Unlock()
return cacheStmt, nil
} }
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {

View File

@ -2,6 +2,7 @@ package tests_test
import ( import (
"context" "context"
"sync"
"errors" "errors"
"testing" "testing"
"time" "time"
@ -90,6 +91,68 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
tx2.Commit() tx2.Commit()
} }
func TestPreparedStmtDeadlock(t *testing.T) {
tx, err := OpenTestConnection()
AssertEqual(t, err, nil)
sqlDB, _ := tx.DB()
sqlDB.SetMaxOpenConns(1)
tx = tx.Session(&gorm.Session{PrepareStmt: true})
wg := sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
user := User{Name: "jinzhu"}
tx.Create(&user)
var result User
tx.First(&result)
wg.Done()
}()
}
wg.Wait()
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 2)
for _, stmt := range conn.Stmts {
if stmt == nil {
t.Fatalf("stmt cannot bee nil")
}
}
AssertEqual(t, sqlDB.Stats().InUse, 0)
}
func TestPreparedStmtError(t *testing.T) {
tx, err := OpenTestConnection()
AssertEqual(t, err, nil)
sqlDB, _ := tx.DB()
sqlDB.SetMaxOpenConns(1)
tx = tx.Session(&gorm.Session{PrepareStmt: true})
wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
// err prepare
tag := Tag{Locale: "zh"}
tx.Table("users").Find(&tag)
wg.Done()
}()
}
wg.Wait()
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 0)
AssertEqual(t, sqlDB.Stats().InUse, 0)
}
func TestPreparedStmtInTransaction(t *testing.T) { func TestPreparedStmtInTransaction(t *testing.T) {
user := User{Name: "jinzhu"} user := User{Name: "jinzhu"}