diff --git a/sqlite3.go b/sqlite3.go index a931735..c9edd40 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -1171,9 +1171,13 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result defer close(done) go func(db *C.sqlite3) { select { - case <-ctx.Done(): - C.sqlite3_interrupt(db) case <-done: + case <-ctx.Done(): + select { + case <-done: + default: + C.sqlite3_interrupt(db) + } } }(s.c.db) diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go index 2662fcf..44fc4df 100644 --- a/sqlite3_go18_test.go +++ b/sqlite3_go18_test.go @@ -134,3 +134,24 @@ func TestShortTimeout(t *testing.T) { t.Fatal(ctx.Err()) } } + +func TestExecCancel(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + if _, err = db.Exec("create table foo (id integer primary key)"); err != nil { + t.Fatal(err) + } + + for n := 0; n < 100; n++ { + ctx, cancel := context.WithCancel(context.Background()) + _, err = db.ExecContext(ctx, "insert into foo (id) values (?)", n) + cancel() + if err != nil { + t.Fatal(err) + } + } +}