Merge pull request #744 from azavorotnii/ctx_cancel
Fix context cancellation racy handling
This commit is contained in:
commit
590d44c02b
96
sqlite3.go
96
sqlite3.go
|
@ -328,7 +328,7 @@ type SQLiteRows struct {
|
||||||
decltype []string
|
decltype []string
|
||||||
cls bool
|
cls bool
|
||||||
closed bool
|
closed bool
|
||||||
done chan struct{}
|
ctx context.Context // no better alternative to pass context into Next() method
|
||||||
}
|
}
|
||||||
|
|
||||||
type functionInfo struct {
|
type functionInfo struct {
|
||||||
|
@ -1847,22 +1847,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
|
||||||
decltype: nil,
|
decltype: nil,
|
||||||
cls: s.cls,
|
cls: s.cls,
|
||||||
closed: false,
|
closed: false,
|
||||||
done: make(chan struct{}),
|
ctx: ctx,
|
||||||
}
|
|
||||||
|
|
||||||
if ctxdone := ctx.Done(); ctxdone != nil {
|
|
||||||
go func(db *C.sqlite3) {
|
|
||||||
select {
|
|
||||||
case <-ctxdone:
|
|
||||||
select {
|
|
||||||
case <-rows.done:
|
|
||||||
default:
|
|
||||||
C.sqlite3_interrupt(db)
|
|
||||||
rows.Close()
|
|
||||||
}
|
|
||||||
case <-rows.done:
|
|
||||||
}
|
|
||||||
}(s.c.db)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return rows, nil
|
return rows, nil
|
||||||
|
@ -1890,29 +1875,43 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||||
return s.exec(context.Background(), list)
|
return s.exec(context.Background(), list)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// exec executes a query that doesn't return rows. Attempts to honor context timeout.
|
||||||
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
|
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
|
||||||
|
if ctx.Done() == nil {
|
||||||
|
return s.execSync(args)
|
||||||
|
}
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
r driver.Result
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan result)
|
||||||
|
go func() {
|
||||||
|
r, err := s.execSync(args)
|
||||||
|
resultCh <- result{r, err}
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case rv := <- resultCh:
|
||||||
|
return rv.r, rv.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
select {
|
||||||
|
case <-resultCh: // no need to interrupt
|
||||||
|
default:
|
||||||
|
// this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
|
||||||
|
C.sqlite3_interrupt(s.c.db)
|
||||||
|
<-resultCh // ensure goroutine completed
|
||||||
|
}
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) {
|
||||||
if err := s.bind(args); err != nil {
|
if err := s.bind(args); err != nil {
|
||||||
C.sqlite3_reset(s.s)
|
C.sqlite3_reset(s.s)
|
||||||
C.sqlite3_clear_bindings(s.s)
|
C.sqlite3_clear_bindings(s.s)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if ctxdone := ctx.Done(); ctxdone != nil {
|
|
||||||
done := make(chan struct{})
|
|
||||||
defer close(done)
|
|
||||||
go func(db *C.sqlite3) {
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-ctxdone:
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
default:
|
|
||||||
C.sqlite3_interrupt(db)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(s.c.db)
|
|
||||||
}
|
|
||||||
|
|
||||||
var rowid, changes C.longlong
|
var rowid, changes C.longlong
|
||||||
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
|
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
|
||||||
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
|
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
|
||||||
|
@ -1933,9 +1932,6 @@ func (rc *SQLiteRows) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
rc.closed = true
|
rc.closed = true
|
||||||
if rc.done != nil {
|
|
||||||
close(rc.done)
|
|
||||||
}
|
|
||||||
if rc.cls {
|
if rc.cls {
|
||||||
rc.s.mu.Unlock()
|
rc.s.mu.Unlock()
|
||||||
return rc.s.Close()
|
return rc.s.Close()
|
||||||
|
@ -1979,13 +1975,39 @@ func (rc *SQLiteRows) DeclTypes() []string {
|
||||||
return rc.declTypes()
|
return rc.declTypes()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next move cursor to next.
|
// Next move cursor to next. Attempts to honor context timeout from QueryContext call.
|
||||||
func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
||||||
rc.s.mu.Lock()
|
rc.s.mu.Lock()
|
||||||
defer rc.s.mu.Unlock()
|
defer rc.s.mu.Unlock()
|
||||||
|
|
||||||
if rc.s.closed {
|
if rc.s.closed {
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rc.ctx.Done() == nil {
|
||||||
|
return rc.nextSyncLocked(dest)
|
||||||
|
}
|
||||||
|
resultCh := make(chan error)
|
||||||
|
go func() {
|
||||||
|
resultCh <- rc.nextSyncLocked(dest)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case err := <- resultCh:
|
||||||
|
return err
|
||||||
|
case <-rc.ctx.Done():
|
||||||
|
select {
|
||||||
|
case <-resultCh: // no need to interrupt
|
||||||
|
default:
|
||||||
|
// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
|
||||||
|
C.sqlite3_interrupt(rc.s.c.db)
|
||||||
|
<-resultCh // ensure goroutine completed
|
||||||
|
}
|
||||||
|
return rc.ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextSyncLocked moves cursor to next; must be called with locked mutex.
|
||||||
|
func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error {
|
||||||
rv := C._sqlite3_step_internal(rc.s.s)
|
rv := C._sqlite3_step_internal(rc.s.s)
|
||||||
if rv == C.SQLITE_DONE {
|
if rv == C.SQLITE_DONE {
|
||||||
return io.EOF
|
return io.EOF
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -135,6 +136,93 @@ func TestShortTimeout(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryRowContextCancel(t *testing.T) {
|
||||||
|
srcTempFilename := TempFilename(t)
|
||||||
|
defer os.Remove(srcTempFilename)
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3", srcTempFilename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
initDatabase(t, db, 100)
|
||||||
|
|
||||||
|
const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
|
||||||
|
var keyID string
|
||||||
|
unexpectedErrors := make(map[string]int)
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
row := db.QueryRowContext(ctx, query)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
// it is fine to get "nil" as context cancellation can be handled with delay
|
||||||
|
if err := row.Scan(&keyID); err != nil && err != context.Canceled {
|
||||||
|
if err.Error() == "sql: Rows are closed" {
|
||||||
|
// see https://github.com/golang/go/issues/24431
|
||||||
|
// fixed in 1.11.1 to properly return context error
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
unexpectedErrors[err.Error()]++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for errText, count := range unexpectedErrors {
|
||||||
|
t.Error(errText, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryRowContextCancelParallel(t *testing.T) {
|
||||||
|
srcTempFilename := TempFilename(t)
|
||||||
|
defer os.Remove(srcTempFilename)
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3", srcTempFilename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
db.SetMaxOpenConns(10)
|
||||||
|
db.SetMaxIdleConns(5)
|
||||||
|
|
||||||
|
defer db.Close()
|
||||||
|
initDatabase(t, db, 100)
|
||||||
|
|
||||||
|
const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
defer wg.Wait()
|
||||||
|
|
||||||
|
testCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
var keyID string
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-testCtx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
row := db.QueryRowContext(ctx, query)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
_ = row.Scan(&keyID) // see TestQueryRowContextCancel
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
var keyID string
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
// note that testCtx is not cancelled during query execution
|
||||||
|
row := db.QueryRowContext(testCtx, query)
|
||||||
|
|
||||||
|
if err := row.Scan(&keyID); err != nil {
|
||||||
|
t.Fatal(i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestExecCancel(t *testing.T) {
|
func TestExecCancel(t *testing.T) {
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in New Issue