forked from mirror/go-sqlcipher
Merge branch 'master' into master
This commit is contained in:
commit
132eeedb4a
|
@ -65,7 +65,7 @@ FAQ
|
||||||
|
|
||||||
* Want to get time.Time with current locale
|
* Want to get time.Time with current locale
|
||||||
|
|
||||||
Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`.
|
Use `_loc=auto` in SQLite3 filename schema like `file:foo.db?_loc=auto`.
|
||||||
|
|
||||||
* Can I use this in multiple routines concurrently?
|
* Can I use this in multiple routines concurrently?
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,12 @@ func main() {
|
||||||
&sqlite3.SQLiteDriver{
|
&sqlite3.SQLiteDriver{
|
||||||
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||||
sqlite3conn = append(sqlite3conn, conn)
|
sqlite3conn = append(sqlite3conn, conn)
|
||||||
|
conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
|
||||||
|
switch op {
|
||||||
|
case sqlite3.SQLITE_INSERT:
|
||||||
|
log.Println("Notified of insert on db", db, "table", table, "rowid", rowid)
|
||||||
|
}
|
||||||
|
})
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
18
callback.go
18
callback.go
|
@ -59,6 +59,24 @@ func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.ch
|
||||||
return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
|
return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//export commitHookTrampoline
|
||||||
|
func commitHookTrampoline(handle uintptr) int {
|
||||||
|
callback := lookupHandle(handle).(func() int)
|
||||||
|
return callback()
|
||||||
|
}
|
||||||
|
|
||||||
|
//export rollbackHookTrampoline
|
||||||
|
func rollbackHookTrampoline(handle uintptr) {
|
||||||
|
callback := lookupHandle(handle).(func())
|
||||||
|
callback()
|
||||||
|
}
|
||||||
|
|
||||||
|
//export updateHookTrampoline
|
||||||
|
func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) {
|
||||||
|
callback := lookupHandle(handle).(func(int, string, string, int64))
|
||||||
|
callback(op, C.GoString(db), C.GoString(table), rowid)
|
||||||
|
}
|
||||||
|
|
||||||
// Use handles to avoid passing Go pointers to C.
|
// Use handles to avoid passing Go pointers to C.
|
||||||
|
|
||||||
type handleVal struct {
|
type handleVal struct {
|
||||||
|
|
152
sqlite3.go
152
sqlite3.go
|
@ -7,7 +7,7 @@ package sqlite3
|
||||||
|
|
||||||
/*
|
/*
|
||||||
#cgo CFLAGS: -std=gnu99
|
#cgo CFLAGS: -std=gnu99
|
||||||
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
|
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE=1
|
||||||
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
|
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
|
||||||
#cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15
|
#cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15
|
||||||
#cgo CFLAGS: -DSQLITE_DISABLE_INTRINSIC
|
#cgo CFLAGS: -DSQLITE_DISABLE_INTRINSIC
|
||||||
|
@ -102,6 +102,9 @@ int _sqlite3_create_function(
|
||||||
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||||
|
|
||||||
int compareTrampoline(void*, int, char*, int, char*);
|
int compareTrampoline(void*, int, char*, int, char*);
|
||||||
|
int commitHookTrampoline(void*);
|
||||||
|
void rollbackHookTrampoline(void*);
|
||||||
|
void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64);
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
|
@ -115,6 +118,7 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
@ -151,6 +155,12 @@ func Version() (libVersion string, libVersionNumber int, sourceID string) {
|
||||||
return libVersion, libVersionNumber, sourceID
|
return libVersion, libVersionNumber, sourceID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
SQLITE_DELETE = C.SQLITE_DELETE
|
||||||
|
SQLITE_INSERT = C.SQLITE_INSERT
|
||||||
|
SQLITE_UPDATE = C.SQLITE_UPDATE
|
||||||
|
)
|
||||||
|
|
||||||
// SQLiteDriver implement sql.Driver.
|
// SQLiteDriver implement sql.Driver.
|
||||||
type SQLiteDriver struct {
|
type SQLiteDriver struct {
|
||||||
Extensions []string
|
Extensions []string
|
||||||
|
@ -159,6 +169,7 @@ type SQLiteDriver struct {
|
||||||
|
|
||||||
// SQLiteConn implement sql.Conn.
|
// SQLiteConn implement sql.Conn.
|
||||||
type SQLiteConn struct {
|
type SQLiteConn struct {
|
||||||
|
mu sync.Mutex
|
||||||
db *C.sqlite3
|
db *C.sqlite3
|
||||||
loc *time.Location
|
loc *time.Location
|
||||||
txlock string
|
txlock string
|
||||||
|
@ -173,6 +184,7 @@ type SQLiteTx struct {
|
||||||
|
|
||||||
// SQLiteStmt implement sql.Stmt.
|
// SQLiteStmt implement sql.Stmt.
|
||||||
type SQLiteStmt struct {
|
type SQLiteStmt struct {
|
||||||
|
mu sync.Mutex
|
||||||
c *SQLiteConn
|
c *SQLiteConn
|
||||||
s *C.sqlite3_stmt
|
s *C.sqlite3_stmt
|
||||||
t string
|
t string
|
||||||
|
@ -193,6 +205,7 @@ type SQLiteRows struct {
|
||||||
cols []string
|
cols []string
|
||||||
decltype []string
|
decltype []string
|
||||||
cls bool
|
cls bool
|
||||||
|
closed bool
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -338,6 +351,51 @@ func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterCommitHook sets the commit hook for a connection.
|
||||||
|
//
|
||||||
|
// If the callback returns non-zero the transaction will become a rollback.
|
||||||
|
//
|
||||||
|
// If there is an existing commit hook for this connection, it will be
|
||||||
|
// removed. If callback is nil the existing hook (if any) will be removed
|
||||||
|
// without creating a new one.
|
||||||
|
func (c *SQLiteConn) RegisterCommitHook(callback func() int) {
|
||||||
|
if callback == nil {
|
||||||
|
C.sqlite3_commit_hook(c.db, nil, nil)
|
||||||
|
} else {
|
||||||
|
C.sqlite3_commit_hook(c.db, (*[0]byte)(unsafe.Pointer(C.commitHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRollbackHook sets the rollback hook for a connection.
|
||||||
|
//
|
||||||
|
// If there is an existing rollback hook for this connection, it will be
|
||||||
|
// removed. If callback is nil the existing hook (if any) will be removed
|
||||||
|
// without creating a new one.
|
||||||
|
func (c *SQLiteConn) RegisterRollbackHook(callback func()) {
|
||||||
|
if callback == nil {
|
||||||
|
C.sqlite3_rollback_hook(c.db, nil, nil)
|
||||||
|
} else {
|
||||||
|
C.sqlite3_rollback_hook(c.db, (*[0]byte)(unsafe.Pointer(C.rollbackHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterUpdateHook sets the update hook for a connection.
|
||||||
|
//
|
||||||
|
// The parameters to the callback are the operation (one of the constants
|
||||||
|
// SQLITE_INSERT, SQLITE_DELETE, or SQLITE_UPDATE), the database name, the
|
||||||
|
// table name, and the rowid.
|
||||||
|
//
|
||||||
|
// If there is an existing update hook for this connection, it will be
|
||||||
|
// removed. If callback is nil the existing hook (if any) will be removed
|
||||||
|
// without creating a new one.
|
||||||
|
func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64)) {
|
||||||
|
if callback == nil {
|
||||||
|
C.sqlite3_update_hook(c.db, nil, nil)
|
||||||
|
} else {
|
||||||
|
C.sqlite3_update_hook(c.db, (*[0]byte)(unsafe.Pointer(C.updateHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterFunc makes a Go function available as a SQLite function.
|
// RegisterFunc makes a Go function available as a SQLite function.
|
||||||
//
|
//
|
||||||
// The Go function can have arguments of the following types: any
|
// The Go function can have arguments of the following types: any
|
||||||
|
@ -568,6 +626,8 @@ func errorString(err Error) string {
|
||||||
// "deferred", "exclusive".
|
// "deferred", "exclusive".
|
||||||
// _foreign_keys=X
|
// _foreign_keys=X
|
||||||
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
|
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
|
||||||
|
// _recursive_triggers=X
|
||||||
|
// Enable or disable recursive triggers. X can be 1 or 0.
|
||||||
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
if C.sqlite3_threadsafe() == 0 {
|
if C.sqlite3_threadsafe() == 0 {
|
||||||
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
|
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
|
||||||
|
@ -577,6 +637,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
txlock := "BEGIN"
|
txlock := "BEGIN"
|
||||||
busyTimeout := 5000
|
busyTimeout := 5000
|
||||||
foreignKeys := -1
|
foreignKeys := -1
|
||||||
|
recursiveTriggers := -1
|
||||||
pos := strings.IndexRune(dsn, '?')
|
pos := strings.IndexRune(dsn, '?')
|
||||||
if pos >= 1 {
|
if pos >= 1 {
|
||||||
params, err := url.ParseQuery(dsn[pos+1:])
|
params, err := url.ParseQuery(dsn[pos+1:])
|
||||||
|
@ -631,6 +692,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// _recursive_triggers
|
||||||
|
if val := params.Get("_recursive_triggers"); val != "" {
|
||||||
|
switch val {
|
||||||
|
case "1":
|
||||||
|
recursiveTriggers = 1
|
||||||
|
case "0":
|
||||||
|
recursiveTriggers = 0
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("Invalid _recursive_triggers: %v", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !strings.HasPrefix(dsn, "file:") {
|
if !strings.HasPrefix(dsn, "file:") {
|
||||||
dsn = dsn[:pos]
|
dsn = dsn[:pos]
|
||||||
}
|
}
|
||||||
|
@ -677,6 +750,17 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if recursiveTriggers == 0 {
|
||||||
|
if err := exec("PRAGMA recursive_triggers = OFF;"); err != nil {
|
||||||
|
C.sqlite3_close_v2(db)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else if recursiveTriggers == 1 {
|
||||||
|
if err := exec("PRAGMA recursive_triggers = ON;"); err != nil {
|
||||||
|
C.sqlite3_close_v2(db)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
|
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
|
||||||
|
|
||||||
|
@ -704,11 +788,22 @@ func (c *SQLiteConn) Close() error {
|
||||||
return c.lastError()
|
return c.lastError()
|
||||||
}
|
}
|
||||||
deleteHandles(c)
|
deleteHandles(c)
|
||||||
|
c.mu.Lock()
|
||||||
c.db = nil
|
c.db = nil
|
||||||
|
c.mu.Unlock()
|
||||||
runtime.SetFinalizer(c, nil)
|
runtime.SetFinalizer(c, nil)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *SQLiteConn) dbConnOpen() bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return c.db != nil
|
||||||
|
}
|
||||||
|
|
||||||
// Prepare the query string. Return a new statement.
|
// Prepare the query string. Return a new statement.
|
||||||
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
|
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
|
||||||
return c.prepare(context.Background(), query)
|
return c.prepare(context.Background(), query)
|
||||||
|
@ -734,14 +829,17 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
|
||||||
|
|
||||||
// Close the statement.
|
// Close the statement.
|
||||||
func (s *SQLiteStmt) Close() error {
|
func (s *SQLiteStmt) Close() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
if s.closed {
|
if s.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
s.closed = true
|
s.closed = true
|
||||||
if s.c == nil || s.c.db == nil {
|
if !s.c.dbConnOpen() {
|
||||||
return errors.New("sqlite statement with already closed database connection")
|
return errors.New("sqlite statement with already closed database connection")
|
||||||
}
|
}
|
||||||
rv := C.sqlite3_finalize(s.s)
|
rv := C.sqlite3_finalize(s.s)
|
||||||
|
s.s = nil
|
||||||
if rv != C.SQLITE_OK {
|
if rv != C.SQLITE_OK {
|
||||||
return s.c.lastError()
|
return s.c.lastError()
|
||||||
}
|
}
|
||||||
|
@ -759,6 +857,8 @@ type bindArg struct {
|
||||||
v driver.Value
|
v driver.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var placeHolder = []byte{0}
|
||||||
|
|
||||||
func (s *SQLiteStmt) bind(args []namedValue) error {
|
func (s *SQLiteStmt) bind(args []namedValue) error {
|
||||||
rv := C.sqlite3_reset(s.s)
|
rv := C.sqlite3_reset(s.s)
|
||||||
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
|
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
|
||||||
|
@ -780,8 +880,7 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
|
||||||
rv = C.sqlite3_bind_null(s.s, n)
|
rv = C.sqlite3_bind_null(s.s, n)
|
||||||
case string:
|
case string:
|
||||||
if len(v) == 0 {
|
if len(v) == 0 {
|
||||||
b := []byte{0}
|
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0))
|
||||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0))
|
|
||||||
} else {
|
} else {
|
||||||
b := []byte(v)
|
b := []byte(v)
|
||||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||||
|
@ -797,11 +896,11 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
|
||||||
case float64:
|
case float64:
|
||||||
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
|
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
|
||||||
case []byte:
|
case []byte:
|
||||||
if len(v) == 0 {
|
ln := len(v)
|
||||||
rv = C._sqlite3_bind_blob(s.s, n, nil, 0)
|
if ln == 0 {
|
||||||
} else {
|
v = placeHolder
|
||||||
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v)))
|
|
||||||
}
|
}
|
||||||
|
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
|
||||||
case time.Time:
|
case time.Time:
|
||||||
b := []byte(v.Format(SQLiteTimestampFormats[0]))
|
b := []byte(v.Format(SQLiteTimestampFormats[0]))
|
||||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||||
|
@ -836,6 +935,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
|
||||||
cols: nil,
|
cols: nil,
|
||||||
decltype: nil,
|
decltype: nil,
|
||||||
cls: s.cls,
|
cls: s.cls,
|
||||||
|
closed: false,
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -908,25 +1008,33 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result
|
||||||
|
|
||||||
// Close the rows.
|
// Close the rows.
|
||||||
func (rc *SQLiteRows) Close() error {
|
func (rc *SQLiteRows) Close() error {
|
||||||
if rc.s.closed {
|
rc.s.mu.Lock()
|
||||||
|
if rc.s.closed || rc.closed {
|
||||||
|
rc.s.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
rc.closed = true
|
||||||
if rc.done != nil {
|
if rc.done != nil {
|
||||||
close(rc.done)
|
close(rc.done)
|
||||||
}
|
}
|
||||||
if rc.cls {
|
if rc.cls {
|
||||||
|
rc.s.mu.Unlock()
|
||||||
return rc.s.Close()
|
return rc.s.Close()
|
||||||
}
|
}
|
||||||
rv := C.sqlite3_reset(rc.s.s)
|
rv := C.sqlite3_reset(rc.s.s)
|
||||||
if rv != C.SQLITE_OK {
|
if rv != C.SQLITE_OK {
|
||||||
|
rc.s.mu.Unlock()
|
||||||
return rc.s.c.lastError()
|
return rc.s.c.lastError()
|
||||||
}
|
}
|
||||||
|
rc.s.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Columns return column names.
|
// Columns return column names.
|
||||||
func (rc *SQLiteRows) Columns() []string {
|
func (rc *SQLiteRows) Columns() []string {
|
||||||
if rc.nc != len(rc.cols) {
|
rc.s.mu.Lock()
|
||||||
|
defer rc.s.mu.Unlock()
|
||||||
|
if rc.s.s != nil && rc.nc != len(rc.cols) {
|
||||||
rc.cols = make([]string, rc.nc)
|
rc.cols = make([]string, rc.nc)
|
||||||
for i := 0; i < rc.nc; i++ {
|
for i := 0; i < rc.nc; i++ {
|
||||||
rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
|
rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
|
||||||
|
@ -935,9 +1043,8 @@ func (rc *SQLiteRows) Columns() []string {
|
||||||
return rc.cols
|
return rc.cols
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeclTypes return column types.
|
func (rc *SQLiteRows) declTypes() []string {
|
||||||
func (rc *SQLiteRows) DeclTypes() []string {
|
if rc.s.s != nil && rc.decltype == nil {
|
||||||
if rc.decltype == nil {
|
|
||||||
rc.decltype = make([]string, rc.nc)
|
rc.decltype = make([]string, rc.nc)
|
||||||
for i := 0; i < rc.nc; i++ {
|
for i := 0; i < rc.nc; i++ {
|
||||||
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
|
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
|
||||||
|
@ -946,8 +1053,20 @@ func (rc *SQLiteRows) DeclTypes() []string {
|
||||||
return rc.decltype
|
return rc.decltype
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeclTypes return column types.
|
||||||
|
func (rc *SQLiteRows) DeclTypes() []string {
|
||||||
|
rc.s.mu.Lock()
|
||||||
|
defer rc.s.mu.Unlock()
|
||||||
|
return rc.declTypes()
|
||||||
|
}
|
||||||
|
|
||||||
// Next move cursor to next.
|
// Next move cursor to next.
|
||||||
func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
||||||
|
if rc.s.closed {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
rc.s.mu.Lock()
|
||||||
|
defer rc.s.mu.Unlock()
|
||||||
rv := C.sqlite3_step(rc.s.s)
|
rv := C.sqlite3_step(rc.s.s)
|
||||||
if rv == C.SQLITE_DONE {
|
if rv == C.SQLITE_DONE {
|
||||||
return io.EOF
|
return io.EOF
|
||||||
|
@ -960,7 +1079,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rc.DeclTypes()
|
rc.declTypes()
|
||||||
|
|
||||||
for i := range dest {
|
for i := range dest {
|
||||||
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
|
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
|
||||||
|
@ -973,10 +1092,11 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
||||||
// large to be a reasonable timestamp in seconds.
|
// large to be a reasonable timestamp in seconds.
|
||||||
if val > 1e12 || val < -1e12 {
|
if val > 1e12 || val < -1e12 {
|
||||||
val *= int64(time.Millisecond) // convert ms to nsec
|
val *= int64(time.Millisecond) // convert ms to nsec
|
||||||
|
t = time.Unix(0, val)
|
||||||
} else {
|
} else {
|
||||||
val *= int64(time.Second) // convert sec to nsec
|
t = time.Unix(val, 0)
|
||||||
}
|
}
|
||||||
t = time.Unix(0, val).UTC()
|
t = t.UTC()
|
||||||
if rc.s.c.loc != nil {
|
if rc.s.c.loc != nil {
|
||||||
t = t.In(rc.s.c.loc)
|
t = t.In(rc.s.c.loc)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,9 +8,13 @@
|
||||||
package sqlite3
|
package sqlite3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNamedParams(t *testing.T) {
|
func TestNamedParams(t *testing.T) {
|
||||||
|
@ -48,3 +52,91 @@ func TestNamedParams(t *testing.T) {
|
||||||
t.Error("Failed to db.QueryRow: not matched results")
|
t.Error("Failed to db.QueryRow: not matched results")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
testTableStatements = []string{
|
||||||
|
`DROP TABLE IF EXISTS test_table`,
|
||||||
|
`
|
||||||
|
CREATE TABLE IF NOT EXISTS test_table (
|
||||||
|
key1 VARCHAR(64) PRIMARY KEY,
|
||||||
|
key_id VARCHAR(64) NOT NULL,
|
||||||
|
key2 VARCHAR(64) NOT NULL,
|
||||||
|
key3 VARCHAR(64) NOT NULL,
|
||||||
|
key4 VARCHAR(64) NOT NULL,
|
||||||
|
key5 VARCHAR(64) NOT NULL,
|
||||||
|
key6 VARCHAR(64) NOT NULL,
|
||||||
|
data BLOB NOT NULL
|
||||||
|
);`,
|
||||||
|
}
|
||||||
|
letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
|
)
|
||||||
|
|
||||||
|
func randStringBytes(n int) string {
|
||||||
|
b := make([]byte, n)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func initDatabase(t *testing.T, db *sql.DB, rowCount int64) {
|
||||||
|
t.Logf("Executing db initializing statements")
|
||||||
|
for _, query := range testTableStatements {
|
||||||
|
_, err := db.Exec(query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for i := int64(0); i < rowCount; i++ {
|
||||||
|
query := `INSERT INTO test_table
|
||||||
|
(key1, key_id, key2, key3, key4, key5, key6, data)
|
||||||
|
VALUES
|
||||||
|
(?, ?, ?, ?, ?, ?, ?, ?);`
|
||||||
|
args := []interface{}{
|
||||||
|
randStringBytes(50),
|
||||||
|
fmt.Sprint(i),
|
||||||
|
randStringBytes(50),
|
||||||
|
randStringBytes(50),
|
||||||
|
randStringBytes(50),
|
||||||
|
randStringBytes(50),
|
||||||
|
randStringBytes(50),
|
||||||
|
randStringBytes(50),
|
||||||
|
randStringBytes(2048),
|
||||||
|
}
|
||||||
|
_, err := db.Exec(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShortTimeout(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", "file::memory:?mode=memory&cache=shared")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
initDatabase(t, db, 10000)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Microsecond)
|
||||||
|
defer cancel()
|
||||||
|
query := `SELECT key1, key_id, key2, key3, key4, key5, key6, data
|
||||||
|
FROM test_table
|
||||||
|
ORDER BY key2 ASC`
|
||||||
|
rows, err := db.QueryContext(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var key1, keyid, key2, key3, key4, key5, key6 string
|
||||||
|
var data []byte
|
||||||
|
err = rows.Scan(&key1, &keyid, &key2, &key3, &key4, &key5, &key6, &data)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if context.DeadlineExceeded != ctx.Err() {
|
||||||
|
t.Fatal(ctx.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -10,5 +10,6 @@ package sqlite3
|
||||||
#cgo CFLAGS: -DUSE_LIBSQLITE3
|
#cgo CFLAGS: -DUSE_LIBSQLITE3
|
||||||
#cgo linux LDFLAGS: -lsqlite3
|
#cgo linux LDFLAGS: -lsqlite3
|
||||||
#cgo darwin LDFLAGS: -L/usr/local/opt/sqlite/lib -lsqlite3
|
#cgo darwin LDFLAGS: -L/usr/local/opt/sqlite/lib -lsqlite3
|
||||||
|
#cgo solaris LDFLAGS: -lsqlite3
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
|
|
|
@ -9,5 +9,6 @@ package sqlite3
|
||||||
/*
|
/*
|
||||||
#cgo CFLAGS: -I.
|
#cgo CFLAGS: -I.
|
||||||
#cgo linux LDFLAGS: -ldl
|
#cgo linux LDFLAGS: -ldl
|
||||||
|
#cgo solaris LDFLAGS: -lc
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
|
|
582
sqlite3_test.go
582
sqlite3_test.go
|
@ -6,21 +6,22 @@
|
||||||
package sqlite3
|
package sqlite3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"math/rand"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/mattn/go-sqlite3/sqlite3_test"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TempFilename(t *testing.T) string {
|
func TempFilename(t *testing.T) string {
|
||||||
|
@ -136,6 +137,35 @@ func TestForeignKeys(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRecursiveTriggers(t *testing.T) {
|
||||||
|
cases := map[string]bool{
|
||||||
|
"?_recursive_triggers=1": true,
|
||||||
|
"?_recursive_triggers=0": false,
|
||||||
|
}
|
||||||
|
for option, want := range cases {
|
||||||
|
fname := TempFilename(t)
|
||||||
|
uri := "file:" + fname + option
|
||||||
|
db, err := sql.Open("sqlite3", uri)
|
||||||
|
if err != nil {
|
||||||
|
os.Remove(fname)
|
||||||
|
t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var enabled bool
|
||||||
|
err = db.QueryRow("PRAGMA recursive_triggers;").Scan(&enabled)
|
||||||
|
db.Close()
|
||||||
|
os.Remove(fname)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("query recursive_triggers for %s: %v", uri, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if enabled != want {
|
||||||
|
t.Errorf("\"PRAGMA recursive_triggers;\" for %q = %t; want %t", uri, enabled, want)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestClose(t *testing.T) {
|
func TestClose(t *testing.T) {
|
||||||
tempFilename := TempFilename(t)
|
tempFilename := TempFilename(t)
|
||||||
defer os.Remove(tempFilename)
|
defer os.Remove(tempFilename)
|
||||||
|
@ -403,6 +433,7 @@ func TestTimestamp(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{"nonsense", time.Time{}},
|
{"nonsense", time.Time{}},
|
||||||
{"0000-00-00 00:00:00", time.Time{}},
|
{"0000-00-00 00:00:00", time.Time{}},
|
||||||
|
{time.Time{}.Unix(), time.Time{}},
|
||||||
{timestamp1, timestamp1},
|
{timestamp1, timestamp1},
|
||||||
{timestamp2.Unix(), timestamp2.Truncate(time.Second)},
|
{timestamp2.Unix(), timestamp2.Truncate(time.Second)},
|
||||||
{timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)},
|
{timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)},
|
||||||
|
@ -840,18 +871,6 @@ func TestTimezoneConversion(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSuite(t *testing.T) {
|
|
||||||
tempFilename := TempFilename(t)
|
|
||||||
defer os.Remove(tempFilename)
|
|
||||||
db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
sqlite3_test.RunTests(t, db, sqlite3_test.SQLITE)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Execer & Queryer currently disabled
|
// TODO: Execer & Queryer currently disabled
|
||||||
// https://github.com/mattn/go-sqlite3/issues/82
|
// https://github.com/mattn/go-sqlite3/issues/82
|
||||||
func TestExecer(t *testing.T) {
|
func TestExecer(t *testing.T) {
|
||||||
|
@ -1385,6 +1404,122 @@ func TestPinger(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateAndTransactionHooks(t *testing.T) {
|
||||||
|
var events []string
|
||||||
|
var commitHookReturn = 0
|
||||||
|
|
||||||
|
sql.Register("sqlite3_UpdateHook", &SQLiteDriver{
|
||||||
|
ConnectHook: func(conn *SQLiteConn) error {
|
||||||
|
conn.RegisterCommitHook(func() int {
|
||||||
|
events = append(events, "commit")
|
||||||
|
return commitHookReturn
|
||||||
|
})
|
||||||
|
conn.RegisterRollbackHook(func() {
|
||||||
|
events = append(events, "rollback")
|
||||||
|
})
|
||||||
|
conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
|
||||||
|
events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid))
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
db, err := sql.Open("sqlite3_UpdateHook", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
statements := []string{
|
||||||
|
"create table foo (id integer primary key)",
|
||||||
|
"insert into foo values (9)",
|
||||||
|
"update foo set id = 99 where id = 9",
|
||||||
|
"delete from foo where id = 99",
|
||||||
|
}
|
||||||
|
for _, statement := range statements {
|
||||||
|
_, err = db.Exec(statement)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to prepare test data [%v]: %v", statement, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
commitHookReturn = 1
|
||||||
|
_, err = db.Exec("insert into foo values (5)")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Commit hook failed to rollback transaction")
|
||||||
|
}
|
||||||
|
|
||||||
|
var expected = []string{
|
||||||
|
"commit",
|
||||||
|
fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT),
|
||||||
|
"commit",
|
||||||
|
fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE),
|
||||||
|
"commit",
|
||||||
|
fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE),
|
||||||
|
"commit",
|
||||||
|
fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT),
|
||||||
|
"commit",
|
||||||
|
"rollback",
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(events, expected) {
|
||||||
|
t.Errorf("Expected notifications %v but got %v", expected, events)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNilAndEmptyBytes(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
actualNil := []byte("use this to use an actual nil not a reference to nil")
|
||||||
|
emptyBytes := []byte{}
|
||||||
|
for tsti, tst := range []struct {
|
||||||
|
name string
|
||||||
|
columnType string
|
||||||
|
insertBytes []byte
|
||||||
|
expectedBytes []byte
|
||||||
|
}{
|
||||||
|
{"actual nil blob", "blob", actualNil, nil},
|
||||||
|
{"referenced nil blob", "blob", nil, nil},
|
||||||
|
{"empty blob", "blob", emptyBytes, emptyBytes},
|
||||||
|
{"actual nil text", "text", actualNil, nil},
|
||||||
|
{"referenced nil text", "text", nil, nil},
|
||||||
|
{"empty text", "text", emptyBytes, emptyBytes},
|
||||||
|
} {
|
||||||
|
if _, err = db.Exec(fmt.Sprintf("create table tbl%d (txt %s)", tsti, tst.columnType)); err != nil {
|
||||||
|
t.Fatal(tst.name, err)
|
||||||
|
}
|
||||||
|
if bytes.Equal(tst.insertBytes, actualNil) {
|
||||||
|
if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), nil); err != nil {
|
||||||
|
t.Fatal(tst.name, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), &tst.insertBytes); err != nil {
|
||||||
|
t.Fatal(tst.name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rows, err := db.Query(fmt.Sprintf("select txt from tbl%d", tsti))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(tst.name, err)
|
||||||
|
}
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Fatal(tst.name, "no rows")
|
||||||
|
}
|
||||||
|
var scanBytes []byte
|
||||||
|
if err = rows.Scan(&scanBytes); err != nil {
|
||||||
|
t.Fatal(tst.name, err)
|
||||||
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
t.Fatal(tst.name, err)
|
||||||
|
}
|
||||||
|
if tst.expectedBytes == nil && scanBytes != nil {
|
||||||
|
t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
|
||||||
|
} else if !bytes.Equal(scanBytes, tst.expectedBytes) {
|
||||||
|
t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var customFunctionOnce sync.Once
|
var customFunctionOnce sync.Once
|
||||||
|
|
||||||
func BenchmarkCustomFunctions(b *testing.B) {
|
func BenchmarkCustomFunctions(b *testing.B) {
|
||||||
|
@ -1419,3 +1554,422 @@ func BenchmarkCustomFunctions(b *testing.B) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSuite(t *testing.T) {
|
||||||
|
tempFilename := TempFilename(t)
|
||||||
|
defer os.Remove(tempFilename)
|
||||||
|
d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer d.Close()
|
||||||
|
|
||||||
|
db = &TestDB{t, d, SQLITE, sync.Once{}}
|
||||||
|
testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
|
||||||
|
|
||||||
|
if !testing.Short() {
|
||||||
|
for _, b := range benchmarks {
|
||||||
|
fmt.Printf("%-20s", b.Name)
|
||||||
|
r := testing.Benchmark(b.F)
|
||||||
|
fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
db.tearDown()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dialect is a type of dialect of databases.
|
||||||
|
type Dialect int
|
||||||
|
|
||||||
|
// Dialects for databases.
|
||||||
|
const (
|
||||||
|
SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
|
||||||
|
POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
|
||||||
|
MYSQL // MYSQL mean MySQL dialect
|
||||||
|
)
|
||||||
|
|
||||||
|
// DB provide context for the tests
|
||||||
|
type TestDB struct {
|
||||||
|
*testing.T
|
||||||
|
*sql.DB
|
||||||
|
dialect Dialect
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
var db *TestDB
|
||||||
|
|
||||||
|
// the following tables will be created and dropped during the test
|
||||||
|
var testTables = []string{"foo", "bar", "t", "bench"}
|
||||||
|
|
||||||
|
var tests = []testing.InternalTest{
|
||||||
|
{Name: "TestResult", F: testResult},
|
||||||
|
{Name: "TestBlobs", F: testBlobs},
|
||||||
|
{Name: "TestManyQueryRow", F: testManyQueryRow},
|
||||||
|
{Name: "TestTxQuery", F: testTxQuery},
|
||||||
|
{Name: "TestPreparedStmt", F: testPreparedStmt},
|
||||||
|
}
|
||||||
|
|
||||||
|
var benchmarks = []testing.InternalBenchmark{
|
||||||
|
{Name: "BenchmarkExec", F: benchmarkExec},
|
||||||
|
{Name: "BenchmarkQuery", F: benchmarkQuery},
|
||||||
|
{Name: "BenchmarkParams", F: benchmarkParams},
|
||||||
|
{Name: "BenchmarkStmt", F: benchmarkStmt},
|
||||||
|
{Name: "BenchmarkRows", F: benchmarkRows},
|
||||||
|
{Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result {
|
||||||
|
res, err := db.Exec(sql, args...)
|
||||||
|
if err != nil {
|
||||||
|
db.Fatalf("Error running %q: %v", sql, err)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *TestDB) tearDown() {
|
||||||
|
for _, tbl := range testTables {
|
||||||
|
switch db.dialect {
|
||||||
|
case SQLITE:
|
||||||
|
db.mustExec("drop table if exists " + tbl)
|
||||||
|
case MYSQL, POSTGRESQL:
|
||||||
|
db.mustExec("drop table if exists " + tbl)
|
||||||
|
default:
|
||||||
|
db.Fatal("unknown dialect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// q replaces ? parameters if needed
|
||||||
|
func (db *TestDB) q(sql string) string {
|
||||||
|
switch db.dialect {
|
||||||
|
case POSTGRESQL: // repace with $1, $2, ..
|
||||||
|
qrx := regexp.MustCompile(`\?`)
|
||||||
|
n := 0
|
||||||
|
return qrx.ReplaceAllStringFunc(sql, func(string) string {
|
||||||
|
n++
|
||||||
|
return "$" + strconv.Itoa(n)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return sql
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *TestDB) blobType(size int) string {
|
||||||
|
switch db.dialect {
|
||||||
|
case SQLITE:
|
||||||
|
return fmt.Sprintf("blob[%d]", size)
|
||||||
|
case POSTGRESQL:
|
||||||
|
return "bytea"
|
||||||
|
case MYSQL:
|
||||||
|
return fmt.Sprintf("VARBINARY(%d)", size)
|
||||||
|
}
|
||||||
|
panic("unknown dialect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *TestDB) serialPK() string {
|
||||||
|
switch db.dialect {
|
||||||
|
case SQLITE:
|
||||||
|
return "integer primary key autoincrement"
|
||||||
|
case POSTGRESQL:
|
||||||
|
return "serial primary key"
|
||||||
|
case MYSQL:
|
||||||
|
return "integer primary key auto_increment"
|
||||||
|
}
|
||||||
|
panic("unknown dialect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *TestDB) now() string {
|
||||||
|
switch db.dialect {
|
||||||
|
case SQLITE:
|
||||||
|
return "datetime('now')"
|
||||||
|
case POSTGRESQL:
|
||||||
|
return "now()"
|
||||||
|
case MYSQL:
|
||||||
|
return "now()"
|
||||||
|
}
|
||||||
|
panic("unknown dialect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeBench() {
|
||||||
|
if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer st.Close()
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testResult is test for result
|
||||||
|
func testResult(t *testing.T) {
|
||||||
|
db.tearDown()
|
||||||
|
db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
|
||||||
|
|
||||||
|
for i := 1; i < 3; i++ {
|
||||||
|
r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
|
||||||
|
n, err := r.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if n != 1 {
|
||||||
|
t.Errorf("got %v, want %v", n, 1)
|
||||||
|
}
|
||||||
|
n, err = r.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if n != int64(i) {
|
||||||
|
t.Errorf("got %v, want %v", n, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := db.Exec("error!"); err == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testBlobs is test for blobs
|
||||||
|
func testBlobs(t *testing.T) {
|
||||||
|
db.tearDown()
|
||||||
|
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||||
|
db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
|
||||||
|
db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
|
||||||
|
|
||||||
|
want := fmt.Sprintf("%x", blob)
|
||||||
|
|
||||||
|
b := make([]byte, 16)
|
||||||
|
err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
|
||||||
|
got := fmt.Sprintf("%x", b)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("[]byte scan: %v", err)
|
||||||
|
} else if got != want {
|
||||||
|
t.Errorf("for []byte, got %q; want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
|
||||||
|
want = string(blob)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("string scan: %v", err)
|
||||||
|
} else if got != want {
|
||||||
|
t.Errorf("for string, got %q; want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testManyQueryRow is test for many query row
|
||||||
|
func testManyQueryRow(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Log("skipping in short mode")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
db.tearDown()
|
||||||
|
db.mustExec("create table foo (id integer primary key, name varchar(50))")
|
||||||
|
db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
||||||
|
var name string
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
|
||||||
|
if err != nil || name != "bob" {
|
||||||
|
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testTxQuery is test for transactional query
|
||||||
|
func testTxQuery(t *testing.T) {
|
||||||
|
db.tearDown()
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
_, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
if !r.Next() {
|
||||||
|
if r.Err() != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Fatal("expected one rows")
|
||||||
|
}
|
||||||
|
|
||||||
|
var name string
|
||||||
|
err = r.Scan(&name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testPreparedStmt is test for prepared statement
|
||||||
|
func testPreparedStmt(t *testing.T) {
|
||||||
|
db.tearDown()
|
||||||
|
db.mustExec("CREATE TABLE t (count INT)")
|
||||||
|
sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("prepare 1: %v", err)
|
||||||
|
}
|
||||||
|
ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("prepare 2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for n := 1; n <= 3; n++ {
|
||||||
|
if _, err := ins.Exec(n); err != nil {
|
||||||
|
t.Fatalf("insert(%d) = %v", n, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const nRuns = 10
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < nRuns; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 10; j++ {
|
||||||
|
count := 0
|
||||||
|
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
|
||||||
|
t.Errorf("Query: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := ins.Exec(rand.Intn(100)); err != nil {
|
||||||
|
t.Errorf("Insert: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmarks need to use panic() since b.Error errors are lost when
|
||||||
|
// running via testing.Benchmark() I would like to run these via go
|
||||||
|
// test -bench but calling Benchmark() from a benchmark test
|
||||||
|
// currently hangs go.
|
||||||
|
|
||||||
|
// benchmarkExec is benchmark for exec
|
||||||
|
func benchmarkExec(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if _, err := db.Exec("select 1"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// benchmarkQuery is benchmark for query
|
||||||
|
func benchmarkQuery(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
var n sql.NullString
|
||||||
|
var i int
|
||||||
|
var f float64
|
||||||
|
var s string
|
||||||
|
// var t time.Time
|
||||||
|
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// benchmarkParams is benchmark for params
|
||||||
|
func benchmarkParams(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
var n sql.NullString
|
||||||
|
var i int
|
||||||
|
var f float64
|
||||||
|
var s string
|
||||||
|
// var t time.Time
|
||||||
|
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// benchmarkStmt is benchmark for statement
|
||||||
|
func benchmarkStmt(b *testing.B) {
|
||||||
|
st, err := db.Prepare("select ?, ?, ?, ?")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer st.Close()
|
||||||
|
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
var n sql.NullString
|
||||||
|
var i int
|
||||||
|
var f float64
|
||||||
|
var s string
|
||||||
|
// var t time.Time
|
||||||
|
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// benchmarkRows is benchmark for rows
|
||||||
|
func benchmarkRows(b *testing.B) {
|
||||||
|
db.once.Do(makeBench)
|
||||||
|
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
var n sql.NullString
|
||||||
|
var i int
|
||||||
|
var f float64
|
||||||
|
var s string
|
||||||
|
var t time.Time
|
||||||
|
r, err := db.Query("select * from bench")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
for r.Next() {
|
||||||
|
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err = r.Err(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// benchmarkStmtRows is benchmark for statement rows
|
||||||
|
func benchmarkStmtRows(b *testing.B) {
|
||||||
|
db.once.Do(makeBench)
|
||||||
|
|
||||||
|
st, err := db.Prepare("select * from bench")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer st.Close()
|
||||||
|
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
var n sql.NullString
|
||||||
|
var i int
|
||||||
|
var f float64
|
||||||
|
var s string
|
||||||
|
var t time.Time
|
||||||
|
r, err := st.Query()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
for r.Next() {
|
||||||
|
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err = r.Err(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,423 +0,0 @@
|
||||||
package sqlite3_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Dialect is a type of dialect of databases.
|
|
||||||
type Dialect int
|
|
||||||
|
|
||||||
// Dialects for databases.
|
|
||||||
const (
|
|
||||||
SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
|
|
||||||
POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
|
|
||||||
MYSQL // MYSQL mean MySQL dialect
|
|
||||||
)
|
|
||||||
|
|
||||||
// DB provide context for the tests
|
|
||||||
type DB struct {
|
|
||||||
*testing.T
|
|
||||||
*sql.DB
|
|
||||||
dialect Dialect
|
|
||||||
once sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
var db *DB
|
|
||||||
|
|
||||||
// the following tables will be created and dropped during the test
|
|
||||||
var testTables = []string{"foo", "bar", "t", "bench"}
|
|
||||||
|
|
||||||
var tests = []testing.InternalTest{
|
|
||||||
{Name: "TestBlobs", F: TestBlobs},
|
|
||||||
{Name: "TestManyQueryRow", F: TestManyQueryRow},
|
|
||||||
{Name: "TestTxQuery", F: TestTxQuery},
|
|
||||||
{Name: "TestPreparedStmt", F: TestPreparedStmt},
|
|
||||||
}
|
|
||||||
|
|
||||||
var benchmarks = []testing.InternalBenchmark{
|
|
||||||
{Name: "BenchmarkExec", F: BenchmarkExec},
|
|
||||||
{Name: "BenchmarkQuery", F: BenchmarkQuery},
|
|
||||||
{Name: "BenchmarkParams", F: BenchmarkParams},
|
|
||||||
{Name: "BenchmarkStmt", F: BenchmarkStmt},
|
|
||||||
{Name: "BenchmarkRows", F: BenchmarkRows},
|
|
||||||
{Name: "BenchmarkStmtRows", F: BenchmarkStmtRows},
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunTests runs the SQL test suite
|
|
||||||
func RunTests(t *testing.T, d *sql.DB, dialect Dialect) {
|
|
||||||
db = &DB{t, d, dialect, sync.Once{}}
|
|
||||||
testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
|
|
||||||
|
|
||||||
if !testing.Short() {
|
|
||||||
for _, b := range benchmarks {
|
|
||||||
fmt.Printf("%-20s", b.Name)
|
|
||||||
r := testing.Benchmark(b.F)
|
|
||||||
fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
db.tearDown()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) mustExec(sql string, args ...interface{}) sql.Result {
|
|
||||||
res, err := db.Exec(sql, args...)
|
|
||||||
if err != nil {
|
|
||||||
db.Fatalf("Error running %q: %v", sql, err)
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) tearDown() {
|
|
||||||
for _, tbl := range testTables {
|
|
||||||
switch db.dialect {
|
|
||||||
case SQLITE:
|
|
||||||
db.mustExec("drop table if exists " + tbl)
|
|
||||||
case MYSQL, POSTGRESQL:
|
|
||||||
db.mustExec("drop table if exists " + tbl)
|
|
||||||
default:
|
|
||||||
db.Fatal("unknown dialect")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// q replaces ? parameters if needed
|
|
||||||
func (db *DB) q(sql string) string {
|
|
||||||
switch db.dialect {
|
|
||||||
case POSTGRESQL: // repace with $1, $2, ..
|
|
||||||
qrx := regexp.MustCompile(`\?`)
|
|
||||||
n := 0
|
|
||||||
return qrx.ReplaceAllStringFunc(sql, func(string) string {
|
|
||||||
n++
|
|
||||||
return "$" + strconv.Itoa(n)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return sql
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) blobType(size int) string {
|
|
||||||
switch db.dialect {
|
|
||||||
case SQLITE:
|
|
||||||
return fmt.Sprintf("blob[%d]", size)
|
|
||||||
case POSTGRESQL:
|
|
||||||
return "bytea"
|
|
||||||
case MYSQL:
|
|
||||||
return fmt.Sprintf("VARBINARY(%d)", size)
|
|
||||||
}
|
|
||||||
panic("unknown dialect")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) serialPK() string {
|
|
||||||
switch db.dialect {
|
|
||||||
case SQLITE:
|
|
||||||
return "integer primary key autoincrement"
|
|
||||||
case POSTGRESQL:
|
|
||||||
return "serial primary key"
|
|
||||||
case MYSQL:
|
|
||||||
return "integer primary key auto_increment"
|
|
||||||
}
|
|
||||||
panic("unknown dialect")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) now() string {
|
|
||||||
switch db.dialect {
|
|
||||||
case SQLITE:
|
|
||||||
return "datetime('now')"
|
|
||||||
case POSTGRESQL:
|
|
||||||
return "now()"
|
|
||||||
case MYSQL:
|
|
||||||
return "now()"
|
|
||||||
}
|
|
||||||
panic("unknown dialect")
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeBench() {
|
|
||||||
if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
defer st.Close()
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestResult is test for result
|
|
||||||
func TestResult(t *testing.T) {
|
|
||||||
db.tearDown()
|
|
||||||
db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
|
|
||||||
|
|
||||||
for i := 1; i < 3; i++ {
|
|
||||||
r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
|
|
||||||
n, err := r.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if n != 1 {
|
|
||||||
t.Errorf("got %v, want %v", n, 1)
|
|
||||||
}
|
|
||||||
n, err = r.LastInsertId()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if n != int64(i) {
|
|
||||||
t.Errorf("got %v, want %v", n, i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if _, err := db.Exec("error!"); err == nil {
|
|
||||||
t.Fatalf("expected error")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBlobs is test for blobs
|
|
||||||
func TestBlobs(t *testing.T) {
|
|
||||||
db.tearDown()
|
|
||||||
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
|
||||||
db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
|
|
||||||
db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
|
|
||||||
|
|
||||||
want := fmt.Sprintf("%x", blob)
|
|
||||||
|
|
||||||
b := make([]byte, 16)
|
|
||||||
err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
|
|
||||||
got := fmt.Sprintf("%x", b)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("[]byte scan: %v", err)
|
|
||||||
} else if got != want {
|
|
||||||
t.Errorf("for []byte, got %q; want %q", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
|
|
||||||
want = string(blob)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("string scan: %v", err)
|
|
||||||
} else if got != want {
|
|
||||||
t.Errorf("for string, got %q; want %q", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestManyQueryRow is test for many query row
|
|
||||||
func TestManyQueryRow(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Log("skipping in short mode")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
db.tearDown()
|
|
||||||
db.mustExec("create table foo (id integer primary key, name varchar(50))")
|
|
||||||
db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
|
||||||
var name string
|
|
||||||
for i := 0; i < 10000; i++ {
|
|
||||||
err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
|
|
||||||
if err != nil || name != "bob" {
|
|
||||||
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestTxQuery is test for transactional query
|
|
||||||
func TestTxQuery(t *testing.T) {
|
|
||||||
db.tearDown()
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
_, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer r.Close()
|
|
||||||
|
|
||||||
if !r.Next() {
|
|
||||||
if r.Err() != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
t.Fatal("expected one rows")
|
|
||||||
}
|
|
||||||
|
|
||||||
var name string
|
|
||||||
err = r.Scan(&name)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPreparedStmt is test for prepared statement
|
|
||||||
func TestPreparedStmt(t *testing.T) {
|
|
||||||
db.tearDown()
|
|
||||||
db.mustExec("CREATE TABLE t (count INT)")
|
|
||||||
sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("prepare 1: %v", err)
|
|
||||||
}
|
|
||||||
ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("prepare 2: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for n := 1; n <= 3; n++ {
|
|
||||||
if _, err := ins.Exec(n); err != nil {
|
|
||||||
t.Fatalf("insert(%d) = %v", n, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const nRuns = 10
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < nRuns; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
for j := 0; j < 10; j++ {
|
|
||||||
count := 0
|
|
||||||
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
|
|
||||||
t.Errorf("Query: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err := ins.Exec(rand.Intn(100)); err != nil {
|
|
||||||
t.Errorf("Insert: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Benchmarks need to use panic() since b.Error errors are lost when
|
|
||||||
// running via testing.Benchmark() I would like to run these via go
|
|
||||||
// test -bench but calling Benchmark() from a benchmark test
|
|
||||||
// currently hangs go.
|
|
||||||
|
|
||||||
// BenchmarkExec is benchmark for exec
|
|
||||||
func BenchmarkExec(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
if _, err := db.Exec("select 1"); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkQuery is benchmark for query
|
|
||||||
func BenchmarkQuery(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
var n sql.NullString
|
|
||||||
var i int
|
|
||||||
var f float64
|
|
||||||
var s string
|
|
||||||
// var t time.Time
|
|
||||||
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkParams is benchmark for params
|
|
||||||
func BenchmarkParams(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
var n sql.NullString
|
|
||||||
var i int
|
|
||||||
var f float64
|
|
||||||
var s string
|
|
||||||
// var t time.Time
|
|
||||||
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkStmt is benchmark for statement
|
|
||||||
func BenchmarkStmt(b *testing.B) {
|
|
||||||
st, err := db.Prepare("select ?, ?, ?, ?")
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
defer st.Close()
|
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
var n sql.NullString
|
|
||||||
var i int
|
|
||||||
var f float64
|
|
||||||
var s string
|
|
||||||
// var t time.Time
|
|
||||||
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkRows is benchmark for rows
|
|
||||||
func BenchmarkRows(b *testing.B) {
|
|
||||||
db.once.Do(makeBench)
|
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
var n sql.NullString
|
|
||||||
var i int
|
|
||||||
var f float64
|
|
||||||
var s string
|
|
||||||
var t time.Time
|
|
||||||
r, err := db.Query("select * from bench")
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
for r.Next() {
|
|
||||||
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err = r.Err(); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkStmtRows is benchmark for statement rows
|
|
||||||
func BenchmarkStmtRows(b *testing.B) {
|
|
||||||
db.once.Do(makeBench)
|
|
||||||
|
|
||||||
st, err := db.Prepare("select * from bench")
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
defer st.Close()
|
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
|
||||||
var n sql.NullString
|
|
||||||
var i int
|
|
||||||
var f float64
|
|
||||||
var s string
|
|
||||||
var t time.Time
|
|
||||||
r, err := st.Query()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
for r.Next() {
|
|
||||||
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err = r.Err(); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue