From 7c7d9c8ce53d422b3c29b63e4fe397531e24c8bd Mon Sep 17 00:00:00 2001 From: Gert-Jan Timmer Date: Sun, 24 Jun 2018 22:57:57 +0200 Subject: [PATCH] Update * Increase coverage of backup tests * FormatDSN * Rewrite C.in to types for Stringer implementation * Add PRAGMA func * Fix TestPinger * Driver, DriverContext => Fix Extensions, ConnectHook * Add Stringer to CryptEncoder, CryptSaltedEncoder * Fix OpenConnector * Renamed context to vtable_context and include build tag --- _driver/driver_test.go | 2067 ++++++++++++++++++++++ {driver => _driver}/error_test.go | 0 driver/backup.go | 2 +- driver/backup_test.go | 78 +- driver/config.go | 687 +++++-- driver/config_test.go | 627 +++++++ driver/connection.go | 14 + driver/connection_go18.go | 1 + driver/connection_go18_test.go | 24 +- driver/connector.go | 4 +- driver/connector_test.go | 39 + driver/crypt.go | 36 +- driver/driver.go | 41 +- driver/driver.goconvey | 4 + driver/driver_go110.go | 20 +- driver/driver_test.go | 2049 +-------------------- driver/error.go | 4 - driver/opt_stat4_test.go | 2 +- driver/opt_userauth_test.go | 10 +- driver/pragma.go | 23 + driver/{context.go => vtable_context.go} | 1 + 21 files changed, 3522 insertions(+), 2211 deletions(-) create mode 100644 _driver/driver_test.go rename {driver => _driver}/error_test.go (100%) create mode 100644 driver/config_test.go create mode 100644 driver/connector_test.go create mode 100644 driver/driver.goconvey create mode 100644 driver/pragma.go rename driver/{context.go => vtable_context.go} (99%) diff --git a/_driver/driver_test.go b/_driver/driver_test.go new file mode 100644 index 0000000..bbd46ba --- /dev/null +++ b/_driver/driver_test.go @@ -0,0 +1,2067 @@ +// Copyright (C) 2018 The Go-SQLite3 Authors. +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package sqlite3 + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "io/ioutil" + "math/rand" + "net/url" + "os" + "reflect" + "regexp" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +func TempFilename(t *testing.T) string { + f, err := ioutil.TempFile("", "go-sqlite3-test-") + if err != nil { + t.Fatal(err) + } + f.Close() + return f.Name() +} + +func doTestOpen(t *testing.T, option string) (string, error) { + var url string + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + if option != "" { + url = tempFilename + option + } else { + url = tempFilename + } + db, err := sql.Open("sqlite3", url) + if err != nil { + return "Failed to open database:", err + } + defer os.Remove(tempFilename) + defer db.Close() + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (id integer)") + if err != nil { + return "Failed to create table:", err + } + + if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() { + return "Failed to create ./foo.db", nil + } + + return "", nil +} + +func TestOpen(t *testing.T) { + cases := map[string]bool{ + "": true, + "?txlock=immediate": true, + "?txlock=deferred": true, + "?txlock=exclusive": true, + "?txlock=bogus": false, + } + for option, expectedPass := range cases { + result, err := doTestOpen(t, option) + if result == "" { + if !expectedPass { + errmsg := fmt.Sprintf("_txlock error not caught at dbOpen with option: %s", option) + t.Fatal(errmsg) + } + } else if expectedPass { + if err == nil { + t.Fatal(result) + } else { + t.Fatal(result, err) + } + } + } +} + +func TestReadonly(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + + db1, err := sql.Open("sqlite3", "file:"+tempFilename) + if err != nil { + t.Fatal(err) + } + db1.Exec("CREATE TABLE test (x int, y float)") + + db2, err := sql.Open("sqlite3", "file:"+tempFilename+"?mode=ro") + if err != nil { + t.Fatal(err) + } + _ = db2 + _, err = db2.Exec("INSERT INTO test VALUES (1, 3.14)") + if err == nil { + t.Fatal("didn't expect INSERT into read-only database to work") + } +} + +func TestForeignKeys(t *testing.T) { + cases := map[string]bool{ + "?foreign_keys=1": true, + "?foreign_keys=0": false, + } + for option, want := range cases { + fname := TempFilename(t) + uri := "file:" + fname + option + db, err := sql.Open("sqlite3", uri) + if err != nil { + os.Remove(fname) + t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err) + continue + } + var enabled bool + err = db.QueryRow("PRAGMA foreign_keys;").Scan(&enabled) + db.Close() + os.Remove(fname) + if err != nil { + t.Errorf("query foreign_keys for %s: %v", uri, err) + continue + } + if enabled != want { + t.Errorf("\"PRAGMA foreign_keys;\" for %q = %t; want %t", uri, enabled, want) + continue + } + } +} + +func TestRecursiveTriggers(t *testing.T) { + cases := map[string]bool{ + "?_recursive_triggers=1": true, + "?_recursive_triggers=0": false, + } + for option, want := range cases { + fname := TempFilename(t) + uri := "file:" + fname + option + db, err := sql.Open("sqlite3", uri) + if err != nil { + os.Remove(fname) + t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err) + continue + } + var enabled bool + err = db.QueryRow("PRAGMA recursive_triggers;").Scan(&enabled) + db.Close() + os.Remove(fname) + if err != nil { + t.Errorf("query recursive_triggers for %s: %v", uri, err) + continue + } + if enabled != want { + t.Errorf("\"PRAGMA recursive_triggers;\" for %q = %t; want %t", uri, enabled, want) + continue + } + } +} + +func TestClose(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (id integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + stmt, err := db.Prepare("select id from foo where id = ?") + if err != nil { + t.Fatal("Failed to select records:", err) + } + + db.Close() + _, err = stmt.Exec(1) + if err == nil { + t.Fatal("Failed to operate closed statement") + } +} + +func TestInsert(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (id integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + res, err := db.Exec("insert into foo(id) values(123)") + if err != nil { + t.Fatal("Failed to insert record:", err) + } + affected, _ := res.RowsAffected() + if affected != 1 { + t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) + } + + rows, err := db.Query("select id from foo") + if err != nil { + t.Fatal("Failed to select records:", err) + } + defer rows.Close() + + rows.Next() + + var result int + rows.Scan(&result) + if result != 123 { + t.Errorf("Expected %d for fetched result, but %d:", 123, result) + } +} + +func TestUpsert(t *testing.T) { + _, n, _ := Version() + if !(n >= 3024000) { + t.Skip("UPSERT requires sqlite3 => 3.24.0") + } + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (name string primary key, counter integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + for i := 0; i < 10; i++ { + res, err := db.Exec("insert into foo(name, counter) values('key', 1) on conflict (name) do update set counter=counter+1") + if err != nil { + t.Fatal("Failed to upsert record:", err) + } + affected, _ := res.RowsAffected() + if affected != 1 { + t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) + } + } + rows, err := db.Query("select name, counter from foo") + if err != nil { + t.Fatal("Failed to select records:", err) + } + defer rows.Close() + + rows.Next() + + var resultName string + var resultCounter int + rows.Scan(&resultName, &resultCounter) + if resultName != "key" { + t.Errorf("Expected %s for fetched result, but %s:", "key", resultName) + } + if resultCounter != 10 { + t.Errorf("Expected %d for fetched result, but %d:", 10, resultCounter) + } + +} + +func TestUpdate(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (id integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + res, err := db.Exec("insert into foo(id) values(123)") + if err != nil { + t.Fatal("Failed to insert record:", err) + } + expected, err := res.LastInsertId() + if err != nil { + t.Fatal("Failed to get LastInsertId:", err) + } + affected, _ := res.RowsAffected() + if err != nil { + t.Fatal("Failed to get RowsAffected:", err) + } + if affected != 1 { + t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) + } + + res, err = db.Exec("update foo set id = 234") + if err != nil { + t.Fatal("Failed to update record:", err) + } + lastID, err := res.LastInsertId() + if err != nil { + t.Fatal("Failed to get LastInsertId:", err) + } + if expected != lastID { + t.Errorf("Expected %q for last Id, but %q:", expected, lastID) + } + affected, _ = res.RowsAffected() + if err != nil { + t.Fatal("Failed to get RowsAffected:", err) + } + if affected != 1 { + t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) + } + + rows, err := db.Query("select id from foo") + if err != nil { + t.Fatal("Failed to select records:", err) + } + defer rows.Close() + + rows.Next() + + var result int + rows.Scan(&result) + if result != 234 { + t.Errorf("Expected %d for fetched result, but %d:", 234, result) + } +} + +func TestDelete(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("drop table foo") + _, err = db.Exec("create table foo (id integer)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + res, err := db.Exec("insert into foo(id) values(123)") + if err != nil { + t.Fatal("Failed to insert record:", err) + } + expected, err := res.LastInsertId() + if err != nil { + t.Fatal("Failed to get LastInsertId:", err) + } + affected, err := res.RowsAffected() + if err != nil { + t.Fatal("Failed to get RowsAffected:", err) + } + if affected != 1 { + t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected) + } + + res, err = db.Exec("delete from foo where id = 123") + if err != nil { + t.Fatal("Failed to delete record:", err) + } + lastID, err := res.LastInsertId() + if err != nil { + t.Fatal("Failed to get LastInsertId:", err) + } + if expected != lastID { + t.Errorf("Expected %q for last Id, but %q:", expected, lastID) + } + affected, err = res.RowsAffected() + if err != nil { + t.Fatal("Failed to get RowsAffected:", err) + } + if affected != 1 { + t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected) + } + + rows, err := db.Query("select id from foo") + if err != nil { + t.Fatal("Failed to select records:", err) + } + defer rows.Close() + + if rows.Next() { + t.Error("Fetched row but expected not rows") + } +} + +func TestBooleanRoundtrip(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("DROP TABLE foo") + _, err = db.Exec("CREATE TABLE foo(id INTEGER, value BOOL)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = db.Exec("INSERT INTO foo(id, value) VALUES(1, ?)", true) + if err != nil { + t.Fatal("Failed to insert true value:", err) + } + + _, err = db.Exec("INSERT INTO foo(id, value) VALUES(2, ?)", false) + if err != nil { + t.Fatal("Failed to insert false value:", err) + } + + rows, err := db.Query("SELECT id, value FROM foo") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + defer rows.Close() + + for rows.Next() { + var id int + var value bool + + if err := rows.Scan(&id, &value); err != nil { + t.Error("Unable to scan results:", err) + continue + } + + if id == 1 && !value { + t.Error("Value for id 1 should be true, not false") + + } else if id == 2 && value { + t.Error("Value for id 2 should be false, not true") + } + } +} + +func timezone(t time.Time) string { return t.Format("-07:00") } + +func TestTimestamp(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("DROP TABLE foo") + _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts timeSTAMP, dt DATETIME)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) + timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) + timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) + tzTest := time.FixedZone("TEST", -9*3600-13*60) + tests := []struct { + value interface{} + expected time.Time + }{ + {"nonsense", time.Time{}}, + {"0000-00-00 00:00:00", time.Time{}}, + {time.Time{}.Unix(), time.Time{}}, + {timestamp1, timestamp1}, + {timestamp2.Unix(), timestamp2.Truncate(time.Second)}, + {timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)}, + {timestamp1.In(tzTest), timestamp1.In(tzTest)}, + {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1}, + {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1}, + {timestamp1.Format("2006-01-02 15:04:05"), timestamp1}, + {timestamp1.Format("2006-01-02T15:04:05"), timestamp1}, + {timestamp2, timestamp2}, + {"2006-01-02 15:04:05.123456789", timestamp2}, + {"2006-01-02T15:04:05.123456789", timestamp2}, + {"2006-01-02T05:51:05.123456789-09:13", timestamp2.In(tzTest)}, + {"2012-11-04", timestamp3}, + {"2012-11-04 00:00", timestamp3}, + {"2012-11-04 00:00:00", timestamp3}, + {"2012-11-04 00:00:00.000", timestamp3}, + {"2012-11-04T00:00", timestamp3}, + {"2012-11-04T00:00:00", timestamp3}, + {"2012-11-04T00:00:00.000", timestamp3}, + {"2006-01-02T15:04:05.123456789Z", timestamp2}, + {"2012-11-04Z", timestamp3}, + {"2012-11-04 00:00Z", timestamp3}, + {"2012-11-04 00:00:00Z", timestamp3}, + {"2012-11-04 00:00:00.000Z", timestamp3}, + {"2012-11-04T00:00Z", timestamp3}, + {"2012-11-04T00:00:00Z", timestamp3}, + {"2012-11-04T00:00:00.000Z", timestamp3}, + } + for i := range tests { + _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) + if err != nil { + t.Fatal("Failed to insert timestamp:", err) + } + } + + rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + defer rows.Close() + + seen := 0 + for rows.Next() { + var id int + var ts, dt time.Time + + if err := rows.Scan(&id, &ts, &dt); err != nil { + t.Error("Unable to scan results:", err) + continue + } + if id < 0 || id >= len(tests) { + t.Error("Bad row id: ", id) + continue + } + seen++ + if !tests[id].expected.Equal(ts) { + t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) + } + if !tests[id].expected.Equal(dt) { + t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) + } + if timezone(tests[id].expected) != timezone(ts) { + t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value, + timezone(tests[id].expected), timezone(ts)) + } + if timezone(tests[id].expected) != timezone(dt) { + t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value, + timezone(tests[id].expected), timezone(dt)) + } + } + + if seen != len(tests) { + t.Errorf("Expected to see %d rows", len(tests)) + } +} + +func TestBoolean(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + + defer db.Close() + + _, err = db.Exec("CREATE TABLE foo(id INTEGER, fbool BOOLEAN)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + bool1 := true + _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(1, ?)", bool1) + if err != nil { + t.Fatal("Failed to insert boolean:", err) + } + + bool2 := false + _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(2, ?)", bool2) + if err != nil { + t.Fatal("Failed to insert boolean:", err) + } + + bool3 := "nonsense" + _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(3, ?)", bool3) + if err != nil { + t.Fatal("Failed to insert nonsense:", err) + } + + rows, err := db.Query("SELECT id, fbool FROM foo where fbool = ?", bool1) + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + counter := 0 + + var id int + var fbool bool + + for rows.Next() { + if err := rows.Scan(&id, &fbool); err != nil { + t.Fatal("Unable to scan results:", err) + } + counter++ + } + + if counter != 1 { + t.Fatalf("Expected 1 row but %v", counter) + } + + if id != 1 && !fbool { + t.Fatalf("Value for id 1 should be %v, not %v", bool1, fbool) + } + + rows, err = db.Query("SELECT id, fbool FROM foo where fbool = ?", bool2) + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + + counter = 0 + + for rows.Next() { + if err := rows.Scan(&id, &fbool); err != nil { + t.Fatal("Unable to scan results:", err) + } + counter++ + } + + if counter != 1 { + t.Fatalf("Expected 1 row but %v", counter) + } + + if id != 2 && fbool { + t.Fatalf("Value for id 2 should be %v, not %v", bool2, fbool) + } + + // make sure "nonsense" triggered an error + rows, err = db.Query("SELECT id, fbool FROM foo where id=?;", 3) + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + + rows.Next() + err = rows.Scan(&id, &fbool) + if err == nil { + t.Error("Expected error from \"nonsense\" bool") + } +} + +func TestFloat32(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE foo(id INTEGER)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = db.Exec("INSERT INTO foo(id) VALUES(null)") + if err != nil { + t.Fatal("Failed to insert null:", err) + } + + rows, err := db.Query("SELECT id FROM foo") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + + if !rows.Next() { + t.Fatal("Unable to query results:", err) + } + + var id interface{} + if err := rows.Scan(&id); err != nil { + t.Fatal("Unable to scan results:", err) + } + if id != nil { + t.Error("Expected nil but not") + } +} + +func TestNull(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + rows, err := db.Query("SELECT 3.141592") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + + if !rows.Next() { + t.Fatal("Unable to query results:", err) + } + + var v interface{} + if err := rows.Scan(&v); err != nil { + t.Fatal("Unable to scan results:", err) + } + f, ok := v.(float64) + if !ok { + t.Error("Expected float but not") + } + if f != 3.141592 { + t.Error("Expected 3.141592 but not") + } +} + +func TestWAL(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + if _, err = db.Exec("PRAGMA journal_mode=WAL;"); err != nil { + t.Fatal("Failed to Exec PRAGMA journal_mode:", err) + } + if _, err = db.Exec("PRAGMA locking_mode=EXCLUSIVE;"); err != nil { + t.Fatal("Failed to Exec PRAGMA locking_mode:", err) + } + if _, err = db.Exec("CREATE TABLE test (id SERIAL, user TEXT NOT NULL, name TEXT NOT NULL);"); err != nil { + t.Fatal("Failed to Exec CREATE TABLE:", err) + } + if _, err = db.Exec("INSERT INTO test (user, name) VALUES ('user','name');"); err != nil { + t.Fatal("Failed to Exec INSERT:", err) + } + + trans, err := db.Begin() + if err != nil { + t.Fatal("Failed to Begin:", err) + } + s, err := trans.Prepare("INSERT INTO test (user, name) VALUES (?, ?);") + if err != nil { + t.Fatal("Failed to Prepare:", err) + } + + var count int + if err = trans.QueryRow("SELECT count(user) FROM test;").Scan(&count); err != nil { + t.Fatal("Failed to QueryRow:", err) + } + if _, err = s.Exec("bbbb", "aaaa"); err != nil { + t.Fatal("Failed to Exec prepared statement:", err) + } + if err = s.Close(); err != nil { + t.Fatal("Failed to Close prepared statement:", err) + } + if err = trans.Commit(); err != nil { + t.Fatal("Failed to Commit:", err) + } +} + +func TestTimezoneConversion(t *testing.T) { + zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + for _, tz := range zones { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename+"?_loc="+url.QueryEscape(tz)) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("DROP TABLE foo") + _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts TIMESTAMP, dt DATETIME)") + if err != nil { + t.Fatal("Failed to create table:", err) + } + + loc, err := time.LoadLocation(tz) + if err != nil { + t.Fatal("Failed to load location:", err) + } + + timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) + timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) + timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) + tests := []struct { + value interface{} + expected time.Time + }{ + {"nonsense", time.Time{}.In(loc)}, + {"0000-00-00 00:00:00", time.Time{}.In(loc)}, + {timestamp1, timestamp1.In(loc)}, + {timestamp1.Unix(), timestamp1.In(loc)}, + {timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02 15:04:05"), timestamp1.In(loc)}, + {timestamp1.Format("2006-01-02T15:04:05"), timestamp1.In(loc)}, + {timestamp2, timestamp2.In(loc)}, + {"2006-01-02 15:04:05.123456789", timestamp2.In(loc)}, + {"2006-01-02T15:04:05.123456789", timestamp2.In(loc)}, + {"2012-11-04", timestamp3.In(loc)}, + {"2012-11-04 00:00", timestamp3.In(loc)}, + {"2012-11-04 00:00:00", timestamp3.In(loc)}, + {"2012-11-04 00:00:00.000", timestamp3.In(loc)}, + {"2012-11-04T00:00", timestamp3.In(loc)}, + {"2012-11-04T00:00:00", timestamp3.In(loc)}, + {"2012-11-04T00:00:00.000", timestamp3.In(loc)}, + } + for i := range tests { + _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) + if err != nil { + t.Fatal("Failed to insert timestamp:", err) + } + } + + rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC") + if err != nil { + t.Fatal("Unable to query foo table:", err) + } + defer rows.Close() + + seen := 0 + for rows.Next() { + var id int + var ts, dt time.Time + + if err := rows.Scan(&id, &ts, &dt); err != nil { + t.Error("Unable to scan results:", err) + continue + } + if id < 0 || id >= len(tests) { + t.Error("Bad row id: ", id) + continue + } + seen++ + if !tests[id].expected.Equal(ts) { + t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, ts) + } + if !tests[id].expected.Equal(dt) { + t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) + } + if tests[id].expected.Location().String() != ts.Location().String() { + t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), ts.Location().String()) + } + if tests[id].expected.Location().String() != dt.Location().String() { + t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), dt.Location().String()) + } + } + + if seen != len(tests) { + t.Errorf("Expected to see %d rows", len(tests)) + } + } +} + +func TestExecer(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec(` + create table foo (id integer); -- one comment + insert into foo(id) values(?); + insert into foo(id) values(?); + insert into foo(id) values(?); -- another comment + `, 1, 2, 3) + if err != nil { + t.Error("Failed to call db.Exec:", err) + } +} + +func TestQueryer(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec(` + create table foo (id integer); + `) + if err != nil { + t.Error("Failed to call db.Query:", err) + } + + rows, err := db.Query(` + insert into foo(id) values(?); + insert into foo(id) values(?); + insert into foo(id) values(?); + select id from foo order by id; + `, 3, 2, 1) + if err != nil { + t.Error("Failed to call db.Query:", err) + } + defer rows.Close() + n := 1 + if rows != nil { + for rows.Next() { + var id int + err = rows.Scan(&id) + if err != nil { + t.Error("Failed to db.Query:", err) + } + if id != n { + t.Error("Failed to db.Query: not matched results") + } + } + } +} + +func TestStress(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + db.Exec("CREATE TABLE foo (id int);") + db.Exec("INSERT INTO foo VALUES(1);") + db.Exec("INSERT INTO foo VALUES(2);") + db.Close() + + for i := 0; i < 10000; i++ { + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + + for j := 0; j < 3; j++ { + rows, err := db.Query("select * from foo where id=1;") + if err != nil { + t.Error("Failed to call db.Query:", err) + } + for rows.Next() { + var i int + if err := rows.Scan(&i); err != nil { + t.Errorf("Scan failed: %v\n", err) + } + } + if err := rows.Err(); err != nil { + t.Errorf("Post-scan failed: %v\n", err) + } + rows.Close() + } + db.Close() + } +} + +func TestDateTimeLocal(t *testing.T) { + zone := "Asia/Tokyo" + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename+"?_loc="+zone) + if err != nil { + t.Fatal("Failed to open database:", err) + } + db.Exec("CREATE TABLE foo (dt datetime);") + db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');") + + row := db.QueryRow("select * from foo") + var d time.Time + err = row.Scan(&d) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } + if d.Hour() == 15 || !strings.Contains(d.String(), "JST") { + t.Fatal("Result should have timezone", d) + } + db.Close() + + db, err = sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + + row = db.QueryRow("select * from foo") + err = row.Scan(&d) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } + if d.UTC().Hour() != 15 || !strings.Contains(d.String(), "UTC") { + t.Fatalf("Result should not have timezone %v %v", zone, d.String()) + } + + _, err = db.Exec("DELETE FROM foo") + if err != nil { + t.Fatal("Failed to delete table:", err) + } + dt, err := time.Parse("2006/1/2 15/4/5 -0700 MST", "2015/3/5 15/16/17 +0900 JST") + if err != nil { + t.Fatal("Failed to parse datetime:", err) + } + db.Exec("INSERT INTO foo VALUES(?);", dt) + + db.Close() + db, err = sql.Open("sqlite3", tempFilename+"?_loc="+zone) + if err != nil { + t.Fatal("Failed to open database:", err) + } + + row = db.QueryRow("select * from foo") + err = row.Scan(&d) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } + if d.Hour() != 15 || !strings.Contains(d.String(), "JST") { + t.Fatalf("Result should have timezone %v %v", zone, d.String()) + } +} + +func TestStringContainingZero(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec(` + create table foo (id integer, name, extra text); + `) + if err != nil { + t.Error("Failed to call db.Query:", err) + } + + const text = "foo\x00bar" + + _, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, text) + if err != nil { + t.Error("Failed to call db.Exec:", err) + } + + row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, text) + if row == nil { + t.Error("Failed to call db.QueryRow") + } + + var id int + var extra string + err = row.Scan(&id, &extra) + if err != nil { + t.Error("Failed to db.Scan:", err) + } + if id != 1 || extra != text { + t.Error("Failed to db.QueryRow: not matched results") + } +} + +const CurrentTimeStamp = "2006-01-02 15:04:05" + +type TimeStamp struct{ *time.Time } + +func (t TimeStamp) Scan(value interface{}) error { + var err error + switch v := value.(type) { + case string: + *t.Time, err = time.Parse(CurrentTimeStamp, v) + case []byte: + *t.Time, err = time.Parse(CurrentTimeStamp, string(v)) + default: + err = errors.New("invalid type for current_timestamp") + } + return err +} + +func (t TimeStamp) Value() (driver.Value, error) { + return t.Time.Format(CurrentTimeStamp), nil +} + +func TestDateTimeNow(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + var d time.Time + err = db.QueryRow("SELECT datetime('now')").Scan(TimeStamp{&d}) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } +} + +func TestFunctionRegistration(t *testing.T) { + addi8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) } + addi64 := func(a, b int64) int64 { return a + b } + addu8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) } + addu64 := func(a, b uint64) uint64 { return a + b } + addiu := func(a int, b uint) int64 { return int64(a) + int64(b) } + addf32_64 := func(a float32, b float64) float64 { return float64(a) + b } + not := func(a bool) bool { return !a } + regex := func(re, s string) (bool, error) { + return regexp.MatchString(re, s) + } + generic := func(a interface{}) int64 { + switch a.(type) { + case int64: + return 1 + case float64: + return 2 + case []byte: + return 3 + case string: + return 4 + default: + panic("unreachable") + } + } + variadic := func(a, b int64, c ...int64) int64 { + ret := a + b + for _, d := range c { + ret += d + } + return ret + } + variadicGeneric := func(a ...interface{}) int64 { + return int64(len(a)) + } + + sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.RegisterFunc("addi8_16_32", addi8_16_32, true); err != nil { + return err + } + if err := conn.RegisterFunc("addi64", addi64, true); err != nil { + return err + } + if err := conn.RegisterFunc("addu8_16_32", addu8_16_32, true); err != nil { + return err + } + if err := conn.RegisterFunc("addu64", addu64, true); err != nil { + return err + } + if err := conn.RegisterFunc("addiu", addiu, true); err != nil { + return err + } + if err := conn.RegisterFunc("addf32_64", addf32_64, true); err != nil { + return err + } + if err := conn.RegisterFunc("not", not, true); err != nil { + return err + } + if err := conn.RegisterFunc("regex", regex, true); err != nil { + return err + } + if err := conn.RegisterFunc("generic", generic, true); err != nil { + return err + } + if err := conn.RegisterFunc("variadic", variadic, true); err != nil { + return err + } + if err := conn.RegisterFunc("variadicGeneric", variadicGeneric, true); err != nil { + return err + } + return nil + }, + }) + db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + ops := []struct { + query string + expected interface{} + }{ + {"SELECT addi8_16_32(1,2)", int32(3)}, + {"SELECT addi64(1,2)", int64(3)}, + {"SELECT addu8_16_32(1,2)", uint32(3)}, + {"SELECT addu64(1,2)", uint64(3)}, + {"SELECT addiu(1,2)", int64(3)}, + {"SELECT addf32_64(1.5,1.5)", float64(3)}, + {"SELECT not(1)", false}, + {"SELECT not(0)", true}, + {`SELECT regex("^foo.*", "foobar")`, true}, + {`SELECT regex("^foo.*", "barfoobar")`, false}, + {"SELECT generic(1)", int64(1)}, + {"SELECT generic(1.1)", int64(2)}, + {`SELECT generic(NULL)`, int64(3)}, + {`SELECT generic("foo")`, int64(4)}, + {"SELECT variadic(1,2)", int64(3)}, + {"SELECT variadic(1,2,3,4)", int64(10)}, + {"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)}, + {`SELECT variadicGeneric(1,"foo",2.3, NULL)`, int64(4)}, + } + + for _, op := range ops { + ret := reflect.New(reflect.TypeOf(op.expected)) + err = db.QueryRow(op.query).Scan(ret.Interface()) + if err != nil { + t.Errorf("Query %q failed: %s", op.query, err) + } else if !reflect.DeepEqual(ret.Elem().Interface(), op.expected) { + t.Errorf("Query %q returned wrong value: got %v (%T), want %v (%T)", op.query, ret.Elem().Interface(), ret.Elem().Interface(), op.expected, op.expected) + } + } +} + +type sumAggregator int64 + +func (s *sumAggregator) Step(x int64) { + *s += sumAggregator(x) +} + +func (s *sumAggregator) Done() int64 { + return int64(*s) +} + +func TestAggregatorRegistration(t *testing.T) { + customSum := func() *sumAggregator { + var ret sumAggregator + return &ret + } + + sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.RegisterAggregator("customSum", customSum, true); err != nil { + return err + } + return nil + }, + }) + db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + // trace feature is not implemented + t.Skip("Failed to create table:", err) + } + + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)") + if err != nil { + t.Fatal("Failed to insert records:", err) + } + + tests := []struct { + dept, sum int64 + }{ + {1, 30}, + {2, 42}, + } + + for _, test := range tests { + var ret int64 + err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret) + if err != nil { + t.Fatal("Query failed:", err) + } + if ret != test.sum { + t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum) + } + } +} + +func rot13(r rune) rune { + switch { + case r >= 'A' && r <= 'Z': + return 'A' + (r-'A'+13)%26 + case r >= 'a' && r <= 'z': + return 'a' + (r-'a'+13)%26 + } + return r +} + +func TestCollationRegistration(t *testing.T) { + collateRot13 := func(a, b string) int { + ra, rb := strings.Map(rot13, a), strings.Map(rot13, b) + return strings.Compare(ra, rb) + } + collateRot13Reverse := func(a, b string) int { + return collateRot13(b, a) + } + + sql.Register("sqlite3_CollationRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.RegisterCollation("rot13", collateRot13); err != nil { + return err + } + if err := conn.RegisterCollation("rot13reverse", collateRot13Reverse); err != nil { + return err + } + return nil + }, + }) + + db, err := sql.Open("sqlite3_CollationRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + populate := []string{ + `CREATE TABLE test (s TEXT)`, + `INSERT INTO test VALUES ("aaaa")`, + `INSERT INTO test VALUES ("ffff")`, + `INSERT INTO test VALUES ("qqqq")`, + `INSERT INTO test VALUES ("tttt")`, + `INSERT INTO test VALUES ("zzzz")`, + } + for _, stmt := range populate { + if _, err := db.Exec(stmt); err != nil { + t.Fatal("Failed to populate test DB:", err) + } + } + + ops := []struct { + query string + want []string + }{ + { + "SELECT * FROM test ORDER BY s COLLATE rot13 ASC", + []string{ + "qqqq", + "tttt", + "zzzz", + "aaaa", + "ffff", + }, + }, + { + "SELECT * FROM test ORDER BY s COLLATE rot13 DESC", + []string{ + "ffff", + "aaaa", + "zzzz", + "tttt", + "qqqq", + }, + }, + { + "SELECT * FROM test ORDER BY s COLLATE rot13reverse ASC", + []string{ + "ffff", + "aaaa", + "zzzz", + "tttt", + "qqqq", + }, + }, + { + "SELECT * FROM test ORDER BY s COLLATE rot13reverse DESC", + []string{ + "qqqq", + "tttt", + "zzzz", + "aaaa", + "ffff", + }, + }, + } + + for _, op := range ops { + rows, err := db.Query(op.query) + if err != nil { + t.Fatalf("Query %q failed: %s", op.query, err) + } + got := []string{} + defer rows.Close() + for rows.Next() { + var s string + if err = rows.Scan(&s); err != nil { + t.Fatalf("Reading row for %q: %s", op.query, err) + } + got = append(got, s) + } + if err = rows.Err(); err != nil { + t.Fatalf("Reading rows for %q: %s", op.query, err) + } + + if !reflect.DeepEqual(got, op.want) { + t.Fatalf("Unexpected output from %q\ngot:\n%s\n\nwant:\n%s", op.query, strings.Join(got, "\n"), strings.Join(op.want, "\n")) + } + } +} + +func TestDeclTypes(t *testing.T) { + + d := SQLiteDriver{} + + conn, err := d.Open(":memory:") + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + defer conn.Close() + + sqlite3conn := conn.(*SQLiteConn) + + _, err = sqlite3conn.Exec("create table foo (id integer not null primary key, name text)", nil) + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = sqlite3conn.Exec("insert into foo(name) values(\"bar\")", nil) + if err != nil { + t.Fatal("Failed to insert:", err) + } + + rs, err := sqlite3conn.Query("select * from foo", nil) + if err != nil { + t.Fatal("Failed to select:", err) + } + defer rs.Close() + + declTypes := rs.(*SQLiteRows).DeclTypes() + + if !reflect.DeepEqual(declTypes, []string{"integer", "text"}) { + t.Fatal("Unexpected declTypes:", declTypes) + } +} + +func TestUpdateAndTransactionHooks(t *testing.T) { + var events []string + var commitHookReturn = 0 + + sql.Register("sqlite3_UpdateHook", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + conn.RegisterCommitHook(func() int { + events = append(events, "commit") + return commitHookReturn + }) + conn.RegisterRollbackHook(func() { + events = append(events, "rollback") + }) + conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) { + events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid)) + }) + return nil + }, + }) + db, err := sql.Open("sqlite3_UpdateHook", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + statements := []string{ + "create table foo (id integer primary key)", + "insert into foo values (9)", + "update foo set id = 99 where id = 9", + "delete from foo where id = 99", + } + for _, statement := range statements { + _, err = db.Exec(statement) + if err != nil { + t.Fatalf("Unable to prepare test data [%v]: %v", statement, err) + } + } + + commitHookReturn = 1 + _, err = db.Exec("insert into foo values (5)") + if err == nil { + t.Error("Commit hook failed to rollback transaction") + } + + var expected = []string{ + "commit", + fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT), + "commit", + fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE), + "commit", + fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE), + "commit", + fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT), + "commit", + "rollback", + } + if !reflect.DeepEqual(events, expected) { + t.Errorf("Expected notifications %v but got %v", expected, events) + } +} + +func TestNilAndEmptyBytes(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + actualNil := []byte("use this to use an actual nil not a reference to nil") + emptyBytes := []byte{} + for tsti, tst := range []struct { + name string + columnType string + insertBytes []byte + expectedBytes []byte + }{ + {"actual nil blob", "blob", actualNil, nil}, + {"referenced nil blob", "blob", nil, nil}, + {"empty blob", "blob", emptyBytes, emptyBytes}, + {"actual nil text", "text", actualNil, nil}, + {"referenced nil text", "text", nil, nil}, + {"empty text", "text", emptyBytes, emptyBytes}, + } { + if _, err = db.Exec(fmt.Sprintf("create table tbl%d (txt %s)", tsti, tst.columnType)); err != nil { + t.Fatal(tst.name, err) + } + if bytes.Equal(tst.insertBytes, actualNil) { + if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), nil); err != nil { + t.Fatal(tst.name, err) + } + } else { + if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), &tst.insertBytes); err != nil { + t.Fatal(tst.name, err) + } + } + rows, err := db.Query(fmt.Sprintf("select txt from tbl%d", tsti)) + if err != nil { + t.Fatal(tst.name, err) + } + if !rows.Next() { + t.Fatal(tst.name, "no rows") + } + var scanBytes []byte + if err = rows.Scan(&scanBytes); err != nil { + t.Fatal(tst.name, err) + } + if err = rows.Err(); err != nil { + t.Fatal(tst.name, err) + } + if tst.expectedBytes == nil && scanBytes != nil { + t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes) + } else if !bytes.Equal(scanBytes, tst.expectedBytes) { + t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes) + } + } +} + +func TestInsertNilByteSlice(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + if _, err := db.Exec("create table blob_not_null (b blob not null)"); err != nil { + t.Fatal(err) + } + var nilSlice []byte + if _, err := db.Exec("insert into blob_not_null (b) values (?)", nilSlice); err == nil { + t.Fatal("didn't expect INSERT to 'not null' column with a nil []byte slice to work") + } + zeroLenSlice := []byte{} + if _, err := db.Exec("insert into blob_not_null (b) values (?)", zeroLenSlice); err != nil { + t.Fatal("failed to insert zero-length slice") + } +} + +var customFunctionOnce sync.Once + +func BenchmarkCustomFunctions(b *testing.B) { + customFunctionOnce.Do(func() { + customAdd := func(a, b int64) int64 { + return a + b + } + + sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + // Impure function to force sqlite to reexecute it each time. + return conn.RegisterFunc("custom_add", customAdd, false) + }, + }) + }) + + db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:") + if err != nil { + b.Fatal("Failed to open database:", err) + } + defer db.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var i int64 + err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i) + if err != nil { + b.Fatal("Failed to run custom add:", err) + } + } +} + +func TestSuite(t *testing.T) { + tempFilename := TempFilename(t) + defer os.Remove(tempFilename) + d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999") + if err != nil { + t.Fatal(err) + } + defer d.Close() + + db = &TestDB{t, d, SQLITE, sync.Once{}} + testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests) + + if !testing.Short() { + for _, b := range benchmarks { + fmt.Printf("%-20s", b.Name) + r := testing.Benchmark(b.F) + fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds()) + } + } + db.tearDown() +} + +// Dialect is a type of dialect of databases. +type Dialect int + +// Dialects for databases. +const ( + SQLITE Dialect = iota // SQLITE mean SQLite3 dialect + POSTGRESQL // POSTGRESQL mean PostgreSQL dialect + MYSQL // MYSQL mean MySQL dialect +) + +// DB provide context for the tests +type TestDB struct { + *testing.T + *sql.DB + dialect Dialect + once sync.Once +} + +var db *TestDB + +// the following tables will be created and dropped during the test +var testTables = []string{"foo", "bar", "t", "bench"} + +var tests = []testing.InternalTest{ + {Name: "TestResult", F: testResult}, + {Name: "TestBlobs", F: testBlobs}, + {Name: "TestMultiBlobs", F: testMultiBlobs}, + {Name: "TestManyQueryRow", F: testManyQueryRow}, + {Name: "TestTxQuery", F: testTxQuery}, + {Name: "TestPreparedStmt", F: testPreparedStmt}, +} + +var benchmarks = []testing.InternalBenchmark{ + {Name: "BenchmarkExec", F: benchmarkExec}, + {Name: "BenchmarkQuery", F: benchmarkQuery}, + {Name: "BenchmarkParams", F: benchmarkParams}, + {Name: "BenchmarkStmt", F: benchmarkStmt}, + {Name: "BenchmarkRows", F: benchmarkRows}, + {Name: "BenchmarkStmtRows", F: benchmarkStmtRows}, +} + +func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result { + res, err := db.Exec(sql, args...) + if err != nil { + db.Fatalf("Error running %q: %v", sql, err) + } + return res +} + +func (db *TestDB) tearDown() { + for _, tbl := range testTables { + switch db.dialect { + case SQLITE: + db.mustExec("drop table if exists " + tbl) + case MYSQL, POSTGRESQL: + db.mustExec("drop table if exists " + tbl) + default: + db.Fatal("unknown dialect") + } + } +} + +// q replaces ? parameters if needed +func (db *TestDB) q(sql string) string { + switch db.dialect { + case POSTGRESQL: // replace with $1, $2, .. + qrx := regexp.MustCompile(`\?`) + n := 0 + return qrx.ReplaceAllStringFunc(sql, func(string) string { + n++ + return "$" + strconv.Itoa(n) + }) + } + return sql +} + +func (db *TestDB) blobType(size int) string { + switch db.dialect { + case SQLITE: + return fmt.Sprintf("blob[%d]", size) + case POSTGRESQL: + return "bytea" + case MYSQL: + return fmt.Sprintf("VARBINARY(%d)", size) + } + panic("unknown dialect") +} + +func (db *TestDB) serialPK() string { + switch db.dialect { + case SQLITE: + return "integer primary key autoincrement" + case POSTGRESQL: + return "serial primary key" + case MYSQL: + return "integer primary key auto_increment" + } + panic("unknown dialect") +} + +func (db *TestDB) now() string { + switch db.dialect { + case SQLITE: + return "datetime('now')" + case POSTGRESQL: + return "now()" + case MYSQL: + return "now()" + } + panic("unknown dialect") +} + +func makeBench() { + if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil { + panic(err) + } + st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)") + if err != nil { + panic(err) + } + defer st.Close() + for i := 0; i < 100; i++ { + if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil { + panic(err) + } + } +} + +// testResult is test for result +func testResult(t *testing.T) { + db.tearDown() + db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))") + + for i := 1; i < 3; i++ { + r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i)) + n, err := r.RowsAffected() + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Errorf("got %v, want %v", n, 1) + } + n, err = r.LastInsertId() + if err != nil { + t.Fatal(err) + } + if n != int64(i) { + t.Errorf("got %v, want %v", n, i) + } + } + if _, err := db.Exec("error!"); err == nil { + t.Fatalf("expected error") + } +} + +// testBlobs is test for blobs +func testBlobs(t *testing.T) { + db.tearDown() + var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")") + db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob) + + want := fmt.Sprintf("%x", blob) + + b := make([]byte, 16) + err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b) + got := fmt.Sprintf("%x", b) + if err != nil { + t.Errorf("[]byte scan: %v", err) + } else if got != want { + t.Errorf("for []byte, got %q; want %q", got, want) + } + + err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got) + want = string(blob) + if err != nil { + t.Errorf("string scan: %v", err) + } else if got != want { + t.Errorf("for string, got %q; want %q", got, want) + } +} + +func testMultiBlobs(t *testing.T) { + db.tearDown() + db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")") + var blob0 = []byte{0, 1, 2, 3, 4, 5, 6, 7} + db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob0) + var blob1 = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 1, blob1) + + r, err := db.Query(db.q("select bar from foo order by id")) + if err != nil { + t.Fatal(err) + } + defer r.Close() + if !r.Next() { + if r.Err() != nil { + t.Fatal(err) + } + t.Fatal("expected one rows") + } + + want0 := fmt.Sprintf("%x", blob0) + b0 := make([]byte, 8) + err = r.Scan(&b0) + if err != nil { + t.Fatal(err) + } + got0 := fmt.Sprintf("%x", b0) + + if !r.Next() { + if r.Err() != nil { + t.Fatal(err) + } + t.Fatal("expected one rows") + } + + want1 := fmt.Sprintf("%x", blob1) + b1 := make([]byte, 16) + err = r.Scan(&b1) + if err != nil { + t.Fatal(err) + } + got1 := fmt.Sprintf("%x", b1) + if got0 != want0 { + t.Errorf("for []byte, got %q; want %q", got0, want0) + } + if got1 != want1 { + t.Errorf("for []byte, got %q; want %q", got1, want1) + } +} + +// testManyQueryRow is test for many query row +func testManyQueryRow(t *testing.T) { + if testing.Short() { + t.Log("skipping in short mode") + return + } + db.tearDown() + db.mustExec("create table foo (id integer primary key, name varchar(50))") + db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob") + var name string + for i := 0; i < 10000; i++ { + err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name) + if err != nil || name != "bob" { + t.Fatalf("on query %d: err=%v, name=%q", i, err, name) + } + } +} + +// testTxQuery is test for transactional query +func testTxQuery(t *testing.T) { + db.tearDown() + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + _, err = tx.Exec("create table foo (id integer primary key, name varchar(50))") + if err != nil { + t.Fatal(err) + } + + _, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob") + if err != nil { + t.Fatal(err) + } + + r, err := tx.Query(db.q("select name from foo where id = ?"), 1) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + if r.Err() != nil { + t.Fatal(err) + } + t.Fatal("expected one rows") + } + + var name string + err = r.Scan(&name) + if err != nil { + t.Fatal(err) + } +} + +// testPreparedStmt is test for prepared statement +func testPreparedStmt(t *testing.T) { + db.tearDown() + db.mustExec("CREATE TABLE t (count INT)") + sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC") + if err != nil { + t.Fatalf("prepare 1: %v", err) + } + ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)")) + if err != nil { + t.Fatalf("prepare 2: %v", err) + } + + for n := 1; n <= 3; n++ { + if _, err := ins.Exec(n); err != nil { + t.Fatalf("insert(%d) = %v", n, err) + } + } + + const nRuns = 10 + var wg sync.WaitGroup + for i := 0; i < nRuns; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + count := 0 + if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows { + t.Errorf("Query: %v", err) + return + } + if _, err := ins.Exec(rand.Intn(100)); err != nil { + t.Errorf("Insert: %v", err) + return + } + } + }() + } + wg.Wait() +} + +// Benchmarks need to use panic() since b.Error errors are lost when +// running via testing.Benchmark() I would like to run these via go +// test -bench but calling Benchmark() from a benchmark test +// currently hangs go. + +// benchmarkExec is benchmark for exec +func benchmarkExec(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := db.Exec("select 1"); err != nil { + panic(err) + } + } +} + +// benchmarkQuery is benchmark for query +func benchmarkQuery(b *testing.B) { + for i := 0; i < b.N; i++ { + var n sql.NullString + var i int + var f float64 + var s string + // var t time.Time + if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil { + panic(err) + } + } +} + +// benchmarkParams is benchmark for params +func benchmarkParams(b *testing.B) { + for i := 0; i < b.N; i++ { + var n sql.NullString + var i int + var f float64 + var s string + // var t time.Time + if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { + panic(err) + } + } +} + +// benchmarkStmt is benchmark for statement +func benchmarkStmt(b *testing.B) { + st, err := db.Prepare("select ?, ?, ?, ?") + if err != nil { + panic(err) + } + defer st.Close() + + for n := 0; n < b.N; n++ { + var n sql.NullString + var i int + var f float64 + var s string + // var t time.Time + if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { + panic(err) + } + } +} + +// benchmarkRows is benchmark for rows +func benchmarkRows(b *testing.B) { + db.once.Do(makeBench) + + for n := 0; n < b.N; n++ { + var n sql.NullString + var i int + var f float64 + var s string + var t time.Time + r, err := db.Query("select * from bench") + if err != nil { + panic(err) + } + for r.Next() { + if err = r.Scan(&n, &i, &f, &s, &t); err != nil { + panic(err) + } + } + if err = r.Err(); err != nil { + panic(err) + } + } +} + +// benchmarkStmtRows is benchmark for statement rows +func benchmarkStmtRows(b *testing.B) { + db.once.Do(makeBench) + + st, err := db.Prepare("select * from bench") + if err != nil { + panic(err) + } + defer st.Close() + + for n := 0; n < b.N; n++ { + var n sql.NullString + var i int + var f float64 + var s string + var t time.Time + r, err := st.Query() + if err != nil { + panic(err) + } + for r.Next() { + if err = r.Scan(&n, &i, &f, &s, &t); err != nil { + panic(err) + } + } + if err = r.Err(); err != nil { + panic(err) + } + } +} diff --git a/driver/error_test.go b/_driver/error_test.go similarity index 100% rename from driver/error_test.go rename to _driver/error_test.go diff --git a/driver/backup.go b/driver/backup.go index c6690b4..92c4535 100644 --- a/driver/backup.go +++ b/driver/backup.go @@ -80,7 +80,7 @@ func (b *SQLiteBackup) Close() error { b.b = nil runtime.SetFinalizer(b, nil) - if ret != 0 { + if ret != C.SQLITE_OK { return Error{Code: ErrNo(ret)} } return nil diff --git a/driver/backup_test.go b/driver/backup_test.go index 53b6976..abf32b6 100644 --- a/driver/backup_test.go +++ b/driver/backup_test.go @@ -4,6 +4,7 @@ // license that can be found in the LICENSE file. // +build cgo +// +build go1.8 package sqlite3 @@ -255,10 +256,10 @@ func TestBackupError(t *testing.T) { const driverName = "sqlite3_TestBackupError" // The driver's connection will be needed in order to perform the backup. - var dbDriverConn *SQLiteConn + var dbDriverConn []*SQLiteConn sql.Register(driverName, &SQLiteDriver{ ConnectHook: func(conn *SQLiteConn) error { - dbDriverConn = conn + dbDriverConn = append(dbDriverConn, conn) return nil }, }) @@ -266,21 +267,23 @@ func TestBackupError(t *testing.T) { // Connect to the database. dbTempFilename := TempFilename(t) defer os.Remove(dbTempFilename) - db, err := sql.Open(driverName, dbTempFilename) + srcDb, err := sql.Open(driverName, dbTempFilename) if err != nil { t.Fatal("Failed to open the database:", err) } - defer db.Close() - db.Ping() + defer srcDb.Close() + srcDb.Ping() + + srcDriverConn := dbDriverConn[0] // Need the driver connection in order to perform the backup. - if dbDriverConn == nil { + if srcDriverConn == nil { t.Fatal("Failed to get the driver connection.") } // Prepare to perform the backup. // Intentionally using the same connection for both the source and destination databases, to trigger an error result. - backup, err := dbDriverConn.Backup("main", dbDriverConn, "main") + backup, err := srcDriverConn.Backup("main", srcDriverConn, "main") if err == nil { t.Fatal("Failed to get the expected error result.") } @@ -291,4 +294,65 @@ func TestBackupError(t *testing.T) { if backup != nil { t.Fatal("Failed to get the expected nil backup result.") } + + // Generate some test data for the given ID. + var generateTestData = func(id int) string { + return fmt.Sprintf("test-%v", id) + } + + // Populate the source database with a test table containing some test data. + tx, err := srcDb.Begin() + if err != nil { + t.Fatal("Failed to begin a transaction when populating the source database:", err) + } + _, err = srcDb.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)") + if err != nil { + tx.Rollback() + t.Fatal("Failed to create the source database \"test\" table:", err) + } + for id := 0; id < testRowCount; id++ { + _, err = srcDb.Exec("INSERT INTO test (id, value) VALUES (?, ?)", id, generateTestData(id)) + if err != nil { + tx.Rollback() + t.Fatal("Failed to insert a row into the source database \"test\" table:", err) + } + } + err = tx.Commit() + if err != nil { + t.Fatal("Failed to populate the source database:", err) + } + + destTempFilename := TempFilename(t) + defer os.Remove(destTempFilename) + + // TODO: Rewrite for Golang:1.8 + // Current part of test uses golang:1.10 + destCfg := NewConfig() + destCfg.Database = destTempFilename + destDB := sql.OpenDB(destCfg) // Needs to be rewritten + destDB.Close() + + // Reconfigure to open READ-ONLY + destCfg.Mode = ModeReadOnly + var destConn *SQLiteConn + destCfg.ConnectHook = func(conn *SQLiteConn) error { + destConn = conn + return nil + } + + // OpenDB with Config + destDB = sql.OpenDB(destCfg) + destDB.Ping() + defer destDB.Close() + + backup, err = destConn.Backup("main", srcDriverConn, "main") + _, err = backup.Step(0) + if err == nil || err.(Error).Code != ErrReadonly { + t.Fatalf("Expected read-only error; received: (%d) %s", err.(Error).Code, err.(Error).Error()) + } + + err = backup.Close() + if err == nil || err.(Error).Code != ErrReadonly { + t.Fatalf("Expected read-only error; received: (%d) %s", err.(Error).Code, err.(Error).Error()) + } } diff --git a/driver/config.go b/driver/config.go index cf1a6cb..f08330a 100644 --- a/driver/config.go +++ b/driver/config.go @@ -13,33 +13,159 @@ package sqlite3 #else #include #endif + +#include +#include + +#ifndef SQLITE_OPEN_READWRITE +# define SQLITE_OPEN_READWRITE 0 +#endif + +#ifndef SQLITE_OPEN_FULLMUTEX +# define SQLITE_OPEN_FULLMUTEX 0 +#endif + +#ifndef SQLITE_DETERMINISTIC +# define SQLITE_DETERMINISTIC 0 +#endif + +static int +_sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) { +#ifdef SQLITE_OPEN_URI + return sqlite3_open_v2(filename, ppDb, flags | SQLITE_OPEN_URI, zVfs); +#else + return sqlite3_open_v2(filename, ppDb, flags, zVfs); +#endif +} */ import "C" import ( + "bytes" "database/sql/driver" + "errors" "fmt" "net/url" + "runtime" "strconv" "strings" "time" + "unsafe" ) +// Mode represents the open open for the database connection +type Mode C.int + +func (m Mode) String() string { + switch m { + case ModeReadOnly: + return "ro" + case ModeReadWrite: + return "rw" + case ModeReadWriteCreate: + return "rwc" + case ModeMemory: + return "memory" + default: + return "" + } +} + +// C returns the C.int of Mode +func (m Mode) C() C.int { + return C.int(m) +} + const ( - // SQLITE_OPEN_MUTEX_NO will force the database connection opens + // ModeReadOnly defines SQLITE_OPEN_READONLY for the database connection. + ModeReadOnly = Mode(C.SQLITE_OPEN_READONLY) + + // ModeReadWrite defines SQLITE_OPEN_READWRITE for the database connection. + ModeReadWrite = Mode(C.SQLITE_OPEN_READWRITE) + + // ModeReadWriteCreate defines SQLITE_OPEN_READWRITE and SQLITE_OPEN_CREATE. + ModeReadWriteCreate = Mode(C.SQLITE_OPEN_READWRITE | C.SQLITE_OPEN_CREATE) + + // ModeMemory defines mode=memory which will + // create a pure in-memory database that never reads or writes from disk + ModeMemory = Mode(C.SQLITE_OPEN_MEMORY) +) + +// CacheMode represents the current CacheMode +type CacheMode C.int + +func (c CacheMode) String() string { + switch c { + case CacheModeShared: + return "shared" + case CacheModePrivate: + return "private" + default: + return "" + } +} + +// C returns the C.int of CacheMode +func (c CacheMode) C() C.int { + return C.int(c) +} + +const ( + // CacheModeShared sets the cache mode of SQLite to 'shared' + CacheModeShared = CacheMode(C.SQLITE_OPEN_SHAREDCACHE) + + // CacheModePrivate sets the cache mode of SQLite to 'private' + CacheModePrivate = CacheMode(C.SQLITE_OPEN_PRIVATECACHE) +) + +// Mutex represents how the database opens connections within +// single or multi-threading +type Mutex C.int + +func (m Mutex) String() string { + switch m { + case MutexNo: + return "no" + case MutexFull: + return "full" + default: + return "" + } +} + +// C returns the C.int of Mutex +func (m Mutex) C() C.int { + return C.int(m) +} + +const ( + // MutexNo will force the database connection opens // in the multi-thread threading mode as long as the // single-thread mode has not been set at compile-time or start-time. - SQLITE_OPEN_MUTEX_NO = C.SQLITE_OPEN_NOMUTEX + MutexNo = Mutex(C.SQLITE_OPEN_NOMUTEX) - // SQLITE_OPEN_MUTEX_FULL will force the database connection opens + // MutexFull will force the database connection opens // in the serialized threading mode unless single-thread // was previously selected at compile-time or start-time. - SQLITE_OPEN_MUTEX_FULL = C.SQLITE_OPEN_FULLMUTEX + MutexFull = Mutex(C.SQLITE_OPEN_FULLMUTEX) ) // TxLock defines the Transaction Lock Behaviour. type TxLock string func (tx TxLock) String() string { + switch tx { + case TxLockDeferred: + return "deferred" + case TxLockImmediate: + return "immediate" + case TxLockExclusive: + return "exclusive" + default: + return "" + } +} + +func (tx TxLock) Value() string { return string(tx) } @@ -116,21 +242,25 @@ const ( // If the locking mode is NORMAL when first entering WAL journal mode, //then the locking mode can be changed between NORMAL and EXCLUSIVE // and back again at any time and without needing to exit WAL journal mode. -type LockingMode uint8 +type LockingMode string + +func (l LockingMode) String() string { + return strings.ToLower(string(l)) +} const ( // LockingModeNormal In NORMAL locking-mode // (the default unless overridden at compile-time using SQLITE_DEFAULT_LOCKING_MODE), // a database connection unlocks the database file at the conclusion // of each read or write transaction. - LockingModeNormal LockingMode = iota + LockingModeNormal = LockingMode("NORMAL") // LockingModeExclusive When the locking-mode is set to EXCLUSIVE, // the database connection never releases file-locks. // The first time the database is read in EXCLUSIVE mode, // a shared lock is obtained and held. // The first time the database is written, an exclusive lock is obtained and held. - LockingModeExclusive + LockingModeExclusive = LockingMode("EXCLUSIVE") ) // AutoVacuum defines the auto vacuum status of the database. @@ -142,7 +272,11 @@ const ( // that allows each database page to be traced backwards to its referrer. // Therefore, auto-vacuuming must be turned on before any tables are created. // It is not possible to enable or disable auto-vacuum after a table has been created. -type AutoVacuum uint8 +type AutoVacuum string + +func (av AutoVacuum) String() string { + return strings.ToLower(string(av)) +} const ( // AutoVacuumNone setting means that auto-vacuum is disabled. @@ -164,7 +298,7 @@ const ( // then invoke the VACUUM command to reorganize the entire database file. // To change from "full" or "incremental" back to "none" // always requires running VACUUM even on an empty database. - AutoVacuumNone = AutoVacuum(0) + AutoVacuumNone = AutoVacuum("NONE") // AutoVacuumFull sets auto vacuum of the database to FULL. // @@ -176,7 +310,7 @@ const ( // the way that the VACUUM command does. // In fact, because it moves pages around within the file, // auto-vacuum can actually make fragmentation worse. - AutoVacuumFull = AutoVacuum(1) + AutoVacuumFull = AutoVacuum("FULL") // AutoVacuumIncremental sets the auto vacuum of the database to INCREMENTAL. // @@ -186,7 +320,7 @@ const ( // at each commit as it does with auto_vacuum=full. // In incremental mode, the separate incremental_vacuum pragma must be invoked // to cause the auto-vacuum to occur. - AutoVacuumIncremental = AutoVacuum(2) + AutoVacuumIncremental = AutoVacuum("INCREMENTAL") ) // JournalMode defines the journal mode associated with the current database connection. @@ -198,19 +332,23 @@ const ( // Note also that the journal_mode cannot be changed while a transaction is active. type JournalMode string +func (j JournalMode) String() string { + return strings.ToLower(string(j)) +} + const ( // JournalModeDelete is the normal behavior. // In the DELETE mode, the rollback journal is deleted at the conclusion // of each transaction. // Indeed, the delete operation is the action that causes the transaction to commit. // (See the document titled Atomic Commit In SQLite for additional detail.) - JournalModeDelete JournalMode = "DELETE" + JournalModeDelete = JournalMode("DELETE") // JournalModeTruncate commits transactions by truncating the rollback journal // to zero-length instead of deleting it. // On many systems, truncating a file is much faster // than deleting the file since the containing directory does not need to be changed. - JournalModeTruncate JournalMode = "TRUNCATE" + JournalModeTruncate = JournalMode("TRUNCATE") // JournalModePersist prevents the rollback journal from being deleted // at the end of each transaction. @@ -220,26 +358,26 @@ const ( // where deleting or truncating a file is much more expensive // than overwriting the first block of a file with zeros. // See also: PRAGMA journal_size_limit and SQLITE_DEFAULT_JOURNAL_SIZE_LIMIT. - JournalModePersist JournalMode = "PERSIST" + JournalModePersist = JournalMode("PERSIST") // JournalModeMemory stores the rollback journal in volatile RAM. // This saves disk I/O but at the expense of database safety and integrity. // If the application using SQLite crashes in the middle of a transaction // when the MEMORY journaling mode is set, // then the database file will very likely go corrupt. - JournalModeMemory JournalMode = "MEMORY" + JournalModeMemory = JournalMode("MEMORY") // JournalModeWAL uses a write-ahead log instead of a rollback journal // to implement transactions. // The WAL journaling mode is persistent; // after being set it stays in effect across multiple database connections // and after closing and reopening the database. - JournalModeWAL JournalMode = "WAL" + JournalModeWAL = JournalMode("WAL") - // JournalModeDisabled disables the rollback journal completely. + // JournalModeOff disables the rollback journal completely. // No rollback journal is ever created and hence there is never a rollback journal to delete. // The OFF journaling mode disables the atomic commit and rollback capabilities of SQLite. The ROLLBACK command no longer works; it behaves in an undefined way. Applications must avoid using the ROLLBACK command when the journal mode is OFF. If the application crashes in the middle of a transaction when the OFF journaling mode is set, then the database file will very likely go corrupt. - JournalModeDisabled JournalMode = "OFF" + JournalModeOff = JournalMode("OFF") ) // SecureDelete defines the secure-delete setting. @@ -254,6 +392,10 @@ const ( // or else run VACUUM after the delete or update. type SecureDelete string +func (sd SecureDelete) String() string { + return strings.ToLower(string(sd)) +} + const ( // SecureDeleteOff disables secure deletion of content. SecureDeleteOff = SecureDelete("OFF") @@ -272,7 +414,11 @@ const ( ) // Synchronous sync setting of the database connection. -type Synchronous uint8 +type Synchronous string + +func (s Synchronous) String() string { + return strings.ToLower(string(s)) +} const ( // SynchronousOff sets synchronous to OFF (0), @@ -281,7 +427,7 @@ const ( // but the database might become corrupted if the operating system crashes // or the computer loses power before that data has been written to the disk surface. // On the other hand, commits can be orders of magnitude faster with synchronous OFF. - SynchronousOff = Synchronous(0) + SynchronousOff = Synchronous("OFF") // SynchronousNormal sets synchronous to NORMAL (1), // the SQLite database engine will still sync at the most critical moments, @@ -297,7 +443,7 @@ const ( // Transactions are durable across application crashes regardless // of the synchronous setting or journal mode. // The synchronous=NORMAL setting is a good choice for most applications running in WAL mode. - SynchronousNormal = Synchronous(1) + SynchronousNormal = Synchronous("NORMAL") // SynchronousFull sets synchronous to FULL (2), // the SQLite database engine will use the xSync method of the VFS @@ -306,42 +452,13 @@ const ( // will not corrupt the database. FULL synchronous is very safe, // but it is also slower. ///FULL is the most commonly used synchronous setting when not in WAL mode. - SynchronousFull = Synchronous(2) + SynchronousFull = Synchronous("FULL") // SynchronousExtra is like FULL with the addition that the directory containing // a rollback journal is synced after that journal is unlinked to commit a transaction // in DELETE mode. EXTRA provides additional durability if the commit // is followed closely by a power loss. - SynchronousExtra = Synchronous(3) -) - -// CacheMode defines the shared-cache mode of SQLite. -type CacheMode string - -const ( - // CacheModeShared sets the cache mode of SQLite to 'shared' - CacheModeShared = CacheMode("SHARED") - - // CacheModePrivate sets the cache mode of SQLite to 'private' - CacheModePrivate = CacheMode("PRIVATE") -) - -// Mode defines the open mode of the SQLite database. -type Mode string - -const ( - // ModeReadOnly defines SQLITE_OPEN_READONLY for the database connection. - ModeReadOnly = Mode("RO") - - // ModeReadWrite defines SQLITE_OPEN_READWRITE for the database connection. - ModeReadWrite = Mode("RW") - - // ModeReadWriteCreate defines SQLITE_OPEN_READWRITE and SQLITE_OPEN_CREATE. - ModeReadWriteCreate = Mode("RWC") - - // ModeMemory defines mode=memory which will - // create a pure in-memory database that never reads or writes from disk - ModeMemory = Mode("MEMORY") + SynchronousExtra = Synchronous("EXTRA") ) // Config is configuration parsed from a DSN string. @@ -349,12 +466,19 @@ const ( // the NewConfig function should be used, which sets default values. // Manual usage is allowed type Config struct { + // Database + Database string + // Mode of the SQLite database Mode Mode // CacheMode of the SQLite Connection Cache CacheMode + // Mutex flag SQLITE_OPEN_MUTEX_NO, SQLITE_OPEN_MUTEX_FULL + // Defaults to SQLITE_OPEN_MUTEX_FULL + Mutex Mutex + // The immutable parameter is a boolean query parameter that indicates // that the database file is stored on read-only media. When immutable is set, // SQLite assumes that the database file cannot be changed, @@ -364,10 +488,6 @@ type Config struct { // does in fact change can result in incorrect query results and/or SQLITE_CORRUPT errors. Immutable bool - // Mutex flag SQLITE_OPEN_MUTEX_NO, SQLITE_OPEN_MUTEX_FULL - // Defaults to SQLITE_OPEN_MUTEX_FULL - Mutex int - // TimeZone location TimeZone *time.Location @@ -430,6 +550,12 @@ type Config struct { // WriteableSchema enables of disables the ability to using UPDATE, INSERT, DELETE // Warning: misuse of this pragma can easily result in a corrupt database file. WriteableSchema bool + + // Extensions + Extensions []string + + // ConnectHook + ConnectHook func(*SQLiteConn) error } // Auth holds the authentication configuration for the SQLite UserAuth module. @@ -450,10 +576,15 @@ type Auth struct { // NewConfig creates a new Config and sets default values. func NewConfig() *Config { return &Config{ + // This is the behavior that is always used + // for sqlite3_open() and sqlite3_open16(). + // This is way it is set as default. + Mode: ModeReadWriteCreate, + + Database: ":memory:", Cache: CacheModePrivate, Immutable: false, - Authentication: &Auth{}, - Mutex: SQLITE_OPEN_MUTEX_FULL, + Mutex: MutexFull, TransactionLock: TxLockDeferred, LockingMode: LockingModeNormal, AutoVacuum: AutoVacuumNone, @@ -467,57 +598,386 @@ func NewConfig() *Config { SecureDelete: SecureDeleteOff, Synchronous: SynchronousNormal, WriteableSchema: false, + Authentication: &Auth{ + Encoder: NewSHA1Encoder(), + }, } } // FormatDSN formats the given Config into a DSN string which can be passed to // the driver. func (cfg *Config) FormatDSN() string { - // TODO: FormatDSN - return "" + var buf bytes.Buffer + + params := url.Values{} + params.Set("mode", cfg.Mode.String()) + params.Set("cache", cfg.Cache.String()) + params.Set("mutex", cfg.Mutex.String()) + + if cfg.Immutable { + params.Set("immutable", "true") + } + + if cfg.TimeZone != nil { + if cfg.TimeZone == time.Local { + params.Set("tz", "auto") + } else { + params.Set("tz", cfg.TimeZone.String()) + } + } + + if cfg.TransactionLock != TxLockDeferred { + params.Set("txlock", cfg.TransactionLock.String()) + } + + if cfg.LockingMode != LockingModeNormal { + params.Set("lock", cfg.LockingMode.String()) + } + + if cfg.AutoVacuum != AutoVacuumNone { + params.Set("vacuum", cfg.AutoVacuum.String()) + } + + if cfg.CaseSensitiveLike { + params.Set("cslike", "true") + } + + if cfg.DeferForeignKeys { + params.Set("defer_fk", "true") + } + + if cfg.ForeignKeyConstraints { + params.Set("fk", "true") + } + + if cfg.IgnoreCheckConstraints { + params.Set("ignore_check_contraints", "true") + } + + if cfg.JournalMode != JournalModeDelete { + params.Set("journal", cfg.JournalMode.String()) + } + + if cfg.QueryOnly { + params.Set("query_only", "true") + } + + if cfg.RecursiveTriggers { + params.Set("recursive_triggers", "true") + } + + if cfg.SecureDelete != SecureDeleteOff { + params.Set("secure_delete", cfg.SecureDelete.String()) + } + + if cfg.Synchronous != SynchronousNormal { + params.Set("syn", cfg.Synchronous.String()) + } + + if cfg.WriteableSchema { + params.Set("writable_schema", "true") + } + + if len(cfg.Authentication.Username) > 0 && len(cfg.Authentication.Password) > 0 { + params.Set("user", cfg.Authentication.Username) + params.Set("pass", cfg.Authentication.Password) + + if len(cfg.Authentication.Salt) > 0 { + params.Set("salt", cfg.Authentication.Salt) + } + + if cfg.Authentication.Encoder != nil { + params.Set("crypt", cfg.Authentication.Encoder.String()) + } + } + + if !strings.HasPrefix(cfg.Database, "file:") { + buf.WriteString("file:") + } + buf.WriteString(cfg.Database) + + // Append Options + buf.WriteRune('?') + buf.WriteString(params.Encode()) + + return buf.String() } // Create connection from Configuration func (cfg *Config) createConnection() (driver.Conn, error) { - //var db *C.sqlite3 + if C.sqlite3_threadsafe() == 0 { + return nil, errors.New("sqlite library was not compiled for thread-safe operation") + } + if len(cfg.Database) == 0 { + return nil, fmt.Errorf("No database configured") + } - // name := C.CString(dsn) - // defer C.free(unsafe.Pointer(name)) - // rv := C._sqlite3_open_v2(name, &db, - // mutex|C.SQLITE_OPEN_READWRITE|C.SQLITE_OPEN_CREATE, - // nil) - // if rv != 0 { - // return nil, Error{Code: ErrNo(rv)} - // } - // if db == nil { - // return nil, errors.New("sqlite succeeded without returning a database") - // } + var db *C.sqlite3 - // rv = C.sqlite3_busy_timeout(db, C.int(busyTimeout)) - // if rv != C.SQLITE_OK { - // C.sqlite3_close_v2(db) - // return nil, Error{Code: ErrNo(rv)} - // } + // Configure Database URI + // Because we are adding the 'immutable' flag to the database file + // We are required to conform the database path to an URI + // The immutable flag is an query parameter which means that the URI needs + // to start with 'file:'. Regardless if it is an in-memory database or not. + uri := cfg.Database + if !strings.HasPrefix(uri, "file:") { + uri = fmt.Sprintf("file:%s", uri) + } + if !strings.Contains(uri, ":memory:") { + uri = fmt.Sprintf("%s?immutable=%t", uri, cfg.Immutable) + } + database := C.CString(uri) - // exec := func(s string) error { - // cs := C.CString(s) - // rv := C.sqlite3_exec(db, cs, nil, nil, nil) - // C.free(unsafe.Pointer(cs)) - // if rv != C.SQLITE_OK { - // fmt.Printf("-Open-Exec() %d\n", rv) - // return lastError(db) - // } - // return nil - // } + // Free CString on return + defer C.free(unsafe.Pointer(database)) - // &SQLiteConn{ - // db: db, - // tz: cfg.TimeZone, - // txlock: cfg.TransactionLock.String(), - // } + // Open the database + // https://www.sqlite.org/c3ref/open.html + rv := C._sqlite3_open_v2( + database, + &db, + cfg.Mutex.C()|cfg.Cache.C()|cfg.Mode.C(), + nil) - return nil, nil + // Check if the database was opened succesful. + if rv != C.SQLITE_OK { + fmt.Println(Error{Code: ErrNo(rv)}) + return nil, Error{Code: ErrNo(rv)} + } + + // Verify we have a database pointer + if db == nil { + return nil, errors.New("sqlite succeeded without returning a database") + } + + // Set SQLITE Busy Timeout Handler + rv = C.sqlite3_busy_timeout(db, C.int(cfg.BusyTimeout)) + if rv != C.SQLITE_OK { + // Failed to set busy timeout + // close the database and return the error + C.sqlite3_close_v2(db) + + return nil, Error{Code: ErrNo(rv)} + } + + // Create basic connection + conn := &SQLiteConn{ + db: db, + tz: cfg.TimeZone, + txlock: cfg.TransactionLock.Value(), + } + + // At this point we have the following + // - database pointer + // - basic connection + // + // Now we need to configure the connection according to the *Config + + // USER AUTHENTICATION + // + // User Authentication is always performed even when + // sqlite_userauth is not compiled in, because without user authentication + // the authentication is a no-op. + // + // Workflow + // - Authenticate + // ON::SUCCESS => Continue + // ON::SQLITE_AUTH => Return error and exit Open(...) + // + // - Activate User Authentication + // Check if the user wants to activate User Authentication. + // If so then first create a temporary AuthConn to the database + // This is possible because we are already succesfully authenticated. + // + // - Check if `sqlite_user`` table exists + // YES => Add the provided user from DSN as Admin User and + // activate user authentication. + // NO => Continue + // + // + // Because we need to perform authentication we need to register + // the required functions on the connection. + + // Register sqlite_crypt function with the CryptEncoder provided + // within *Config.Authentication + if err := conn.RegisterFunc("sqlite_crypt", cfg.Authentication.Encoder.Encode, true); err != nil { + return nil, fmt.Errorf("CryptEncoderSHA1: %s", err) + } + + // Register: authenticate + // Authenticate will perform an authentication of the provided username + // and password against the database. + // + // If a database contains the SQLITE_USER table, then the + // call to Authenticate must be invoked with an + // appropriate username and password prior to enable read and write + // access to the database. + // + // Return SQLITE_OK on success or SQLITE_ERROR if the username/password + // combination is incorrect or unknown. + // + // If the SQLITE_USER table is not present in the database file, then + // this interface is a harmless no-op returnning SQLITE_OK. + if err := conn.RegisterFunc("authenticate", conn.authenticate, true); err != nil { + return nil, err + } + + // Register: auth_user_add + // auth_user_add can be used (by an admin user only) + // to create a new user. When called on a no-authentication-required + // database, this routine converts the database into an authentication- + // required database, automatically makes the added user an + // administrator, and logs in the current connection as that user. + // The AuthUserAdd only works for the "main" database, not + // for any ATTACH-ed databases. Any call to AuthUserAdd by a + // non-admin user results in an error. + if err := conn.RegisterFunc("auth_user_add", conn.authUserAdd, true); err != nil { + return nil, err + } + + // Register: auth_user_change + // auth_user_change can be used to change a users + // login credentials or admin privilege. Any user can change their own + // login credentials. Only an admin user can change another users login + // credentials or admin privilege setting. No user may change their own + // admin privilege setting. + if err := conn.RegisterFunc("auth_user_change", conn.authUserChange, true); err != nil { + return nil, err + } + + // Register: auth_user_delete + // auth_user_delete can be used (by an admin user only) + // to delete a user. The currently logged-in user cannot be deleted, + // which guarantees that there is always an admin user and hence that + // the database cannot be converted into a no-authentication-required + // database. + if err := conn.RegisterFunc("auth_user_delete", conn.authUserDelete, true); err != nil { + return nil, err + } + + // Register: auth_enabled + // auth_enabled can be used to check if user authentication is enabled + if err := conn.RegisterFunc("auth_enabled", conn.authEnabled, true); err != nil { + return nil, err + } + + // Preform Authentication only if username and password are provided + // If authentication is not enabled on the database + // this call is a NO-OP. + // Only call this when Username and Password are provided in the *Config + if len(cfg.Authentication.Username) > 0 && len(cfg.Authentication.Password) > 0 { + if err := conn.Authenticate(cfg.Authentication.Username, cfg.Authentication.Password); err != nil { + return nil, err + } + } + + // AUTO VACUUM + // The user preference for auto_vacuum needs to be implemented directly after + // the authentication and before the sqlite_user table gets created if the user + // decides to activate User Authentication because + // auto_vacuum needs to be set before any tables are created + // and activating user authentication creates the internal table `sqlite_user`. + if err := conn.PRAGMA(PRAGMA_AUTO_VACUUM, cfg.AutoVacuum.String()); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + + // Check if authentication is enabled + // This can now be succesfully checked because we are successfully connected. + // Only issue this when Username and Password are configured. + // If this is an unauthenticated database + // the provided user will be created as Admin. + if len(cfg.Authentication.Username) > 0 && len(cfg.Authentication.Password) > 0 { + authExists := conn.AuthEnabled() + if !authExists { + if err := conn.AuthUserAdd(cfg.Authentication.Username, cfg.Authentication.Password, true); err != nil { + return nil, err + } + } + } + + // Case Sensitive LIKE + if err := conn.PRAGMA(PRAGMA_CASE_SENSITIVE_LIKE, strconv.FormatBool(cfg.CaseSensitiveLike)); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + + // Defer Foreign Keys + if err := conn.PRAGMA(PRAGMA_DEFER_FOREIGN_KEYS, strconv.FormatBool(cfg.DeferForeignKeys)); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + + // Ignore CHECK constraints + if err := conn.PRAGMA(PRAGMA_IGNORE_CHECK_CONTRAINTS, strconv.FormatBool(cfg.IgnoreCheckConstraints)); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + + // Journal Mode + if err := conn.PRAGMA(PRAGMA_JOURNAL_MODE, cfg.JournalMode.String()); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + + // Locking Mode + if err := conn.PRAGMA(PRAGMA_LOCKING_MODE, cfg.LockingMode.String()); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + + // Query Only + if err := conn.PRAGMA(PRAGMA_QUERY_ONLY, strconv.FormatBool(cfg.QueryOnly)); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + + // Recursive Triggers + if err := conn.PRAGMA(PRAGMA_RECURSIVE_TRIGGERS, strconv.FormatBool(cfg.RecursiveTriggers)); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + + // Secure Delete + if err := conn.PRAGMA(PRAGMA_SECURE_DELETE, cfg.SecureDelete.String()); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + + // Synchronous + if err := conn.PRAGMA(PRAGMA_SYNCHRONOUS, cfg.SecureDelete.String()); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + + // Writable Schema + if err := conn.PRAGMA(PRAGMA_WRITABLE_SCHEMA, strconv.FormatBool(cfg.WriteableSchema)); err != nil { + C.sqlite3_close_v2(db) + return nil, err + } + + // Load Extensions + if len(cfg.Extensions) > 0 { + if err := conn.loadExtensions(cfg.Extensions); err != nil { + //fmt.Println("Error while loading Extensions") + conn.Close() + return nil, err + } + } + + // Configure Connect Hooks + if cfg.ConnectHook != nil { + if err := cfg.ConnectHook(conn); err != nil { + conn.Close() + return nil, err + } + } + + // Configure Finalizer + runtime.SetFinalizer(conn, (*SQLiteConn).Close) + + return conn, nil } // ParseDSN parses the DSN string to a Config @@ -525,8 +985,13 @@ func ParseDSN(dsn string) (cfg *Config, err error) { // New default with default values cfg = NewConfig() + cfg.Database = dsn + pos := strings.IndexRune(dsn, '?') if pos >= 1 { + // Update DatabaseURI + cfg.Database = dsn[0:pos] + // Parse Options params, err := url.ParseQuery(dsn[pos+1:]) if err != nil { @@ -592,13 +1057,32 @@ func ParseDSN(dsn string) (cfg *Config, err error) { if k == "mode" { val := params.Get(k) switch strings.ToUpper(val) { - case "RO", "RW", "RWC", "MEMORY": - cfg.Mode = Mode(strings.ToUpper(val)) + case "RO": + cfg.Mode = ModeReadOnly + case "RW": + cfg.Mode = ModeReadWrite + case "RWC": + cfg.Mode = ModeReadWriteCreate + case "MEMORY": + cfg.Mode = ModeMemory default: return nil, fmt.Errorf("Unknown mode: %v, expecting value of 'ro, rw, rwc, memory'", val) } } + // Mutex + if k == "mutex" { + val := params.Get(k) + switch strings.ToLower(val) { + case "no": + cfg.Mutex = MutexNo + case "full": + cfg.Mutex = MutexFull + default: + return nil, fmt.Errorf("Invalid mutex: %v, expecting value of 'no, full", val) + } + } + // Timezone if k == "tz" || k == "timezone" || k == "loc" { val := params.Get(k) @@ -613,19 +1097,6 @@ func ParseDSN(dsn string) (cfg *Config, err error) { } } - // Mutex - if k == "mutex" { - val := params.Get(k) - switch strings.ToLower(val) { - case "no": - cfg.Mutex = SQLITE_OPEN_MUTEX_NO - case "full": - cfg.Mutex = SQLITE_OPEN_MUTEX_FULL - default: - return nil, fmt.Errorf("Invalid mutex: %v, expecting value of 'no, full", val) - } - } - // Transaction Lock if k == "txlock" || k == "transaction_lock" { val := params.Get(k) diff --git a/driver/config_test.go b/driver/config_test.go new file mode 100644 index 0000000..976c2dc --- /dev/null +++ b/driver/config_test.go @@ -0,0 +1,627 @@ +// Copyright (C) 2018 The Go-SQLite3 Authors. +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// +build cgo + +package sqlite3 + +import ( + "reflect" + "testing" + "time" +) + +func TestConfig(t *testing.T) { + +} + +func TestParseDSN(t *testing.T) { + // URI + uriCases := map[string]*Config{ + "file:test.db": &Config{ + Database: "file:test.db", + }, + "file::memory:": &Config{ + Database: "file::memory:", + }, + "test.db": &Config{ + Database: "test.db", + }, + ":memory:": &Config{ + Database: ":memory:", + }, + "test.db?%35%2%%43?test=false": nil, + } + + for dsn, c := range uriCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal(err) + } + if c != nil { + if cfg.Database != c.Database { + t.Fatalf("Failed to parse database uri; expected: %s, got: %s", c.Database, cfg.Database) + } + } + } + + // Auth + authCases := map[string]*Config{ + "test.db?user=admin&pass=admin&salt=test": &Config{ + Authentication: &Auth{ + Username: "admin", + Password: "admin", + Salt: "test", + }, + }, + } + + for dsn, c := range authCases { + cfg, err := ParseDSN(dsn) + if err != nil { + t.Fatal(err) + } + if cfg.Authentication.Username != c.Authentication.Username { + t.Fatalf("Failed to parse 'user'; expected: %s, got %s", c.Authentication.Username, cfg.Authentication.Username) + } + if cfg.Authentication.Password != c.Authentication.Password { + t.Fatalf("Failed to parse 'pass'; expected: %s, got %s", c.Authentication.Password, cfg.Authentication.Password) + } + if cfg.Authentication.Salt != c.Authentication.Salt { + t.Fatalf("Failed to parse 'salt'; expected: %s, got %s", c.Authentication.Salt, cfg.Authentication.Salt) + } + } + + // Crypt + cryptCases := map[string]*Config{ + "test.db?crypt=auto": nil, + "test.db?crypt=sha1": &Config{ + Authentication: &Auth{ + Encoder: NewSHA1Encoder(), + }, + }, + "test.db?crypt=ssha1": nil, + "test.db?crypt=ssha1&salt=salt": &Config{ + Authentication: &Auth{ + Salt: "salt", + Encoder: NewSSHA1Encoder("salt"), + }, + }, + "test.db?crypt=sha256": &Config{ + Authentication: &Auth{ + Encoder: NewSHA256Encoder(), + }, + }, + "test.db?crypt=ssha256": nil, + "test.db?crypt=ssha256&salt=salt": &Config{ + Authentication: &Auth{ + Salt: "salt", + Encoder: NewSSHA256Encoder("salt"), + }, + }, + "test.db?crypt=sha384": &Config{ + Authentication: &Auth{ + Encoder: NewSHA384Encoder(), + }, + }, + "test.db?crypt=ssha384": nil, + "test.db?crypt=ssha384&salt=salt": &Config{ + Authentication: &Auth{ + Salt: "salt", + Encoder: NewSSHA384Encoder("salt"), + }, + }, + "test.db?crypt=sha512": &Config{ + Authentication: &Auth{ + Encoder: NewSHA512Encoder(), + }, + }, + "test.db?crypt=ssha512": nil, + "test.db?crypt=ssha512&salt=salt": &Config{ + Authentication: &Auth{ + Salt: "salt", + Encoder: NewSSHA512Encoder("salt"), + }, + }, + } + + for dsn, c := range cryptCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'crypt'") + } + if c != nil { + if reflect.TypeOf(cfg.Authentication.Encoder).String() != reflect.TypeOf(c.Authentication.Encoder).String() { + t.Fatal("Failed to parse 'crypt'") + } + if len(cfg.Authentication.Salt) > 0 { + if cfg.Authentication.Salt != c.Authentication.Salt { + t.Fatal("Failed to parse: 'salt'") + } + } + } + } + + // Cache + cacheCases := map[string]*Config{ + "test.db?cache=shared": &Config{ + Cache: CacheModeShared, + }, + "test.db?cache=private": &Config{ + Cache: CacheModePrivate, + }, + "test.db?cache=bogus": nil, + } + + for dsn, c := range cacheCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'cache'") + } + if c != nil { + if cfg.Cache != c.Cache { + t.Fatalf("Failed to parse 'cache'; expected: %d, got: %d", c.Cache, cfg.Cache) + } + } + } + + // Immutable + immutableCases := map[string]*Config{ + "test.db?immutable=false": &Config{ + Immutable: false, + }, + "test.db?immutable=true": &Config{ + Immutable: true, + }, + "test.db?immutable=active": nil, + } + + for dsn, c := range immutableCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'immutable'") + } + if c != nil { + if cfg.Immutable != c.Immutable { + t.Fatalf("Failed to parse 'immutable'; expected: %t, got: %t", c.Immutable, cfg.Immutable) + } + } + } + + // Mode + modeCases := map[string]*Config{ + "test.db?mode=ro": &Config{ + Mode: ModeReadOnly, + }, + "test.db?mode=rw": &Config{ + Mode: ModeReadWrite, + }, + "test.db?mode=rwc": &Config{ + Mode: ModeReadWriteCreate, + }, + "test.db?mode=memory": &Config{ + Mode: ModeMemory, + }, + "test.db?mode=full": nil, + } + + for dsn, c := range modeCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'mode'") + } + if c != nil { + if cfg.Mode != c.Mode { + t.Fatalf("Failed to parse 'mode'; expected: %d, got: %d", c.Mode, cfg.Mode) + } + } + } + + // Mutex + mutexCases := map[string]*Config{ + "test.db?mutex=no": &Config{ + Mutex: MutexNo, + }, + "test.db?mutex=full": &Config{ + Mutex: MutexFull, + }, + "test.db?mutex=bogus": nil, + } + + for dsn, c := range mutexCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal(err) + } + if c != nil { + if cfg.Mutex != c.Mutex { + t.Fatalf("Failed to parse 'mutex'; expected: %d, got: %d", c.Mutex, cfg.Mutex) + } + } + } + + // Timezone + ams, _ := time.LoadLocation("Europe/Amsterdam") + tzCases := map[string]*Config{ + "test.db?tz=auto": &Config{ + TimeZone: time.Local, + }, + "test.db?tz=Europe/Amsterdam": &Config{ + TimeZone: ams, + }, + "test.db?tz=Atlantis": nil, + } + + for dsn, c := range tzCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal(err) + } + if c != nil { + if cfg.TimeZone.String() != c.TimeZone.String() { + t.Fatal("Failed to parse timezone") + } + } + } + + // Transaction Lock + txLockCases := map[string]*Config{ + "test.db?txlock=deferred": &Config{ + TransactionLock: TxLockDeferred, + }, + "test.db?txlock=immediate": &Config{ + TransactionLock: TxLockImmediate, + }, + ":memory:?txlock=exclusive": &Config{ + TransactionLock: TxLockExclusive, + }, + "test.db?txlock=bogus": nil, + } + + for dsn, c := range txLockCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal(err) + } + if c != nil { + if cfg.TransactionLock != c.TransactionLock { + t.Fatalf("Failed to parse txlock; expected: %s, got: %s", c.TransactionLock, cfg.Database) + } + } + } + + // Auto Vacuum + vacuumCases := map[string]*Config{ + "test.db?vacuum=none": &Config{ + AutoVacuum: AutoVacuumNone, + }, + "test.db?vacuum=full": &Config{ + AutoVacuum: AutoVacuumFull, + }, + "test.db?vacuum=incremental": &Config{ + AutoVacuum: AutoVacuumIncremental, + }, + "test.db?vacuum=bogus": nil, + } + + for dsn, c := range vacuumCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'autovacuum, vacuum'") + } + if c != nil { + if cfg.AutoVacuum != c.AutoVacuum { + t.Fatalf("Failed to parse 'autovacuum'; expected: %s, got: %s", c.AutoVacuum, cfg.AutoVacuum) + } + } + } + + // Busy Timeout + timeoutCases := map[string]*Config{ + "test.db?timeout=5000": &Config{ + BusyTimeout: 5000 * time.Millisecond, + }, + "test.db?timeout=never": nil, + } + + for dsn, c := range timeoutCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'timeout'") + } + if c != nil { + if cfg.BusyTimeout != c.BusyTimeout { + t.Fatalf("Failed to parse 'timeout'; expected: %d, got: %d", c.BusyTimeout, cfg.BusyTimeout) + } + } + } + + // Case sensitive LIKE + cslikeCases := map[string]*Config{ + "test.db?cslike=false": &Config{ + CaseSensitiveLike: false, + }, + "test.db?cslike=true": &Config{ + CaseSensitiveLike: true, + }, + "test.db?cslike=active": nil, + } + + for dsn, c := range cslikeCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'cslike'") + } + if c != nil { + if cfg.CaseSensitiveLike != c.CaseSensitiveLike { + t.Fatalf("Failed to parse 'cslike'; expected: %t, got: %t", c.CaseSensitiveLike, cfg.CaseSensitiveLike) + } + } + } + + // Defer Foreign Keys + dfkCases := map[string]*Config{ + "test.db?defer_fk=false": &Config{ + DeferForeignKeys: false, + }, + "test.db?defer_fk=true": &Config{ + DeferForeignKeys: true, + }, + "test.db?defer_fk=active": nil, + } + + for dsn, c := range dfkCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'defer_fk'") + } + if c != nil { + if cfg.DeferForeignKeys != c.DeferForeignKeys { + t.Fatalf("Failed to parse 'defer_fk'; expected: %t, got: %t", c.DeferForeignKeys, cfg.DeferForeignKeys) + } + } + } + + // Foreign Key + fkCases := map[string]*Config{ + "test.db?fk=false": &Config{ + ForeignKeyConstraints: false, + }, + "test.db?fk=true": &Config{ + ForeignKeyConstraints: true, + }, + "test.db?fk=active": nil, + } + + for dsn, c := range fkCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'fk'") + } + if c != nil { + if cfg.ForeignKeyConstraints != c.ForeignKeyConstraints { + t.Fatalf("Failed to parse 'fk'; expected: %t, got: %t", c.ForeignKeyConstraints, cfg.ForeignKeyConstraints) + } + } + } + + // Ignore CHECK constraints + iCases := map[string]*Config{ + "test.db?ignore_check_constraints=false": &Config{ + IgnoreCheckConstraints: false, + }, + "test.db?ignore_check_constraints=true": &Config{ + IgnoreCheckConstraints: true, + }, + "test.db?ignore_check_constraints=active": nil, + } + + for dsn, c := range iCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'ignore_check_constraints'") + } + if c != nil { + if cfg.IgnoreCheckConstraints != c.IgnoreCheckConstraints { + t.Fatalf("Failed to parse 'ignore_check_constraints'; expected: %t, got: %t", c.IgnoreCheckConstraints, cfg.IgnoreCheckConstraints) + } + } + } + + // Synchronous + syncCases := map[string]*Config{ + "test.db?sync=off": &Config{ + Synchronous: SynchronousOff, + }, + "test.db?sync=normal": &Config{ + Synchronous: SynchronousNormal, + }, + "test.db?sync=full": &Config{ + Synchronous: SynchronousFull, + }, + "test.db?sync=extra": &Config{ + Synchronous: SynchronousExtra, + }, + "test.db?sync=bogus": nil, + } + + for dsn, c := range syncCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'sync'") + } + if c != nil { + if cfg.Synchronous != c.Synchronous { + t.Fatalf("Failed to parse 'sync'; expected: %s, got: %s", c.Synchronous, cfg.Synchronous) + } + } + } + + // Journal Mode + journalCases := map[string]*Config{ + "test.db?journal=delete": &Config{ + JournalMode: JournalModeDelete, + }, + "test.db?journal=truncate": &Config{ + JournalMode: JournalModeTruncate, + }, + "test.db?journal=persist": &Config{ + JournalMode: JournalModePersist, + }, + "test.db?journal=memory": &Config{ + JournalMode: JournalModeMemory, + }, + "test.db?journal=off": &Config{ + JournalMode: JournalModeOff, + }, + "test.db?journal=wal": &Config{ + JournalMode: JournalModeWAL, + Synchronous: SynchronousNormal, + }, + "test.db?journal=auto": nil, + } + + for dsn, c := range journalCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'journal'") + } + if c != nil { + if cfg.JournalMode != c.JournalMode { + t.Fatalf("Failed to parse 'journal'; expected: %s, got: %s", c.JournalMode, cfg.JournalMode) + } else { + if c.JournalMode == JournalModeWAL { + if cfg.Synchronous != c.Synchronous { + t.Fatal("Failed to auto adjust Synchronous mode to normal") + } + } + } + } + } + + // Locking Mode + lockingModeCases := map[string]*Config{ + "test.db?lock=normal": &Config{ + LockingMode: LockingModeNormal, + }, + "test.db?lock=exclusive": &Config{ + LockingMode: LockingModeExclusive, + }, + "test.db?lock=auto": nil, + } + + for dsn, c := range lockingModeCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'locking_mode'") + } + if c != nil { + if cfg.LockingMode != c.LockingMode { + t.Fatalf("Failed to parse 'locking_mode'; expected: %s, got: %s", c.LockingMode, cfg.LockingMode) + } + } + } + + // Query Only + qCases := map[string]*Config{ + "test.db?query_only=false": &Config{ + QueryOnly: false, + }, + "test.db?query_only=true": &Config{ + QueryOnly: true, + }, + "test.db?query_only=active": nil, + } + + for dsn, c := range qCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'query_only'") + } + if c != nil { + if cfg.QueryOnly != c.QueryOnly { + t.Fatalf("Failed to parse 'query_only'; expected: %t, got: %t", c.QueryOnly, cfg.QueryOnly) + } + } + } + + // Recursive Triggers + rtCases := map[string]*Config{ + "test.db?rt=false": &Config{ + RecursiveTriggers: false, + }, + "test.db?rt=true": &Config{ + RecursiveTriggers: true, + }, + "test.db?rt=active": nil, + } + + for dsn, c := range rtCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'recursive_triggers'") + } + if c != nil { + if cfg.RecursiveTriggers != c.RecursiveTriggers { + t.Fatalf("Failed to parse 'recursive_triggers'; expected: %t, got: %t", c.RecursiveTriggers, cfg.RecursiveTriggers) + } + } + } + + // Secure Delete + scCases := map[string]*Config{ + "test.db?secure_delete=off": &Config{ + SecureDelete: SecureDeleteOff, + }, + "test.db?secure_delete=on": &Config{ + SecureDelete: SecureDeleteOn, + }, + "test.db?secure_delete=fast": &Config{ + SecureDelete: SecureDeleteFast, + }, + "test.db?secure_delete=auto": nil, + } + + for dsn, c := range scCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'secure_delete'") + } + if c != nil { + if cfg.SecureDelete != c.SecureDelete { + t.Fatalf("Failed to parse 'secure_delete'; expected: %s, got: %s", c.SecureDelete, cfg.SecureDelete) + } + } + } + + // Writable Schema + wsCases := map[string]*Config{ + "test.db?writable_schema=false": &Config{ + WriteableSchema: false, + }, + "test.db?writable_schema=true": &Config{ + WriteableSchema: true, + }, + "test.db?writable_schema=active": nil, + } + + for dsn, c := range wsCases { + cfg, err := ParseDSN(dsn) + if err != nil && c != nil { + t.Fatal("Failed to parse 'writable_schema'") + } + if c != nil { + if cfg.WriteableSchema != c.WriteableSchema { + t.Fatalf("Failed to parse 'writable_schema'; expected: %t, got: %t", c.WriteableSchema, cfg.WriteableSchema) + } + } + } +} + +func TestFormatDSN(t *testing.T) { + // TODO: TestFormatDSN + cfg := NewConfig() + cfg.FormatDSN() +} diff --git a/driver/connection.go b/driver/connection.go index a8351b3..86a8fc6 100644 --- a/driver/connection.go +++ b/driver/connection.go @@ -45,6 +45,20 @@ type SQLiteConn struct { aggregators []*aggInfo } +func (c *SQLiteConn) PRAGMA(name, value string) error { + stmt := fmt.Sprintf("PRAGMA %s = %s;", name, value) + + cs := C.CString(stmt) + rv := C.sqlite3_exec(c.db, cs, nil, nil, nil) + C.free(unsafe.Pointer(cs)) + + if rv != C.SQLITE_OK { + return lastError(c.db) + } + + return nil +} + type functionInfo struct { f reflect.Value argConverters []callbackArgConverter diff --git a/driver/connection_go18.go b/driver/connection_go18.go index a21e487..6dc84fb 100644 --- a/driver/connection_go18.go +++ b/driver/connection_go18.go @@ -29,6 +29,7 @@ func (c *SQLiteConn) Ping(ctx context.Context) error { if c.db == nil { return errors.New("Connection was closed") } + return nil } diff --git a/driver/connection_go18_test.go b/driver/connection_go18_test.go index 66cad43..a04de93 100644 --- a/driver/connection_go18_test.go +++ b/driver/connection_go18_test.go @@ -156,17 +156,37 @@ func TestExecCancel(t *testing.T) { } func TestPinger(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") + driverName := "sqlite3_pinger" + + var dbDriverConn []*SQLiteConn + sql.Register(driverName, &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + dbDriverConn = append(dbDriverConn, conn) + return nil + }, + }) + + db, err := sql.Open(driverName, ":memory:") if err != nil { t.Fatal(err) } + + // Ping Database err = db.Ping() if err != nil { t.Fatal(err) } db.Close() + + // Ping database + // response should be: database closed err = db.Ping() - fmt.Println(err) + if err == nil { + t.Fatal("Should be closed") + } + + // Ping Database through connection + err = dbDriverConn[0].Ping(context.Background()) if err == nil { t.Fatal("Should be closed") } diff --git a/driver/connector.go b/driver/connector.go index 5576e1b..fcf5969 100644 --- a/driver/connector.go +++ b/driver/connector.go @@ -27,6 +27,8 @@ func (c *Config) Connect(ctx context.Context) (driver.Conn, error) { // Driver returns &SQLiteDriver{}. func (c *Config) Driver() driver.Driver { return &SQLiteDriver{ - Config: c, + Config: c, + Extensions: c.Extensions, + ConnectHook: c.ConnectHook, } } diff --git a/driver/connector_test.go b/driver/connector_test.go new file mode 100644 index 0000000..84fb5b0 --- /dev/null +++ b/driver/connector_test.go @@ -0,0 +1,39 @@ +// Copyright (C) 2018 The Go-SQLite3 Authors. +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// +build cgo +// +build go1.10 + +package sqlite3 + +import ( + "context" + "testing" +) + +func TestConnectorDriver(t *testing.T) { + // Create default Config + cfg := NewConfig() + cfg.ConnectHook = func(conn *SQLiteConn) error { + return nil + } + + // Create Driver from Config + drv := cfg.Driver() + if drv.(*SQLiteDriver).ConnectHook == nil { + t.Fatal("Failed to created Driver from Config") + } +} + +func TestConnectorConnect(t *testing.T) { + // Create default Config + cfg := NewConfig() + + // Create Connection to database from Config + conn, err := cfg.Connect(context.Background()) + if err != nil || conn == nil { + t.Fatal("Failed to create connection from Config") + } +} diff --git a/driver/crypt.go b/driver/crypt.go index 79b472a..8f2be38 100644 --- a/driver/crypt.go +++ b/driver/crypt.go @@ -46,6 +46,7 @@ import ( "crypto/sha1" "crypto/sha256" "crypto/sha512" + "fmt" ) // Force Implementation @@ -63,6 +64,7 @@ var ( // CryptEncoder provides the interface for implementing // a sqlite_crypt encoder. type CryptEncoder interface { + fmt.Stringer Encode(pass []byte, hash interface{}) []byte } @@ -80,6 +82,10 @@ func (e *sha1Encoder) Encode(pass []byte, hash interface{}) []byte { return h[:] } +func (e *sha1Encoder) String() string { + return "sha1" +} + // NewSHA1Encoder returns a new SHA1 Encoder. func NewSHA1Encoder() CryptEncoder { return &sha1Encoder{} @@ -100,6 +106,10 @@ func (e *ssha1Encoder) Salt() string { return e.salt } +func (e *ssha1Encoder) String() string { + return "ssha1" +} + // NewSSHA1Encoder returns a new salted SHA1 Encoder. func NewSSHA1Encoder(salt string) CryptSaltedEncoder { return &ssha1Encoder{ @@ -114,6 +124,10 @@ func (e *sha256Encoder) Encode(pass []byte, hash interface{}) []byte { return h[:] } +func (e *sha256Encoder) String() string { + return "sha256" +} + // NewSHA256Encoder returns a new SHA256 Encoder. func NewSHA256Encoder() CryptEncoder { return &sha256Encoder{} @@ -134,6 +148,10 @@ func (e *ssha256Encoder) Salt() string { return e.salt } +func (e *ssha256Encoder) String() string { + return "ssha256" +} + // NewSSHA256Encoder returns a new salted SHA256 Encoder. func NewSSHA256Encoder(salt string) CryptSaltedEncoder { return &ssha256Encoder{ @@ -148,6 +166,10 @@ func (e *sha384Encoder) Encode(pass []byte, hash interface{}) []byte { return h[:] } +func (e *sha384Encoder) String() string { + return "sha384" +} + // NewSHA384Encoder returns a new SHA384 Encoder. func NewSHA384Encoder() CryptEncoder { return &sha384Encoder{} @@ -168,6 +190,10 @@ func (e *ssha384Encoder) Salt() string { return e.salt } +func (e *ssha384Encoder) String() string { + return "ssha384" +} + // NewSSHA384Encoder returns a new salted SHA384 Encoder. func NewSSHA384Encoder(salt string) CryptSaltedEncoder { return &ssha384Encoder{ @@ -182,6 +208,10 @@ func (e *sha512Encoder) Encode(pass []byte, hash interface{}) []byte { return h[:] } +func (e *sha512Encoder) String() string { + return "sha512" +} + // NewSHA512Encoder returns a new SHA512 Encoder. func NewSHA512Encoder() CryptEncoder { return &sha512Encoder{} @@ -202,7 +232,11 @@ func (e *ssha512Encoder) Salt() string { return e.salt } -// NewSSHA384Encoder returns a new salted SHA512 Encoder. +func (e *ssha512Encoder) String() string { + return "ssha512" +} + +// NewSSHA512Encoder returns a new salted SHA512 Encoder. func NewSSHA512Encoder(salt string) CryptSaltedEncoder { return &ssha512Encoder{ salt: salt, diff --git a/driver/driver.go b/driver/driver.go index ed1f774..854272e 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -28,39 +28,11 @@ package sqlite3 #else #include #endif -#include -#include - -#ifdef __CYGWIN__ -# include -#endif - -#ifndef SQLITE_OPEN_READWRITE -# define SQLITE_OPEN_READWRITE 0 -#endif - -#ifndef SQLITE_OPEN_FULLMUTEX -# define SQLITE_OPEN_FULLMUTEX 0 -#endif - -#ifndef SQLITE_DETERMINISTIC -# define SQLITE_DETERMINISTIC 0 -#endif - -static int -_sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) { -#ifdef SQLITE_OPEN_URI - return sqlite3_open_v2(filename, ppDb, flags | SQLITE_OPEN_URI, zVfs); -#else - return sqlite3_open_v2(filename, ppDb, flags, zVfs); -#endif -} */ import "C" import ( "database/sql" "database/sql/driver" - "errors" ) var ( @@ -80,14 +52,19 @@ type SQLiteDriver struct { // Open database and return a new connection. func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { - if C.sqlite3_threadsafe() == 0 { - return nil, errors.New("sqlite library was not compiled for thread-safe operation") - } - cfg, err := ParseDSN(dsn) if err != nil { return nil, err } + // Configure Extensions + cfg.Extensions = d.Extensions + + // Configure ConnectHook + cfg.ConnectHook = d.ConnectHook + + // Set Configuration + d.Config = cfg + return cfg.createConnection() } diff --git a/driver/driver.goconvey b/driver/driver.goconvey new file mode 100644 index 0000000..697e095 --- /dev/null +++ b/driver/driver.goconvey @@ -0,0 +1,4 @@ + +-tags=sqlite_userauth + +-cover diff --git a/driver/driver_go110.go b/driver/driver_go110.go index 17d7195..f9748fa 100644 --- a/driver/driver_go110.go +++ b/driver/driver_go110.go @@ -8,7 +8,9 @@ package sqlite3 -import "database/sql/driver" +import ( + "database/sql/driver" +) var ( _ driver.DriverContext = (*SQLiteDriver)(nil) @@ -20,5 +22,19 @@ var ( // The two-step sequence allows drivers to parse the name just once and also provides // access to per-Conn contexts. func (d *SQLiteDriver) OpenConnector(dsn string) (driver.Connector, error) { - return ParseDSN(dsn) + cfg, err := ParseDSN(dsn) + if err != nil { + return nil, err + } + + // Configure Extensions + cfg.Extensions = d.Extensions + + // Configure ConnectHook + cfg.ConnectHook = d.ConnectHook + + // Set Configuration + d.Config = cfg + + return cfg, nil } diff --git a/driver/driver_test.go b/driver/driver_test.go index b661c61..725bb71 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -3,25 +3,13 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. +// +build cgo + package sqlite3 import ( - "bytes" - "database/sql" - "database/sql/driver" - "errors" - "fmt" "io/ioutil" - "math/rand" - "net/url" - "os" - "reflect" - "regexp" - "strconv" - "strings" - "sync" "testing" - "time" ) func TempFilename(t *testing.T) string { @@ -32,2036 +20,3 @@ func TempFilename(t *testing.T) string { f.Close() return f.Name() } - -func doTestOpen(t *testing.T, option string) (string, error) { - var url string - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - if option != "" { - url = tempFilename + option - } else { - url = tempFilename - } - db, err := sql.Open("sqlite3", url) - if err != nil { - return "Failed to open database:", err - } - defer os.Remove(tempFilename) - defer db.Close() - - _, err = db.Exec("drop table foo") - _, err = db.Exec("create table foo (id integer)") - if err != nil { - return "Failed to create table:", err - } - - if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() { - return "Failed to create ./foo.db", nil - } - - return "", nil -} - -func TestOpen(t *testing.T) { - cases := map[string]bool{ - "": true, - "?_txlock=immediate": true, - "?_txlock=deferred": true, - "?_txlock=exclusive": true, - "?_txlock=bogus": false, - } - for option, expectedPass := range cases { - result, err := doTestOpen(t, option) - if result == "" { - if !expectedPass { - errmsg := fmt.Sprintf("_txlock error not caught at dbOpen with option: %s", option) - t.Fatal(errmsg) - } - } else if expectedPass { - if err == nil { - t.Fatal(result) - } else { - t.Fatal(result, err) - } - } - } -} - -func TestReadonly(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - - db1, err := sql.Open("sqlite3", "file:"+tempFilename) - if err != nil { - t.Fatal(err) - } - db1.Exec("CREATE TABLE test (x int, y float)") - - db2, err := sql.Open("sqlite3", "file:"+tempFilename+"?mode=ro") - if err != nil { - t.Fatal(err) - } - _ = db2 - _, err = db2.Exec("INSERT INTO test VALUES (1, 3.14)") - if err == nil { - t.Fatal("didn't expect INSERT into read-only database to work") - } -} - -func TestForeignKeys(t *testing.T) { - cases := map[string]bool{ - "?_foreign_keys=1": true, - "?_foreign_keys=0": false, - } - for option, want := range cases { - fname := TempFilename(t) - uri := "file:" + fname + option - db, err := sql.Open("sqlite3", uri) - if err != nil { - os.Remove(fname) - t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err) - continue - } - var enabled bool - err = db.QueryRow("PRAGMA foreign_keys;").Scan(&enabled) - db.Close() - os.Remove(fname) - if err != nil { - t.Errorf("query foreign_keys for %s: %v", uri, err) - continue - } - if enabled != want { - t.Errorf("\"PRAGMA foreign_keys;\" for %q = %t; want %t", uri, enabled, want) - continue - } - } -} - -func TestRecursiveTriggers(t *testing.T) { - cases := map[string]bool{ - "?_recursive_triggers=1": true, - "?_recursive_triggers=0": false, - } - for option, want := range cases { - fname := TempFilename(t) - uri := "file:" + fname + option - db, err := sql.Open("sqlite3", uri) - if err != nil { - os.Remove(fname) - t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err) - continue - } - var enabled bool - err = db.QueryRow("PRAGMA recursive_triggers;").Scan(&enabled) - db.Close() - os.Remove(fname) - if err != nil { - t.Errorf("query recursive_triggers for %s: %v", uri, err) - continue - } - if enabled != want { - t.Errorf("\"PRAGMA recursive_triggers;\" for %q = %t; want %t", uri, enabled, want) - continue - } - } -} - -func TestClose(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - - _, err = db.Exec("drop table foo") - _, err = db.Exec("create table foo (id integer)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - stmt, err := db.Prepare("select id from foo where id = ?") - if err != nil { - t.Fatal("Failed to select records:", err) - } - - db.Close() - _, err = stmt.Exec(1) - if err == nil { - t.Fatal("Failed to operate closed statement") - } -} - -func TestInsert(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("drop table foo") - _, err = db.Exec("create table foo (id integer)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - res, err := db.Exec("insert into foo(id) values(123)") - if err != nil { - t.Fatal("Failed to insert record:", err) - } - affected, _ := res.RowsAffected() - if affected != 1 { - t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) - } - - rows, err := db.Query("select id from foo") - if err != nil { - t.Fatal("Failed to select records:", err) - } - defer rows.Close() - - rows.Next() - - var result int - rows.Scan(&result) - if result != 123 { - t.Errorf("Expected %d for fetched result, but %d:", 123, result) - } -} - -func TestUpsert(t *testing.T) { - _, n, _ := Version() - if !(n >= 3024000) { - t.Skip("UPSERT requires sqlite3 => 3.24.0") - } - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("drop table foo") - _, err = db.Exec("create table foo (name string primary key, counter integer)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - for i := 0; i < 10; i++ { - res, err := db.Exec("insert into foo(name, counter) values('key', 1) on conflict (name) do update set counter=counter+1") - if err != nil { - t.Fatal("Failed to upsert record:", err) - } - affected, _ := res.RowsAffected() - if affected != 1 { - t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) - } - } - rows, err := db.Query("select name, counter from foo") - if err != nil { - t.Fatal("Failed to select records:", err) - } - defer rows.Close() - - rows.Next() - - var resultName string - var resultCounter int - rows.Scan(&resultName, &resultCounter) - if resultName != "key" { - t.Errorf("Expected %s for fetched result, but %s:", "key", resultName) - } - if resultCounter != 10 { - t.Errorf("Expected %d for fetched result, but %d:", 10, resultCounter) - } - -} - -func TestUpdate(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("drop table foo") - _, err = db.Exec("create table foo (id integer)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - res, err := db.Exec("insert into foo(id) values(123)") - if err != nil { - t.Fatal("Failed to insert record:", err) - } - expected, err := res.LastInsertId() - if err != nil { - t.Fatal("Failed to get LastInsertId:", err) - } - affected, _ := res.RowsAffected() - if err != nil { - t.Fatal("Failed to get RowsAffected:", err) - } - if affected != 1 { - t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) - } - - res, err = db.Exec("update foo set id = 234") - if err != nil { - t.Fatal("Failed to update record:", err) - } - lastID, err := res.LastInsertId() - if err != nil { - t.Fatal("Failed to get LastInsertId:", err) - } - if expected != lastID { - t.Errorf("Expected %q for last Id, but %q:", expected, lastID) - } - affected, _ = res.RowsAffected() - if err != nil { - t.Fatal("Failed to get RowsAffected:", err) - } - if affected != 1 { - t.Fatalf("Expected %d for affected rows, but %d:", 1, affected) - } - - rows, err := db.Query("select id from foo") - if err != nil { - t.Fatal("Failed to select records:", err) - } - defer rows.Close() - - rows.Next() - - var result int - rows.Scan(&result) - if result != 234 { - t.Errorf("Expected %d for fetched result, but %d:", 234, result) - } -} - -func TestDelete(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("drop table foo") - _, err = db.Exec("create table foo (id integer)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - res, err := db.Exec("insert into foo(id) values(123)") - if err != nil { - t.Fatal("Failed to insert record:", err) - } - expected, err := res.LastInsertId() - if err != nil { - t.Fatal("Failed to get LastInsertId:", err) - } - affected, err := res.RowsAffected() - if err != nil { - t.Fatal("Failed to get RowsAffected:", err) - } - if affected != 1 { - t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected) - } - - res, err = db.Exec("delete from foo where id = 123") - if err != nil { - t.Fatal("Failed to delete record:", err) - } - lastID, err := res.LastInsertId() - if err != nil { - t.Fatal("Failed to get LastInsertId:", err) - } - if expected != lastID { - t.Errorf("Expected %q for last Id, but %q:", expected, lastID) - } - affected, err = res.RowsAffected() - if err != nil { - t.Fatal("Failed to get RowsAffected:", err) - } - if affected != 1 { - t.Errorf("Expected %d for cout of affected rows, but %q:", 1, affected) - } - - rows, err := db.Query("select id from foo") - if err != nil { - t.Fatal("Failed to select records:", err) - } - defer rows.Close() - - if rows.Next() { - t.Error("Fetched row but expected not rows") - } -} - -func TestBooleanRoundtrip(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("DROP TABLE foo") - _, err = db.Exec("CREATE TABLE foo(id INTEGER, value BOOL)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - _, err = db.Exec("INSERT INTO foo(id, value) VALUES(1, ?)", true) - if err != nil { - t.Fatal("Failed to insert true value:", err) - } - - _, err = db.Exec("INSERT INTO foo(id, value) VALUES(2, ?)", false) - if err != nil { - t.Fatal("Failed to insert false value:", err) - } - - rows, err := db.Query("SELECT id, value FROM foo") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - defer rows.Close() - - for rows.Next() { - var id int - var value bool - - if err := rows.Scan(&id, &value); err != nil { - t.Error("Unable to scan results:", err) - continue - } - - if id == 1 && !value { - t.Error("Value for id 1 should be true, not false") - - } else if id == 2 && value { - t.Error("Value for id 2 should be false, not true") - } - } -} - -func timezone(t time.Time) string { return t.Format("-07:00") } - -func TestTimestamp(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("DROP TABLE foo") - _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts timeSTAMP, dt DATETIME)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) - timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) - timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) - tzTest := time.FixedZone("TEST", -9*3600-13*60) - tests := []struct { - value interface{} - expected time.Time - }{ - {"nonsense", time.Time{}}, - {"0000-00-00 00:00:00", time.Time{}}, - {time.Time{}.Unix(), time.Time{}}, - {timestamp1, timestamp1}, - {timestamp2.Unix(), timestamp2.Truncate(time.Second)}, - {timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)}, - {timestamp1.In(tzTest), timestamp1.In(tzTest)}, - {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1}, - {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1}, - {timestamp1.Format("2006-01-02 15:04:05"), timestamp1}, - {timestamp1.Format("2006-01-02T15:04:05"), timestamp1}, - {timestamp2, timestamp2}, - {"2006-01-02 15:04:05.123456789", timestamp2}, - {"2006-01-02T15:04:05.123456789", timestamp2}, - {"2006-01-02T05:51:05.123456789-09:13", timestamp2.In(tzTest)}, - {"2012-11-04", timestamp3}, - {"2012-11-04 00:00", timestamp3}, - {"2012-11-04 00:00:00", timestamp3}, - {"2012-11-04 00:00:00.000", timestamp3}, - {"2012-11-04T00:00", timestamp3}, - {"2012-11-04T00:00:00", timestamp3}, - {"2012-11-04T00:00:00.000", timestamp3}, - {"2006-01-02T15:04:05.123456789Z", timestamp2}, - {"2012-11-04Z", timestamp3}, - {"2012-11-04 00:00Z", timestamp3}, - {"2012-11-04 00:00:00Z", timestamp3}, - {"2012-11-04 00:00:00.000Z", timestamp3}, - {"2012-11-04T00:00Z", timestamp3}, - {"2012-11-04T00:00:00Z", timestamp3}, - {"2012-11-04T00:00:00.000Z", timestamp3}, - } - for i := range tests { - _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) - if err != nil { - t.Fatal("Failed to insert timestamp:", err) - } - } - - rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - defer rows.Close() - - seen := 0 - for rows.Next() { - var id int - var ts, dt time.Time - - if err := rows.Scan(&id, &ts, &dt); err != nil { - t.Error("Unable to scan results:", err) - continue - } - if id < 0 || id >= len(tests) { - t.Error("Bad row id: ", id) - continue - } - seen++ - if !tests[id].expected.Equal(ts) { - t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) - } - if !tests[id].expected.Equal(dt) { - t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) - } - if timezone(tests[id].expected) != timezone(ts) { - t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value, - timezone(tests[id].expected), timezone(ts)) - } - if timezone(tests[id].expected) != timezone(dt) { - t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value, - timezone(tests[id].expected), timezone(dt)) - } - } - - if seen != len(tests) { - t.Errorf("Expected to see %d rows", len(tests)) - } -} - -func TestBoolean(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - - defer db.Close() - - _, err = db.Exec("CREATE TABLE foo(id INTEGER, fbool BOOLEAN)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - bool1 := true - _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(1, ?)", bool1) - if err != nil { - t.Fatal("Failed to insert boolean:", err) - } - - bool2 := false - _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(2, ?)", bool2) - if err != nil { - t.Fatal("Failed to insert boolean:", err) - } - - bool3 := "nonsense" - _, err = db.Exec("INSERT INTO foo(id, fbool) VALUES(3, ?)", bool3) - if err != nil { - t.Fatal("Failed to insert nonsense:", err) - } - - rows, err := db.Query("SELECT id, fbool FROM foo where fbool = ?", bool1) - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - counter := 0 - - var id int - var fbool bool - - for rows.Next() { - if err := rows.Scan(&id, &fbool); err != nil { - t.Fatal("Unable to scan results:", err) - } - counter++ - } - - if counter != 1 { - t.Fatalf("Expected 1 row but %v", counter) - } - - if id != 1 && !fbool { - t.Fatalf("Value for id 1 should be %v, not %v", bool1, fbool) - } - - rows, err = db.Query("SELECT id, fbool FROM foo where fbool = ?", bool2) - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - - counter = 0 - - for rows.Next() { - if err := rows.Scan(&id, &fbool); err != nil { - t.Fatal("Unable to scan results:", err) - } - counter++ - } - - if counter != 1 { - t.Fatalf("Expected 1 row but %v", counter) - } - - if id != 2 && fbool { - t.Fatalf("Value for id 2 should be %v, not %v", bool2, fbool) - } - - // make sure "nonsense" triggered an error - rows, err = db.Query("SELECT id, fbool FROM foo where id=?;", 3) - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - - rows.Next() - err = rows.Scan(&id, &fbool) - if err == nil { - t.Error("Expected error from \"nonsense\" bool") - } -} - -func TestFloat32(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("CREATE TABLE foo(id INTEGER)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - _, err = db.Exec("INSERT INTO foo(id) VALUES(null)") - if err != nil { - t.Fatal("Failed to insert null:", err) - } - - rows, err := db.Query("SELECT id FROM foo") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - - if !rows.Next() { - t.Fatal("Unable to query results:", err) - } - - var id interface{} - if err := rows.Scan(&id); err != nil { - t.Fatal("Unable to scan results:", err) - } - if id != nil { - t.Error("Expected nil but not") - } -} - -func TestNull(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - rows, err := db.Query("SELECT 3.141592") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - - if !rows.Next() { - t.Fatal("Unable to query results:", err) - } - - var v interface{} - if err := rows.Scan(&v); err != nil { - t.Fatal("Unable to scan results:", err) - } - f, ok := v.(float64) - if !ok { - t.Error("Expected float but not") - } - if f != 3.141592 { - t.Error("Expected 3.141592 but not") - } -} - -func TestWAL(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - if _, err = db.Exec("PRAGMA journal_mode=WAL;"); err != nil { - t.Fatal("Failed to Exec PRAGMA journal_mode:", err) - } - if _, err = db.Exec("PRAGMA locking_mode=EXCLUSIVE;"); err != nil { - t.Fatal("Failed to Exec PRAGMA locking_mode:", err) - } - if _, err = db.Exec("CREATE TABLE test (id SERIAL, user TEXT NOT NULL, name TEXT NOT NULL);"); err != nil { - t.Fatal("Failed to Exec CREATE TABLE:", err) - } - if _, err = db.Exec("INSERT INTO test (user, name) VALUES ('user','name');"); err != nil { - t.Fatal("Failed to Exec INSERT:", err) - } - - trans, err := db.Begin() - if err != nil { - t.Fatal("Failed to Begin:", err) - } - s, err := trans.Prepare("INSERT INTO test (user, name) VALUES (?, ?);") - if err != nil { - t.Fatal("Failed to Prepare:", err) - } - - var count int - if err = trans.QueryRow("SELECT count(user) FROM test;").Scan(&count); err != nil { - t.Fatal("Failed to QueryRow:", err) - } - if _, err = s.Exec("bbbb", "aaaa"); err != nil { - t.Fatal("Failed to Exec prepared statement:", err) - } - if err = s.Close(); err != nil { - t.Fatal("Failed to Close prepared statement:", err) - } - if err = trans.Commit(); err != nil { - t.Fatal("Failed to Commit:", err) - } -} - -func TestTimezoneConversion(t *testing.T) { - zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} - for _, tz := range zones { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename+"?_loc="+url.QueryEscape(tz)) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("DROP TABLE foo") - _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts TIMESTAMP, dt DATETIME)") - if err != nil { - t.Fatal("Failed to create table:", err) - } - - loc, err := time.LoadLocation(tz) - if err != nil { - t.Fatal("Failed to load location:", err) - } - - timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) - timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) - timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) - tests := []struct { - value interface{} - expected time.Time - }{ - {"nonsense", time.Time{}.In(loc)}, - {"0000-00-00 00:00:00", time.Time{}.In(loc)}, - {timestamp1, timestamp1.In(loc)}, - {timestamp1.Unix(), timestamp1.In(loc)}, - {timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1.In(loc)}, - {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1.In(loc)}, - {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1.In(loc)}, - {timestamp1.Format("2006-01-02 15:04:05"), timestamp1.In(loc)}, - {timestamp1.Format("2006-01-02T15:04:05"), timestamp1.In(loc)}, - {timestamp2, timestamp2.In(loc)}, - {"2006-01-02 15:04:05.123456789", timestamp2.In(loc)}, - {"2006-01-02T15:04:05.123456789", timestamp2.In(loc)}, - {"2012-11-04", timestamp3.In(loc)}, - {"2012-11-04 00:00", timestamp3.In(loc)}, - {"2012-11-04 00:00:00", timestamp3.In(loc)}, - {"2012-11-04 00:00:00.000", timestamp3.In(loc)}, - {"2012-11-04T00:00", timestamp3.In(loc)}, - {"2012-11-04T00:00:00", timestamp3.In(loc)}, - {"2012-11-04T00:00:00.000", timestamp3.In(loc)}, - } - for i := range tests { - _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) - if err != nil { - t.Fatal("Failed to insert timestamp:", err) - } - } - - rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC") - if err != nil { - t.Fatal("Unable to query foo table:", err) - } - defer rows.Close() - - seen := 0 - for rows.Next() { - var id int - var ts, dt time.Time - - if err := rows.Scan(&id, &ts, &dt); err != nil { - t.Error("Unable to scan results:", err) - continue - } - if id < 0 || id >= len(tests) { - t.Error("Bad row id: ", id) - continue - } - seen++ - if !tests[id].expected.Equal(ts) { - t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, ts) - } - if !tests[id].expected.Equal(dt) { - t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) - } - if tests[id].expected.Location().String() != ts.Location().String() { - t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), ts.Location().String()) - } - if tests[id].expected.Location().String() != dt.Location().String() { - t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), dt.Location().String()) - } - } - - if seen != len(tests) { - t.Errorf("Expected to see %d rows", len(tests)) - } - } -} - -func TestExecer(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec(` - create table foo (id integer); -- one comment - insert into foo(id) values(?); - insert into foo(id) values(?); - insert into foo(id) values(?); -- another comment - `, 1, 2, 3) - if err != nil { - t.Error("Failed to call db.Exec:", err) - } -} - -func TestQueryer(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec(` - create table foo (id integer); - `) - if err != nil { - t.Error("Failed to call db.Query:", err) - } - - rows, err := db.Query(` - insert into foo(id) values(?); - insert into foo(id) values(?); - insert into foo(id) values(?); - select id from foo order by id; - `, 3, 2, 1) - if err != nil { - t.Error("Failed to call db.Query:", err) - } - defer rows.Close() - n := 1 - if rows != nil { - for rows.Next() { - var id int - err = rows.Scan(&id) - if err != nil { - t.Error("Failed to db.Query:", err) - } - if id != n { - t.Error("Failed to db.Query: not matched results") - } - } - } -} - -func TestStress(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - db.Exec("CREATE TABLE foo (id int);") - db.Exec("INSERT INTO foo VALUES(1);") - db.Exec("INSERT INTO foo VALUES(2);") - db.Close() - - for i := 0; i < 10000; i++ { - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - - for j := 0; j < 3; j++ { - rows, err := db.Query("select * from foo where id=1;") - if err != nil { - t.Error("Failed to call db.Query:", err) - } - for rows.Next() { - var i int - if err := rows.Scan(&i); err != nil { - t.Errorf("Scan failed: %v\n", err) - } - } - if err := rows.Err(); err != nil { - t.Errorf("Post-scan failed: %v\n", err) - } - rows.Close() - } - db.Close() - } -} - -func TestDateTimeLocal(t *testing.T) { - zone := "Asia/Tokyo" - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename+"?_loc="+zone) - if err != nil { - t.Fatal("Failed to open database:", err) - } - db.Exec("CREATE TABLE foo (dt datetime);") - db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');") - - row := db.QueryRow("select * from foo") - var d time.Time - err = row.Scan(&d) - if err != nil { - t.Fatal("Failed to scan datetime:", err) - } - if d.Hour() == 15 || !strings.Contains(d.String(), "JST") { - t.Fatal("Result should have timezone", d) - } - db.Close() - - db, err = sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - - row = db.QueryRow("select * from foo") - err = row.Scan(&d) - if err != nil { - t.Fatal("Failed to scan datetime:", err) - } - if d.UTC().Hour() != 15 || !strings.Contains(d.String(), "UTC") { - t.Fatalf("Result should not have timezone %v %v", zone, d.String()) - } - - _, err = db.Exec("DELETE FROM foo") - if err != nil { - t.Fatal("Failed to delete table:", err) - } - dt, err := time.Parse("2006/1/2 15/4/5 -0700 MST", "2015/3/5 15/16/17 +0900 JST") - if err != nil { - t.Fatal("Failed to parse datetime:", err) - } - db.Exec("INSERT INTO foo VALUES(?);", dt) - - db.Close() - db, err = sql.Open("sqlite3", tempFilename+"?_loc="+zone) - if err != nil { - t.Fatal("Failed to open database:", err) - } - - row = db.QueryRow("select * from foo") - err = row.Scan(&d) - if err != nil { - t.Fatal("Failed to scan datetime:", err) - } - if d.Hour() != 15 || !strings.Contains(d.String(), "JST") { - t.Fatalf("Result should have timezone %v %v", zone, d.String()) - } -} - -func TestStringContainingZero(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec(` - create table foo (id integer, name, extra text); - `) - if err != nil { - t.Error("Failed to call db.Query:", err) - } - - const text = "foo\x00bar" - - _, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, text) - if err != nil { - t.Error("Failed to call db.Exec:", err) - } - - row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, text) - if row == nil { - t.Error("Failed to call db.QueryRow") - } - - var id int - var extra string - err = row.Scan(&id, &extra) - if err != nil { - t.Error("Failed to db.Scan:", err) - } - if id != 1 || extra != text { - t.Error("Failed to db.QueryRow: not matched results") - } -} - -const CurrentTimeStamp = "2006-01-02 15:04:05" - -type TimeStamp struct{ *time.Time } - -func (t TimeStamp) Scan(value interface{}) error { - var err error - switch v := value.(type) { - case string: - *t.Time, err = time.Parse(CurrentTimeStamp, v) - case []byte: - *t.Time, err = time.Parse(CurrentTimeStamp, string(v)) - default: - err = errors.New("invalid type for current_timestamp") - } - return err -} - -func (t TimeStamp) Value() (driver.Value, error) { - return t.Time.Format(CurrentTimeStamp), nil -} - -func TestDateTimeNow(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - db, err := sql.Open("sqlite3", tempFilename) - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - var d time.Time - err = db.QueryRow("SELECT datetime('now')").Scan(TimeStamp{&d}) - if err != nil { - t.Fatal("Failed to scan datetime:", err) - } -} - -func TestFunctionRegistration(t *testing.T) { - addi8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) } - addi64 := func(a, b int64) int64 { return a + b } - addu8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) } - addu64 := func(a, b uint64) uint64 { return a + b } - addiu := func(a int, b uint) int64 { return int64(a) + int64(b) } - addf32_64 := func(a float32, b float64) float64 { return float64(a) + b } - not := func(a bool) bool { return !a } - regex := func(re, s string) (bool, error) { - return regexp.MatchString(re, s) - } - generic := func(a interface{}) int64 { - switch a.(type) { - case int64: - return 1 - case float64: - return 2 - case []byte: - return 3 - case string: - return 4 - default: - panic("unreachable") - } - } - variadic := func(a, b int64, c ...int64) int64 { - ret := a + b - for _, d := range c { - ret += d - } - return ret - } - variadicGeneric := func(a ...interface{}) int64 { - return int64(len(a)) - } - - sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ - ConnectHook: func(conn *SQLiteConn) error { - if err := conn.RegisterFunc("addi8_16_32", addi8_16_32, true); err != nil { - return err - } - if err := conn.RegisterFunc("addi64", addi64, true); err != nil { - return err - } - if err := conn.RegisterFunc("addu8_16_32", addu8_16_32, true); err != nil { - return err - } - if err := conn.RegisterFunc("addu64", addu64, true); err != nil { - return err - } - if err := conn.RegisterFunc("addiu", addiu, true); err != nil { - return err - } - if err := conn.RegisterFunc("addf32_64", addf32_64, true); err != nil { - return err - } - if err := conn.RegisterFunc("not", not, true); err != nil { - return err - } - if err := conn.RegisterFunc("regex", regex, true); err != nil { - return err - } - if err := conn.RegisterFunc("generic", generic, true); err != nil { - return err - } - if err := conn.RegisterFunc("variadic", variadic, true); err != nil { - return err - } - if err := conn.RegisterFunc("variadicGeneric", variadicGeneric, true); err != nil { - return err - } - return nil - }, - }) - db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - ops := []struct { - query string - expected interface{} - }{ - {"SELECT addi8_16_32(1,2)", int32(3)}, - {"SELECT addi64(1,2)", int64(3)}, - {"SELECT addu8_16_32(1,2)", uint32(3)}, - {"SELECT addu64(1,2)", uint64(3)}, - {"SELECT addiu(1,2)", int64(3)}, - {"SELECT addf32_64(1.5,1.5)", float64(3)}, - {"SELECT not(1)", false}, - {"SELECT not(0)", true}, - {`SELECT regex("^foo.*", "foobar")`, true}, - {`SELECT regex("^foo.*", "barfoobar")`, false}, - {"SELECT generic(1)", int64(1)}, - {"SELECT generic(1.1)", int64(2)}, - {`SELECT generic(NULL)`, int64(3)}, - {`SELECT generic("foo")`, int64(4)}, - {"SELECT variadic(1,2)", int64(3)}, - {"SELECT variadic(1,2,3,4)", int64(10)}, - {"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)}, - {`SELECT variadicGeneric(1,"foo",2.3, NULL)`, int64(4)}, - } - - for _, op := range ops { - ret := reflect.New(reflect.TypeOf(op.expected)) - err = db.QueryRow(op.query).Scan(ret.Interface()) - if err != nil { - t.Errorf("Query %q failed: %s", op.query, err) - } else if !reflect.DeepEqual(ret.Elem().Interface(), op.expected) { - t.Errorf("Query %q returned wrong value: got %v (%T), want %v (%T)", op.query, ret.Elem().Interface(), ret.Elem().Interface(), op.expected, op.expected) - } - } -} - -type sumAggregator int64 - -func (s *sumAggregator) Step(x int64) { - *s += sumAggregator(x) -} - -func (s *sumAggregator) Done() int64 { - return int64(*s) -} - -func TestAggregatorRegistration(t *testing.T) { - customSum := func() *sumAggregator { - var ret sumAggregator - return &ret - } - - sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{ - ConnectHook: func(conn *SQLiteConn) error { - if err := conn.RegisterAggregator("customSum", customSum, true); err != nil { - return err - } - return nil - }, - }) - db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("create table foo (department integer, profits integer)") - if err != nil { - // trace feature is not implemented - t.Skip("Failed to create table:", err) - } - - _, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)") - if err != nil { - t.Fatal("Failed to insert records:", err) - } - - tests := []struct { - dept, sum int64 - }{ - {1, 30}, - {2, 42}, - } - - for _, test := range tests { - var ret int64 - err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret) - if err != nil { - t.Fatal("Query failed:", err) - } - if ret != test.sum { - t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum) - } - } -} - -func rot13(r rune) rune { - switch { - case r >= 'A' && r <= 'Z': - return 'A' + (r-'A'+13)%26 - case r >= 'a' && r <= 'z': - return 'a' + (r-'a'+13)%26 - } - return r -} - -func TestCollationRegistration(t *testing.T) { - collateRot13 := func(a, b string) int { - ra, rb := strings.Map(rot13, a), strings.Map(rot13, b) - return strings.Compare(ra, rb) - } - collateRot13Reverse := func(a, b string) int { - return collateRot13(b, a) - } - - sql.Register("sqlite3_CollationRegistration", &SQLiteDriver{ - ConnectHook: func(conn *SQLiteConn) error { - if err := conn.RegisterCollation("rot13", collateRot13); err != nil { - return err - } - if err := conn.RegisterCollation("rot13reverse", collateRot13Reverse); err != nil { - return err - } - return nil - }, - }) - - db, err := sql.Open("sqlite3_CollationRegistration", ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - populate := []string{ - `CREATE TABLE test (s TEXT)`, - `INSERT INTO test VALUES ("aaaa")`, - `INSERT INTO test VALUES ("ffff")`, - `INSERT INTO test VALUES ("qqqq")`, - `INSERT INTO test VALUES ("tttt")`, - `INSERT INTO test VALUES ("zzzz")`, - } - for _, stmt := range populate { - if _, err := db.Exec(stmt); err != nil { - t.Fatal("Failed to populate test DB:", err) - } - } - - ops := []struct { - query string - want []string - }{ - { - "SELECT * FROM test ORDER BY s COLLATE rot13 ASC", - []string{ - "qqqq", - "tttt", - "zzzz", - "aaaa", - "ffff", - }, - }, - { - "SELECT * FROM test ORDER BY s COLLATE rot13 DESC", - []string{ - "ffff", - "aaaa", - "zzzz", - "tttt", - "qqqq", - }, - }, - { - "SELECT * FROM test ORDER BY s COLLATE rot13reverse ASC", - []string{ - "ffff", - "aaaa", - "zzzz", - "tttt", - "qqqq", - }, - }, - { - "SELECT * FROM test ORDER BY s COLLATE rot13reverse DESC", - []string{ - "qqqq", - "tttt", - "zzzz", - "aaaa", - "ffff", - }, - }, - } - - for _, op := range ops { - rows, err := db.Query(op.query) - if err != nil { - t.Fatalf("Query %q failed: %s", op.query, err) - } - got := []string{} - defer rows.Close() - for rows.Next() { - var s string - if err = rows.Scan(&s); err != nil { - t.Fatalf("Reading row for %q: %s", op.query, err) - } - got = append(got, s) - } - if err = rows.Err(); err != nil { - t.Fatalf("Reading rows for %q: %s", op.query, err) - } - - if !reflect.DeepEqual(got, op.want) { - t.Fatalf("Unexpected output from %q\ngot:\n%s\n\nwant:\n%s", op.query, strings.Join(got, "\n"), strings.Join(op.want, "\n")) - } - } -} - -func TestDeclTypes(t *testing.T) { - - d := SQLiteDriver{} - - conn, err := d.Open(":memory:") - if err != nil { - t.Fatal("Failed to begin transaction:", err) - } - defer conn.Close() - - sqlite3conn := conn.(*SQLiteConn) - - _, err = sqlite3conn.Exec("create table foo (id integer not null primary key, name text)", nil) - if err != nil { - t.Fatal("Failed to create table:", err) - } - - _, err = sqlite3conn.Exec("insert into foo(name) values(\"bar\")", nil) - if err != nil { - t.Fatal("Failed to insert:", err) - } - - rs, err := sqlite3conn.Query("select * from foo", nil) - if err != nil { - t.Fatal("Failed to select:", err) - } - defer rs.Close() - - declTypes := rs.(*SQLiteRows).DeclTypes() - - if !reflect.DeepEqual(declTypes, []string{"integer", "text"}) { - t.Fatal("Unexpected declTypes:", declTypes) - } -} - -func TestUpdateAndTransactionHooks(t *testing.T) { - var events []string - var commitHookReturn = 0 - - sql.Register("sqlite3_UpdateHook", &SQLiteDriver{ - ConnectHook: func(conn *SQLiteConn) error { - conn.RegisterCommitHook(func() int { - events = append(events, "commit") - return commitHookReturn - }) - conn.RegisterRollbackHook(func() { - events = append(events, "rollback") - }) - conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) { - events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid)) - }) - return nil - }, - }) - db, err := sql.Open("sqlite3_UpdateHook", ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - statements := []string{ - "create table foo (id integer primary key)", - "insert into foo values (9)", - "update foo set id = 99 where id = 9", - "delete from foo where id = 99", - } - for _, statement := range statements { - _, err = db.Exec(statement) - if err != nil { - t.Fatalf("Unable to prepare test data [%v]: %v", statement, err) - } - } - - commitHookReturn = 1 - _, err = db.Exec("insert into foo values (5)") - if err == nil { - t.Error("Commit hook failed to rollback transaction") - } - - var expected = []string{ - "commit", - fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT), - "commit", - fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE), - "commit", - fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE), - "commit", - fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT), - "commit", - "rollback", - } - if !reflect.DeepEqual(events, expected) { - t.Errorf("Expected notifications %v but got %v", expected, events) - } -} - -func TestNilAndEmptyBytes(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - actualNil := []byte("use this to use an actual nil not a reference to nil") - emptyBytes := []byte{} - for tsti, tst := range []struct { - name string - columnType string - insertBytes []byte - expectedBytes []byte - }{ - {"actual nil blob", "blob", actualNil, nil}, - {"referenced nil blob", "blob", nil, nil}, - {"empty blob", "blob", emptyBytes, emptyBytes}, - {"actual nil text", "text", actualNil, nil}, - {"referenced nil text", "text", nil, nil}, - {"empty text", "text", emptyBytes, emptyBytes}, - } { - if _, err = db.Exec(fmt.Sprintf("create table tbl%d (txt %s)", tsti, tst.columnType)); err != nil { - t.Fatal(tst.name, err) - } - if bytes.Equal(tst.insertBytes, actualNil) { - if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), nil); err != nil { - t.Fatal(tst.name, err) - } - } else { - if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), &tst.insertBytes); err != nil { - t.Fatal(tst.name, err) - } - } - rows, err := db.Query(fmt.Sprintf("select txt from tbl%d", tsti)) - if err != nil { - t.Fatal(tst.name, err) - } - if !rows.Next() { - t.Fatal(tst.name, "no rows") - } - var scanBytes []byte - if err = rows.Scan(&scanBytes); err != nil { - t.Fatal(tst.name, err) - } - if err = rows.Err(); err != nil { - t.Fatal(tst.name, err) - } - if tst.expectedBytes == nil && scanBytes != nil { - t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes) - } else if !bytes.Equal(scanBytes, tst.expectedBytes) { - t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes) - } - } -} - -func TestInsertNilByteSlice(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - if _, err := db.Exec("create table blob_not_null (b blob not null)"); err != nil { - t.Fatal(err) - } - var nilSlice []byte - if _, err := db.Exec("insert into blob_not_null (b) values (?)", nilSlice); err == nil { - t.Fatal("didn't expect INSERT to 'not null' column with a nil []byte slice to work") - } - zeroLenSlice := []byte{} - if _, err := db.Exec("insert into blob_not_null (b) values (?)", zeroLenSlice); err != nil { - t.Fatal("failed to insert zero-length slice") - } -} - -var customFunctionOnce sync.Once - -func BenchmarkCustomFunctions(b *testing.B) { - customFunctionOnce.Do(func() { - customAdd := func(a, b int64) int64 { - return a + b - } - - sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{ - ConnectHook: func(conn *SQLiteConn) error { - // Impure function to force sqlite to reexecute it each time. - return conn.RegisterFunc("custom_add", customAdd, false) - }, - }) - }) - - db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:") - if err != nil { - b.Fatal("Failed to open database:", err) - } - defer db.Close() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - var i int64 - err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i) - if err != nil { - b.Fatal("Failed to run custom add:", err) - } - } -} - -func TestSuite(t *testing.T) { - tempFilename := TempFilename(t) - defer os.Remove(tempFilename) - d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999") - if err != nil { - t.Fatal(err) - } - defer d.Close() - - db = &TestDB{t, d, SQLITE, sync.Once{}} - testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests) - - if !testing.Short() { - for _, b := range benchmarks { - fmt.Printf("%-20s", b.Name) - r := testing.Benchmark(b.F) - fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds()) - } - } - db.tearDown() -} - -// Dialect is a type of dialect of databases. -type Dialect int - -// Dialects for databases. -const ( - SQLITE Dialect = iota // SQLITE mean SQLite3 dialect - POSTGRESQL // POSTGRESQL mean PostgreSQL dialect - MYSQL // MYSQL mean MySQL dialect -) - -// DB provide context for the tests -type TestDB struct { - *testing.T - *sql.DB - dialect Dialect - once sync.Once -} - -var db *TestDB - -// the following tables will be created and dropped during the test -var testTables = []string{"foo", "bar", "t", "bench"} - -var tests = []testing.InternalTest{ - {Name: "TestResult", F: testResult}, - {Name: "TestBlobs", F: testBlobs}, - {Name: "TestMultiBlobs", F: testMultiBlobs}, - {Name: "TestManyQueryRow", F: testManyQueryRow}, - {Name: "TestTxQuery", F: testTxQuery}, - {Name: "TestPreparedStmt", F: testPreparedStmt}, -} - -var benchmarks = []testing.InternalBenchmark{ - {Name: "BenchmarkExec", F: benchmarkExec}, - {Name: "BenchmarkQuery", F: benchmarkQuery}, - {Name: "BenchmarkParams", F: benchmarkParams}, - {Name: "BenchmarkStmt", F: benchmarkStmt}, - {Name: "BenchmarkRows", F: benchmarkRows}, - {Name: "BenchmarkStmtRows", F: benchmarkStmtRows}, -} - -func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result { - res, err := db.Exec(sql, args...) - if err != nil { - db.Fatalf("Error running %q: %v", sql, err) - } - return res -} - -func (db *TestDB) tearDown() { - for _, tbl := range testTables { - switch db.dialect { - case SQLITE: - db.mustExec("drop table if exists " + tbl) - case MYSQL, POSTGRESQL: - db.mustExec("drop table if exists " + tbl) - default: - db.Fatal("unknown dialect") - } - } -} - -// q replaces ? parameters if needed -func (db *TestDB) q(sql string) string { - switch db.dialect { - case POSTGRESQL: // replace with $1, $2, .. - qrx := regexp.MustCompile(`\?`) - n := 0 - return qrx.ReplaceAllStringFunc(sql, func(string) string { - n++ - return "$" + strconv.Itoa(n) - }) - } - return sql -} - -func (db *TestDB) blobType(size int) string { - switch db.dialect { - case SQLITE: - return fmt.Sprintf("blob[%d]", size) - case POSTGRESQL: - return "bytea" - case MYSQL: - return fmt.Sprintf("VARBINARY(%d)", size) - } - panic("unknown dialect") -} - -func (db *TestDB) serialPK() string { - switch db.dialect { - case SQLITE: - return "integer primary key autoincrement" - case POSTGRESQL: - return "serial primary key" - case MYSQL: - return "integer primary key auto_increment" - } - panic("unknown dialect") -} - -func (db *TestDB) now() string { - switch db.dialect { - case SQLITE: - return "datetime('now')" - case POSTGRESQL: - return "now()" - case MYSQL: - return "now()" - } - panic("unknown dialect") -} - -func makeBench() { - if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil { - panic(err) - } - st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)") - if err != nil { - panic(err) - } - defer st.Close() - for i := 0; i < 100; i++ { - if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil { - panic(err) - } - } -} - -// testResult is test for result -func testResult(t *testing.T) { - db.tearDown() - db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))") - - for i := 1; i < 3; i++ { - r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i)) - n, err := r.RowsAffected() - if err != nil { - t.Fatal(err) - } - if n != 1 { - t.Errorf("got %v, want %v", n, 1) - } - n, err = r.LastInsertId() - if err != nil { - t.Fatal(err) - } - if n != int64(i) { - t.Errorf("got %v, want %v", n, i) - } - } - if _, err := db.Exec("error!"); err == nil { - t.Fatalf("expected error") - } -} - -// testBlobs is test for blobs -func testBlobs(t *testing.T) { - db.tearDown() - var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")") - db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob) - - want := fmt.Sprintf("%x", blob) - - b := make([]byte, 16) - err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b) - got := fmt.Sprintf("%x", b) - if err != nil { - t.Errorf("[]byte scan: %v", err) - } else if got != want { - t.Errorf("for []byte, got %q; want %q", got, want) - } - - err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got) - want = string(blob) - if err != nil { - t.Errorf("string scan: %v", err) - } else if got != want { - t.Errorf("for string, got %q; want %q", got, want) - } -} - -func testMultiBlobs(t *testing.T) { - db.tearDown() - db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")") - var blob0 = []byte{0, 1, 2, 3, 4, 5, 6, 7} - db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob0) - var blob1 = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} - db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 1, blob1) - - r, err := db.Query(db.q("select bar from foo order by id")) - if err != nil { - t.Fatal(err) - } - defer r.Close() - if !r.Next() { - if r.Err() != nil { - t.Fatal(err) - } - t.Fatal("expected one rows") - } - - want0 := fmt.Sprintf("%x", blob0) - b0 := make([]byte, 8) - err = r.Scan(&b0) - if err != nil { - t.Fatal(err) - } - got0 := fmt.Sprintf("%x", b0) - - if !r.Next() { - if r.Err() != nil { - t.Fatal(err) - } - t.Fatal("expected one rows") - } - - want1 := fmt.Sprintf("%x", blob1) - b1 := make([]byte, 16) - err = r.Scan(&b1) - if err != nil { - t.Fatal(err) - } - got1 := fmt.Sprintf("%x", b1) - if got0 != want0 { - t.Errorf("for []byte, got %q; want %q", got0, want0) - } - if got1 != want1 { - t.Errorf("for []byte, got %q; want %q", got1, want1) - } -} - -// testManyQueryRow is test for many query row -func testManyQueryRow(t *testing.T) { - if testing.Short() { - t.Log("skipping in short mode") - return - } - db.tearDown() - db.mustExec("create table foo (id integer primary key, name varchar(50))") - db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob") - var name string - for i := 0; i < 10000; i++ { - err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name) - if err != nil || name != "bob" { - t.Fatalf("on query %d: err=%v, name=%q", i, err, name) - } - } -} - -// testTxQuery is test for transactional query -func testTxQuery(t *testing.T) { - db.tearDown() - tx, err := db.Begin() - if err != nil { - t.Fatal(err) - } - defer tx.Rollback() - - _, err = tx.Exec("create table foo (id integer primary key, name varchar(50))") - if err != nil { - t.Fatal(err) - } - - _, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob") - if err != nil { - t.Fatal(err) - } - - r, err := tx.Query(db.q("select name from foo where id = ?"), 1) - if err != nil { - t.Fatal(err) - } - defer r.Close() - - if !r.Next() { - if r.Err() != nil { - t.Fatal(err) - } - t.Fatal("expected one rows") - } - - var name string - err = r.Scan(&name) - if err != nil { - t.Fatal(err) - } -} - -// testPreparedStmt is test for prepared statement -func testPreparedStmt(t *testing.T) { - db.tearDown() - db.mustExec("CREATE TABLE t (count INT)") - sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC") - if err != nil { - t.Fatalf("prepare 1: %v", err) - } - ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)")) - if err != nil { - t.Fatalf("prepare 2: %v", err) - } - - for n := 1; n <= 3; n++ { - if _, err := ins.Exec(n); err != nil { - t.Fatalf("insert(%d) = %v", n, err) - } - } - - const nRuns = 10 - var wg sync.WaitGroup - for i := 0; i < nRuns; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < 10; j++ { - count := 0 - if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows { - t.Errorf("Query: %v", err) - return - } - if _, err := ins.Exec(rand.Intn(100)); err != nil { - t.Errorf("Insert: %v", err) - return - } - } - }() - } - wg.Wait() -} - -// Benchmarks need to use panic() since b.Error errors are lost when -// running via testing.Benchmark() I would like to run these via go -// test -bench but calling Benchmark() from a benchmark test -// currently hangs go. - -// benchmarkExec is benchmark for exec -func benchmarkExec(b *testing.B) { - for i := 0; i < b.N; i++ { - if _, err := db.Exec("select 1"); err != nil { - panic(err) - } - } -} - -// benchmarkQuery is benchmark for query -func benchmarkQuery(b *testing.B) { - for i := 0; i < b.N; i++ { - var n sql.NullString - var i int - var f float64 - var s string - // var t time.Time - if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil { - panic(err) - } - } -} - -// benchmarkParams is benchmark for params -func benchmarkParams(b *testing.B) { - for i := 0; i < b.N; i++ { - var n sql.NullString - var i int - var f float64 - var s string - // var t time.Time - if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { - panic(err) - } - } -} - -// benchmarkStmt is benchmark for statement -func benchmarkStmt(b *testing.B) { - st, err := db.Prepare("select ?, ?, ?, ?") - if err != nil { - panic(err) - } - defer st.Close() - - for n := 0; n < b.N; n++ { - var n sql.NullString - var i int - var f float64 - var s string - // var t time.Time - if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { - panic(err) - } - } -} - -// benchmarkRows is benchmark for rows -func benchmarkRows(b *testing.B) { - db.once.Do(makeBench) - - for n := 0; n < b.N; n++ { - var n sql.NullString - var i int - var f float64 - var s string - var t time.Time - r, err := db.Query("select * from bench") - if err != nil { - panic(err) - } - for r.Next() { - if err = r.Scan(&n, &i, &f, &s, &t); err != nil { - panic(err) - } - } - if err = r.Err(); err != nil { - panic(err) - } - } -} - -// benchmarkStmtRows is benchmark for statement rows -func benchmarkStmtRows(b *testing.B) { - db.once.Do(makeBench) - - st, err := db.Prepare("select * from bench") - if err != nil { - panic(err) - } - defer st.Close() - - for n := 0; n < b.N; n++ { - var n sql.NullString - var i int - var f float64 - var s string - var t time.Time - r, err := st.Query() - if err != nil { - panic(err) - } - for r.Next() { - if err = r.Scan(&n, &i, &f, &s, &t); err != nil { - panic(err) - } - } - if err = r.Err(); err != nil { - panic(err) - } - } -} diff --git a/driver/error.go b/driver/error.go index 22885fd..dc6a1dc 100644 --- a/driver/error.go +++ b/driver/error.go @@ -21,9 +21,6 @@ package sqlite3 #endif */ import "C" -import ( - "fmt" -) // ErrNo inherit errno. type ErrNo int @@ -165,7 +162,6 @@ func lastError(db *C.sqlite3) error { } func errorString(err Error) string { - fmt.Println("errorString") return C.GoString(C.sqlite3_errstr(C.int(err.Code))) } diff --git a/driver/opt_stat4_test.go b/driver/opt_stat4_test.go index 19a2492..658067b 100644 --- a/driver/opt_stat4_test.go +++ b/driver/opt_stat4_test.go @@ -47,6 +47,6 @@ func TestStat4(t *testing.T) { } if exists != 1 { - t.Fatal("Failed to enabled STAT4") + t.Fatal("Failed to enable STAT4") } } diff --git a/driver/opt_userauth_test.go b/driver/opt_userauth_test.go index 7d75ee2..4b03107 100644 --- a/driver/opt_userauth_test.go +++ b/driver/opt_userauth_test.go @@ -61,7 +61,7 @@ func init() { file = TempFilename(t) } - db, err = sql.Open("sqlite3_with_conn", "file:"+file+fmt.Sprintf("?_auth&_auth_user=%s&_auth_pass=%s", username, password)) + db, err = sql.Open("sqlite3_with_conn", "file:"+file+fmt.Sprintf("?user=%s&pass=%s", username, password)) if err != nil { defer os.Remove(file) return file, nil, nil, err @@ -87,7 +87,7 @@ func init() { file = TempFilename(t) } - db, err = sql.Open("sqlite3_with_conn", "file:"+file+fmt.Sprintf("?_auth&_auth_user=%s&_auth_pass=%s&_auth_crypt=%s&_auth_salt=%s", username, password, crypt, salt)) + db, err = sql.Open("sqlite3_with_conn", "file:"+file+fmt.Sprintf("?user=%s&pass=%s&crypt=%s&salt=%s", username, password, crypt, salt)) if err != nil { defer os.Remove(file) return file, nil, nil, err @@ -270,7 +270,7 @@ func TestUserAuthAddAdmin(t *testing.T) { func TestUserAuthAddUser(t *testing.T) { f1, db1, c, err := connect(t, "", "admin", "admin") - if err != nil && c == nil && db == nil { + if err != nil && c == nil && db1 == nil { t.Fatal(err) } defer os.Remove(f1) @@ -365,7 +365,7 @@ func TestUserAuthAddUser(t *testing.T) { func TestUserAuthModifyUser(t *testing.T) { f1, db1, c1, err := connect(t, "", "admin", "admin") - if err != nil && c1 == nil && db == nil { + if err != nil && c1 == nil && db1 == nil { t.Fatal(err) } defer os.Remove(f1) @@ -464,7 +464,7 @@ func TestUserAuthModifyUser(t *testing.T) { func TestUserAuthDeleteUser(t *testing.T) { f1, db1, c, err := connect(t, "", "admin", "admin") - if err != nil && c == nil && db == nil { + if err != nil && c == nil && db1 == nil { t.Fatal(err) } defer os.Remove(f1) diff --git a/driver/pragma.go b/driver/pragma.go new file mode 100644 index 0000000..f4afd5c --- /dev/null +++ b/driver/pragma.go @@ -0,0 +1,23 @@ +// Copyright (C) 2018 The Go-SQLite3 Authors. +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// +build cgo + +package sqlite3 + +const ( + PRAGMA_AUTO_VACUUM = "auto_vacuum" + PRAGMA_CASE_SENSITIVE_LIKE = "case_sensitive_like" + PRAGMA_DEFER_FOREIGN_KEYS = "defer_foreign_keys" + PRAGMA_FOREIGN_KEYS = "foreign_keys" + PRAGMA_IGNORE_CHECK_CONTRAINTS = "ignore_check_constraints" + PRAGMA_JOURNAL_MODE = "journal_mode" + PRAGMA_LOCKING_MODE = "locking_mode" + PRAGMA_QUERY_ONLY = "query_only" + PRAGMA_RECURSIVE_TRIGGERS = "recursive_triggers" + PRAGMA_SECURE_DELETE = "secure_delete" + PRAGMA_SYNCHRONOUS = "synchronous" + PRAGMA_WRITABLE_SCHEMA = "writable_schema" +) diff --git a/driver/context.go b/driver/vtable_context.go similarity index 99% rename from driver/context.go rename to driver/vtable_context.go index 74ec6c3..0ccc7f0 100644 --- a/driver/context.go +++ b/driver/vtable_context.go @@ -4,6 +4,7 @@ // license that can be found in the LICENSE file. // +build cgo +// +build sqlite_vtable package sqlite3