mirror of https://github.com/mattn/go-sqlite3.git
sqlite3: handle trailing comments and multiple SQL statements in Queries
This commit fixes *SQLiteConn.Query to properly handle trailing comments after a SQL query statement. Previously, trailing comments could lead to an infinite loop. It also changes Query to error if the provided SQL statement contains multiple queries ("SELECT 1; SELECT 2;") - previously only the last query was executed ("SELECT 1; SELECT 2;" would yield only 2). This may be a breaking change as previously: Query consumed all of its args - despite only using the last query (Query now only uses the args required to satisfy the first query and errors if there is a mismatch); Query used only the last query and there may be code using this library that depends on this behavior. Personally, I believe the behavior introduced by this commit is correct and any code relying on the prior undocumented behavior incorrect, but it could still be a break.
This commit is contained in:
parent
c34a16e589
commit
f5e855b246
92
sqlite3.go
92
sqlite3.go
|
@ -30,7 +30,6 @@ package sqlite3
|
|||
#endif
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <ctype.h>
|
||||
|
||||
#ifdef __CYGWIN__
|
||||
# include <errno.h>
|
||||
|
@ -91,16 +90,6 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change
|
|||
return rv;
|
||||
}
|
||||
|
||||
static const char *
|
||||
_trim_leading_spaces(const char *str) {
|
||||
if (str) {
|
||||
while (isspace(*str)) {
|
||||
str++;
|
||||
}
|
||||
}
|
||||
return str;
|
||||
}
|
||||
|
||||
#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
|
||||
extern int _sqlite3_step_blocking(sqlite3_stmt *stmt);
|
||||
extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes);
|
||||
|
@ -121,11 +110,7 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
|
|||
static int
|
||||
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
|
||||
{
|
||||
int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
|
||||
if (pzTail) {
|
||||
*pzTail = _trim_leading_spaces(*pzTail);
|
||||
}
|
||||
return rv;
|
||||
return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
|
||||
}
|
||||
|
||||
#else
|
||||
|
@ -148,12 +133,9 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
|
|||
static int
|
||||
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
|
||||
{
|
||||
int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
|
||||
if (pzTail) {
|
||||
*pzTail = _trim_leading_spaces(*pzTail);
|
||||
}
|
||||
return rv;
|
||||
return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
|
||||
|
@ -951,46 +933,44 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
|
|||
op := pquery // original pointer
|
||||
defer C.free(unsafe.Pointer(op))
|
||||
|
||||
var stmtArgs []driver.NamedValue
|
||||
var tail *C.char
|
||||
s := new(SQLiteStmt) // escapes to the heap so reuse it
|
||||
start := 0
|
||||
for {
|
||||
*s = SQLiteStmt{c: c, cls: true} // reset
|
||||
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
|
||||
s := &SQLiteStmt{c: c, cls: true}
|
||||
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
|
||||
if rv != C.SQLITE_OK {
|
||||
return nil, c.lastError()
|
||||
}
|
||||
if s.s == nil {
|
||||
return &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true}, nil
|
||||
}
|
||||
na := s.NumInput()
|
||||
if n := len(args); n != na {
|
||||
s.finalize()
|
||||
if n < na {
|
||||
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
|
||||
}
|
||||
return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args))
|
||||
}
|
||||
rows, err := s.query(ctx, args)
|
||||
if err != nil && err != driver.ErrSkip {
|
||||
s.finalize()
|
||||
return rows, err
|
||||
}
|
||||
|
||||
// Consume the rest of the query
|
||||
for pquery = tail; pquery != nil && *pquery != 0; pquery = tail {
|
||||
var stmt *C.sqlite3_stmt
|
||||
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &stmt, &tail)
|
||||
if rv != C.SQLITE_OK {
|
||||
rows.Close()
|
||||
return nil, c.lastError()
|
||||
}
|
||||
|
||||
na := s.NumInput()
|
||||
if len(args)-start < na {
|
||||
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
|
||||
if stmt != nil {
|
||||
rows.Close()
|
||||
return nil, errors.New("cannot insert multiple commands into a prepared statement") // WARN
|
||||
}
|
||||
// consume the number of arguments used in the current
|
||||
// statement and append all named arguments not contained
|
||||
// therein
|
||||
stmtArgs = append(stmtArgs[:0], args[start:start+na]...)
|
||||
for i := range args {
|
||||
if (i < start || i >= na) && args[i].Name != "" {
|
||||
stmtArgs = append(stmtArgs, args[i])
|
||||
}
|
||||
}
|
||||
for i := range stmtArgs {
|
||||
stmtArgs[i].Ordinal = i + 1
|
||||
}
|
||||
rows, err := s.query(ctx, stmtArgs)
|
||||
if err != nil && err != driver.ErrSkip {
|
||||
s.finalize()
|
||||
return rows, err
|
||||
}
|
||||
start += na
|
||||
if tail == nil || *tail == '\000' {
|
||||
return rows, nil
|
||||
}
|
||||
rows.Close()
|
||||
s.finalize()
|
||||
pquery = tail
|
||||
}
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// Begin transaction.
|
||||
|
@ -2044,7 +2024,7 @@ func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
|
|||
return s.query(context.Background(), list)
|
||||
}
|
||||
|
||||
func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (*SQLiteRows, error) {
|
||||
if err := s.bind(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
182
sqlite3_test.go
182
sqlite3_test.go
|
@ -18,6 +18,7 @@ import (
|
|||
"math/rand"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
|
@ -1080,67 +1081,158 @@ func TestExecer(t *testing.T) {
|
|||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`
|
||||
create table foo (id integer); -- one comment
|
||||
insert into foo(id) values(?);
|
||||
insert into foo(id) values(?);
|
||||
insert into foo(id) values(?); -- another comment
|
||||
CREATE TABLE foo (id INTEGER); -- one comment
|
||||
INSERT INTO foo(id) VALUES(?);
|
||||
INSERT INTO foo(id) VALUES(?);
|
||||
INSERT INTO foo(id) VALUES(?); -- another comment
|
||||
`, 1, 2, 3)
|
||||
if err != nil {
|
||||
t.Error("Failed to call db.Exec:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryer(t *testing.T) {
|
||||
tempFilename := TempFilename(t)
|
||||
defer os.Remove(tempFilename)
|
||||
db, err := sql.Open("sqlite3", tempFilename)
|
||||
func testQuery(t *testing.T, seed bool, test func(t *testing.T, db *sql.DB)) {
|
||||
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.sqlite3"))
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`
|
||||
create table foo (id integer);
|
||||
`)
|
||||
if err != nil {
|
||||
t.Error("Failed to call db.Query:", err)
|
||||
if seed {
|
||||
if _, err := db.Exec(`create table foo (id integer);`); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO foo(id) VALUES(?);
|
||||
INSERT INTO foo(id) VALUES(?);
|
||||
INSERT INTO foo(id) VALUES(?);
|
||||
`, 3, 2, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
insert into foo(id) values(?);
|
||||
insert into foo(id) values(?);
|
||||
insert into foo(id) values(?);
|
||||
`, 3, 2, 1)
|
||||
if err != nil {
|
||||
t.Error("Failed to call db.Exec:", err)
|
||||
}
|
||||
rows, err := db.Query(`
|
||||
select id from foo order by id;
|
||||
`)
|
||||
if err != nil {
|
||||
t.Error("Failed to call db.Query:", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
n := 0
|
||||
for rows.Next() {
|
||||
var id int
|
||||
err = rows.Scan(&id)
|
||||
// Capture panic so tests can continue
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
buf := make([]byte, 32*1024)
|
||||
n := runtime.Stack(buf, false)
|
||||
t.Fatalf("\npanic: %v\n\n%s\n", e, buf[:n])
|
||||
}
|
||||
}()
|
||||
test(t, db)
|
||||
}
|
||||
|
||||
func testQueryValues(t *testing.T, query string, args ...interface{}) []interface{} {
|
||||
var values []interface{}
|
||||
testQuery(t, true, func(t *testing.T, db *sql.DB) {
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
t.Error("Failed to db.Query:", err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
if id != n+1 {
|
||||
t.Error("Failed to db.Query: not matched results")
|
||||
if rows == nil {
|
||||
t.Fatal("nil rows")
|
||||
}
|
||||
n = n + 1
|
||||
for i := 0; rows.Next(); i++ {
|
||||
if i > 1_000 {
|
||||
t.Fatal("To many iterations of rows.Next():", i)
|
||||
}
|
||||
var v interface{}
|
||||
if err := rows.Scan(&v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
values = append(values, v)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
return values
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
queries := []struct {
|
||||
query string
|
||||
args []interface{}
|
||||
}{
|
||||
{"SELECT id FROM foo ORDER BY id;", nil},
|
||||
{"SELECT id FROM foo WHERE id != ? ORDER BY id;", []interface{}{4}},
|
||||
{"SELECT id FROM foo WHERE id IN (?, ?, ?) ORDER BY id;", []interface{}{1, 2, 3}},
|
||||
|
||||
// Comments
|
||||
{"SELECT id FROM foo ORDER BY id; -- comment", nil},
|
||||
{"SELECT id FROM foo ORDER BY id;\n -- comment\n", nil},
|
||||
{
|
||||
`-- FOO
|
||||
SELECT id FROM foo ORDER BY id; -- BAR
|
||||
/* BAZ */`,
|
||||
nil,
|
||||
},
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
t.Errorf("Post-scan failed: %v\n", err)
|
||||
want := []interface{}{
|
||||
int64(1),
|
||||
int64(2),
|
||||
int64(3),
|
||||
}
|
||||
if n != 3 {
|
||||
t.Errorf("Expected 3 rows but retrieved %v", n)
|
||||
for _, q := range queries {
|
||||
t.Run("", func(t *testing.T) {
|
||||
got := testQueryValues(t, q.query, q.args...)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("Query(%q, %v) = %v; want: %v", q.query, q.args, got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryNoSQL(t *testing.T) {
|
||||
got := testQueryValues(t, "")
|
||||
if got != nil {
|
||||
t.Fatalf("Query(%q, %v) = %v; want: %v", "", nil, got, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func testQueryError(t *testing.T, query string, args ...interface{}) {
|
||||
testQuery(t, true, func(t *testing.T, db *sql.DB) {
|
||||
rows, err := db.Query(query, args...)
|
||||
if err == nil {
|
||||
t.Error("Expected an error got:", err)
|
||||
}
|
||||
if rows != nil {
|
||||
t.Error("Returned rows should be nil on error!")
|
||||
// Attempt to iterate over rows to make sure they don't panic.
|
||||
for i := 0; rows.Next(); i++ {
|
||||
if i > 1_000 {
|
||||
t.Fatal("To many iterations of rows.Next():", i)
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestQueryNotEnoughArgs(t *testing.T) {
|
||||
testQueryError(t, "SELECT FROM foo WHERE id = ? OR id = ?);", 1)
|
||||
}
|
||||
|
||||
func TestQueryTooManyArgs(t *testing.T) {
|
||||
// TODO: test error message / kind
|
||||
testQueryError(t, "SELECT FROM foo WHERE id = ?);", 1, 2)
|
||||
}
|
||||
|
||||
func TestQueryMultipleStatements(t *testing.T) {
|
||||
testQueryError(t, "SELECT 1; SELECT 2;")
|
||||
}
|
||||
|
||||
func TestQueryInvalidTable(t *testing.T) {
|
||||
testQueryError(t, "SELECT COUNT(*) FROM does_not_exist;")
|
||||
}
|
||||
|
||||
func TestStress(t *testing.T) {
|
||||
tempFilename := TempFilename(t)
|
||||
defer os.Remove(tempFilename)
|
||||
|
@ -2112,7 +2204,6 @@ var benchmarks = []testing.InternalBenchmark{
|
|||
{Name: "BenchmarkRows", F: benchmarkRows},
|
||||
{Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
|
||||
{Name: "BenchmarkExecStep", F: benchmarkExecStep},
|
||||
{Name: "BenchmarkQueryStep", F: benchmarkQueryStep},
|
||||
}
|
||||
|
||||
func (db *TestDB) mustExec(sql string, args ...any) sql.Result {
|
||||
|
@ -2580,12 +2671,3 @@ func benchmarkExecStep(b *testing.B) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkQueryStep(b *testing.B) {
|
||||
var i int
|
||||
for n := 0; n < b.N; n++ {
|
||||
if err := db.QueryRow(largeSelectStmt).Scan(&i); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue