mirror of https://github.com/mattn/go-sqlite3.git
sqlite3: handle trailing comments and multiple SQL statements in Queries
This commit fixes *SQLiteConn.Query to properly handle trailing comments after a SQL query statement. Previously, trailing comments could lead to an infinite loop. It also changes Query to error if the provided SQL statement contains multiple queries ("SELECT 1; SELECT 2;") - previously only the last query was executed ("SELECT 1; SELECT 2;" would yield only 2). This may be a breaking change as previously: Query consumed all of its args - despite only using the last query (Query now only uses the args required to satisfy the first query and errors if there is a mismatch); Query used only the last query and there may be code using this library that depends on this behavior. Personally, I believe the behavior introduced by this commit is correct and any code relying on the prior undocumented behavior incorrect, but it could still be a break.
This commit is contained in:
parent
c34a16e589
commit
f5e855b246
92
sqlite3.go
92
sqlite3.go
|
@ -30,7 +30,6 @@ package sqlite3
|
||||||
#endif
|
#endif
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include <ctype.h>
|
|
||||||
|
|
||||||
#ifdef __CYGWIN__
|
#ifdef __CYGWIN__
|
||||||
# include <errno.h>
|
# include <errno.h>
|
||||||
|
@ -91,16 +90,6 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change
|
||||||
return rv;
|
return rv;
|
||||||
}
|
}
|
||||||
|
|
||||||
static const char *
|
|
||||||
_trim_leading_spaces(const char *str) {
|
|
||||||
if (str) {
|
|
||||||
while (isspace(*str)) {
|
|
||||||
str++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return str;
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
|
#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
|
||||||
extern int _sqlite3_step_blocking(sqlite3_stmt *stmt);
|
extern int _sqlite3_step_blocking(sqlite3_stmt *stmt);
|
||||||
extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes);
|
extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes);
|
||||||
|
@ -121,11 +110,7 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
|
||||||
static int
|
static int
|
||||||
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
|
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
|
||||||
{
|
{
|
||||||
int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
|
return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
|
||||||
if (pzTail) {
|
|
||||||
*pzTail = _trim_leading_spaces(*pzTail);
|
|
||||||
}
|
|
||||||
return rv;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
@ -148,12 +133,9 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
|
||||||
static int
|
static int
|
||||||
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
|
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
|
||||||
{
|
{
|
||||||
int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
|
return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
|
||||||
if (pzTail) {
|
|
||||||
*pzTail = _trim_leading_spaces(*pzTail);
|
|
||||||
}
|
|
||||||
return rv;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
|
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
|
||||||
|
@ -951,46 +933,44 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
|
||||||
op := pquery // original pointer
|
op := pquery // original pointer
|
||||||
defer C.free(unsafe.Pointer(op))
|
defer C.free(unsafe.Pointer(op))
|
||||||
|
|
||||||
var stmtArgs []driver.NamedValue
|
|
||||||
var tail *C.char
|
var tail *C.char
|
||||||
s := new(SQLiteStmt) // escapes to the heap so reuse it
|
s := &SQLiteStmt{c: c, cls: true}
|
||||||
start := 0
|
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
|
||||||
for {
|
if rv != C.SQLITE_OK {
|
||||||
*s = SQLiteStmt{c: c, cls: true} // reset
|
return nil, c.lastError()
|
||||||
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
|
}
|
||||||
|
if s.s == nil {
|
||||||
|
return &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true}, nil
|
||||||
|
}
|
||||||
|
na := s.NumInput()
|
||||||
|
if n := len(args); n != na {
|
||||||
|
s.finalize()
|
||||||
|
if n < na {
|
||||||
|
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args))
|
||||||
|
}
|
||||||
|
rows, err := s.query(ctx, args)
|
||||||
|
if err != nil && err != driver.ErrSkip {
|
||||||
|
s.finalize()
|
||||||
|
return rows, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Consume the rest of the query
|
||||||
|
for pquery = tail; pquery != nil && *pquery != 0; pquery = tail {
|
||||||
|
var stmt *C.sqlite3_stmt
|
||||||
|
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &stmt, &tail)
|
||||||
if rv != C.SQLITE_OK {
|
if rv != C.SQLITE_OK {
|
||||||
|
rows.Close()
|
||||||
return nil, c.lastError()
|
return nil, c.lastError()
|
||||||
}
|
}
|
||||||
|
if stmt != nil {
|
||||||
na := s.NumInput()
|
rows.Close()
|
||||||
if len(args)-start < na {
|
return nil, errors.New("cannot insert multiple commands into a prepared statement") // WARN
|
||||||
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
|
|
||||||
}
|
}
|
||||||
// consume the number of arguments used in the current
|
|
||||||
// statement and append all named arguments not contained
|
|
||||||
// therein
|
|
||||||
stmtArgs = append(stmtArgs[:0], args[start:start+na]...)
|
|
||||||
for i := range args {
|
|
||||||
if (i < start || i >= na) && args[i].Name != "" {
|
|
||||||
stmtArgs = append(stmtArgs, args[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i := range stmtArgs {
|
|
||||||
stmtArgs[i].Ordinal = i + 1
|
|
||||||
}
|
|
||||||
rows, err := s.query(ctx, stmtArgs)
|
|
||||||
if err != nil && err != driver.ErrSkip {
|
|
||||||
s.finalize()
|
|
||||||
return rows, err
|
|
||||||
}
|
|
||||||
start += na
|
|
||||||
if tail == nil || *tail == '\000' {
|
|
||||||
return rows, nil
|
|
||||||
}
|
|
||||||
rows.Close()
|
|
||||||
s.finalize()
|
|
||||||
pquery = tail
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return rows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin transaction.
|
// Begin transaction.
|
||||||
|
@ -2044,7 +2024,7 @@ func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||||
return s.query(context.Background(), list)
|
return s.query(context.Background(), list)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (*SQLiteRows, error) {
|
||||||
if err := s.bind(args); err != nil {
|
if err := s.bind(args); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
182
sqlite3_test.go
182
sqlite3_test.go
|
@ -18,6 +18,7 @@ import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
@ -1080,67 +1081,158 @@ func TestExecer(t *testing.T) {
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
_, err = db.Exec(`
|
_, err = db.Exec(`
|
||||||
create table foo (id integer); -- one comment
|
CREATE TABLE foo (id INTEGER); -- one comment
|
||||||
insert into foo(id) values(?);
|
INSERT INTO foo(id) VALUES(?);
|
||||||
insert into foo(id) values(?);
|
INSERT INTO foo(id) VALUES(?);
|
||||||
insert into foo(id) values(?); -- another comment
|
INSERT INTO foo(id) VALUES(?); -- another comment
|
||||||
`, 1, 2, 3)
|
`, 1, 2, 3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Failed to call db.Exec:", err)
|
t.Error("Failed to call db.Exec:", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryer(t *testing.T) {
|
func testQuery(t *testing.T, seed bool, test func(t *testing.T, db *sql.DB)) {
|
||||||
tempFilename := TempFilename(t)
|
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.sqlite3"))
|
||||||
defer os.Remove(tempFilename)
|
|
||||||
db, err := sql.Open("sqlite3", tempFilename)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Failed to open database:", err)
|
t.Fatal("Failed to open database:", err)
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
_, err = db.Exec(`
|
if seed {
|
||||||
create table foo (id integer);
|
if _, err := db.Exec(`create table foo (id integer);`); err != nil {
|
||||||
`)
|
t.Fatal(err)
|
||||||
if err != nil {
|
}
|
||||||
t.Error("Failed to call db.Query:", err)
|
_, err := db.Exec(`
|
||||||
|
INSERT INTO foo(id) VALUES(?);
|
||||||
|
INSERT INTO foo(id) VALUES(?);
|
||||||
|
INSERT INTO foo(id) VALUES(?);
|
||||||
|
`, 3, 2, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.Exec(`
|
// Capture panic so tests can continue
|
||||||
insert into foo(id) values(?);
|
defer func() {
|
||||||
insert into foo(id) values(?);
|
if e := recover(); e != nil {
|
||||||
insert into foo(id) values(?);
|
buf := make([]byte, 32*1024)
|
||||||
`, 3, 2, 1)
|
n := runtime.Stack(buf, false)
|
||||||
if err != nil {
|
t.Fatalf("\npanic: %v\n\n%s\n", e, buf[:n])
|
||||||
t.Error("Failed to call db.Exec:", err)
|
}
|
||||||
}
|
}()
|
||||||
rows, err := db.Query(`
|
test(t, db)
|
||||||
select id from foo order by id;
|
}
|
||||||
`)
|
|
||||||
if err != nil {
|
func testQueryValues(t *testing.T, query string, args ...interface{}) []interface{} {
|
||||||
t.Error("Failed to call db.Query:", err)
|
var values []interface{}
|
||||||
}
|
testQuery(t, true, func(t *testing.T, db *sql.DB) {
|
||||||
defer rows.Close()
|
rows, err := db.Query(query, args...)
|
||||||
n := 0
|
|
||||||
for rows.Next() {
|
|
||||||
var id int
|
|
||||||
err = rows.Scan(&id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("Failed to db.Query:", err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if id != n+1 {
|
if rows == nil {
|
||||||
t.Error("Failed to db.Query: not matched results")
|
t.Fatal("nil rows")
|
||||||
}
|
}
|
||||||
n = n + 1
|
for i := 0; rows.Next(); i++ {
|
||||||
|
if i > 1_000 {
|
||||||
|
t.Fatal("To many iterations of rows.Next():", i)
|
||||||
|
}
|
||||||
|
var v interface{}
|
||||||
|
if err := rows.Scan(&v); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
values = append(values, v)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuery(t *testing.T) {
|
||||||
|
queries := []struct {
|
||||||
|
query string
|
||||||
|
args []interface{}
|
||||||
|
}{
|
||||||
|
{"SELECT id FROM foo ORDER BY id;", nil},
|
||||||
|
{"SELECT id FROM foo WHERE id != ? ORDER BY id;", []interface{}{4}},
|
||||||
|
{"SELECT id FROM foo WHERE id IN (?, ?, ?) ORDER BY id;", []interface{}{1, 2, 3}},
|
||||||
|
|
||||||
|
// Comments
|
||||||
|
{"SELECT id FROM foo ORDER BY id; -- comment", nil},
|
||||||
|
{"SELECT id FROM foo ORDER BY id;\n -- comment\n", nil},
|
||||||
|
{
|
||||||
|
`-- FOO
|
||||||
|
SELECT id FROM foo ORDER BY id; -- BAR
|
||||||
|
/* BAZ */`,
|
||||||
|
nil,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
if err := rows.Err(); err != nil {
|
want := []interface{}{
|
||||||
t.Errorf("Post-scan failed: %v\n", err)
|
int64(1),
|
||||||
|
int64(2),
|
||||||
|
int64(3),
|
||||||
}
|
}
|
||||||
if n != 3 {
|
for _, q := range queries {
|
||||||
t.Errorf("Expected 3 rows but retrieved %v", n)
|
t.Run("", func(t *testing.T) {
|
||||||
|
got := testQueryValues(t, q.query, q.args...)
|
||||||
|
if !reflect.DeepEqual(got, want) {
|
||||||
|
t.Fatalf("Query(%q, %v) = %v; want: %v", q.query, q.args, got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryNoSQL(t *testing.T) {
|
||||||
|
got := testQueryValues(t, "")
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("Query(%q, %v) = %v; want: %v", "", nil, got, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testQueryError(t *testing.T, query string, args ...interface{}) {
|
||||||
|
testQuery(t, true, func(t *testing.T, db *sql.DB) {
|
||||||
|
rows, err := db.Query(query, args...)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected an error got:", err)
|
||||||
|
}
|
||||||
|
if rows != nil {
|
||||||
|
t.Error("Returned rows should be nil on error!")
|
||||||
|
// Attempt to iterate over rows to make sure they don't panic.
|
||||||
|
for i := 0; rows.Next(); i++ {
|
||||||
|
if i > 1_000 {
|
||||||
|
t.Fatal("To many iterations of rows.Next():", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
rows.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryNotEnoughArgs(t *testing.T) {
|
||||||
|
testQueryError(t, "SELECT FROM foo WHERE id = ? OR id = ?);", 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryTooManyArgs(t *testing.T) {
|
||||||
|
// TODO: test error message / kind
|
||||||
|
testQueryError(t, "SELECT FROM foo WHERE id = ?);", 1, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryMultipleStatements(t *testing.T) {
|
||||||
|
testQueryError(t, "SELECT 1; SELECT 2;")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryInvalidTable(t *testing.T) {
|
||||||
|
testQueryError(t, "SELECT COUNT(*) FROM does_not_exist;")
|
||||||
|
}
|
||||||
|
|
||||||
func TestStress(t *testing.T) {
|
func TestStress(t *testing.T) {
|
||||||
tempFilename := TempFilename(t)
|
tempFilename := TempFilename(t)
|
||||||
defer os.Remove(tempFilename)
|
defer os.Remove(tempFilename)
|
||||||
|
@ -2112,7 +2204,6 @@ var benchmarks = []testing.InternalBenchmark{
|
||||||
{Name: "BenchmarkRows", F: benchmarkRows},
|
{Name: "BenchmarkRows", F: benchmarkRows},
|
||||||
{Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
|
{Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
|
||||||
{Name: "BenchmarkExecStep", F: benchmarkExecStep},
|
{Name: "BenchmarkExecStep", F: benchmarkExecStep},
|
||||||
{Name: "BenchmarkQueryStep", F: benchmarkQueryStep},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *TestDB) mustExec(sql string, args ...any) sql.Result {
|
func (db *TestDB) mustExec(sql string, args ...any) sql.Result {
|
||||||
|
@ -2580,12 +2671,3 @@ func benchmarkExecStep(b *testing.B) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func benchmarkQueryStep(b *testing.B) {
|
|
||||||
var i int
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
if err := db.QueryRow(largeSelectStmt).Scan(&i); err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in New Issue