From 308f5f1b2fae699cb38474136205cbab8f081019 Mon Sep 17 00:00:00 2001 From: Xu Xinran Date: Wed, 14 Jun 2017 19:55:09 +0800 Subject: [PATCH 01/22] Treat []byte{} as empty bytes instead of NULL. --- sqlite3.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index 33b9b9c..c3d5d98 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -772,11 +772,13 @@ func (s *SQLiteStmt) bind(args []namedValue) error { case float64: rv = C.sqlite3_bind_double(s.s, n, C.double(v)) case []byte: + var ptr *byte if len(v) == 0 { - rv = C._sqlite3_bind_blob(s.s, n, nil, 0) + ptr = &(make([]byte, 1)[0]) } else { - rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v))) + ptr = &v[0] } + rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(ptr), C.int(len(v))) case time.Time: b := []byte(v.Format(SQLiteTimestampFormats[0])) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) From 3fa7ed217682f0f66953e3c37d7a648bbc5d3e5c Mon Sep 17 00:00:00 2001 From: Xu Xinran Date: Wed, 14 Jun 2017 21:22:40 +0800 Subject: [PATCH 02/22] Use global variable for better performance. --- sqlite3.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index c3d5d98..56e55e2 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -734,6 +734,8 @@ type bindArg struct { v driver.Value } +var placeHolder byte = 0 + func (s *SQLiteStmt) bind(args []namedValue) error { rv := C.sqlite3_reset(s.s) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { @@ -755,8 +757,7 @@ func (s *SQLiteStmt) bind(args []namedValue) error { rv = C.sqlite3_bind_null(s.s, n) case string: if len(v) == 0 { - b := []byte{0} - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0)) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder)), C.int(0)) } else { b := []byte(v) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) @@ -774,7 +775,7 @@ func (s *SQLiteStmt) bind(args []namedValue) error { case []byte: var ptr *byte if len(v) == 0 { - ptr = &(make([]byte, 1)[0]) + ptr = &placeHolder } else { ptr = &v[0] } From cd1cbf523a80ebe39ccb1b829d760852e298c128 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 17 Jun 2017 12:22:09 -0700 Subject: [PATCH 03/22] Sync database-close and statement-close Potential fix for issue #426. --- sqlite3.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/sqlite3.go b/sqlite3.go index 56e55e2..0cd4666 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -113,6 +113,7 @@ import ( "runtime" "strconv" "strings" + "sync" "time" "unsafe" @@ -157,6 +158,7 @@ type SQLiteDriver struct { // SQLiteConn implement sql.Conn. type SQLiteConn struct { + dbMu sync.Mutex db *C.sqlite3 loc *time.Location txlock string @@ -679,11 +681,22 @@ func (c *SQLiteConn) Close() error { return c.lastError() } deleteHandles(c) + c.dbMu.Lock() c.db = nil + c.dbMu.Unlock() runtime.SetFinalizer(c, nil) return nil } +func (c *SQLiteConn) dbConnOpen() bool { + if c == nil { + return false + } + c.dbMu.Lock() + defer c.dbMu.Unlock() + return c.db != nil +} + // Prepare the query string. Return a new statement. func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { return c.prepare(context.Background(), query) @@ -713,7 +726,7 @@ func (s *SQLiteStmt) Close() error { return nil } s.closed = true - if s.c == nil || s.c.db == nil { + if !s.c.dbConnOpen() { return errors.New("sqlite statement with already closed database connection") } rv := C.sqlite3_finalize(s.s) From ef9f773b24d2ec1611476e580c032e98068e481a Mon Sep 17 00:00:00 2001 From: Evgeniy Makeev Date: Tue, 20 Jun 2017 17:36:44 -0700 Subject: [PATCH 04/22] Fix for cgo panic, issue #428: https://github.com/mattn/go-sqlite3/issues/428 --- sqlite3.go | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index 0cd4666..b34c3a5 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -747,7 +747,7 @@ type bindArg struct { v driver.Value } -var placeHolder byte = 0 +var placeHolder = []byte{0} func (s *SQLiteStmt) bind(args []namedValue) error { rv := C.sqlite3_reset(s.s) @@ -770,7 +770,7 @@ func (s *SQLiteStmt) bind(args []namedValue) error { rv = C.sqlite3_bind_null(s.s, n) case string: if len(v) == 0 { - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder)), C.int(0)) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0)) } else { b := []byte(v) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) @@ -786,13 +786,10 @@ func (s *SQLiteStmt) bind(args []namedValue) error { case float64: rv = C.sqlite3_bind_double(s.s, n, C.double(v)) case []byte: - var ptr *byte if len(v) == 0 { - ptr = &placeHolder - } else { - ptr = &v[0] + v = placeHolder } - rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(ptr), C.int(len(v))) + rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v))) case time.Time: b := []byte(v.Format(SQLiteTimestampFormats[0])) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) From 05123859bed77249c3d9ca8efe6adc3cce1e1bed Mon Sep 17 00:00:00 2001 From: deepilla Date: Fri, 30 Jun 2017 13:17:04 -0500 Subject: [PATCH 05/22] Don't convert Unix times to nanoseconds when querying datetime fields. Fixes #430. --- sqlite3.go | 5 +++-- sqlite3_test.go | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index b34c3a5..d3a6407 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -961,10 +961,11 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { // large to be a reasonable timestamp in seconds. if val > 1e12 || val < -1e12 { val *= int64(time.Millisecond) // convert ms to nsec + t = time.Unix(0, val) } else { - val *= int64(time.Second) // convert sec to nsec + t = time.Unix(val, 0) } - t = time.Unix(0, val).UTC() + t = t.UTC() if rc.s.c.loc != nil { t = t.In(rc.s.c.loc) } diff --git a/sqlite3_test.go b/sqlite3_test.go index 03b678d..e563479 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -403,6 +403,7 @@ func TestTimestamp(t *testing.T) { }{ {"nonsense", time.Time{}}, {"0000-00-00 00:00:00", time.Time{}}, + {time.Time{}.Unix(), time.Time{}}, {timestamp1, timestamp1}, {timestamp2.Unix(), timestamp2.Truncate(time.Second)}, {timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)}, From 59bd281a89883d39ef219699e4a46eab87b3cff9 Mon Sep 17 00:00:00 2001 From: Jason Abbott Date: Mon, 3 Jul 2017 12:51:48 -0600 Subject: [PATCH 06/22] Incorporate original PR 271 from https://github.com/brokensandals --- _example/hook/hook.go | 6 +++++ callback.go | 18 +++++++++++++ sqlite3.go | 54 ++++++++++++++++++++++++++++++++++++++ sqlite3_test.go | 61 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 139 insertions(+) diff --git a/_example/hook/hook.go b/_example/hook/hook.go index 17bddeb..6023181 100644 --- a/_example/hook/hook.go +++ b/_example/hook/hook.go @@ -14,6 +14,12 @@ func main() { &sqlite3.SQLiteDriver{ ConnectHook: func(conn *sqlite3.SQLiteConn) error { sqlite3conn = append(sqlite3conn, conn) + conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) { + switch op { + case sqlite3.SQLITE_INSERT: + log.Println("Notified of insert on db", db, "table", table, "rowid", rowid) + } + }) return nil }, }) diff --git a/callback.go b/callback.go index 48fc63a..6a55964 100644 --- a/callback.go +++ b/callback.go @@ -53,6 +53,24 @@ func doneTrampoline(ctx *C.sqlite3_context) { ai.Done(ctx) } +//export commitHookTrampoline +func commitHookTrampoline(handle uintptr) int { + callback := lookupHandle(handle).(func() int) + return callback() +} + +//export rollbackHookTrampoline +func rollbackHookTrampoline(handle uintptr) { + callback := lookupHandle(handle).(func()) + callback() +} + +//export updateHookTrampoline +func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) { + callback := lookupHandle(handle).(func(int, string, string, int64)) + callback(op, C.GoString(db), C.GoString(table), rowid) +} + // Use handles to avoid passing Go pointers to C. type handleVal struct { diff --git a/sqlite3.go b/sqlite3.go index d3a6407..0217cce 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -100,6 +100,9 @@ int _sqlite3_create_function( } void callbackTrampoline(sqlite3_context*, int, sqlite3_value**); +int commitHookTrampoline(void*); +void rollbackHookTrampoline(void*); +void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64); */ import "C" import ( @@ -150,6 +153,12 @@ func Version() (libVersion string, libVersionNumber int, sourceID string) { return libVersion, libVersionNumber, sourceID } +const ( + SQLITE_DELETE = C.SQLITE_DELETE + SQLITE_INSERT = C.SQLITE_INSERT + SQLITE_UPDATE = C.SQLITE_UPDATE +) + // SQLiteDriver implement sql.Driver. type SQLiteDriver struct { Extensions []string @@ -315,6 +324,51 @@ func (tx *SQLiteTx) Rollback() error { return err } +// RegisterCommitHook sets the commit hook for a connection. +// +// If the callback returns non-zero the transaction will become a rollback. +// +// If there is an existing commit hook for this connection, it will be +// removed. If callback is nil the existing hook (if any) will be removed +// without creating a new one. +func (c *SQLiteConn) RegisterCommitHook(callback func() int) { + if callback == nil { + C.sqlite3_commit_hook(c.db, nil, nil) + } else { + C.sqlite3_commit_hook(c.db, (*[0]byte)(unsafe.Pointer(C.commitHookTrampoline)), unsafe.Pointer(newHandle(c, callback))) + } +} + +// RegisterRollbackHook sets the rollback hook for a connection. +// +// If there is an existing rollback hook for this connection, it will be +// removed. If callback is nil the existing hook (if any) will be removed +// without creating a new one. +func (c *SQLiteConn) RegisterRollbackHook(callback func()) { + if callback == nil { + C.sqlite3_rollback_hook(c.db, nil, nil) + } else { + C.sqlite3_rollback_hook(c.db, (*[0]byte)(unsafe.Pointer(C.rollbackHookTrampoline)), unsafe.Pointer(newHandle(c, callback))) + } +} + +// RegisterUpdateHook sets the update hook for a connection. +// +// The parameters to the callback are the operation (one of the constants +// SQLITE_INSERT, SQLITE_DELETE, or SQLITE_UPDATE), the database name, the +// table name, and the rowid. +// +// If there is an existing update hook for this connection, it will be +// removed. If callback is nil the existing hook (if any) will be removed +// without creating a new one. +func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64)) { + if callback == nil { + C.sqlite3_update_hook(c.db, nil, nil) + } else { + C.sqlite3_update_hook(c.db, (*[0]byte)(unsafe.Pointer(C.updateHookTrampoline)), unsafe.Pointer(newHandle(c, callback))) + } +} + // RegisterFunc makes a Go function available as a SQLite function. // // The Go function can have arguments of the following types: any diff --git a/sqlite3_test.go b/sqlite3_test.go index e563479..f11c349 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1265,6 +1265,67 @@ func TestPinger(t *testing.T) { } } +func TestUpdateAndTransactionHooks(t *testing.T) { + var events []string + var commitHookReturn = 0 + + sql.Register("sqlite3_UpdateHook", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + conn.RegisterCommitHook(func() int { + events = append(events, "commit") + return commitHookReturn + }) + conn.RegisterRollbackHook(func() { + events = append(events, "rollback") + }) + conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) { + events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid)) + }) + return nil + }, + }) + db, err := sql.Open("sqlite3_UpdateHook", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + statements := []string{ + "create table foo (id integer primary key)", + "insert into foo values (9)", + "update foo set id = 99 where id = 9", + "delete from foo where id = 99", + } + for _, statement := range statements { + _, err = db.Exec(statement) + if err != nil { + t.Fatalf("Unable to prepare test data [%v]: %v", statement, err) + } + } + + commitHookReturn = 1 + _, err = db.Exec("insert into foo values (5)") + if err == nil { + t.Error("Commit hook failed to rollback transaction") + } + + var expected = []string{ + "commit", + fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT), + "commit", + fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE), + "commit", + fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE), + "commit", + fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT), + "commit", + "rollback", + } + if !reflect.DeepEqual(events, expected) { + t.Errorf("Expected notifications %v but got %v", expected, events) + } +} + var customFunctionOnce sync.Once func BenchmarkCustomFunctions(b *testing.B) { From acfa60124032040b9f5a9406f5a772ee16fe845e Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 5 Jul 2017 17:25:03 +0900 Subject: [PATCH 07/22] SQLITE_THREADSAFE=1 fixes #274 --- sqlite3.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlite3.go b/sqlite3.go index d3a6407..bfb4d1f 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -7,7 +7,7 @@ package sqlite3 /* #cgo CFLAGS: -std=gnu99 -#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE +#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE=1 #cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61 #cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15 #cgo CFLAGS: -DSQLITE_DISABLE_INTRINSIC From 848386d7a20cdd7b4314334b97e958c933a361ba Mon Sep 17 00:00:00 2001 From: Ross Light Date: Sun, 9 Jul 2017 07:32:14 -0700 Subject: [PATCH 08/22] Add connection option for recursive triggers Similar to foreign keys, the recursive triggers PRAGMA affects the interpretation of all statements on a connection. --- sqlite3.go | 26 ++++++++++++++++++++++++++ sqlite3_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/sqlite3.go b/sqlite3.go index 2b7b8df..2ebf7e7 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -599,6 +599,8 @@ func errorString(err Error) string { // "deferred", "exclusive". // _foreign_keys=X // Enable or disable enforcement of foreign keys. X can be 1 or 0. +// _recursive_triggers=X +// Enable or disable recursive triggers. X can be 1 or 0. func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { if C.sqlite3_threadsafe() == 0 { return nil, errors.New("sqlite library was not compiled for thread-safe operation") @@ -608,6 +610,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { txlock := "BEGIN" busyTimeout := 5000 foreignKeys := -1 + recursiveTriggers := -1 pos := strings.IndexRune(dsn, '?') if pos >= 1 { params, err := url.ParseQuery(dsn[pos+1:]) @@ -662,6 +665,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { } } + // _recursive_triggers + if val := params.Get("_recursive_triggers"); val != "" { + switch val { + case "1": + recursiveTriggers = 1 + case "0": + recursiveTriggers = 0 + default: + return nil, fmt.Errorf("Invalid _recursive_triggers: %v", val) + } + } + if !strings.HasPrefix(dsn, "file:") { dsn = dsn[:pos] } @@ -708,6 +723,17 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, err } } + if recursiveTriggers == 0 { + if err := exec("PRAGMA recursive_triggers = OFF;"); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + } else if recursiveTriggers == 1 { + if err := exec("PRAGMA recursive_triggers = ON;"); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + } conn := &SQLiteConn{db: db, loc: loc, txlock: txlock} diff --git a/sqlite3_test.go b/sqlite3_test.go index f11c349..a00e622 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -136,6 +136,35 @@ func TestForeignKeys(t *testing.T) { } } +func TestRecursiveTriggers(t *testing.T) { + cases := map[string]bool{ + "?_recursive_triggers=1": true, + "?_recursive_triggers=0": false, + } + for option, want := range cases { + fname := TempFilename(t) + uri := "file:" + fname + option + db, err := sql.Open("sqlite3", uri) + if err != nil { + os.Remove(fname) + t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err) + continue + } + var enabled bool + err = db.QueryRow("PRAGMA recursive_triggers;").Scan(&enabled) + db.Close() + os.Remove(fname) + if err != nil { + t.Errorf("query recursive_triggers for %s: %v", uri, err) + continue + } + if enabled != want { + t.Errorf("\"PRAGMA recursive_triggers;\" for %q = %t; want %t", uri, enabled, want) + continue + } + } +} + func TestClose(t *testing.T) { tempFilename := TempFilename(t) defer os.Remove(tempFilename) From 569232dc083e6dbcaab291f39bd5f78af2aedf08 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 2 Aug 2017 00:06:18 +0900 Subject: [PATCH 09/22] fix possibly double Close. fixes #448 --- sqlite3.go | 14 +++++++++----- sqlite3_test/{sqltest.go => sqlite3_test.go} | 0 2 files changed, 9 insertions(+), 5 deletions(-) rename sqlite3_test/{sqltest.go => sqlite3_test.go} (100%) diff --git a/sqlite3.go b/sqlite3.go index 2ebf7e7..3fc7354 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -167,7 +167,7 @@ type SQLiteDriver struct { // SQLiteConn implement sql.Conn. type SQLiteConn struct { - dbMu sync.Mutex + mu sync.Mutex db *C.sqlite3 loc *time.Location txlock string @@ -197,6 +197,7 @@ type SQLiteResult struct { // SQLiteRows implement sql.Rows. type SQLiteRows struct { + mu sync.Mutex s *SQLiteStmt nc int cols []string @@ -761,9 +762,9 @@ func (c *SQLiteConn) Close() error { return c.lastError() } deleteHandles(c) - c.dbMu.Lock() + c.mu.Lock() c.db = nil - c.dbMu.Unlock() + c.mu.Unlock() runtime.SetFinalizer(c, nil) return nil } @@ -772,8 +773,8 @@ func (c *SQLiteConn) dbConnOpen() bool { if c == nil { return false } - c.dbMu.Lock() - defer c.dbMu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() return c.db != nil } @@ -980,7 +981,10 @@ func (rc *SQLiteRows) Close() error { return nil } if rc.done != nil { + rc.mu.Lock() close(rc.done) + rc.done = nil + rc.mu.Unlock() } if rc.cls { return rc.s.Close() diff --git a/sqlite3_test/sqltest.go b/sqlite3_test/sqlite3_test.go similarity index 100% rename from sqlite3_test/sqltest.go rename to sqlite3_test/sqlite3_test.go From 42a4d148c2b3dd62511663f2a53ff8a416501a0b Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 2 Aug 2017 01:26:57 +0900 Subject: [PATCH 10/22] fix tests on tip --- sqlite3_test.go | 434 +++++++++++++++++++++++++++++++++-- sqlite3_test/sqlite3_test.go | 423 ---------------------------------- 2 files changed, 420 insertions(+), 437 deletions(-) delete mode 100644 sqlite3_test/sqlite3_test.go diff --git a/sqlite3_test.go b/sqlite3_test.go index a00e622..5a94bd6 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -11,16 +11,16 @@ import ( "errors" "fmt" "io/ioutil" + "math/rand" "net/url" "os" "reflect" "regexp" + "strconv" "strings" "sync" "testing" "time" - - "github.com/mattn/go-sqlite3/sqlite3_test" ) func TempFilename(t *testing.T) string { @@ -870,18 +870,6 @@ func TestTimezoneConversion(t *testing.T) { } } -func TestSuite(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - sqlite3_test.RunTests(t, db, sqlite3_test.SQLITE) -} - // TODO: Execer & Queryer currently disabled // https://github.com/mattn/go-sqlite3/issues/82 func TestExecer(t *testing.T) { @@ -1389,3 +1377,421 @@ func BenchmarkCustomFunctions(b *testing.B) { } } } + +func TestSuite(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999") + if err != nil { + t.Fatal(err) + } + defer d.Close() + + db = &TestDB{t, d, SQLITE, sync.Once{}} + testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests) + + if !testing.Short() { + for _, b := range benchmarks { + fmt.Printf("%-20s", b.Name) + r := testing.Benchmark(b.F) + fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds()) + } + } + db.tearDown() +} + +// Dialect is a type of dialect of databases. +type Dialect int + +// Dialects for databases. +const ( + SQLITE Dialect = iota // SQLITE mean SQLite3 dialect + POSTGRESQL // POSTGRESQL mean PostgreSQL dialect + MYSQL // MYSQL mean MySQL dialect +) + +// DB provide context for the tests +type TestDB struct { + *testing.T + *sql.DB + dialect Dialect + once sync.Once +} + +var db *TestDB + +// the following tables will be created and dropped during the test +var testTables = []string{"foo", "bar", "t", "bench"} + +var tests = []testing.InternalTest{ + {Name: "TestBlobs", F: TestBlobs}, + {Name: "TestManyQueryRow", F: TestManyQueryRow}, + {Name: "TestTxQuery", F: TestTxQuery}, + {Name: "TestPreparedStmt", F: TestPreparedStmt}, +} + +var benchmarks = []testing.InternalBenchmark{ + {Name: "BenchmarkExec", F: BenchmarkExec}, + {Name: "BenchmarkQuery", F: BenchmarkQuery}, + {Name: "BenchmarkParams", F: BenchmarkParams}, + {Name: "BenchmarkStmt", F: BenchmarkStmt}, + {Name: "BenchmarkRows", F: BenchmarkRows}, + {Name: "BenchmarkStmtRows", F: BenchmarkStmtRows}, +} + +func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result { + res, err := db.Exec(sql, args...) + if err != nil { + db.Fatalf("Error running %q: %v", sql, err) + } + return res +} + +func (db *TestDB) tearDown() { + for _, tbl := range testTables { + switch db.dialect { + case SQLITE: + db.mustExec("drop table if exists " + tbl) + case MYSQL, POSTGRESQL: + db.mustExec("drop table if exists " + tbl) + default: + db.Fatal("unknown dialect") + } + } +} + +// q replaces ? parameters if needed +func (db *TestDB) q(sql string) string { + switch db.dialect { + case POSTGRESQL: // repace with $1, $2, .. + qrx := regexp.MustCompile(`\?`) + n := 0 + return qrx.ReplaceAllStringFunc(sql, func(string) string { + n++ + return "$" + strconv.Itoa(n) + }) + } + return sql +} + +func (db *TestDB) blobType(size int) string { + switch db.dialect { + case SQLITE: + return fmt.Sprintf("blob[%d]", size) + case POSTGRESQL: + return "bytea" + case MYSQL: + return fmt.Sprintf("VARBINARY(%d)", size) + } + panic("unknown dialect") +} + +func (db *TestDB) serialPK() string { + switch db.dialect { + case SQLITE: + return "integer primary key autoincrement" + case POSTGRESQL: + return "serial primary key" + case MYSQL: + return "integer primary key auto_increment" + } + panic("unknown dialect") +} + +func (db *TestDB) now() string { + switch db.dialect { + case SQLITE: + return "datetime('now')" + case POSTGRESQL: + return "now()" + case MYSQL: + return "now()" + } + panic("unknown dialect") +} + +func makeBench() { + if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil { + panic(err) + } + st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)") + if err != nil { + panic(err) + } + defer st.Close() + for i := 0; i < 100; i++ { + if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil { + panic(err) + } + } +} + +// TestResult is test for result +func TestResult(t *testing.T) { + db.tearDown() + db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))") + + for i := 1; i < 3; i++ { + r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i)) + n, err := r.RowsAffected() + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Errorf("got %v, want %v", n, 1) + } + n, err = r.LastInsertId() + if err != nil { + t.Fatal(err) + } + if n != int64(i) { + t.Errorf("got %v, want %v", n, i) + } + } + if _, err := db.Exec("error!"); err == nil { + t.Fatalf("expected error") + } +} + +// TestBlobs is test for blobs +func TestBlobs(t *testing.T) { + db.tearDown() + var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")") + db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob) + + want := fmt.Sprintf("%x", blob) + + b := make([]byte, 16) + err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b) + got := fmt.Sprintf("%x", b) + if err != nil { + t.Errorf("[]byte scan: %v", err) + } else if got != want { + t.Errorf("for []byte, got %q; want %q", got, want) + } + + err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got) + want = string(blob) + if err != nil { + t.Errorf("string scan: %v", err) + } else if got != want { + t.Errorf("for string, got %q; want %q", got, want) + } +} + +// TestManyQueryRow is test for many query row +func TestManyQueryRow(t *testing.T) { + if testing.Short() { + t.Log("skipping in short mode") + return + } + db.tearDown() + db.mustExec("create table foo (id integer primary key, name varchar(50))") + db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob") + var name string + for i := 0; i < 10000; i++ { + err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name) + if err != nil || name != "bob" { + t.Fatalf("on query %d: err=%v, name=%q", i, err, name) + } + } +} + +// TestTxQuery is test for transactional query +func TestTxQuery(t *testing.T) { + db.tearDown() + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + _, err = tx.Exec("create table foo (id integer primary key, name varchar(50))") + if err != nil { + t.Fatal(err) + } + + _, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob") + if err != nil { + t.Fatal(err) + } + + r, err := tx.Query(db.q("select name from foo where id = ?"), 1) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + if r.Err() != nil { + t.Fatal(err) + } + t.Fatal("expected one rows") + } + + var name string + err = r.Scan(&name) + if err != nil { + t.Fatal(err) + } +} + +// TestPreparedStmt is test for prepared statement +func TestPreparedStmt(t *testing.T) { + db.tearDown() + db.mustExec("CREATE TABLE t (count INT)") + sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC") + if err != nil { + t.Fatalf("prepare 1: %v", err) + } + ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)")) + if err != nil { + t.Fatalf("prepare 2: %v", err) + } + + for n := 1; n <= 3; n++ { + if _, err := ins.Exec(n); err != nil { + t.Fatalf("insert(%d) = %v", n, err) + } + } + + const nRuns = 10 + var wg sync.WaitGroup + for i := 0; i < nRuns; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + count := 0 + if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows { + t.Errorf("Query: %v", err) + return + } + if _, err := ins.Exec(rand.Intn(100)); err != nil { + t.Errorf("Insert: %v", err) + return + } + } + }() + } + wg.Wait() +} + +// Benchmarks need to use panic() since b.Error errors are lost when +// running via testing.Benchmark() I would like to run these via go +// test -bench but calling Benchmark() from a benchmark test +// currently hangs go. + +// BenchmarkExec is benchmark for exec +func BenchmarkExec(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := db.Exec("select 1"); err != nil { + panic(err) + } + } +} + +// BenchmarkQuery is benchmark for query +func BenchmarkQuery(b *testing.B) { + for i := 0; i < b.N; i++ { + var n sql.NullString + var i int + var f float64 + var s string + // var t time.Time + if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil { + panic(err) + } + } +} + +// BenchmarkParams is benchmark for params +func BenchmarkParams(b *testing.B) { + for i := 0; i < b.N; i++ { + var n sql.NullString + var i int + var f float64 + var s string + // var t time.Time + if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { + panic(err) + } + } +} + +// BenchmarkStmt is benchmark for statement +func BenchmarkStmt(b *testing.B) { + st, err := db.Prepare("select ?, ?, ?, ?") + if err != nil { + panic(err) + } + defer st.Close() + + for n := 0; n < b.N; n++ { + var n sql.NullString + var i int + var f float64 + var s string + // var t time.Time + if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { + panic(err) + } + } +} + +// BenchmarkRows is benchmark for rows +func BenchmarkRows(b *testing.B) { + db.once.Do(makeBench) + + for n := 0; n < b.N; n++ { + var n sql.NullString + var i int + var f float64 + var s string + var t time.Time + r, err := db.Query("select * from bench") + if err != nil { + panic(err) + } + for r.Next() { + if err = r.Scan(&n, &i, &f, &s, &t); err != nil { + panic(err) + } + } + if err = r.Err(); err != nil { + panic(err) + } + } +} + +// BenchmarkStmtRows is benchmark for statement rows +func BenchmarkStmtRows(b *testing.B) { + db.once.Do(makeBench) + + st, err := db.Prepare("select * from bench") + if err != nil { + panic(err) + } + defer st.Close() + + for n := 0; n < b.N; n++ { + var n sql.NullString + var i int + var f float64 + var s string + var t time.Time + r, err := st.Query() + if err != nil { + panic(err) + } + for r.Next() { + if err = r.Scan(&n, &i, &f, &s, &t); err != nil { + panic(err) + } + } + if err = r.Err(); err != nil { + panic(err) + } + } +} diff --git a/sqlite3_test/sqlite3_test.go b/sqlite3_test/sqlite3_test.go deleted file mode 100644 index 0ad9c3a..0000000 --- a/sqlite3_test/sqlite3_test.go +++ /dev/null @@ -1,423 +0,0 @@ -package sqlite3_test - -import ( - "database/sql" - "fmt" - "math/rand" - "regexp" - "strconv" - "sync" - "testing" - "time" -) - -// Dialect is a type of dialect of databases. -type Dialect int - -// Dialects for databases. -const ( - SQLITE Dialect = iota // SQLITE mean SQLite3 dialect - POSTGRESQL // POSTGRESQL mean PostgreSQL dialect - MYSQL // MYSQL mean MySQL dialect -) - -// DB provide context for the tests -type DB struct { - *testing.T - *sql.DB - dialect Dialect - once sync.Once -} - -var db *DB - -// the following tables will be created and dropped during the test -var testTables = []string{"foo", "bar", "t", "bench"} - -var tests = []testing.InternalTest{ - {Name: "TestBlobs", F: TestBlobs}, - {Name: "TestManyQueryRow", F: TestManyQueryRow}, - {Name: "TestTxQuery", F: TestTxQuery}, - {Name: "TestPreparedStmt", F: TestPreparedStmt}, -} - -var benchmarks = []testing.InternalBenchmark{ - {Name: "BenchmarkExec", F: BenchmarkExec}, - {Name: "BenchmarkQuery", F: BenchmarkQuery}, - {Name: "BenchmarkParams", F: BenchmarkParams}, - {Name: "BenchmarkStmt", F: BenchmarkStmt}, - {Name: "BenchmarkRows", F: BenchmarkRows}, - {Name: "BenchmarkStmtRows", F: BenchmarkStmtRows}, -} - -// RunTests runs the SQL test suite -func RunTests(t *testing.T, d *sql.DB, dialect Dialect) { - db = &DB{t, d, dialect, sync.Once{}} - testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests) - - if !testing.Short() { - for _, b := range benchmarks { - fmt.Printf("%-20s", b.Name) - r := testing.Benchmark(b.F) - fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds()) - } - } - db.tearDown() -} - -func (db *DB) mustExec(sql string, args ...interface{}) sql.Result { - res, err := db.Exec(sql, args...) - if err != nil { - db.Fatalf("Error running %q: %v", sql, err) - } - return res -} - -func (db *DB) tearDown() { - for _, tbl := range testTables { - switch db.dialect { - case SQLITE: - db.mustExec("drop table if exists " + tbl) - case MYSQL, POSTGRESQL: - db.mustExec("drop table if exists " + tbl) - default: - db.Fatal("unknown dialect") - } - } -} - -// q replaces ? parameters if needed -func (db *DB) q(sql string) string { - switch db.dialect { - case POSTGRESQL: // repace with $1, $2, .. - qrx := regexp.MustCompile(`\?`) - n := 0 - return qrx.ReplaceAllStringFunc(sql, func(string) string { - n++ - return "$" + strconv.Itoa(n) - }) - } - return sql -} - -func (db *DB) blobType(size int) string { - switch db.dialect { - case SQLITE: - return fmt.Sprintf("blob[%d]", size) - case POSTGRESQL: - return "bytea" - case MYSQL: - return fmt.Sprintf("VARBINARY(%d)", size) - } - panic("unknown dialect") -} - -func (db *DB) serialPK() string { - switch db.dialect { - case SQLITE: - return "integer primary key autoincrement" - case POSTGRESQL: - return "serial primary key" - case MYSQL: - return "integer primary key auto_increment" - } - panic("unknown dialect") -} - -func (db *DB) now() string { - switch db.dialect { - case SQLITE: - return "datetime('now')" - case POSTGRESQL: - return "now()" - case MYSQL: - return "now()" - } - panic("unknown dialect") -} - -func makeBench() { - if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil { - panic(err) - } - st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)") - if err != nil { - panic(err) - } - defer st.Close() - for i := 0; i < 100; i++ { - if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil { - panic(err) - } - } -} - -// TestResult is test for result -func TestResult(t *testing.T) { - db.tearDown() - db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))") - - for i := 1; i < 3; i++ { - r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i)) - n, err := r.RowsAffected() - if err != nil { - t.Fatal(err) - } - if n != 1 { - t.Errorf("got %v, want %v", n, 1) - } - n, err = r.LastInsertId() - if err != nil { - t.Fatal(err) - } - if n != int64(i) { - t.Errorf("got %v, want %v", n, i) - } - } - if _, err := db.Exec("error!"); err == nil { - t.Fatalf("expected error") - } -} - -// TestBlobs is test for blobs -func TestBlobs(t *testing.T) { - db.tearDown() - var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")") - db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob) - - want := fmt.Sprintf("%x", blob) - - b := make([]byte, 16) - err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b) - got := fmt.Sprintf("%x", b) - if err != nil { - t.Errorf("[]byte scan: %v", err) - } else if got != want { - t.Errorf("for []byte, got %q; want %q", got, want) - } - - err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got) - want = string(blob) - if err != nil { - t.Errorf("string scan: %v", err) - } else if got != want { - t.Errorf("for string, got %q; want %q", got, want) - } -} - -// TestManyQueryRow is test for many query row -func TestManyQueryRow(t *testing.T) { - if testing.Short() { - t.Log("skipping in short mode") - return - } - db.tearDown() - db.mustExec("create table foo (id integer primary key, name varchar(50))") - db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob") - var name string - for i := 0; i < 10000; i++ { - err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name) - if err != nil || name != "bob" { - t.Fatalf("on query %d: err=%v, name=%q", i, err, name) - } - } -} - -// TestTxQuery is test for transactional query -func TestTxQuery(t *testing.T) { - db.tearDown() - tx, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer tx.Rollback() - - _, err = tx.Exec("create table foo (id integer primary key, name varchar(50))") - if err != nil { - t.Fatal(err) - } - - _, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob") - if err != nil { - t.Fatal(err) - } - - r, err := tx.Query(db.q("select name from foo where id = ?"), 1) - if err != nil { - t.Fatal(err) - } - defer r.Close() - - if !r.Next() { - if r.Err() != nil { - t.Fatal(err) - } - t.Fatal("expected one rows") - } - - var name string - err = r.Scan(&name) - if err != nil { - t.Fatal(err) - } -} - -// TestPreparedStmt is test for prepared statement -func TestPreparedStmt(t *testing.T) { - db.tearDown() - db.mustExec("CREATE TABLE t (count INT)") - sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC") - if err != nil { - t.Fatalf("prepare 1: %v", err) - } - ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)")) - if err != nil { - t.Fatalf("prepare 2: %v", err) - } - - for n := 1; n <= 3; n++ { - if _, err := ins.Exec(n); err != nil { - t.Fatalf("insert(%d) = %v", n, err) - } - } - - const nRuns = 10 - var wg sync.WaitGroup - for i := 0; i < nRuns; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < 10; j++ { - count := 0 - if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows { - t.Errorf("Query: %v", err) - return - } - if _, err := ins.Exec(rand.Intn(100)); err != nil { - t.Errorf("Insert: %v", err) - return - } - } - }() - } - wg.Wait() -} - -// Benchmarks need to use panic() since b.Error errors are lost when -// running via testing.Benchmark() I would like to run these via go -// test -bench but calling Benchmark() from a benchmark test -// currently hangs go. - -// BenchmarkExec is benchmark for exec -func BenchmarkExec(b *testing.B) { - for i := 0; i < b.N; i++ { - if _, err := db.Exec("select 1"); err != nil { - panic(err) - } - } -} - -// BenchmarkQuery is benchmark for query -func BenchmarkQuery(b *testing.B) { - for i := 0; i < b.N; i++ { - var n sql.NullString - var i int - var f float64 - var s string - // var t time.Time - if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil { - panic(err) - } - } -} - -// BenchmarkParams is benchmark for params -func BenchmarkParams(b *testing.B) { - for i := 0; i < b.N; i++ { - var n sql.NullString - var i int - var f float64 - var s string - // var t time.Time - if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { - panic(err) - } - } -} - -// BenchmarkStmt is benchmark for statement -func BenchmarkStmt(b *testing.B) { - st, err := db.Prepare("select ?, ?, ?, ?") - if err != nil { - panic(err) - } - defer st.Close() - - for n := 0; n < b.N; n++ { - var n sql.NullString - var i int - var f float64 - var s string - // var t time.Time - if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { - panic(err) - } - } -} - -// BenchmarkRows is benchmark for rows -func BenchmarkRows(b *testing.B) { - db.once.Do(makeBench) - - for n := 0; n < b.N; n++ { - var n sql.NullString - var i int - var f float64 - var s string - var t time.Time - r, err := db.Query("select * from bench") - if err != nil { - panic(err) - } - for r.Next() { - if err = r.Scan(&n, &i, &f, &s, &t); err != nil { - panic(err) - } - } - if err = r.Err(); err != nil { - panic(err) - } - } -} - -// BenchmarkStmtRows is benchmark for statement rows -func BenchmarkStmtRows(b *testing.B) { - db.once.Do(makeBench) - - st, err := db.Prepare("select * from bench") - if err != nil { - panic(err) - } - defer st.Close() - - for n := 0; n < b.N; n++ { - var n sql.NullString - var i int - var f float64 - var s string - var t time.Time - r, err := st.Query() - if err != nil { - panic(err) - } - for r.Next() { - if err = r.Scan(&n, &i, &f, &s, &t); err != nil { - panic(err) - } - } - if err = r.Err(); err != nil { - panic(err) - } - } -} From 1828334c4a7937cf4d957e36e995b9d6ba4fc535 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 2 Aug 2017 01:43:14 +0900 Subject: [PATCH 11/22] remove mutex --- sqlite3.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index 3fc7354..d0327b8 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -197,12 +197,12 @@ type SQLiteResult struct { // SQLiteRows implement sql.Rows. type SQLiteRows struct { - mu sync.Mutex s *SQLiteStmt nc int cols []string decltype []string cls bool + closed bool done chan struct{} } @@ -905,6 +905,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows, cols: nil, decltype: nil, cls: s.cls, + closed: false, done: make(chan struct{}), } @@ -977,14 +978,12 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result // Close the rows. func (rc *SQLiteRows) Close() error { - if rc.s.closed { + if rc.s.closed || rc.closed { return nil } + rc.closed = true if rc.done != nil { - rc.mu.Lock() close(rc.done) - rc.done = nil - rc.mu.Unlock() } if rc.cls { return rc.s.Close() From 7133e5d7f58869b174d0e412322c68ac78b77cc8 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 2 Aug 2017 01:49:00 +0900 Subject: [PATCH 12/22] ignore errors in teardown --- sqlite3_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlite3_test.go b/sqlite3_test.go index 5a94bd6..ad9aba9 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1451,9 +1451,9 @@ func (db *TestDB) tearDown() { for _, tbl := range testTables { switch db.dialect { case SQLITE: - db.mustExec("drop table if exists " + tbl) + db.Exec("drop table if exists " + tbl) case MYSQL, POSTGRESQL: - db.mustExec("drop table if exists " + tbl) + db.Exec("drop table if exists " + tbl) default: db.Fatal("unknown dialect") } From 6654e412c3c7eabb310d920cf73a2102dbf8c632 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 2 Aug 2017 02:06:40 +0900 Subject: [PATCH 13/22] fix test --- sqlite3_test.go | 69 +++++++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/sqlite3_test.go b/sqlite3_test.go index ad9aba9..7c545e1 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1424,19 +1424,20 @@ var db *TestDB var testTables = []string{"foo", "bar", "t", "bench"} var tests = []testing.InternalTest{ - {Name: "TestBlobs", F: TestBlobs}, - {Name: "TestManyQueryRow", F: TestManyQueryRow}, - {Name: "TestTxQuery", F: TestTxQuery}, - {Name: "TestPreparedStmt", F: TestPreparedStmt}, + {Name: "TestResult", F: testResult}, + {Name: "TestBlobs", F: testBlobs}, + {Name: "TestManyQueryRow", F: testManyQueryRow}, + {Name: "TestTxQuery", F: testTxQuery}, + {Name: "TestPreparedStmt", F: testPreparedStmt}, } var benchmarks = []testing.InternalBenchmark{ - {Name: "BenchmarkExec", F: BenchmarkExec}, - {Name: "BenchmarkQuery", F: BenchmarkQuery}, - {Name: "BenchmarkParams", F: BenchmarkParams}, - {Name: "BenchmarkStmt", F: BenchmarkStmt}, - {Name: "BenchmarkRows", F: BenchmarkRows}, - {Name: "BenchmarkStmtRows", F: BenchmarkStmtRows}, + {Name: "BenchmarkExec", F: benchmarkExec}, + {Name: "BenchmarkQuery", F: benchmarkQuery}, + {Name: "BenchmarkParams", F: benchmarkParams}, + {Name: "BenchmarkStmt", F: benchmarkStmt}, + {Name: "BenchmarkRows", F: benchmarkRows}, + {Name: "BenchmarkStmtRows", F: benchmarkStmtRows}, } func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result { @@ -1451,9 +1452,9 @@ func (db *TestDB) tearDown() { for _, tbl := range testTables { switch db.dialect { case SQLITE: - db.Exec("drop table if exists " + tbl) + db.mustExec("drop table if exists " + tbl) case MYSQL, POSTGRESQL: - db.Exec("drop table if exists " + tbl) + db.mustExec("drop table if exists " + tbl) default: db.Fatal("unknown dialect") } @@ -1526,8 +1527,8 @@ func makeBench() { } } -// TestResult is test for result -func TestResult(t *testing.T) { +// testResult is test for result +func testResult(t *testing.T) { db.tearDown() db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))") @@ -1553,8 +1554,8 @@ func TestResult(t *testing.T) { } } -// TestBlobs is test for blobs -func TestBlobs(t *testing.T) { +// testBlobs is test for blobs +func testBlobs(t *testing.T) { db.tearDown() var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")") @@ -1580,8 +1581,8 @@ func TestBlobs(t *testing.T) { } } -// TestManyQueryRow is test for many query row -func TestManyQueryRow(t *testing.T) { +// testManyQueryRow is test for many query row +func testManyQueryRow(t *testing.T) { if testing.Short() { t.Log("skipping in short mode") return @@ -1598,8 +1599,8 @@ func TestManyQueryRow(t *testing.T) { } } -// TestTxQuery is test for transactional query -func TestTxQuery(t *testing.T) { +// testTxQuery is test for transactional query +func testTxQuery(t *testing.T) { db.tearDown() tx, err := db.Begin() if err != nil { @@ -1637,8 +1638,8 @@ func TestTxQuery(t *testing.T) { } } -// TestPreparedStmt is test for prepared statement -func TestPreparedStmt(t *testing.T) { +// testPreparedStmt is test for prepared statement +func testPreparedStmt(t *testing.T) { db.tearDown() db.mustExec("CREATE TABLE t (count INT)") sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC") @@ -1683,8 +1684,8 @@ func TestPreparedStmt(t *testing.T) { // test -bench but calling Benchmark() from a benchmark test // currently hangs go. -// BenchmarkExec is benchmark for exec -func BenchmarkExec(b *testing.B) { +// benchmarkExec is benchmark for exec +func benchmarkExec(b *testing.B) { for i := 0; i < b.N; i++ { if _, err := db.Exec("select 1"); err != nil { panic(err) @@ -1692,8 +1693,8 @@ func BenchmarkExec(b *testing.B) { } } -// BenchmarkQuery is benchmark for query -func BenchmarkQuery(b *testing.B) { +// benchmarkQuery is benchmark for query +func benchmarkQuery(b *testing.B) { for i := 0; i < b.N; i++ { var n sql.NullString var i int @@ -1706,8 +1707,8 @@ func BenchmarkQuery(b *testing.B) { } } -// BenchmarkParams is benchmark for params -func BenchmarkParams(b *testing.B) { +// benchmarkParams is benchmark for params +func benchmarkParams(b *testing.B) { for i := 0; i < b.N; i++ { var n sql.NullString var i int @@ -1720,8 +1721,8 @@ func BenchmarkParams(b *testing.B) { } } -// BenchmarkStmt is benchmark for statement -func BenchmarkStmt(b *testing.B) { +// benchmarkStmt is benchmark for statement +func benchmarkStmt(b *testing.B) { st, err := db.Prepare("select ?, ?, ?, ?") if err != nil { panic(err) @@ -1740,8 +1741,8 @@ func BenchmarkStmt(b *testing.B) { } } -// BenchmarkRows is benchmark for rows -func BenchmarkRows(b *testing.B) { +// benchmarkRows is benchmark for rows +func benchmarkRows(b *testing.B) { db.once.Do(makeBench) for n := 0; n < b.N; n++ { @@ -1765,8 +1766,8 @@ func BenchmarkRows(b *testing.B) { } } -// BenchmarkStmtRows is benchmark for statement rows -func BenchmarkStmtRows(b *testing.B) { +// benchmarkStmtRows is benchmark for statement rows +func benchmarkStmtRows(b *testing.B) { db.once.Do(makeBench) st, err := db.Prepare("select * from bench") From d1772f082687b34ca377224523c6e1b5b545425a Mon Sep 17 00:00:00 2001 From: Greg Holt Date: Mon, 21 Aug 2017 13:22:09 -0700 Subject: [PATCH 14/22] Added TestNilAndEmptyBytes --- sqlite3_test.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/sqlite3_test.go b/sqlite3_test.go index 7c545e1..8169f3d 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -6,6 +6,7 @@ package sqlite3 import ( + "bytes" "database/sql" "database/sql/driver" "errors" @@ -1343,6 +1344,59 @@ func TestUpdateAndTransactionHooks(t *testing.T) { } } +func TestNilAndEmptyBytes(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + actualNil := []byte("use this to use an actual nil not a reference to nil") + emptyBytes := []byte{} + for tsti, tst := range []struct { + name string + columnType string + insertBytes []byte + expectedBytes []byte + }{ + {"actual nil blob", "blob", actualNil, nil}, + {"referenced nil blob", "blob", nil, nil}, + {"empty blob", "blob", emptyBytes, emptyBytes}, + {"actual nil text", "text", actualNil, nil}, + {"referenced nil text", "text", nil, nil}, + {"empty text", "text", emptyBytes, emptyBytes}, + } { + if _, err = db.Exec(fmt.Sprintf("create table tbl%d (txt %s)", tsti, tst.columnType)); err != nil { + t.Fatal(tst.name, err) + } + if bytes.Equal(tst.insertBytes, actualNil) { + if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), nil); err != nil { + t.Fatal(tst.name, err) + } + } else { + if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), &tst.insertBytes); err != nil { + t.Fatal(tst.name, err) + } + } + rows, err := db.Query(fmt.Sprintf("select txt from tbl%d", tsti)) + if err != nil { + t.Fatal(tst.name, err) + } + if !rows.Next() { + t.Fatal(tst.name, "no rows") + } + var scanBytes []byte + if err = rows.Scan(&scanBytes); err != nil { + t.Fatal(tst.name, err) + } + if err = rows.Err(); err != nil { + t.Fatal(tst.name, err) + } + if !bytes.Equal(scanBytes, tst.expectedBytes) { + t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes) + } + } +} + var customFunctionOnce sync.Once func BenchmarkCustomFunctions(b *testing.B) { From 85e456ef27b5b5a64214bc5a74cbf8ec3114f5e5 Mon Sep 17 00:00:00 2001 From: Greg Holt Date: Mon, 21 Aug 2017 13:30:07 -0700 Subject: [PATCH 15/22] Fix to pass TestNilAndEmptyBytes --- sqlite3.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index d0327b8..bff6b7c 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -867,10 +867,11 @@ func (s *SQLiteStmt) bind(args []namedValue) error { case float64: rv = C.sqlite3_bind_double(s.s, n, C.double(v)) case []byte: - if len(v) == 0 { + ln := len(v) + if ln == 0 { v = placeHolder } - rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v))) + rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln)) case time.Time: b := []byte(v.Format(SQLiteTimestampFormats[0])) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) From b1c8062c18ee31834bdd0cd70c90af8590ce1f1a Mon Sep 17 00:00:00 2001 From: Greg Holt Date: Mon, 21 Aug 2017 13:45:34 -0700 Subject: [PATCH 16/22] Improved TestNilAndEmptyBytes I forgot that bytes.Equals treats nil and []byte{} as equal. --- sqlite3_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sqlite3_test.go b/sqlite3_test.go index 8169f3d..09e6727 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1391,7 +1391,9 @@ func TestNilAndEmptyBytes(t *testing.T) { if err = rows.Err(); err != nil { t.Fatal(tst.name, err) } - if !bytes.Equal(scanBytes, tst.expectedBytes) { + if tst.expectedBytes == nil && scanBytes != nil { + t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes) + } else if !bytes.Equal(scanBytes, tst.expectedBytes) { t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes) } } From a97e7bb12f7a08fabc155855e9861f983d6d7dec Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Sun, 27 Aug 2017 14:00:09 +0900 Subject: [PATCH 17/22] fix README.md close #456 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 01d28a2..ad00f10 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ FAQ * Want to get time.Time with current locale - Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`. + Use `_loc=auto` in SQLite3 filename schema like `file:foo.db?_loc=auto`. * Can I use this in multiple routines concurrently? From d40d4905434d5abfab8d29584adc850797897769 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Mon, 28 Aug 2017 18:58:02 +0900 Subject: [PATCH 18/22] fixes #458 --- sqlite3.go | 1 + sqlite3_go18_test.go | 92 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/sqlite3.go b/sqlite3.go index bff6b7c..076c9bd 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -811,6 +811,7 @@ func (s *SQLiteStmt) Close() error { return errors.New("sqlite statement with already closed database connection") } rv := C.sqlite3_finalize(s.s) + s.s = nil if rv != C.SQLITE_OK { return s.c.lastError() } diff --git a/sqlite3_go18_test.go b/sqlite3_go18_test.go index f076b81..a5f4aae 100644 --- a/sqlite3_go18_test.go +++ b/sqlite3_go18_test.go @@ -8,9 +8,13 @@ package sqlite3 import ( + "context" "database/sql" + "fmt" + "math/rand" "os" "testing" + "time" ) func TestNamedParams(t *testing.T) { @@ -48,3 +52,91 @@ func TestNamedParams(t *testing.T) { t.Error("Failed to db.QueryRow: not matched results") } } + +var ( + testTableStatements = []string{ + `DROP TABLE IF EXISTS test_table`, + ` +CREATE TABLE IF NOT EXISTS test_table ( + key1 VARCHAR(64) PRIMARY KEY, + key_id VARCHAR(64) NOT NULL, + key2 VARCHAR(64) NOT NULL, + key3 VARCHAR(64) NOT NULL, + key4 VARCHAR(64) NOT NULL, + key5 VARCHAR(64) NOT NULL, + key6 VARCHAR(64) NOT NULL, + data BLOB NOT NULL +);`, + } + letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +) + +func randStringBytes(n int) string { + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +} + +func initDatabase(t *testing.T, db *sql.DB, rowCount int64) { + t.Logf("Executing db initializing statements") + for _, query := range testTableStatements { + _, err := db.Exec(query) + if err != nil { + t.Fatal(err) + } + } + for i := int64(0); i < rowCount; i++ { + query := `INSERT INTO test_table + (key1, key_id, key2, key3, key4, key5, key6, data) + VALUES + (?, ?, ?, ?, ?, ?, ?, ?);` + args := []interface{}{ + randStringBytes(50), + fmt.Sprint(i), + randStringBytes(50), + randStringBytes(50), + randStringBytes(50), + randStringBytes(50), + randStringBytes(50), + randStringBytes(50), + randStringBytes(2048), + } + _, err := db.Exec(query, args...) + if err != nil { + t.Fatal(err) + } + } +} + +func TestShortTimeout(t *testing.T) { + db, err := sql.Open("sqlite3", "file::memory:?mode=memory&cache=shared") + if err != nil { + t.Fatal(err) + } + defer db.Close() + initDatabase(t, db, 10000) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Microsecond) + defer cancel() + query := `SELECT key1, key_id, key2, key3, key4, key5, key6, data + FROM test_table + ORDER BY key2 ASC` + rows, err := db.QueryContext(ctx, query) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + for rows.Next() { + var key1, keyid, key2, key3, key4, key5, key6 string + var data []byte + err = rows.Scan(&key1, &keyid, &key2, &key3, &key4, &key5, &key6, &data) + if err != nil { + break + } + } + if context.DeadlineExceeded != ctx.Err() { + t.Fatal(ctx.Err()) + } +} From ee720241fc5db8cf245ee588870940ad7a92863c Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 30 Aug 2017 13:19:01 +0900 Subject: [PATCH 19/22] support Solaris See #459 --- sqlite3_libsqlite3.go | 1 + sqlite3_other.go | 1 + 2 files changed, 2 insertions(+) diff --git a/sqlite3_libsqlite3.go b/sqlite3_libsqlite3.go index 135863e..e4557e6 100644 --- a/sqlite3_libsqlite3.go +++ b/sqlite3_libsqlite3.go @@ -10,5 +10,6 @@ package sqlite3 #cgo CFLAGS: -DUSE_LIBSQLITE3 #cgo linux LDFLAGS: -lsqlite3 #cgo darwin LDFLAGS: -L/usr/local/opt/sqlite/lib -lsqlite3 +#cgo solaris LDFLAGS: -lsqlite3 */ import "C" diff --git a/sqlite3_other.go b/sqlite3_other.go index a20d02c..f721b5e 100644 --- a/sqlite3_other.go +++ b/sqlite3_other.go @@ -9,5 +9,6 @@ package sqlite3 /* #cgo CFLAGS: -I. #cgo linux LDFLAGS: -ldl +#cgo solaris LDFLAGS: -lc */ import "C" From 911f1c4fa68b1823e76c909e777571f563bdb184 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 30 Aug 2017 13:29:47 +0900 Subject: [PATCH 20/22] fix lock --- sqlite3.go | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index 076c9bd..506a1b3 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -182,6 +182,7 @@ type SQLiteTx struct { // SQLiteStmt implement sql.Stmt. type SQLiteStmt struct { + mu sync.Mutex c *SQLiteConn s *C.sqlite3_stmt t string @@ -803,6 +804,8 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er // Close the statement. func (s *SQLiteStmt) Close() error { + s.mu.Lock() + defer s.mu.Unlock() if s.closed { return nil } @@ -980,7 +983,9 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result // Close the rows. func (rc *SQLiteRows) Close() error { + rc.s.mu.Lock() if rc.s.closed || rc.closed { + rc.s.mu.Unlock() return nil } rc.closed = true @@ -988,18 +993,23 @@ func (rc *SQLiteRows) Close() error { close(rc.done) } if rc.cls { + rc.s.mu.Unlock() return rc.s.Close() } rv := C.sqlite3_reset(rc.s.s) if rv != C.SQLITE_OK { + rc.s.mu.Unlock() return rc.s.c.lastError() } + rc.s.mu.Unlock() return nil } // Columns return column names. func (rc *SQLiteRows) Columns() []string { - if rc.nc != len(rc.cols) { + rc.s.mu.Lock() + defer rc.s.mu.Unlock() + if rc.s.s != nil && rc.nc != len(rc.cols) { rc.cols = make([]string, rc.nc) for i := 0; i < rc.nc; i++ { rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i))) @@ -1010,7 +1020,9 @@ func (rc *SQLiteRows) Columns() []string { // DeclTypes return column types. func (rc *SQLiteRows) DeclTypes() []string { - if rc.decltype == nil { + rc.s.mu.Lock() + defer rc.s.mu.Unlock() + if rc.s.s != nil && rc.decltype == nil { rc.decltype = make([]string, rc.nc) for i := 0; i < rc.nc; i++ { rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i)))) @@ -1021,20 +1033,30 @@ func (rc *SQLiteRows) DeclTypes() []string { // Next move cursor to next. func (rc *SQLiteRows) Next(dest []driver.Value) error { + if rc.s.closed { + return io.EOF + } + rc.s.mu.Lock() rv := C.sqlite3_step(rc.s.s) if rv == C.SQLITE_DONE { + rc.s.mu.Unlock() return io.EOF } if rv != C.SQLITE_ROW { + defer rc.s.mu.Unlock() rv = C.sqlite3_reset(rc.s.s) if rv != C.SQLITE_OK { return rc.s.c.lastError() } + rc.s.mu.Unlock() return nil } rc.DeclTypes() + rc.s.mu.Lock() + defer rc.s.mu.Unlock() + for i := range dest { switch C.sqlite3_column_type(rc.s.s, C.int(i)) { case C.SQLITE_INTEGER: From 58ed4a0810db4a6757547ca36adf4cf2ffc18fdd Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 30 Aug 2017 19:30:53 +0900 Subject: [PATCH 21/22] fix race --- sqlite3.go | 1 - 1 file changed, 1 deletion(-) diff --git a/sqlite3.go b/sqlite3.go index 506a1b3..c345727 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -1048,7 +1048,6 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { if rv != C.SQLITE_OK { return rc.s.c.lastError() } - rc.s.mu.Unlock() return nil } From 8d81c2f1f8f4b00b487c668bb8c9df5daa112900 Mon Sep 17 00:00:00 2001 From: Yasuhiro Matsumoto Date: Wed, 30 Aug 2017 19:37:57 +0900 Subject: [PATCH 22/22] fix race --- sqlite3.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sqlite3.go b/sqlite3.go index c345727..42a2e18 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -1018,10 +1018,7 @@ func (rc *SQLiteRows) Columns() []string { return rc.cols } -// DeclTypes return column types. -func (rc *SQLiteRows) DeclTypes() []string { - rc.s.mu.Lock() - defer rc.s.mu.Unlock() +func (rc *SQLiteRows) declTypes() []string { if rc.s.s != nil && rc.decltype == nil { rc.decltype = make([]string, rc.nc) for i := 0; i < rc.nc; i++ { @@ -1031,19 +1028,25 @@ func (rc *SQLiteRows) DeclTypes() []string { return rc.decltype } +// DeclTypes return column types. +func (rc *SQLiteRows) DeclTypes() []string { + rc.s.mu.Lock() + defer rc.s.mu.Unlock() + return rc.declTypes() +} + // Next move cursor to next. func (rc *SQLiteRows) Next(dest []driver.Value) error { if rc.s.closed { return io.EOF } rc.s.mu.Lock() + defer rc.s.mu.Unlock() rv := C.sqlite3_step(rc.s.s) if rv == C.SQLITE_DONE { - rc.s.mu.Unlock() return io.EOF } if rv != C.SQLITE_ROW { - defer rc.s.mu.Unlock() rv = C.sqlite3_reset(rc.s.s) if rv != C.SQLITE_OK { return rc.s.c.lastError() @@ -1051,10 +1054,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { return nil } - rc.DeclTypes() - - rc.s.mu.Lock() - defer rc.s.mu.Unlock() + rc.declTypes() for i := range dest { switch C.sqlite3_column_type(rc.s.s, C.int(i)) {