sqlite3: handle trailing comments and multiple SQL statements in Queries

This commit fixes *SQLiteConn.Query to properly handle trailing comments
after a SQL query statement. Previously, trailing comments could lead to
an infinite loop.

It also changes Query to error if the provided SQL statement contains
multiple queries ("SELECT 1; SELECT 2;") - previously only the last
query was executed ("SELECT 1; SELECT 2;" would yield only 2).

This may be a breaking change as previously: Query consumed all of its
args - despite only using the last query (Query now only uses the args
required to satisfy the first query and errors if there is a mismatch);
Query used only the last query and there may be code using this library
that depends on this behavior.

Personally, I believe the behavior introduced by this commit is correct
and any code relying on the prior undocumented behavior incorrect, but
it could still be a break.
This commit is contained in:
Charlie Vieth 2023-04-20 00:16:13 -04:00
parent c34a16e589
commit f5e855b246
No known key found for this signature in database
GPG Key ID: F6DBDE178E5DE3F0
2 changed files with 168 additions and 106 deletions

View File

@ -30,7 +30,6 @@ package sqlite3
#endif #endif
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <ctype.h>
#ifdef __CYGWIN__ #ifdef __CYGWIN__
# include <errno.h> # include <errno.h>
@ -91,16 +90,6 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change
return rv; return rv;
} }
static const char *
_trim_leading_spaces(const char *str) {
if (str) {
while (isspace(*str)) {
str++;
}
}
return str;
}
#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY #ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
extern int _sqlite3_step_blocking(sqlite3_stmt *stmt); extern int _sqlite3_step_blocking(sqlite3_stmt *stmt);
extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes); extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes);
@ -121,11 +110,7 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
static int static int
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail) _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
{ {
int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail); return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
if (pzTail) {
*pzTail = _trim_leading_spaces(*pzTail);
}
return rv;
} }
#else #else
@ -148,12 +133,9 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
static int static int
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail) _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
{ {
int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail); return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
if (pzTail) {
*pzTail = _trim_leading_spaces(*pzTail);
}
return rv;
} }
#endif #endif
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) { void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
@ -951,46 +933,44 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
op := pquery // original pointer op := pquery // original pointer
defer C.free(unsafe.Pointer(op)) defer C.free(unsafe.Pointer(op))
var stmtArgs []driver.NamedValue
var tail *C.char var tail *C.char
s := new(SQLiteStmt) // escapes to the heap so reuse it s := &SQLiteStmt{c: c, cls: true}
start := 0 rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
for { if rv != C.SQLITE_OK {
*s = SQLiteStmt{c: c, cls: true} // reset return nil, c.lastError()
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail) }
if s.s == nil {
return &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true}, nil
}
na := s.NumInput()
if n := len(args); n != na {
s.finalize()
if n < na {
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
}
return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args))
}
rows, err := s.query(ctx, args)
if err != nil && err != driver.ErrSkip {
s.finalize()
return rows, err
}
// Consume the rest of the query
for pquery = tail; pquery != nil && *pquery != 0; pquery = tail {
var stmt *C.sqlite3_stmt
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &stmt, &tail)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
rows.Close()
return nil, c.lastError() return nil, c.lastError()
} }
if stmt != nil {
na := s.NumInput() rows.Close()
if len(args)-start < na { return nil, errors.New("cannot insert multiple commands into a prepared statement") // WARN
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
} }
// consume the number of arguments used in the current
// statement and append all named arguments not contained
// therein
stmtArgs = append(stmtArgs[:0], args[start:start+na]...)
for i := range args {
if (i < start || i >= na) && args[i].Name != "" {
stmtArgs = append(stmtArgs, args[i])
}
}
for i := range stmtArgs {
stmtArgs[i].Ordinal = i + 1
}
rows, err := s.query(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip {
s.finalize()
return rows, err
}
start += na
if tail == nil || *tail == '\000' {
return rows, nil
}
rows.Close()
s.finalize()
pquery = tail
} }
return rows, nil
} }
// Begin transaction. // Begin transaction.
@ -2044,7 +2024,7 @@ func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
return s.query(context.Background(), list) return s.query(context.Background(), list)
} }
func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (*SQLiteRows, error) {
if err := s.bind(args); err != nil { if err := s.bind(args); err != nil {
return nil, err return nil, err
} }

View File

@ -18,6 +18,7 @@ import (
"math/rand" "math/rand"
"net/url" "net/url"
"os" "os"
"path/filepath"
"reflect" "reflect"
"regexp" "regexp"
"runtime" "runtime"
@ -1080,67 +1081,158 @@ func TestExecer(t *testing.T) {
defer db.Close() defer db.Close()
_, err = db.Exec(` _, err = db.Exec(`
create table foo (id integer); -- one comment CREATE TABLE foo (id INTEGER); -- one comment
insert into foo(id) values(?); INSERT INTO foo(id) VALUES(?);
insert into foo(id) values(?); INSERT INTO foo(id) VALUES(?);
insert into foo(id) values(?); -- another comment INSERT INTO foo(id) VALUES(?); -- another comment
`, 1, 2, 3) `, 1, 2, 3)
if err != nil { if err != nil {
t.Error("Failed to call db.Exec:", err) t.Error("Failed to call db.Exec:", err)
} }
} }
func TestQueryer(t *testing.T) { func testQuery(t *testing.T, seed bool, test func(t *testing.T, db *sql.DB)) {
tempFilename := TempFilename(t) db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.sqlite3"))
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer db.Close() defer db.Close()
_, err = db.Exec(` if seed {
create table foo (id integer); if _, err := db.Exec(`create table foo (id integer);`); err != nil {
`) t.Fatal(err)
if err != nil { }
t.Error("Failed to call db.Query:", err) _, err := db.Exec(`
INSERT INTO foo(id) VALUES(?);
INSERT INTO foo(id) VALUES(?);
INSERT INTO foo(id) VALUES(?);
`, 3, 2, 1)
if err != nil {
t.Fatal(err)
}
} }
_, err = db.Exec(` // Capture panic so tests can continue
insert into foo(id) values(?); defer func() {
insert into foo(id) values(?); if e := recover(); e != nil {
insert into foo(id) values(?); buf := make([]byte, 32*1024)
`, 3, 2, 1) n := runtime.Stack(buf, false)
if err != nil { t.Fatalf("\npanic: %v\n\n%s\n", e, buf[:n])
t.Error("Failed to call db.Exec:", err) }
} }()
rows, err := db.Query(` test(t, db)
select id from foo order by id; }
`)
if err != nil { func testQueryValues(t *testing.T, query string, args ...interface{}) []interface{} {
t.Error("Failed to call db.Query:", err) var values []interface{}
} testQuery(t, true, func(t *testing.T, db *sql.DB) {
defer rows.Close() rows, err := db.Query(query, args...)
n := 0
for rows.Next() {
var id int
err = rows.Scan(&id)
if err != nil { if err != nil {
t.Error("Failed to db.Query:", err) t.Fatal(err)
} }
if id != n+1 { if rows == nil {
t.Error("Failed to db.Query: not matched results") t.Fatal("nil rows")
} }
n = n + 1 for i := 0; rows.Next(); i++ {
if i > 1_000 {
t.Fatal("To many iterations of rows.Next():", i)
}
var v interface{}
if err := rows.Scan(&v); err != nil {
t.Fatal(err)
}
values = append(values, v)
}
if err := rows.Err(); err != nil {
t.Fatal(err)
}
if err := rows.Close(); err != nil {
t.Fatal(err)
}
})
return values
}
func TestQuery(t *testing.T) {
queries := []struct {
query string
args []interface{}
}{
{"SELECT id FROM foo ORDER BY id;", nil},
{"SELECT id FROM foo WHERE id != ? ORDER BY id;", []interface{}{4}},
{"SELECT id FROM foo WHERE id IN (?, ?, ?) ORDER BY id;", []interface{}{1, 2, 3}},
// Comments
{"SELECT id FROM foo ORDER BY id; -- comment", nil},
{"SELECT id FROM foo ORDER BY id;\n -- comment\n", nil},
{
`-- FOO
SELECT id FROM foo ORDER BY id; -- BAR
/* BAZ */`,
nil,
},
} }
if err := rows.Err(); err != nil { want := []interface{}{
t.Errorf("Post-scan failed: %v\n", err) int64(1),
int64(2),
int64(3),
} }
if n != 3 { for _, q := range queries {
t.Errorf("Expected 3 rows but retrieved %v", n) t.Run("", func(t *testing.T) {
got := testQueryValues(t, q.query, q.args...)
if !reflect.DeepEqual(got, want) {
t.Fatalf("Query(%q, %v) = %v; want: %v", q.query, q.args, got, want)
}
})
} }
} }
func TestQueryNoSQL(t *testing.T) {
got := testQueryValues(t, "")
if got != nil {
t.Fatalf("Query(%q, %v) = %v; want: %v", "", nil, got, nil)
}
}
func testQueryError(t *testing.T, query string, args ...interface{}) {
testQuery(t, true, func(t *testing.T, db *sql.DB) {
rows, err := db.Query(query, args...)
if err == nil {
t.Error("Expected an error got:", err)
}
if rows != nil {
t.Error("Returned rows should be nil on error!")
// Attempt to iterate over rows to make sure they don't panic.
for i := 0; rows.Next(); i++ {
if i > 1_000 {
t.Fatal("To many iterations of rows.Next():", i)
}
}
if err := rows.Err(); err != nil {
t.Error(err)
}
rows.Close()
}
})
}
func TestQueryNotEnoughArgs(t *testing.T) {
testQueryError(t, "SELECT FROM foo WHERE id = ? OR id = ?);", 1)
}
func TestQueryTooManyArgs(t *testing.T) {
// TODO: test error message / kind
testQueryError(t, "SELECT FROM foo WHERE id = ?);", 1, 2)
}
func TestQueryMultipleStatements(t *testing.T) {
testQueryError(t, "SELECT 1; SELECT 2;")
}
func TestQueryInvalidTable(t *testing.T) {
testQueryError(t, "SELECT COUNT(*) FROM does_not_exist;")
}
func TestStress(t *testing.T) { func TestStress(t *testing.T) {
tempFilename := TempFilename(t) tempFilename := TempFilename(t)
defer os.Remove(tempFilename) defer os.Remove(tempFilename)
@ -2112,7 +2204,6 @@ var benchmarks = []testing.InternalBenchmark{
{Name: "BenchmarkRows", F: benchmarkRows}, {Name: "BenchmarkRows", F: benchmarkRows},
{Name: "BenchmarkStmtRows", F: benchmarkStmtRows}, {Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
{Name: "BenchmarkExecStep", F: benchmarkExecStep}, {Name: "BenchmarkExecStep", F: benchmarkExecStep},
{Name: "BenchmarkQueryStep", F: benchmarkQueryStep},
} }
func (db *TestDB) mustExec(sql string, args ...any) sql.Result { func (db *TestDB) mustExec(sql string, args ...any) sql.Result {
@ -2580,12 +2671,3 @@ func benchmarkExecStep(b *testing.B) {
} }
} }
} }
func benchmarkQueryStep(b *testing.B) {
var i int
for n := 0; n < b.N; n++ {
if err := db.QueryRow(largeSelectStmt).Scan(&i); err != nil {
b.Fatal(err)
}
}
}