forked from mirror/go-sqlcipher
Merge branch 'master' into master
This commit is contained in:
commit
132eeedb4a
|
@ -65,7 +65,7 @@ FAQ
|
|||
|
||||
* Want to get time.Time with current locale
|
||||
|
||||
Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`.
|
||||
Use `_loc=auto` in SQLite3 filename schema like `file:foo.db?_loc=auto`.
|
||||
|
||||
* Can I use this in multiple routines concurrently?
|
||||
|
||||
|
|
|
@ -14,6 +14,12 @@ func main() {
|
|||
&sqlite3.SQLiteDriver{
|
||||
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||
sqlite3conn = append(sqlite3conn, conn)
|
||||
conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
|
||||
switch op {
|
||||
case sqlite3.SQLITE_INSERT:
|
||||
log.Println("Notified of insert on db", db, "table", table, "rowid", rowid)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
|
18
callback.go
18
callback.go
|
@ -59,6 +59,24 @@ func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.ch
|
|||
return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
|
||||
}
|
||||
|
||||
//export commitHookTrampoline
|
||||
func commitHookTrampoline(handle uintptr) int {
|
||||
callback := lookupHandle(handle).(func() int)
|
||||
return callback()
|
||||
}
|
||||
|
||||
//export rollbackHookTrampoline
|
||||
func rollbackHookTrampoline(handle uintptr) {
|
||||
callback := lookupHandle(handle).(func())
|
||||
callback()
|
||||
}
|
||||
|
||||
//export updateHookTrampoline
|
||||
func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) {
|
||||
callback := lookupHandle(handle).(func(int, string, string, int64))
|
||||
callback(op, C.GoString(db), C.GoString(table), rowid)
|
||||
}
|
||||
|
||||
// Use handles to avoid passing Go pointers to C.
|
||||
|
||||
type handleVal struct {
|
||||
|
|
152
sqlite3.go
152
sqlite3.go
|
@ -7,7 +7,7 @@ package sqlite3
|
|||
|
||||
/*
|
||||
#cgo CFLAGS: -std=gnu99
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE=1
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
|
||||
#cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15
|
||||
#cgo CFLAGS: -DSQLITE_DISABLE_INTRINSIC
|
||||
|
@ -102,6 +102,9 @@ int _sqlite3_create_function(
|
|||
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||
|
||||
int compareTrampoline(void*, int, char*, int, char*);
|
||||
int commitHookTrampoline(void*);
|
||||
void rollbackHookTrampoline(void*);
|
||||
void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64);
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
|
@ -115,6 +118,7 @@ import (
|
|||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
|
@ -151,6 +155,12 @@ func Version() (libVersion string, libVersionNumber int, sourceID string) {
|
|||
return libVersion, libVersionNumber, sourceID
|
||||
}
|
||||
|
||||
const (
|
||||
SQLITE_DELETE = C.SQLITE_DELETE
|
||||
SQLITE_INSERT = C.SQLITE_INSERT
|
||||
SQLITE_UPDATE = C.SQLITE_UPDATE
|
||||
)
|
||||
|
||||
// SQLiteDriver implement sql.Driver.
|
||||
type SQLiteDriver struct {
|
||||
Extensions []string
|
||||
|
@ -159,6 +169,7 @@ type SQLiteDriver struct {
|
|||
|
||||
// SQLiteConn implement sql.Conn.
|
||||
type SQLiteConn struct {
|
||||
mu sync.Mutex
|
||||
db *C.sqlite3
|
||||
loc *time.Location
|
||||
txlock string
|
||||
|
@ -173,6 +184,7 @@ type SQLiteTx struct {
|
|||
|
||||
// SQLiteStmt implement sql.Stmt.
|
||||
type SQLiteStmt struct {
|
||||
mu sync.Mutex
|
||||
c *SQLiteConn
|
||||
s *C.sqlite3_stmt
|
||||
t string
|
||||
|
@ -193,6 +205,7 @@ type SQLiteRows struct {
|
|||
cols []string
|
||||
decltype []string
|
||||
cls bool
|
||||
closed bool
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
|
@ -338,6 +351,51 @@ func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int
|
|||
return nil
|
||||
}
|
||||
|
||||
// RegisterCommitHook sets the commit hook for a connection.
|
||||
//
|
||||
// If the callback returns non-zero the transaction will become a rollback.
|
||||
//
|
||||
// If there is an existing commit hook for this connection, it will be
|
||||
// removed. If callback is nil the existing hook (if any) will be removed
|
||||
// without creating a new one.
|
||||
func (c *SQLiteConn) RegisterCommitHook(callback func() int) {
|
||||
if callback == nil {
|
||||
C.sqlite3_commit_hook(c.db, nil, nil)
|
||||
} else {
|
||||
C.sqlite3_commit_hook(c.db, (*[0]byte)(unsafe.Pointer(C.commitHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRollbackHook sets the rollback hook for a connection.
|
||||
//
|
||||
// If there is an existing rollback hook for this connection, it will be
|
||||
// removed. If callback is nil the existing hook (if any) will be removed
|
||||
// without creating a new one.
|
||||
func (c *SQLiteConn) RegisterRollbackHook(callback func()) {
|
||||
if callback == nil {
|
||||
C.sqlite3_rollback_hook(c.db, nil, nil)
|
||||
} else {
|
||||
C.sqlite3_rollback_hook(c.db, (*[0]byte)(unsafe.Pointer(C.rollbackHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterUpdateHook sets the update hook for a connection.
|
||||
//
|
||||
// The parameters to the callback are the operation (one of the constants
|
||||
// SQLITE_INSERT, SQLITE_DELETE, or SQLITE_UPDATE), the database name, the
|
||||
// table name, and the rowid.
|
||||
//
|
||||
// If there is an existing update hook for this connection, it will be
|
||||
// removed. If callback is nil the existing hook (if any) will be removed
|
||||
// without creating a new one.
|
||||
func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64)) {
|
||||
if callback == nil {
|
||||
C.sqlite3_update_hook(c.db, nil, nil)
|
||||
} else {
|
||||
C.sqlite3_update_hook(c.db, (*[0]byte)(unsafe.Pointer(C.updateHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterFunc makes a Go function available as a SQLite function.
|
||||
//
|
||||
// The Go function can have arguments of the following types: any
|
||||
|
@ -568,6 +626,8 @@ func errorString(err Error) string {
|
|||
// "deferred", "exclusive".
|
||||
// _foreign_keys=X
|
||||
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
|
||||
// _recursive_triggers=X
|
||||
// Enable or disable recursive triggers. X can be 1 or 0.
|
||||
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||
if C.sqlite3_threadsafe() == 0 {
|
||||
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
|
||||
|
@ -577,6 +637,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
txlock := "BEGIN"
|
||||
busyTimeout := 5000
|
||||
foreignKeys := -1
|
||||
recursiveTriggers := -1
|
||||
pos := strings.IndexRune(dsn, '?')
|
||||
if pos >= 1 {
|
||||
params, err := url.ParseQuery(dsn[pos+1:])
|
||||
|
@ -631,6 +692,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// _recursive_triggers
|
||||
if val := params.Get("_recursive_triggers"); val != "" {
|
||||
switch val {
|
||||
case "1":
|
||||
recursiveTriggers = 1
|
||||
case "0":
|
||||
recursiveTriggers = 0
|
||||
default:
|
||||
return nil, fmt.Errorf("Invalid _recursive_triggers: %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(dsn, "file:") {
|
||||
dsn = dsn[:pos]
|
||||
}
|
||||
|
@ -677,6 +750,17 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if recursiveTriggers == 0 {
|
||||
if err := exec("PRAGMA recursive_triggers = OFF;"); err != nil {
|
||||
C.sqlite3_close_v2(db)
|
||||
return nil, err
|
||||
}
|
||||
} else if recursiveTriggers == 1 {
|
||||
if err := exec("PRAGMA recursive_triggers = ON;"); err != nil {
|
||||
C.sqlite3_close_v2(db)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
|
||||
|
||||
|
@ -704,11 +788,22 @@ func (c *SQLiteConn) Close() error {
|
|||
return c.lastError()
|
||||
}
|
||||
deleteHandles(c)
|
||||
c.mu.Lock()
|
||||
c.db = nil
|
||||
c.mu.Unlock()
|
||||
runtime.SetFinalizer(c, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) dbConnOpen() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.db != nil
|
||||
}
|
||||
|
||||
// Prepare the query string. Return a new statement.
|
||||
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
|
||||
return c.prepare(context.Background(), query)
|
||||
|
@ -734,14 +829,17 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
|
|||
|
||||
// Close the statement.
|
||||
func (s *SQLiteStmt) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed {
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
if s.c == nil || s.c.db == nil {
|
||||
if !s.c.dbConnOpen() {
|
||||
return errors.New("sqlite statement with already closed database connection")
|
||||
}
|
||||
rv := C.sqlite3_finalize(s.s)
|
||||
s.s = nil
|
||||
if rv != C.SQLITE_OK {
|
||||
return s.c.lastError()
|
||||
}
|
||||
|
@ -759,6 +857,8 @@ type bindArg struct {
|
|||
v driver.Value
|
||||
}
|
||||
|
||||
var placeHolder = []byte{0}
|
||||
|
||||
func (s *SQLiteStmt) bind(args []namedValue) error {
|
||||
rv := C.sqlite3_reset(s.s)
|
||||
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
|
||||
|
@ -780,8 +880,7 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
|
|||
rv = C.sqlite3_bind_null(s.s, n)
|
||||
case string:
|
||||
if len(v) == 0 {
|
||||
b := []byte{0}
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0))
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0))
|
||||
} else {
|
||||
b := []byte(v)
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||
|
@ -797,11 +896,11 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
|
|||
case float64:
|
||||
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
rv = C._sqlite3_bind_blob(s.s, n, nil, 0)
|
||||
} else {
|
||||
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v)))
|
||||
ln := len(v)
|
||||
if ln == 0 {
|
||||
v = placeHolder
|
||||
}
|
||||
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
|
||||
case time.Time:
|
||||
b := []byte(v.Format(SQLiteTimestampFormats[0]))
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||
|
@ -836,6 +935,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
|
|||
cols: nil,
|
||||
decltype: nil,
|
||||
cls: s.cls,
|
||||
closed: false,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
|
@ -908,25 +1008,33 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result
|
|||
|
||||
// Close the rows.
|
||||
func (rc *SQLiteRows) Close() error {
|
||||
if rc.s.closed {
|
||||
rc.s.mu.Lock()
|
||||
if rc.s.closed || rc.closed {
|
||||
rc.s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
rc.closed = true
|
||||
if rc.done != nil {
|
||||
close(rc.done)
|
||||
}
|
||||
if rc.cls {
|
||||
rc.s.mu.Unlock()
|
||||
return rc.s.Close()
|
||||
}
|
||||
rv := C.sqlite3_reset(rc.s.s)
|
||||
if rv != C.SQLITE_OK {
|
||||
rc.s.mu.Unlock()
|
||||
return rc.s.c.lastError()
|
||||
}
|
||||
rc.s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Columns return column names.
|
||||
func (rc *SQLiteRows) Columns() []string {
|
||||
if rc.nc != len(rc.cols) {
|
||||
rc.s.mu.Lock()
|
||||
defer rc.s.mu.Unlock()
|
||||
if rc.s.s != nil && rc.nc != len(rc.cols) {
|
||||
rc.cols = make([]string, rc.nc)
|
||||
for i := 0; i < rc.nc; i++ {
|
||||
rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
|
||||
|
@ -935,9 +1043,8 @@ func (rc *SQLiteRows) Columns() []string {
|
|||
return rc.cols
|
||||
}
|
||||
|
||||
// DeclTypes return column types.
|
||||
func (rc *SQLiteRows) DeclTypes() []string {
|
||||
if rc.decltype == nil {
|
||||
func (rc *SQLiteRows) declTypes() []string {
|
||||
if rc.s.s != nil && rc.decltype == nil {
|
||||
rc.decltype = make([]string, rc.nc)
|
||||
for i := 0; i < rc.nc; i++ {
|
||||
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
|
||||
|
@ -946,8 +1053,20 @@ func (rc *SQLiteRows) DeclTypes() []string {
|
|||
return rc.decltype
|
||||
}
|
||||
|
||||
// DeclTypes return column types.
|
||||
func (rc *SQLiteRows) DeclTypes() []string {
|
||||
rc.s.mu.Lock()
|
||||
defer rc.s.mu.Unlock()
|
||||
return rc.declTypes()
|
||||
}
|
||||
|
||||
// Next move cursor to next.
|
||||
func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
||||
if rc.s.closed {
|
||||
return io.EOF
|
||||
}
|
||||
rc.s.mu.Lock()
|
||||
defer rc.s.mu.Unlock()
|
||||
rv := C.sqlite3_step(rc.s.s)
|
||||
if rv == C.SQLITE_DONE {
|
||||
return io.EOF
|
||||
|
@ -960,7 +1079,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
rc.DeclTypes()
|
||||
rc.declTypes()
|
||||
|
||||
for i := range dest {
|
||||
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
|
||||
|
@ -973,10 +1092,11 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
|||
// large to be a reasonable timestamp in seconds.
|
||||
if val > 1e12 || val < -1e12 {
|
||||
val *= int64(time.Millisecond) // convert ms to nsec
|
||||
t = time.Unix(0, val)
|
||||
} else {
|
||||
val *= int64(time.Second) // convert sec to nsec
|
||||
t = time.Unix(val, 0)
|
||||
}
|
||||
t = time.Unix(0, val).UTC()
|
||||
t = t.UTC()
|
||||
if rc.s.c.loc != nil {
|
||||
t = t.In(rc.s.c.loc)
|
||||
}
|
||||
|
|
|
@ -8,9 +8,13 @@
|
|||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNamedParams(t *testing.T) {
|
||||
|
@ -48,3 +52,91 @@ func TestNamedParams(t *testing.T) {
|
|||
t.Error("Failed to db.QueryRow: not matched results")
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
testTableStatements = []string{
|
||||
`DROP TABLE IF EXISTS test_table`,
|
||||
`
|
||||
CREATE TABLE IF NOT EXISTS test_table (
|
||||
key1 VARCHAR(64) PRIMARY KEY,
|
||||
key_id VARCHAR(64) NOT NULL,
|
||||
key2 VARCHAR(64) NOT NULL,
|
||||
key3 VARCHAR(64) NOT NULL,
|
||||
key4 VARCHAR(64) NOT NULL,
|
||||
key5 VARCHAR(64) NOT NULL,
|
||||
key6 VARCHAR(64) NOT NULL,
|
||||
data BLOB NOT NULL
|
||||
);`,
|
||||
}
|
||||
letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
)
|
||||
|
||||
func randStringBytes(n int) string {
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func initDatabase(t *testing.T, db *sql.DB, rowCount int64) {
|
||||
t.Logf("Executing db initializing statements")
|
||||
for _, query := range testTableStatements {
|
||||
_, err := db.Exec(query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
for i := int64(0); i < rowCount; i++ {
|
||||
query := `INSERT INTO test_table
|
||||
(key1, key_id, key2, key3, key4, key5, key6, data)
|
||||
VALUES
|
||||
(?, ?, ?, ?, ?, ?, ?, ?);`
|
||||
args := []interface{}{
|
||||
randStringBytes(50),
|
||||
fmt.Sprint(i),
|
||||
randStringBytes(50),
|
||||
randStringBytes(50),
|
||||
randStringBytes(50),
|
||||
randStringBytes(50),
|
||||
randStringBytes(50),
|
||||
randStringBytes(50),
|
||||
randStringBytes(2048),
|
||||
}
|
||||
_, err := db.Exec(query, args...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestShortTimeout(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", "file::memory:?mode=memory&cache=shared")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
initDatabase(t, db, 10000)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Microsecond)
|
||||
defer cancel()
|
||||
query := `SELECT key1, key_id, key2, key3, key4, key5, key6, data
|
||||
FROM test_table
|
||||
ORDER BY key2 ASC`
|
||||
rows, err := db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var key1, keyid, key2, key3, key4, key5, key6 string
|
||||
var data []byte
|
||||
err = rows.Scan(&key1, &keyid, &key2, &key3, &key4, &key5, &key6, &data)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if context.DeadlineExceeded != ctx.Err() {
|
||||
t.Fatal(ctx.Err())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,5 +10,6 @@ package sqlite3
|
|||
#cgo CFLAGS: -DUSE_LIBSQLITE3
|
||||
#cgo linux LDFLAGS: -lsqlite3
|
||||
#cgo darwin LDFLAGS: -L/usr/local/opt/sqlite/lib -lsqlite3
|
||||
#cgo solaris LDFLAGS: -lsqlite3
|
||||
*/
|
||||
import "C"
|
||||
|
|
|
@ -9,5 +9,6 @@ package sqlite3
|
|||
/*
|
||||
#cgo CFLAGS: -I.
|
||||
#cgo linux LDFLAGS: -ldl
|
||||
#cgo solaris LDFLAGS: -lc
|
||||
*/
|
||||
import "C"
|
||||
|
|
582
sqlite3_test.go
582
sqlite3_test.go
|
@ -6,21 +6,22 @@
|
|||
package sqlite3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mattn/go-sqlite3/sqlite3_test"
|
||||
)
|
||||
|
||||
func TempFilename(t *testing.T) string {
|
||||
|
@ -136,6 +137,35 @@ func TestForeignKeys(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRecursiveTriggers(t *testing.T) {
|
||||
cases := map[string]bool{
|
||||
"?_recursive_triggers=1": true,
|
||||
"?_recursive_triggers=0": false,
|
||||
}
|
||||
for option, want := range cases {
|
||||
fname := TempFilename(t)
|
||||
uri := "file:" + fname + option
|
||||
db, err := sql.Open("sqlite3", uri)
|
||||
if err != nil {
|
||||
os.Remove(fname)
|
||||
t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
|
||||
continue
|
||||
}
|
||||
var enabled bool
|
||||
err = db.QueryRow("PRAGMA recursive_triggers;").Scan(&enabled)
|
||||
db.Close()
|
||||
os.Remove(fname)
|
||||
if err != nil {
|
||||
t.Errorf("query recursive_triggers for %s: %v", uri, err)
|
||||
continue
|
||||
}
|
||||
if enabled != want {
|
||||
t.Errorf("\"PRAGMA recursive_triggers;\" for %q = %t; want %t", uri, enabled, want)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
tempFilename := TempFilename(t)
|
||||
defer os.Remove(tempFilename)
|
||||
|
@ -403,6 +433,7 @@ func TestTimestamp(t *testing.T) {
|
|||
}{
|
||||
{"nonsense", time.Time{}},
|
||||
{"0000-00-00 00:00:00", time.Time{}},
|
||||
{time.Time{}.Unix(), time.Time{}},
|
||||
{timestamp1, timestamp1},
|
||||
{timestamp2.Unix(), timestamp2.Truncate(time.Second)},
|
||||
{timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)},
|
||||
|
@ -840,18 +871,6 @@ func TestTimezoneConversion(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSuite(t *testing.T) {
|
||||
tempFilename := TempFilename(t)
|
||||
defer os.Remove(tempFilename)
|
||||
db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
sqlite3_test.RunTests(t, db, sqlite3_test.SQLITE)
|
||||
}
|
||||
|
||||
// TODO: Execer & Queryer currently disabled
|
||||
// https://github.com/mattn/go-sqlite3/issues/82
|
||||
func TestExecer(t *testing.T) {
|
||||
|
@ -1385,6 +1404,122 @@ func TestPinger(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestUpdateAndTransactionHooks(t *testing.T) {
|
||||
var events []string
|
||||
var commitHookReturn = 0
|
||||
|
||||
sql.Register("sqlite3_UpdateHook", &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
conn.RegisterCommitHook(func() int {
|
||||
events = append(events, "commit")
|
||||
return commitHookReturn
|
||||
})
|
||||
conn.RegisterRollbackHook(func() {
|
||||
events = append(events, "rollback")
|
||||
})
|
||||
conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
|
||||
events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid))
|
||||
})
|
||||
return nil
|
||||
},
|
||||
})
|
||||
db, err := sql.Open("sqlite3_UpdateHook", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
statements := []string{
|
||||
"create table foo (id integer primary key)",
|
||||
"insert into foo values (9)",
|
||||
"update foo set id = 99 where id = 9",
|
||||
"delete from foo where id = 99",
|
||||
}
|
||||
for _, statement := range statements {
|
||||
_, err = db.Exec(statement)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to prepare test data [%v]: %v", statement, err)
|
||||
}
|
||||
}
|
||||
|
||||
commitHookReturn = 1
|
||||
_, err = db.Exec("insert into foo values (5)")
|
||||
if err == nil {
|
||||
t.Error("Commit hook failed to rollback transaction")
|
||||
}
|
||||
|
||||
var expected = []string{
|
||||
"commit",
|
||||
fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT),
|
||||
"commit",
|
||||
fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE),
|
||||
"commit",
|
||||
fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE),
|
||||
"commit",
|
||||
fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT),
|
||||
"commit",
|
||||
"rollback",
|
||||
}
|
||||
if !reflect.DeepEqual(events, expected) {
|
||||
t.Errorf("Expected notifications %v but got %v", expected, events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilAndEmptyBytes(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
actualNil := []byte("use this to use an actual nil not a reference to nil")
|
||||
emptyBytes := []byte{}
|
||||
for tsti, tst := range []struct {
|
||||
name string
|
||||
columnType string
|
||||
insertBytes []byte
|
||||
expectedBytes []byte
|
||||
}{
|
||||
{"actual nil blob", "blob", actualNil, nil},
|
||||
{"referenced nil blob", "blob", nil, nil},
|
||||
{"empty blob", "blob", emptyBytes, emptyBytes},
|
||||
{"actual nil text", "text", actualNil, nil},
|
||||
{"referenced nil text", "text", nil, nil},
|
||||
{"empty text", "text", emptyBytes, emptyBytes},
|
||||
} {
|
||||
if _, err = db.Exec(fmt.Sprintf("create table tbl%d (txt %s)", tsti, tst.columnType)); err != nil {
|
||||
t.Fatal(tst.name, err)
|
||||
}
|
||||
if bytes.Equal(tst.insertBytes, actualNil) {
|
||||
if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), nil); err != nil {
|
||||
t.Fatal(tst.name, err)
|
||||
}
|
||||
} else {
|
||||
if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), &tst.insertBytes); err != nil {
|
||||
t.Fatal(tst.name, err)
|
||||
}
|
||||
}
|
||||
rows, err := db.Query(fmt.Sprintf("select txt from tbl%d", tsti))
|
||||
if err != nil {
|
||||
t.Fatal(tst.name, err)
|
||||
}
|
||||
if !rows.Next() {
|
||||
t.Fatal(tst.name, "no rows")
|
||||
}
|
||||
var scanBytes []byte
|
||||
if err = rows.Scan(&scanBytes); err != nil {
|
||||
t.Fatal(tst.name, err)
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatal(tst.name, err)
|
||||
}
|
||||
if tst.expectedBytes == nil && scanBytes != nil {
|
||||
t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
|
||||
} else if !bytes.Equal(scanBytes, tst.expectedBytes) {
|
||||
t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var customFunctionOnce sync.Once
|
||||
|
||||
func BenchmarkCustomFunctions(b *testing.B) {
|
||||
|
@ -1419,3 +1554,422 @@ func BenchmarkCustomFunctions(b *testing.B) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuite(t *testing.T) {
|
||||
tempFilename := TempFilename(t)
|
||||
defer os.Remove(tempFilename)
|
||||
d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer d.Close()
|
||||
|
||||
db = &TestDB{t, d, SQLITE, sync.Once{}}
|
||||
testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
|
||||
|
||||
if !testing.Short() {
|
||||
for _, b := range benchmarks {
|
||||
fmt.Printf("%-20s", b.Name)
|
||||
r := testing.Benchmark(b.F)
|
||||
fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
|
||||
}
|
||||
}
|
||||
db.tearDown()
|
||||
}
|
||||
|
||||
// Dialect is a type of dialect of databases.
|
||||
type Dialect int
|
||||
|
||||
// Dialects for databases.
|
||||
const (
|
||||
SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
|
||||
POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
|
||||
MYSQL // MYSQL mean MySQL dialect
|
||||
)
|
||||
|
||||
// DB provide context for the tests
|
||||
type TestDB struct {
|
||||
*testing.T
|
||||
*sql.DB
|
||||
dialect Dialect
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
var db *TestDB
|
||||
|
||||
// the following tables will be created and dropped during the test
|
||||
var testTables = []string{"foo", "bar", "t", "bench"}
|
||||
|
||||
var tests = []testing.InternalTest{
|
||||
{Name: "TestResult", F: testResult},
|
||||
{Name: "TestBlobs", F: testBlobs},
|
||||
{Name: "TestManyQueryRow", F: testManyQueryRow},
|
||||
{Name: "TestTxQuery", F: testTxQuery},
|
||||
{Name: "TestPreparedStmt", F: testPreparedStmt},
|
||||
}
|
||||
|
||||
var benchmarks = []testing.InternalBenchmark{
|
||||
{Name: "BenchmarkExec", F: benchmarkExec},
|
||||
{Name: "BenchmarkQuery", F: benchmarkQuery},
|
||||
{Name: "BenchmarkParams", F: benchmarkParams},
|
||||
{Name: "BenchmarkStmt", F: benchmarkStmt},
|
||||
{Name: "BenchmarkRows", F: benchmarkRows},
|
||||
{Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
|
||||
}
|
||||
|
||||
func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result {
|
||||
res, err := db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
db.Fatalf("Error running %q: %v", sql, err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (db *TestDB) tearDown() {
|
||||
for _, tbl := range testTables {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
db.mustExec("drop table if exists " + tbl)
|
||||
case MYSQL, POSTGRESQL:
|
||||
db.mustExec("drop table if exists " + tbl)
|
||||
default:
|
||||
db.Fatal("unknown dialect")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// q replaces ? parameters if needed
|
||||
func (db *TestDB) q(sql string) string {
|
||||
switch db.dialect {
|
||||
case POSTGRESQL: // repace with $1, $2, ..
|
||||
qrx := regexp.MustCompile(`\?`)
|
||||
n := 0
|
||||
return qrx.ReplaceAllStringFunc(sql, func(string) string {
|
||||
n++
|
||||
return "$" + strconv.Itoa(n)
|
||||
})
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func (db *TestDB) blobType(size int) string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return fmt.Sprintf("blob[%d]", size)
|
||||
case POSTGRESQL:
|
||||
return "bytea"
|
||||
case MYSQL:
|
||||
return fmt.Sprintf("VARBINARY(%d)", size)
|
||||
}
|
||||
panic("unknown dialect")
|
||||
}
|
||||
|
||||
func (db *TestDB) serialPK() string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return "integer primary key autoincrement"
|
||||
case POSTGRESQL:
|
||||
return "serial primary key"
|
||||
case MYSQL:
|
||||
return "integer primary key auto_increment"
|
||||
}
|
||||
panic("unknown dialect")
|
||||
}
|
||||
|
||||
func (db *TestDB) now() string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return "datetime('now')"
|
||||
case POSTGRESQL:
|
||||
return "now()"
|
||||
case MYSQL:
|
||||
return "now()"
|
||||
}
|
||||
panic("unknown dialect")
|
||||
}
|
||||
|
||||
func makeBench() {
|
||||
if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
for i := 0; i < 100; i++ {
|
||||
if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// testResult is test for result
|
||||
func testResult(t *testing.T) {
|
||||
db.tearDown()
|
||||
db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
|
||||
|
||||
for i := 1; i < 3; i++ {
|
||||
r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
|
||||
n, err := r.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Errorf("got %v, want %v", n, 1)
|
||||
}
|
||||
n, err = r.LastInsertId()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != int64(i) {
|
||||
t.Errorf("got %v, want %v", n, i)
|
||||
}
|
||||
}
|
||||
if _, err := db.Exec("error!"); err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
// testBlobs is test for blobs
|
||||
func testBlobs(t *testing.T) {
|
||||
db.tearDown()
|
||||
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
|
||||
db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
|
||||
|
||||
want := fmt.Sprintf("%x", blob)
|
||||
|
||||
b := make([]byte, 16)
|
||||
err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
|
||||
got := fmt.Sprintf("%x", b)
|
||||
if err != nil {
|
||||
t.Errorf("[]byte scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for []byte, got %q; want %q", got, want)
|
||||
}
|
||||
|
||||
err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
|
||||
want = string(blob)
|
||||
if err != nil {
|
||||
t.Errorf("string scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for string, got %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// testManyQueryRow is test for many query row
|
||||
func testManyQueryRow(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Log("skipping in short mode")
|
||||
return
|
||||
}
|
||||
db.tearDown()
|
||||
db.mustExec("create table foo (id integer primary key, name varchar(50))")
|
||||
db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
||||
var name string
|
||||
for i := 0; i < 10000; i++ {
|
||||
err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
|
||||
if err != nil || name != "bob" {
|
||||
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// testTxQuery is test for transactional query
|
||||
func testTxQuery(t *testing.T) {
|
||||
db.tearDown()
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if !r.Next() {
|
||||
if r.Err() != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Fatal("expected one rows")
|
||||
}
|
||||
|
||||
var name string
|
||||
err = r.Scan(&name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// testPreparedStmt is test for prepared statement
|
||||
func testPreparedStmt(t *testing.T) {
|
||||
db.tearDown()
|
||||
db.mustExec("CREATE TABLE t (count INT)")
|
||||
sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 1: %v", err)
|
||||
}
|
||||
ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 2: %v", err)
|
||||
}
|
||||
|
||||
for n := 1; n <= 3; n++ {
|
||||
if _, err := ins.Exec(n); err != nil {
|
||||
t.Fatalf("insert(%d) = %v", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
const nRuns = 10
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < nRuns; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
count := 0
|
||||
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
|
||||
t.Errorf("Query: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err := ins.Exec(rand.Intn(100)); err != nil {
|
||||
t.Errorf("Insert: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Benchmarks need to use panic() since b.Error errors are lost when
|
||||
// running via testing.Benchmark() I would like to run these via go
|
||||
// test -bench but calling Benchmark() from a benchmark test
|
||||
// currently hangs go.
|
||||
|
||||
// benchmarkExec is benchmark for exec
|
||||
func benchmarkExec(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := db.Exec("select 1"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkQuery is benchmark for query
|
||||
func benchmarkQuery(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkParams is benchmark for params
|
||||
func benchmarkParams(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkStmt is benchmark for statement
|
||||
func benchmarkStmt(b *testing.B) {
|
||||
st, err := db.Prepare("select ?, ?, ?, ?")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkRows is benchmark for rows
|
||||
func benchmarkRows(b *testing.B) {
|
||||
db.once.Do(makeBench)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
var t time.Time
|
||||
r, err := db.Query("select * from bench")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for r.Next() {
|
||||
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
if err = r.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkStmtRows is benchmark for statement rows
|
||||
func benchmarkStmtRows(b *testing.B) {
|
||||
db.once.Do(makeBench)
|
||||
|
||||
st, err := db.Prepare("select * from bench")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
var t time.Time
|
||||
r, err := st.Query()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for r.Next() {
|
||||
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
if err = r.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,423 +0,0 @@
|
|||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Dialect is a type of dialect of databases.
|
||||
type Dialect int
|
||||
|
||||
// Dialects for databases.
|
||||
const (
|
||||
SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
|
||||
POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
|
||||
MYSQL // MYSQL mean MySQL dialect
|
||||
)
|
||||
|
||||
// DB provide context for the tests
|
||||
type DB struct {
|
||||
*testing.T
|
||||
*sql.DB
|
||||
dialect Dialect
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
var db *DB
|
||||
|
||||
// the following tables will be created and dropped during the test
|
||||
var testTables = []string{"foo", "bar", "t", "bench"}
|
||||
|
||||
var tests = []testing.InternalTest{
|
||||
{Name: "TestBlobs", F: TestBlobs},
|
||||
{Name: "TestManyQueryRow", F: TestManyQueryRow},
|
||||
{Name: "TestTxQuery", F: TestTxQuery},
|
||||
{Name: "TestPreparedStmt", F: TestPreparedStmt},
|
||||
}
|
||||
|
||||
var benchmarks = []testing.InternalBenchmark{
|
||||
{Name: "BenchmarkExec", F: BenchmarkExec},
|
||||
{Name: "BenchmarkQuery", F: BenchmarkQuery},
|
||||
{Name: "BenchmarkParams", F: BenchmarkParams},
|
||||
{Name: "BenchmarkStmt", F: BenchmarkStmt},
|
||||
{Name: "BenchmarkRows", F: BenchmarkRows},
|
||||
{Name: "BenchmarkStmtRows", F: BenchmarkStmtRows},
|
||||
}
|
||||
|
||||
// RunTests runs the SQL test suite
|
||||
func RunTests(t *testing.T, d *sql.DB, dialect Dialect) {
|
||||
db = &DB{t, d, dialect, sync.Once{}}
|
||||
testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
|
||||
|
||||
if !testing.Short() {
|
||||
for _, b := range benchmarks {
|
||||
fmt.Printf("%-20s", b.Name)
|
||||
r := testing.Benchmark(b.F)
|
||||
fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
|
||||
}
|
||||
}
|
||||
db.tearDown()
|
||||
}
|
||||
|
||||
func (db *DB) mustExec(sql string, args ...interface{}) sql.Result {
|
||||
res, err := db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
db.Fatalf("Error running %q: %v", sql, err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (db *DB) tearDown() {
|
||||
for _, tbl := range testTables {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
db.mustExec("drop table if exists " + tbl)
|
||||
case MYSQL, POSTGRESQL:
|
||||
db.mustExec("drop table if exists " + tbl)
|
||||
default:
|
||||
db.Fatal("unknown dialect")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// q replaces ? parameters if needed
|
||||
func (db *DB) q(sql string) string {
|
||||
switch db.dialect {
|
||||
case POSTGRESQL: // repace with $1, $2, ..
|
||||
qrx := regexp.MustCompile(`\?`)
|
||||
n := 0
|
||||
return qrx.ReplaceAllStringFunc(sql, func(string) string {
|
||||
n++
|
||||
return "$" + strconv.Itoa(n)
|
||||
})
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func (db *DB) blobType(size int) string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return fmt.Sprintf("blob[%d]", size)
|
||||
case POSTGRESQL:
|
||||
return "bytea"
|
||||
case MYSQL:
|
||||
return fmt.Sprintf("VARBINARY(%d)", size)
|
||||
}
|
||||
panic("unknown dialect")
|
||||
}
|
||||
|
||||
func (db *DB) serialPK() string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return "integer primary key autoincrement"
|
||||
case POSTGRESQL:
|
||||
return "serial primary key"
|
||||
case MYSQL:
|
||||
return "integer primary key auto_increment"
|
||||
}
|
||||
panic("unknown dialect")
|
||||
}
|
||||
|
||||
func (db *DB) now() string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return "datetime('now')"
|
||||
case POSTGRESQL:
|
||||
return "now()"
|
||||
case MYSQL:
|
||||
return "now()"
|
||||
}
|
||||
panic("unknown dialect")
|
||||
}
|
||||
|
||||
func makeBench() {
|
||||
if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
for i := 0; i < 100; i++ {
|
||||
if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestResult is test for result
|
||||
func TestResult(t *testing.T) {
|
||||
db.tearDown()
|
||||
db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
|
||||
|
||||
for i := 1; i < 3; i++ {
|
||||
r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
|
||||
n, err := r.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Errorf("got %v, want %v", n, 1)
|
||||
}
|
||||
n, err = r.LastInsertId()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != int64(i) {
|
||||
t.Errorf("got %v, want %v", n, i)
|
||||
}
|
||||
}
|
||||
if _, err := db.Exec("error!"); err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBlobs is test for blobs
|
||||
func TestBlobs(t *testing.T) {
|
||||
db.tearDown()
|
||||
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
|
||||
db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
|
||||
|
||||
want := fmt.Sprintf("%x", blob)
|
||||
|
||||
b := make([]byte, 16)
|
||||
err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
|
||||
got := fmt.Sprintf("%x", b)
|
||||
if err != nil {
|
||||
t.Errorf("[]byte scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for []byte, got %q; want %q", got, want)
|
||||
}
|
||||
|
||||
err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
|
||||
want = string(blob)
|
||||
if err != nil {
|
||||
t.Errorf("string scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for string, got %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestManyQueryRow is test for many query row
|
||||
func TestManyQueryRow(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Log("skipping in short mode")
|
||||
return
|
||||
}
|
||||
db.tearDown()
|
||||
db.mustExec("create table foo (id integer primary key, name varchar(50))")
|
||||
db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
||||
var name string
|
||||
for i := 0; i < 10000; i++ {
|
||||
err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
|
||||
if err != nil || name != "bob" {
|
||||
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestTxQuery is test for transactional query
|
||||
func TestTxQuery(t *testing.T) {
|
||||
db.tearDown()
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if !r.Next() {
|
||||
if r.Err() != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Fatal("expected one rows")
|
||||
}
|
||||
|
||||
var name string
|
||||
err = r.Scan(&name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPreparedStmt is test for prepared statement
|
||||
func TestPreparedStmt(t *testing.T) {
|
||||
db.tearDown()
|
||||
db.mustExec("CREATE TABLE t (count INT)")
|
||||
sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 1: %v", err)
|
||||
}
|
||||
ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 2: %v", err)
|
||||
}
|
||||
|
||||
for n := 1; n <= 3; n++ {
|
||||
if _, err := ins.Exec(n); err != nil {
|
||||
t.Fatalf("insert(%d) = %v", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
const nRuns = 10
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < nRuns; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
count := 0
|
||||
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
|
||||
t.Errorf("Query: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err := ins.Exec(rand.Intn(100)); err != nil {
|
||||
t.Errorf("Insert: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Benchmarks need to use panic() since b.Error errors are lost when
|
||||
// running via testing.Benchmark() I would like to run these via go
|
||||
// test -bench but calling Benchmark() from a benchmark test
|
||||
// currently hangs go.
|
||||
|
||||
// BenchmarkExec is benchmark for exec
|
||||
func BenchmarkExec(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := db.Exec("select 1"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkQuery is benchmark for query
|
||||
func BenchmarkQuery(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkParams is benchmark for params
|
||||
func BenchmarkParams(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkStmt is benchmark for statement
|
||||
func BenchmarkStmt(b *testing.B) {
|
||||
st, err := db.Prepare("select ?, ?, ?, ?")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRows is benchmark for rows
|
||||
func BenchmarkRows(b *testing.B) {
|
||||
db.once.Do(makeBench)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
var t time.Time
|
||||
r, err := db.Query("select * from bench")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for r.Next() {
|
||||
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
if err = r.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkStmtRows is benchmark for statement rows
|
||||
func BenchmarkStmtRows(b *testing.B) {
|
||||
db.once.Do(makeBench)
|
||||
|
||||
st, err := db.Prepare("select * from bench")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
var t time.Time
|
||||
r, err := st.Query()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for r.Next() {
|
||||
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
if err = r.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue