sqlite3: reduce C to Go string conversions in SQLiteConn.{query,exec}

This commit fixes an issue where SQLiteConn.{exec,query} would use a
quadratic amount of memory when processing a query string that
consisted of multiple SQL statements ("stmt 1; stmt 2; ...").

Previously both exec/query used SQLiteConn.prepare() which converts the
"tail" pointer set by sqlite3_prepare_v2() back into a Go string, which
then has to be converted back to a C string for the next call to
prepare() (assuming there are multiple SQL statements in the provided
query).

This commit fixes this by changing both exec and query to convert the
query to a C string once, call sqlite3_prepare_v2() directly, and pass
the returned "tail" pointer to the next sqlite3_prepare_v2() call.

It also changes prepare() to use pointer arithmetic to calculate the
offset of the remaining "tail" portion of the query as a substring of
the original query. This saves a call to C.GoString() and an allocation.

Benchmarks:
```
goos: darwin
goarch: arm64
pkg: github.com/mattn/go-sqlite3
                            │ base.10.txt  │             new.10.txt              │
                            │    sec/op    │   sec/op     vs base                │
Suite/BenchmarkExec-10         1.351µ ± 1%   1.247µ ± 1%   -7.74% (p=0.000 n=10)
Suite/BenchmarkQuery-10        3.830µ ± 1%   3.558µ ± 1%   -7.11% (p=0.000 n=10)
Suite/BenchmarkParams-10       4.221µ ± 0%   4.228µ ± 1%        ~ (p=1.000 n=10)
Suite/BenchmarkStmt-10         2.906µ ± 1%   2.864µ ± 1%   -1.45% (p=0.001 n=10)
Suite/BenchmarkRows-10         149.1µ ± 4%   148.2µ ± 1%   -0.61% (p=0.023 n=10)
Suite/BenchmarkStmtRows-10     147.3µ ± 1%   145.6µ ± 0%   -1.16% (p=0.000 n=10)
Suite/BenchmarkExecStep-10    1898.9µ ± 3%   889.0µ ± 1%  -53.18% (p=0.000 n=10)
Suite/BenchmarkQueryStep-10   1848.0µ ± 1%   894.6µ ± 1%  -51.59% (p=0.000 n=10)
geomean                        38.56µ        31.30µ       -18.84%

                            │  base.10.txt   │               new.10.txt               │
                            │      B/op      │     B/op      vs base                  │
Suite/BenchmarkExec-10            184.0 ± 0%     176.0 ± 0%   -4.35% (p=0.000 n=10)
Suite/BenchmarkQuery-10           864.0 ± 0%     856.0 ± 0%   -0.93% (p=0.000 n=10)
Suite/BenchmarkParams-10        1.289Ki ± 0%   1.281Ki ± 0%   -0.61% (p=0.000 n=10)
Suite/BenchmarkStmt-10          1.078Ki ± 0%   1.078Ki ± 0%        ~ (p=1.000 n=10) ¹
Suite/BenchmarkRows-10          34.45Ki ± 0%   34.45Ki ± 0%   -0.02% (p=0.000 n=10)
Suite/BenchmarkStmtRows-10      34.40Ki ± 0%   34.40Ki ± 0%        ~ (p=1.000 n=10) ¹
Suite/BenchmarkExecStep-10    5334.61Ki ± 0%   70.41Ki ± 0%  -98.68% (p=0.000 n=10)
Suite/BenchmarkQueryStep-10    5397.4Ki ± 0%   133.2Ki ± 0%  -97.53% (p=0.000 n=10)
geomean                         17.06Ki        6.208Ki       -63.62%
¹ all samples are equal

                            │ base.10.txt  │              new.10.txt               │
                            │  allocs/op   │  allocs/op   vs base                  │
Suite/BenchmarkExec-10          13.00 ± 0%    12.00 ± 0%   -7.69% (p=0.000 n=10)
Suite/BenchmarkQuery-10         46.00 ± 0%    45.00 ± 0%   -2.17% (p=0.000 n=10)
Suite/BenchmarkParams-10        54.00 ± 0%    53.00 ± 0%   -1.85% (p=0.000 n=10)
Suite/BenchmarkStmt-10          49.00 ± 0%    49.00 ± 0%        ~ (p=1.000 n=10) ¹
Suite/BenchmarkRows-10         2.042k ± 0%   2.041k ± 0%   -0.05% (p=0.000 n=10)
Suite/BenchmarkStmtRows-10     2.038k ± 0%   2.038k ± 0%        ~ (p=1.000 n=10) ¹
Suite/BenchmarkExecStep-10    13.000k ± 0%   8.004k ± 0%  -38.43% (p=0.000 n=10)
Suite/BenchmarkQueryStep-10   11.013k ± 0%   6.017k ± 0%  -45.36% (p=0.000 n=10)
geomean                         418.6         359.8       -14.04%
¹ all samples are equal

```
This commit is contained in:
Charlie Vieth 2023-02-02 17:17:49 -05:00
parent 3c0390b77c
commit c34a16e589
No known key found for this signature in database
GPG Key ID: F6DBDE178E5DE3F0
2 changed files with 94 additions and 29 deletions

View File

@ -30,6 +30,7 @@ package sqlite3
#endif #endif
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <ctype.h>
#ifdef __CYGWIN__ #ifdef __CYGWIN__
# include <errno.h> # include <errno.h>
@ -90,6 +91,16 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change
return rv; return rv;
} }
// _trim_leading_spaces returns a pointer to the first character of str
// that isspace() does not classify as whitespace. A NULL str is returned
// unchanged. The result points into str; nothing is copied or modified.
static const char *
_trim_leading_spaces(const char *str) {
	if (str) {
		// The <ctype.h> functions require a value representable as
		// unsigned char (or EOF). Plain char may be signed, so bytes
		// >= 0x80 (e.g. UTF-8 in SQL text) must be cast first.
		while (isspace((unsigned char)*str)) {
			str++;
		}
	}
	return str;
}
#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY #ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
extern int _sqlite3_step_blocking(sqlite3_stmt *stmt); extern int _sqlite3_step_blocking(sqlite3_stmt *stmt);
extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes); extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes);
@ -110,7 +121,11 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
static int static int
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail) _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
{ {
return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail); int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
if (pzTail) {
*pzTail = _trim_leading_spaces(*pzTail);
}
return rv;
} }
#else #else
@ -133,7 +148,11 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
static int static int
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail) _sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
{ {
return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail); int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
if (pzTail) {
*pzTail = _trim_leading_spaces(*pzTail);
}
return rv;
} }
#endif #endif
@ -858,25 +877,34 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
} }
func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
pquery := C.CString(query)
op := pquery // original pointer
defer C.free(unsafe.Pointer(op))
var stmtArgs []driver.NamedValue
var tail *C.char
s := new(SQLiteStmt) // escapes to the heap so reuse it
defer s.finalize()
start := 0 start := 0
for { for {
s, err := c.prepare(ctx, query) *s = SQLiteStmt{c: c} // reset
if err != nil { rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
return nil, err if rv != C.SQLITE_OK {
return nil, c.lastError()
} }
var res driver.Result var res driver.Result
if s.(*SQLiteStmt).s != nil { if s.s != nil {
stmtArgs := make([]driver.NamedValue, 0, len(args))
na := s.NumInput() na := s.NumInput()
if len(args)-start < na { if len(args)-start < na {
s.Close() s.finalize()
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
} }
// consume the number of arguments used in the current // consume the number of arguments used in the current
// statement and append all named arguments not // statement and append all named arguments not
// contained therein // contained therein
if len(args[start:start+na]) > 0 { if len(args[start:start+na]) > 0 {
stmtArgs = append(stmtArgs, args[start:start+na]...) stmtArgs = append(stmtArgs[:0], args[start:start+na]...)
for i := range args { for i := range args {
if (i < start || i >= na) && args[i].Name != "" { if (i < start || i >= na) && args[i].Name != "" {
stmtArgs = append(stmtArgs, args[i]) stmtArgs = append(stmtArgs, args[i])
@ -886,23 +914,23 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.Named
stmtArgs[i].Ordinal = i + 1 stmtArgs[i].Ordinal = i + 1
} }
} }
res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs) var err error
res, err = s.exec(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip { if err != nil && err != driver.ErrSkip {
s.Close() s.finalize()
return nil, err return nil, err
} }
start += na start += na
} }
tail := s.(*SQLiteStmt).t s.finalize()
s.Close() if tail == nil || *tail == '\000' {
if tail == "" {
if res == nil { if res == nil {
// https://github.com/mattn/go-sqlite3/issues/963 // https://github.com/mattn/go-sqlite3/issues/963
res = &SQLiteResult{0, 0} res = &SQLiteResult{0, 0}
} }
return res, nil return res, nil
} }
query = tail pquery = tail
} }
} }
@ -919,14 +947,21 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
} }
func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
pquery := C.CString(query)
op := pquery // original pointer
defer C.free(unsafe.Pointer(op))
var stmtArgs []driver.NamedValue
var tail *C.char
s := new(SQLiteStmt) // escapes to the heap so reuse it
start := 0 start := 0
for { for {
stmtArgs := make([]driver.NamedValue, 0, len(args)) *s = SQLiteStmt{c: c, cls: true} // reset
s, err := c.prepare(ctx, query) rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
if err != nil { if rv != C.SQLITE_OK {
return nil, err return nil, c.lastError()
} }
s.(*SQLiteStmt).cls = true
na := s.NumInput() na := s.NumInput()
if len(args)-start < na { if len(args)-start < na {
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start) return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
@ -934,7 +969,7 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
// consume the number of arguments used in the current // consume the number of arguments used in the current
// statement and append all named arguments not contained // statement and append all named arguments not contained
// therein // therein
stmtArgs = append(stmtArgs, args[start:start+na]...) stmtArgs = append(stmtArgs[:0], args[start:start+na]...)
for i := range args { for i := range args {
if (i < start || i >= na) && args[i].Name != "" { if (i < start || i >= na) && args[i].Name != "" {
stmtArgs = append(stmtArgs, args[i]) stmtArgs = append(stmtArgs, args[i])
@ -943,19 +978,18 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
for i := range stmtArgs { for i := range stmtArgs {
stmtArgs[i].Ordinal = i + 1 stmtArgs[i].Ordinal = i + 1
} }
rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs) rows, err := s.query(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip { if err != nil && err != driver.ErrSkip {
s.Close() s.finalize()
return rows, err return rows, err
} }
start += na start += na
tail := s.(*SQLiteStmt).t if tail == nil || *tail == '\000' {
if tail == "" {
return rows, nil return rows, nil
} }
rows.Close() rows.Close()
s.Close() s.finalize()
query = tail pquery = tail
} }
} }
@ -1818,8 +1852,11 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
return nil, c.lastError() return nil, c.lastError()
} }
var t string var t string
if tail != nil && *tail != '\000' { if tail != nil && *tail != 0 {
t = strings.TrimSpace(C.GoString(tail)) n := int(uintptr(unsafe.Pointer(tail))) - int(uintptr(unsafe.Pointer(pquery)))
if 0 <= n && n < len(query) {
t = strings.TrimSpace(query[n:])
}
} }
ss := &SQLiteStmt{c: c, s: s, t: t} ss := &SQLiteStmt{c: c, s: s, t: t}
runtime.SetFinalizer(ss, (*SQLiteStmt).Close) runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
@ -1913,6 +1950,13 @@ func (s *SQLiteStmt) Close() error {
return nil return nil
} }
// finalize passes the statement handle to sqlite3_finalize and then
// clears it, so a later call on the same SQLiteStmt does nothing.
// The result code of sqlite3_finalize is discarded.
func (s *SQLiteStmt) finalize() {
	if s.s == nil {
		return
	}
	C.sqlite3_finalize(s.s)
	s.s = nil
}
// NumInput return a number of parameters. // NumInput return a number of parameters.
func (s *SQLiteStmt) NumInput() int { func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s)) return int(C.sqlite3_bind_parameter_count(s.s))

View File

@ -2111,6 +2111,8 @@ var benchmarks = []testing.InternalBenchmark{
{Name: "BenchmarkStmt", F: benchmarkStmt}, {Name: "BenchmarkStmt", F: benchmarkStmt},
{Name: "BenchmarkRows", F: benchmarkRows}, {Name: "BenchmarkRows", F: benchmarkRows},
{Name: "BenchmarkStmtRows", F: benchmarkStmtRows}, {Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
{Name: "BenchmarkExecStep", F: benchmarkExecStep},
{Name: "BenchmarkQueryStep", F: benchmarkQueryStep},
} }
func (db *TestDB) mustExec(sql string, args ...any) sql.Result { func (db *TestDB) mustExec(sql string, args ...any) sql.Result {
@ -2568,3 +2570,22 @@ func benchmarkStmtRows(b *testing.B) {
} }
} }
} }
// largeSelectStmt is a single query string made of 1,000 "select 1;"
// statements, one per line. benchmarkExecStep and benchmarkQueryStep run
// it to exercise queries that contain many SQL statements.
var largeSelectStmt = strings.Repeat("select 1;\n", 1_000)
// benchmarkExecStep runs largeSelectStmt through db.Exec once per
// benchmark iteration and aborts the benchmark on the first error.
func benchmarkExecStep(b *testing.B) {
	for iter := 0; iter < b.N; iter++ {
		_, err := db.Exec(largeSelectStmt)
		if err != nil {
			b.Fatal(err)
		}
	}
}
// benchmarkQueryStep runs largeSelectStmt through db.QueryRow once per
// benchmark iteration, scans the resulting row into an int, and aborts
// the benchmark on the first error.
func benchmarkQueryStep(b *testing.B) {
	var got int
	for iter := 0; iter < b.N; iter++ {
		row := db.QueryRow(largeSelectStmt)
		if err := row.Scan(&got); err != nil {
			b.Fatal(err)
		}
	}
}