forked from mirror/go-sqlcipher
Fix context cancellation racy handling
[why] Context cancellation goroutine is not in sync with Next() method lifetime. It leads to sql.ErrNoRows instead of context.Canceled often (easy to reproduce). It leads to interruption of next query executed on same connection (harder to reproduce). [how] Do query in goroutine, wait when interruption done. [testing] Add unit test that reproduces error cases.
This commit is contained in:
parent
d3c690956b
commit
7e1a61dbcd
96
sqlite3.go
96
sqlite3.go
|
@ -328,7 +328,7 @@ type SQLiteRows struct {
|
||||||
decltype []string
|
decltype []string
|
||||||
cls bool
|
cls bool
|
||||||
closed bool
|
closed bool
|
||||||
done chan struct{}
|
ctx context.Context // no better alternative to pass context into Next() method
|
||||||
}
|
}
|
||||||
|
|
||||||
type functionInfo struct {
|
type functionInfo struct {
|
||||||
|
@ -1846,22 +1846,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
|
||||||
decltype: nil,
|
decltype: nil,
|
||||||
cls: s.cls,
|
cls: s.cls,
|
||||||
closed: false,
|
closed: false,
|
||||||
done: make(chan struct{}),
|
ctx: ctx,
|
||||||
}
|
|
||||||
|
|
||||||
if ctxdone := ctx.Done(); ctxdone != nil {
|
|
||||||
go func(db *C.sqlite3) {
|
|
||||||
select {
|
|
||||||
case <-ctxdone:
|
|
||||||
select {
|
|
||||||
case <-rows.done:
|
|
||||||
default:
|
|
||||||
C.sqlite3_interrupt(db)
|
|
||||||
rows.Close()
|
|
||||||
}
|
|
||||||
case <-rows.done:
|
|
||||||
}
|
|
||||||
}(s.c.db)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return rows, nil
|
return rows, nil
|
||||||
|
@ -1889,29 +1874,43 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||||
return s.exec(context.Background(), list)
|
return s.exec(context.Background(), list)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// exec executes a query that doesn't return rows. Attempts to honor context timeout.
|
||||||
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
|
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
|
||||||
|
if ctx.Done() == nil {
|
||||||
|
return s.execSync(args)
|
||||||
|
}
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
r driver.Result
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan result)
|
||||||
|
go func() {
|
||||||
|
r, err := s.execSync(args)
|
||||||
|
resultCh <- result{r, err}
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case rv := <- resultCh:
|
||||||
|
return rv.r, rv.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
select {
|
||||||
|
case <-resultCh: // no need to interrupt
|
||||||
|
default:
|
||||||
|
// this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
|
||||||
|
C.sqlite3_interrupt(s.c.db)
|
||||||
|
<-resultCh // ensure goroutine completed
|
||||||
|
}
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) {
|
||||||
if err := s.bind(args); err != nil {
|
if err := s.bind(args); err != nil {
|
||||||
C.sqlite3_reset(s.s)
|
C.sqlite3_reset(s.s)
|
||||||
C.sqlite3_clear_bindings(s.s)
|
C.sqlite3_clear_bindings(s.s)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if ctxdone := ctx.Done(); ctxdone != nil {
|
|
||||||
done := make(chan struct{})
|
|
||||||
defer close(done)
|
|
||||||
go func(db *C.sqlite3) {
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-ctxdone:
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
default:
|
|
||||||
C.sqlite3_interrupt(db)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(s.c.db)
|
|
||||||
}
|
|
||||||
|
|
||||||
var rowid, changes C.longlong
|
var rowid, changes C.longlong
|
||||||
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
|
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
|
||||||
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
|
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
|
||||||
|
@ -1932,9 +1931,6 @@ func (rc *SQLiteRows) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
rc.closed = true
|
rc.closed = true
|
||||||
if rc.done != nil {
|
|
||||||
close(rc.done)
|
|
||||||
}
|
|
||||||
if rc.cls {
|
if rc.cls {
|
||||||
rc.s.mu.Unlock()
|
rc.s.mu.Unlock()
|
||||||
return rc.s.Close()
|
return rc.s.Close()
|
||||||
|
@ -1978,13 +1974,39 @@ func (rc *SQLiteRows) DeclTypes() []string {
|
||||||
return rc.declTypes()
|
return rc.declTypes()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next move cursor to next.
|
// Next move cursor to next. Attempts to honor context timeout from QueryContext call.
|
||||||
func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
||||||
rc.s.mu.Lock()
|
rc.s.mu.Lock()
|
||||||
defer rc.s.mu.Unlock()
|
defer rc.s.mu.Unlock()
|
||||||
|
|
||||||
if rc.s.closed {
|
if rc.s.closed {
|
||||||
return io.EOF
|
return io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rc.ctx.Done() == nil {
|
||||||
|
return rc.nextSyncLocked(dest)
|
||||||
|
}
|
||||||
|
resultCh := make(chan error)
|
||||||
|
go func() {
|
||||||
|
resultCh <- rc.nextSyncLocked(dest)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case err := <- resultCh:
|
||||||
|
return err
|
||||||
|
case <-rc.ctx.Done():
|
||||||
|
select {
|
||||||
|
case <-resultCh: // no need to interrupt
|
||||||
|
default:
|
||||||
|
// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
|
||||||
|
C.sqlite3_interrupt(rc.s.c.db)
|
||||||
|
<-resultCh // ensure goroutine completed
|
||||||
|
}
|
||||||
|
return rc.ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// nextSyncLocked moves cursor to next; must be called with locked mutex.
|
||||||
|
func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error {
|
||||||
rv := C._sqlite3_step_internal(rc.s.s)
|
rv := C._sqlite3_step_internal(rc.s.s)
|
||||||
if rv == C.SQLITE_DONE {
|
if rv == C.SQLITE_DONE {
|
||||||
return io.EOF
|
return io.EOF
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -135,6 +136,93 @@ func TestShortTimeout(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryRowContextCancel(t *testing.T) {
|
||||||
|
srcTempFilename := TempFilename(t)
|
||||||
|
defer os.Remove(srcTempFilename)
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3", srcTempFilename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
initDatabase(t, db, 100)
|
||||||
|
|
||||||
|
const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
|
||||||
|
var keyID string
|
||||||
|
unexpectedErrors := make(map[string]int)
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
row := db.QueryRowContext(ctx, query)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
// it is fine to get "nil" as context cancellation can be handled with delay
|
||||||
|
if err := row.Scan(&keyID); err != nil && err != context.Canceled {
|
||||||
|
if err.Error() == "sql: Rows are closed" {
|
||||||
|
// see https://github.com/golang/go/issues/24431
|
||||||
|
// fixed in 1.11.1 to properly return context error
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
unexpectedErrors[err.Error()]++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for errText, count := range unexpectedErrors {
|
||||||
|
t.Error(errText, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryRowContextCancelParallel(t *testing.T) {
|
||||||
|
srcTempFilename := TempFilename(t)
|
||||||
|
defer os.Remove(srcTempFilename)
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3", srcTempFilename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
db.SetMaxOpenConns(10)
|
||||||
|
db.SetMaxIdleConns(5)
|
||||||
|
|
||||||
|
defer db.Close()
|
||||||
|
initDatabase(t, db, 100)
|
||||||
|
|
||||||
|
const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
defer wg.Wait()
|
||||||
|
|
||||||
|
testCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
var keyID string
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-testCtx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
row := db.QueryRowContext(ctx, query)
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
_ = row.Scan(&keyID) // see TestQueryRowContextCancel
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
var keyID string
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
// note that testCtx is not cancelled during query execution
|
||||||
|
row := db.QueryRowContext(testCtx, query)
|
||||||
|
|
||||||
|
if err := row.Scan(&keyID); err != nil {
|
||||||
|
t.Fatal(i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestExecCancel(t *testing.T) {
|
func TestExecCancel(t *testing.T) {
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in New Issue