forked from mirror/go-sqlcipher
Merge pull request #744 from azavorotnii/ctx_cancel
Fix context cancellation racy handling
This commit is contained in:
commit
590d44c02b
96
sqlite3.go
96
sqlite3.go
|
@ -328,7 +328,7 @@ type SQLiteRows struct {
|
|||
decltype []string
|
||||
cls bool
|
||||
closed bool
|
||||
done chan struct{}
|
||||
ctx context.Context // no better alternative to pass context into Next() method
|
||||
}
|
||||
|
||||
type functionInfo struct {
|
||||
|
@ -1847,22 +1847,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
|
|||
decltype: nil,
|
||||
cls: s.cls,
|
||||
closed: false,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if ctxdone := ctx.Done(); ctxdone != nil {
|
||||
go func(db *C.sqlite3) {
|
||||
select {
|
||||
case <-ctxdone:
|
||||
select {
|
||||
case <-rows.done:
|
||||
default:
|
||||
C.sqlite3_interrupt(db)
|
||||
rows.Close()
|
||||
}
|
||||
case <-rows.done:
|
||||
}
|
||||
}(s.c.db)
|
||||
ctx: ctx,
|
||||
}
|
||||
|
||||
return rows, nil
|
||||
|
@ -1890,29 +1875,43 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
|
|||
return s.exec(context.Background(), list)
|
||||
}
|
||||
|
||||
// exec executes a query that doesn't return rows. Attempts to honor context timeout.
|
||||
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
|
||||
if ctx.Done() == nil {
|
||||
return s.execSync(args)
|
||||
}
|
||||
|
||||
type result struct {
|
||||
r driver.Result
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result)
|
||||
go func() {
|
||||
r, err := s.execSync(args)
|
||||
resultCh <- result{r, err}
|
||||
}()
|
||||
select {
|
||||
case rv := <- resultCh:
|
||||
return rv.r, rv.err
|
||||
case <-ctx.Done():
|
||||
select {
|
||||
case <-resultCh: // no need to interrupt
|
||||
default:
|
||||
// this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
|
||||
C.sqlite3_interrupt(s.c.db)
|
||||
<-resultCh // ensure goroutine completed
|
||||
}
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SQLiteStmt) execSync(args []namedValue) (driver.Result, error) {
|
||||
if err := s.bind(args); err != nil {
|
||||
C.sqlite3_reset(s.s)
|
||||
C.sqlite3_clear_bindings(s.s)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ctxdone := ctx.Done(); ctxdone != nil {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go func(db *C.sqlite3) {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctxdone:
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
C.sqlite3_interrupt(db)
|
||||
}
|
||||
}
|
||||
}(s.c.db)
|
||||
}
|
||||
|
||||
var rowid, changes C.longlong
|
||||
rv := C._sqlite3_step_row_internal(s.s, &rowid, &changes)
|
||||
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
|
||||
|
@ -1933,9 +1932,6 @@ func (rc *SQLiteRows) Close() error {
|
|||
return nil
|
||||
}
|
||||
rc.closed = true
|
||||
if rc.done != nil {
|
||||
close(rc.done)
|
||||
}
|
||||
if rc.cls {
|
||||
rc.s.mu.Unlock()
|
||||
return rc.s.Close()
|
||||
|
@ -1979,13 +1975,39 @@ func (rc *SQLiteRows) DeclTypes() []string {
|
|||
return rc.declTypes()
|
||||
}
|
||||
|
||||
// Next move cursor to next.
|
||||
// Next move cursor to next. Attempts to honor context timeout from QueryContext call.
|
||||
func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
||||
rc.s.mu.Lock()
|
||||
defer rc.s.mu.Unlock()
|
||||
|
||||
if rc.s.closed {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
if rc.ctx.Done() == nil {
|
||||
return rc.nextSyncLocked(dest)
|
||||
}
|
||||
resultCh := make(chan error)
|
||||
go func() {
|
||||
resultCh <- rc.nextSyncLocked(dest)
|
||||
}()
|
||||
select {
|
||||
case err := <- resultCh:
|
||||
return err
|
||||
case <-rc.ctx.Done():
|
||||
select {
|
||||
case <-resultCh: // no need to interrupt
|
||||
default:
|
||||
// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
|
||||
C.sqlite3_interrupt(rc.s.c.db)
|
||||
<-resultCh // ensure goroutine completed
|
||||
}
|
||||
return rc.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// nextSyncLocked moves cursor to next; must be called with locked mutex.
|
||||
func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error {
|
||||
rv := C._sqlite3_step_internal(rc.s.s)
|
||||
if rv == C.SQLITE_DONE {
|
||||
return io.EOF
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"io/ioutil"
|
||||
"math/rand"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
@ -135,6 +136,93 @@ func TestShortTimeout(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestQueryRowContextCancel(t *testing.T) {
|
||||
srcTempFilename := TempFilename(t)
|
||||
defer os.Remove(srcTempFilename)
|
||||
|
||||
db, err := sql.Open("sqlite3", srcTempFilename)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
initDatabase(t, db, 100)
|
||||
|
||||
const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
|
||||
var keyID string
|
||||
unexpectedErrors := make(map[string]int)
|
||||
for i := 0; i < 10000; i++ {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
row := db.QueryRowContext(ctx, query)
|
||||
|
||||
cancel()
|
||||
// it is fine to get "nil" as context cancellation can be handled with delay
|
||||
if err := row.Scan(&keyID); err != nil && err != context.Canceled {
|
||||
if err.Error() == "sql: Rows are closed" {
|
||||
// see https://github.com/golang/go/issues/24431
|
||||
// fixed in 1.11.1 to properly return context error
|
||||
continue
|
||||
}
|
||||
unexpectedErrors[err.Error()]++
|
||||
}
|
||||
}
|
||||
for errText, count := range unexpectedErrors {
|
||||
t.Error(errText, count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryRowContextCancelParallel(t *testing.T) {
|
||||
srcTempFilename := TempFilename(t)
|
||||
defer os.Remove(srcTempFilename)
|
||||
|
||||
db, err := sql.Open("sqlite3", srcTempFilename)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db.SetMaxOpenConns(10)
|
||||
db.SetMaxIdleConns(5)
|
||||
|
||||
defer db.Close()
|
||||
initDatabase(t, db, 100)
|
||||
|
||||
const query = `SELECT key_id FROM test_table ORDER BY key2 ASC`
|
||||
wg := sync.WaitGroup{}
|
||||
defer wg.Wait()
|
||||
|
||||
testCtx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
var keyID string
|
||||
for {
|
||||
select {
|
||||
case <-testCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
row := db.QueryRowContext(ctx, query)
|
||||
|
||||
cancel()
|
||||
_ = row.Scan(&keyID) // see TestQueryRowContextCancel
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var keyID string
|
||||
for i := 0; i < 10000; i++ {
|
||||
// note that testCtx is not cancelled during query execution
|
||||
row := db.QueryRowContext(testCtx, query)
|
||||
|
||||
if err := row.Scan(&keyID); err != nil {
|
||||
t.Fatal(i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecCancel(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue