forked from mirror/go-sqlcipher
Upgrade to the latest sqlcipher.
This commit is contained in:
commit
04d43c043e
|
@ -1,3 +1,4 @@
|
|||
*.db
|
||||
*.exe
|
||||
*.dll
|
||||
*.o
|
||||
|
|
14
.travis.yml
14
.travis.yml
|
@ -1,13 +1,19 @@
|
|||
language: go
|
||||
sudo: required
|
||||
dist: trusty
|
||||
env:
|
||||
- GOTAGS=
|
||||
- GOTAGS=libsqlite3
|
||||
- GOTAGS=trace
|
||||
- GOTAGS=vtable
|
||||
go:
|
||||
- 1.5
|
||||
- 1.6
|
||||
- tip
|
||||
- 1.7.x
|
||||
- 1.8.x
|
||||
- 1.9.x
|
||||
- master
|
||||
before_install:
|
||||
- go get github.com/mattn/goveralls
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
script:
|
||||
- $HOME/gopath/bin/goveralls -repotoken 3qJVUE0iQwqnCbmNcDsjYu1nh4J4KIFXx
|
||||
- go test -v . -tags "libsqlite3"
|
||||
- go test -race -v . -tags "$GOTAGS"
|
||||
|
|
25
README.md
25
README.md
|
@ -6,10 +6,10 @@ SQLCipher driver conforming to the built-in database/sql interface and using the
|
|||
|
||||
|
||||
which is
|
||||
`3.8.8.3 2015-02-25 13:29:11 9d6c1880fb75660bbabd693175579529785f8a6b`
|
||||
`3.20.1`
|
||||
|
||||
Working with sqlcipher version which is
|
||||
`3.8.6 2014-08-15 11:46:33 9491ba7d738528f168657adb43a198238abde19e`
|
||||
`3.4.2`
|
||||
|
||||
It's wrapper with
|
||||
* [go-sqlite3](https://github.com/mattn/go-sqlite3) sqlite3 driver for go that using database/sql.
|
||||
|
@ -73,6 +73,8 @@ Here is some help from go-sqlite3 project.
|
|||
|
||||
Use `go build --tags "icu"`
|
||||
|
||||
Available extensions: `json1`, `fts5`, `icu`
|
||||
|
||||
* Can't build go-sqlite3 on windows 64bit.
|
||||
|
||||
> Probably, you are using go 1.0, go1.0 has a problem when it comes to compiling/linking on windows 64bit.
|
||||
|
@ -81,16 +83,29 @@ Here is some help from go-sqlite3 project.
|
|||
* Getting insert error while query is opened.
|
||||
|
||||
> You can pass some arguments into the connection string, for example, a URI.
|
||||
> See: https://github.com/mattn/go-sqlite3/issues/39
|
||||
> See: [#39](https://github.com/mattn/go-sqlite3/issues/39)
|
||||
|
||||
* Do you want to cross compile? mingw on Linux or Mac?
|
||||
|
||||
> See: https://github.com/mattn/go-sqlite3/issues/106
|
||||
> See: [#106](https://github.com/mattn/go-sqlite3/issues/106)
|
||||
> See also: http://www.limitlessfx.com/cross-compile-golang-app-for-windows-from-linux.html
|
||||
|
||||
* Want to get time.Time with current locale
|
||||
|
||||
Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`.
|
||||
Use `_loc=auto` in SQLite3 filename schema like `file:foo.db?_loc=auto`.
|
||||
|
||||
* Can I use this in multiple routines concurrently?
|
||||
|
||||
Yes for readonly. But, No for writable. See [#50](https://github.com/mattn/go-sqlite3/issues/50), [#51](https://github.com/mattn/go-sqlite3/issues/51), [#209](https://github.com/mattn/go-sqlite3/issues/209), [#274](https://github.com/mattn/go-sqlite3/issues/274).
|
||||
|
||||
* Why is it racy if I use a `sql.Open("sqlite3", ":memory:")` database?
|
||||
|
||||
Each connection to :memory: opens a brand new in-memory sql database, so if
|
||||
the stdlib's sql engine happens to open another connection and you've only
|
||||
specified ":memory:", that connection will see a brand new database. A
|
||||
workaround is to use "file::memory:?mode=memory&cache=shared". Every
|
||||
connection to this string will point to the same in-memory database. See
|
||||
[#204](https://github.com/mattn/go-sqlite3/issues/204) for more info.
|
||||
|
||||
* Print some waring messages like `warning: 'RAND_add' is deprecated: first deprecated in OS X 10.7`
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ package main
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/xeodou/go-sqlcipher"
|
||||
)
|
||||
|
||||
|
@ -24,6 +25,7 @@ func main() {
|
|||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
c := "CREATE TABLE IF NOT EXISTS `users` (`id` INTEGER PRIMARY KEY, `name` char, `password` chart, UNIQUE(`name`));"
|
||||
_, err = db.Exec(c)
|
||||
if err != nil {
|
||||
|
|
|
@ -2,9 +2,10 @@ package main
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/mattn/go-sqlite3"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
@ -13,42 +14,48 @@ func main() {
|
|||
&sqlite3.SQLiteDriver{
|
||||
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||
sqlite3conn = append(sqlite3conn, conn)
|
||||
conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
|
||||
switch op {
|
||||
case sqlite3.SQLITE_INSERT:
|
||||
log.Println("Notified of insert on db", db, "table", table, "rowid", rowid)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
},
|
||||
})
|
||||
os.Remove("./foo.db")
|
||||
os.Remove("./bar.db")
|
||||
|
||||
destDb, err := sql.Open("sqlite3_with_hook_example", "./foo.db")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer destDb.Close()
|
||||
destDb.Ping()
|
||||
|
||||
_, err = destDb.Exec("create table foo(id int, value text)")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
_, err = destDb.Exec("insert into foo values(1, 'foo')")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
_, err = destDb.Exec("insert into foo values(2, 'bar')")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
_, err = destDb.Query("select * from foo")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
srcDb, err := sql.Open("sqlite3_with_hook_example", "./bar.db")
|
||||
srcDb, err := sql.Open("sqlite3_with_hook_example", "./foo.db")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer srcDb.Close()
|
||||
srcDb.Ping()
|
||||
|
||||
_, err = srcDb.Exec("create table foo(id int, value text)")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
_, err = srcDb.Exec("insert into foo values(1, 'foo')")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
_, err = srcDb.Exec("insert into foo values(2, 'bar')")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
_, err = srcDb.Query("select * from foo")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
destDb, err := sql.Open("sqlite3_with_hook_example", "./bar.db")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer destDb.Close()
|
||||
destDb.Ping()
|
||||
|
||||
bk, err := sqlite3conn[1].Backup("main", sqlite3conn[0], "main")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func createBulkInsertQuery(n int, start int) (query string, args []interface{}) {
|
||||
values := make([]string, n)
|
||||
args = make([]interface{}, n*2)
|
||||
pos := 0
|
||||
for i := 0; i < n; i++ {
|
||||
values[i] = "(?, ?)"
|
||||
args[pos] = start + i
|
||||
args[pos+1] = fmt.Sprintf("こんにちわ世界%03d", i)
|
||||
pos += 2
|
||||
}
|
||||
query = fmt.Sprintf(
|
||||
"insert into foo(id, name) values %s",
|
||||
strings.Join(values, ", "),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func bukInsert(db *sql.DB, query string, args []interface{}) (err error) {
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = stmt.Exec(args...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func main() {
|
||||
var sqlite3conn *sqlite3.SQLiteConn
|
||||
sql.Register("sqlite3_with_limit", &sqlite3.SQLiteDriver{
|
||||
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||
sqlite3conn = conn
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
os.Remove("./foo.db")
|
||||
db, err := sql.Open("sqlite3_with_limit", "./foo.db")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
sqlStmt := `
|
||||
create table foo (id integer not null primary key, name text);
|
||||
delete from foo;
|
||||
`
|
||||
_, err = db.Exec(sqlStmt)
|
||||
if err != nil {
|
||||
log.Printf("%q: %s\n", err, sqlStmt)
|
||||
return
|
||||
}
|
||||
|
||||
if sqlite3conn == nil {
|
||||
log.Fatal("not set sqlite3 connection")
|
||||
}
|
||||
|
||||
limitVariableNumber := sqlite3conn.GetLimit(sqlite3.SQLITE_LIMIT_VARIABLE_NUMBER)
|
||||
log.Printf("default SQLITE_LIMIT_VARIABLE_NUMBER: %d", limitVariableNumber)
|
||||
|
||||
num := 400
|
||||
query, args := createBulkInsertQuery(num, 0)
|
||||
err = bukInsert(db, query, args)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
smallLimitVariableNumber := 100
|
||||
sqlite3conn.SetLimit(sqlite3.SQLITE_LIMIT_VARIABLE_NUMBER, smallLimitVariableNumber)
|
||||
|
||||
limitVariableNumber = sqlite3conn.GetLimit(sqlite3.SQLITE_LIMIT_VARIABLE_NUMBER)
|
||||
log.Printf("updated SQLITE_LIMIT_VARIABLE_NUMBER: %d", limitVariableNumber)
|
||||
|
||||
query, args = createBulkInsertQuery(num, num)
|
||||
err = bukInsert(db, query, args)
|
||||
if err != nil {
|
||||
if err != nil {
|
||||
log.Printf("expect failed since SQLITE_LIMIT_VARIABLE_NUMBER is too small: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
bigLimitVariableNumber := 999999
|
||||
sqlite3conn.SetLimit(sqlite3.SQLITE_LIMIT_VARIABLE_NUMBER, bigLimitVariableNumber)
|
||||
limitVariableNumber = sqlite3conn.GetLimit(sqlite3.SQLITE_LIMIT_VARIABLE_NUMBER)
|
||||
log.Printf("set SQLITE_LIMIT_VARIABLE_NUMBER: %d", bigLimitVariableNumber)
|
||||
log.Printf("updated SQLITE_LIMIT_VARIABLE_NUMBER: %d", limitVariableNumber)
|
||||
|
||||
query, args = createBulkInsertQuery(500, num+num)
|
||||
err = bukInsert(db, query, args)
|
||||
if err != nil {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Println("no error if SQLITE_LIMIT_VARIABLE_NUMBER > 999")
|
||||
}
|
|
@ -3,8 +3,9 @@ package main
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/mattn/go-sqlite3"
|
||||
"log"
|
||||
|
||||
"github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
@ -29,8 +30,8 @@ func main() {
|
|||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var id, full_name, description, html_url string
|
||||
rows.Scan(&id, &full_name, &description, &html_url)
|
||||
fmt.Printf("%s: %s\n\t%s\n\t%s\n\n", id, full_name, description, html_url)
|
||||
var id, fullName, description, htmlURL string
|
||||
rows.Scan(&id, &fullName, &description, &htmlURL)
|
||||
fmt.Printf("%s: %s\n\t%s\n\t%s\n\n", id, fullName, description, htmlURL)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,264 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func traceCallback(info sqlite3.TraceInfo) int {
|
||||
// Not very readable but may be useful; uncomment next line in case of doubt:
|
||||
//fmt.Printf("Trace: %#v\n", info)
|
||||
|
||||
var dbErrText string
|
||||
if info.DBError.Code != 0 || info.DBError.ExtendedCode != 0 {
|
||||
dbErrText = fmt.Sprintf("; DB error: %#v", info.DBError)
|
||||
} else {
|
||||
dbErrText = "."
|
||||
}
|
||||
|
||||
// Show the Statement-or-Trigger text in curly braces ('{', '}')
|
||||
// since from the *paired* ASCII characters they are
|
||||
// the least used in SQL syntax, therefore better visual delimiters.
|
||||
// Maybe show 'ExpandedSQL' the same way as 'StmtOrTrigger'.
|
||||
//
|
||||
// A known use of curly braces (outside strings) is
|
||||
// for ODBC escape sequences. Not likely to appear here.
|
||||
//
|
||||
// Template languages, etc. don't matter, we should see their *result*
|
||||
// at *this* level.
|
||||
// Strange curly braces in SQL code that reached the database driver
|
||||
// suggest that there is a bug in the application.
|
||||
// The braces are likely to be either template syntax or
|
||||
// a programming language's string interpolation syntax.
|
||||
|
||||
var expandedText string
|
||||
if info.ExpandedSQL != "" {
|
||||
if info.ExpandedSQL == info.StmtOrTrigger {
|
||||
expandedText = " = exp"
|
||||
} else {
|
||||
expandedText = fmt.Sprintf(" expanded {%q}", info.ExpandedSQL)
|
||||
}
|
||||
} else {
|
||||
expandedText = ""
|
||||
}
|
||||
|
||||
// SQLite docs as of September 6, 2016: Tracing and Profiling Functions
|
||||
// https://www.sqlite.org/c3ref/profile.html
|
||||
//
|
||||
// The profile callback time is in units of nanoseconds, however
|
||||
// the current implementation is only capable of millisecond resolution
|
||||
// so the six least significant digits in the time are meaningless.
|
||||
// Future versions of SQLite might provide greater resolution on the profiler callback.
|
||||
|
||||
var runTimeText string
|
||||
if info.RunTimeNanosec == 0 {
|
||||
if info.EventCode == sqlite3.TraceProfile {
|
||||
//runTimeText = "; no time" // seems confusing
|
||||
runTimeText = "; time 0" // no measurement unit
|
||||
} else {
|
||||
//runTimeText = "; no time" // seems useless and confusing
|
||||
}
|
||||
} else {
|
||||
const nanosPerMillisec = 1000000
|
||||
if info.RunTimeNanosec%nanosPerMillisec == 0 {
|
||||
runTimeText = fmt.Sprintf("; time %d ms", info.RunTimeNanosec/nanosPerMillisec)
|
||||
} else {
|
||||
// unexpected: better than millisecond resolution
|
||||
runTimeText = fmt.Sprintf("; time %d ns!!!", info.RunTimeNanosec)
|
||||
}
|
||||
}
|
||||
|
||||
var modeText string
|
||||
if info.AutoCommit {
|
||||
modeText = "-AC-"
|
||||
} else {
|
||||
modeText = "+Tx+"
|
||||
}
|
||||
|
||||
fmt.Printf("Trace: ev %d %s conn 0x%x, stmt 0x%x {%q}%s%s%s\n",
|
||||
info.EventCode, modeText, info.ConnHandle, info.StmtHandle,
|
||||
info.StmtOrTrigger, expandedText,
|
||||
runTimeText,
|
||||
dbErrText)
|
||||
return 0
|
||||
}
|
||||
|
||||
func main() {
|
||||
eventMask := sqlite3.TraceStmt | sqlite3.TraceProfile | sqlite3.TraceRow | sqlite3.TraceClose
|
||||
|
||||
sql.Register("sqlite3_tracing",
|
||||
&sqlite3.SQLiteDriver{
|
||||
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||
err := conn.SetTrace(&sqlite3.TraceConfig{
|
||||
Callback: traceCallback,
|
||||
EventMask: eventMask,
|
||||
WantExpandedSQL: true,
|
||||
})
|
||||
return err
|
||||
},
|
||||
})
|
||||
|
||||
os.Exit(dbMain())
|
||||
}
|
||||
|
||||
// Harder to do DB work in main().
|
||||
// It's better with a separate function because
|
||||
// 'defer' and 'os.Exit' don't go well together.
|
||||
//
|
||||
// DO NOT use 'log.Fatal...' below: remember that it's equivalent to
|
||||
// Print() followed by a call to os.Exit(1) --- and
|
||||
// we want to avoid Exit() so 'defer' can do cleanup.
|
||||
// Use 'log.Panic...' instead.
|
||||
|
||||
func dbMain() int {
|
||||
db, err := sql.Open("sqlite3_tracing", ":memory:")
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to open database: %#+v\n", err)
|
||||
return 1
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
err = db.Ping()
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
|
||||
dbSetup(db)
|
||||
|
||||
dbDoInsert(db)
|
||||
dbDoInsertPrepared(db)
|
||||
dbDoSelect(db)
|
||||
dbDoSelectPrepared(db)
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// 'DDL' stands for "Data Definition Language":
|
||||
|
||||
// Note: "INTEGER PRIMARY KEY NOT NULL AUTOINCREMENT" causes the error
|
||||
// 'near "AUTOINCREMENT": syntax error'; without "NOT NULL" it works.
|
||||
const tableDDL = `CREATE TABLE t1 (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
note VARCHAR NOT NULL
|
||||
)`
|
||||
|
||||
// 'DML' stands for "Data Manipulation Language":
|
||||
|
||||
const insertDML = "INSERT INTO t1 (note) VALUES (?)"
|
||||
const selectDML = "SELECT id, note FROM t1 WHERE note LIKE ?"
|
||||
|
||||
const textPrefix = "bla-1234567890-"
|
||||
const noteTextPattern = "%Prep%"
|
||||
|
||||
const nGenRows = 4 // Number of Rows to Generate (for *each* approach tested)
|
||||
|
||||
func dbSetup(db *sql.DB) {
|
||||
var err error
|
||||
|
||||
_, err = db.Exec("DROP TABLE IF EXISTS t1")
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
_, err = db.Exec(tableDDL)
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func dbDoInsert(db *sql.DB) {
|
||||
const Descr = "DB-Exec"
|
||||
for i := 0; i < nGenRows; i++ {
|
||||
result, err := db.Exec(insertDML, textPrefix+Descr)
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
|
||||
resultDoCheck(result, Descr, i)
|
||||
}
|
||||
}
|
||||
|
||||
func dbDoInsertPrepared(db *sql.DB) {
|
||||
const Descr = "DB-Prepare"
|
||||
|
||||
stmt, err := db.Prepare(insertDML)
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for i := 0; i < nGenRows; i++ {
|
||||
result, err := stmt.Exec(textPrefix + Descr)
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
|
||||
resultDoCheck(result, Descr, i)
|
||||
}
|
||||
}
|
||||
|
||||
func resultDoCheck(result sql.Result, callerDescr string, callIndex int) {
|
||||
lastID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
nAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
|
||||
log.Printf("Exec result for %s (%d): ID = %d, affected = %d\n", callerDescr, callIndex, lastID, nAffected)
|
||||
}
|
||||
|
||||
func dbDoSelect(db *sql.DB) {
|
||||
const Descr = "DB-Query"
|
||||
|
||||
rows, err := db.Query(selectDML, noteTextPattern)
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
rowsDoFetch(rows, Descr)
|
||||
}
|
||||
|
||||
func dbDoSelectPrepared(db *sql.DB) {
|
||||
const Descr = "DB-Prepare"
|
||||
|
||||
stmt, err := db.Prepare(selectDML)
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.Query(noteTextPattern)
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
rowsDoFetch(rows, Descr)
|
||||
}
|
||||
|
||||
func rowsDoFetch(rows *sql.Rows, callerDescr string) {
|
||||
var nRows int
|
||||
var id int64
|
||||
var note string
|
||||
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&id, ¬e)
|
||||
if err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
log.Printf("Row for %s (%d): id=%d, note=%q\n",
|
||||
callerDescr, nRows, id, note)
|
||||
nRows++
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Panic(err)
|
||||
}
|
||||
log.Printf("Total %d rows for %s.\n", nRows, callerDescr)
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func main() {
|
||||
sql.Register("sqlite3_with_extensions", &sqlite3.SQLiteDriver{
|
||||
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||
return conn.CreateModule("github", &githubModule{})
|
||||
},
|
||||
})
|
||||
db, err := sql.Open("sqlite3_with_extensions", ":memory:")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec("create virtual table repo using github(id, full_name, description, html_url)")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := db.Query("select id, full_name, description, html_url from repo")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var id, fullName, description, htmlURL string
|
||||
rows.Scan(&id, &fullName, &description, &htmlURL)
|
||||
fmt.Printf("%s: %s\n\t%s\n\t%s\n\n", id, fullName, description, htmlURL)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type githubRepo struct {
|
||||
ID int `json:"id"`
|
||||
FullName string `json:"full_name"`
|
||||
Description string `json:"description"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
}
|
||||
|
||||
type githubModule struct {
|
||||
}
|
||||
|
||||
func (m *githubModule) Create(c *sqlite3.SQLiteConn, args []string) (sqlite3.VTab, error) {
|
||||
err := c.DeclareVTab(fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id INT,
|
||||
full_name TEXT,
|
||||
description TEXT,
|
||||
html_url TEXT
|
||||
)`, args[0]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ghRepoTable{}, nil
|
||||
}
|
||||
|
||||
func (m *githubModule) Connect(c *sqlite3.SQLiteConn, args []string) (sqlite3.VTab, error) {
|
||||
return m.Create(c, args)
|
||||
}
|
||||
|
||||
func (m *githubModule) DestroyModule() {}
|
||||
|
||||
type ghRepoTable struct {
|
||||
repos []githubRepo
|
||||
}
|
||||
|
||||
func (v *ghRepoTable) Open() (sqlite3.VTabCursor, error) {
|
||||
resp, err := http.Get("https://api.github.com/repositories")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var repos []githubRepo
|
||||
if err := json.Unmarshal(body, &repos); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ghRepoCursor{0, repos}, nil
|
||||
}
|
||||
|
||||
func (v *ghRepoTable) BestIndex(cst []sqlite3.InfoConstraint, ob []sqlite3.InfoOrderBy) (*sqlite3.IndexResult, error) {
|
||||
return &sqlite3.IndexResult{}, nil
|
||||
}
|
||||
|
||||
func (v *ghRepoTable) Disconnect() error { return nil }
|
||||
func (v *ghRepoTable) Destroy() error { return nil }
|
||||
|
||||
type ghRepoCursor struct {
|
||||
index int
|
||||
repos []githubRepo
|
||||
}
|
||||
|
||||
func (vc *ghRepoCursor) Column(c *sqlite3.SQLiteContext, col int) error {
|
||||
switch col {
|
||||
case 0:
|
||||
c.ResultInt(vc.repos[vc.index].ID)
|
||||
case 1:
|
||||
c.ResultText(vc.repos[vc.index].FullName)
|
||||
case 2:
|
||||
c.ResultText(vc.repos[vc.index].Description)
|
||||
case 3:
|
||||
c.ResultText(vc.repos[vc.index].HTMLURL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vc *ghRepoCursor) Filter(idxNum int, idxStr string, vals []interface{}) error {
|
||||
vc.index = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vc *ghRepoCursor) Next() error {
|
||||
vc.index++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vc *ghRepoCursor) EOF() bool {
|
||||
return vc.index >= len(vc.repos)
|
||||
}
|
||||
|
||||
func (vc *ghRepoCursor) Rowid() (int64, error) {
|
||||
return int64(vc.index), nil
|
||||
}
|
||||
|
||||
func (vc *ghRepoCursor) Close() error {
|
||||
return nil
|
||||
}
|
23
backup.go
23
backup.go
|
@ -19,10 +19,12 @@ import (
|
|||
"unsafe"
|
||||
)
|
||||
|
||||
// SQLiteBackup implement interface of Backup.
|
||||
type SQLiteBackup struct {
|
||||
b *C.sqlite3_backup
|
||||
}
|
||||
|
||||
// Backup make backup from src to dest.
|
||||
func (c *SQLiteConn) Backup(dest string, conn *SQLiteConn, src string) (*SQLiteBackup, error) {
|
||||
destptr := C.CString(dest)
|
||||
defer C.free(unsafe.Pointer(destptr))
|
||||
|
@ -37,10 +39,10 @@ func (c *SQLiteConn) Backup(dest string, conn *SQLiteConn, src string) (*SQLiteB
|
|||
return nil, c.lastError()
|
||||
}
|
||||
|
||||
// Backs up for one step. Calls the underlying `sqlite3_backup_step` function.
|
||||
// This function returns a boolean indicating if the backup is done and
|
||||
// an error signalling any other error. Done is returned if the underlying C
|
||||
// function returns SQLITE_DONE (Code 101)
|
||||
// Step to backs up for one step. Calls the underlying `sqlite3_backup_step`
|
||||
// function. This function returns a boolean indicating if the backup is done
|
||||
// and an error signalling any other error. Done is returned if the underlying
|
||||
// C function returns SQLITE_DONE (Code 101)
|
||||
func (b *SQLiteBackup) Step(p int) (bool, error) {
|
||||
ret := C.sqlite3_backup_step(b.b, C.int(p))
|
||||
if ret == C.SQLITE_DONE {
|
||||
|
@ -51,24 +53,33 @@ func (b *SQLiteBackup) Step(p int) (bool, error) {
|
|||
return false, nil
|
||||
}
|
||||
|
||||
// Remaining return whether have the rest for backup.
|
||||
func (b *SQLiteBackup) Remaining() int {
|
||||
return int(C.sqlite3_backup_remaining(b.b))
|
||||
}
|
||||
|
||||
// PageCount return count of pages.
|
||||
func (b *SQLiteBackup) PageCount() int {
|
||||
return int(C.sqlite3_backup_pagecount(b.b))
|
||||
}
|
||||
|
||||
// Finish close backup.
|
||||
func (b *SQLiteBackup) Finish() error {
|
||||
return b.Close()
|
||||
}
|
||||
|
||||
// Close close backup.
|
||||
func (b *SQLiteBackup) Close() error {
|
||||
ret := C.sqlite3_backup_finish(b.b)
|
||||
|
||||
// sqlite3_backup_finish() never fails, it just returns the
|
||||
// error code from previous operations, so clean up before
|
||||
// checking and returning an error
|
||||
b.b = nil
|
||||
runtime.SetFinalizer(b, nil)
|
||||
|
||||
if ret != 0 {
|
||||
return Error{Code: ErrNo(ret)}
|
||||
}
|
||||
b.b = nil
|
||||
runtime.SetFinalizer(b, nil)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,290 @@
|
|||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// The number of rows of test data to create in the source database.
|
||||
// Can be used to control how many pages are available to be backed up.
|
||||
const testRowCount = 100
|
||||
|
||||
// The maximum number of seconds after which the page-by-page backup is considered to have taken too long.
|
||||
const usePagePerStepsTimeoutSeconds = 30
|
||||
|
||||
// Test the backup functionality.
|
||||
func testBackup(t *testing.T, testRowCount int, usePerPageSteps bool) {
|
||||
// This function will be called multiple times.
|
||||
// It uses sql.Register(), which requires the name parameter value to be unique.
|
||||
// There does not currently appear to be a way to unregister a registered driver, however.
|
||||
// So generate a database driver name that will likely be unique.
|
||||
var driverName = fmt.Sprintf("sqlite3_testBackup_%v_%v_%v", testRowCount, usePerPageSteps, time.Now().UnixNano())
|
||||
|
||||
// The driver's connection will be needed in order to perform the backup.
|
||||
driverConns := []*SQLiteConn{}
|
||||
sql.Register(driverName, &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
driverConns = append(driverConns, conn)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// Connect to the source database.
|
||||
srcTempFilename := TempFilename(t)
|
||||
defer os.Remove(srcTempFilename)
|
||||
srcDb, err := sql.Open(driverName, srcTempFilename)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open the source database:", err)
|
||||
}
|
||||
defer srcDb.Close()
|
||||
err = srcDb.Ping()
|
||||
if err != nil {
|
||||
t.Fatal("Failed to connect to the source database:", err)
|
||||
}
|
||||
|
||||
// Connect to the destination database.
|
||||
destTempFilename := TempFilename(t)
|
||||
defer os.Remove(destTempFilename)
|
||||
destDb, err := sql.Open(driverName, destTempFilename)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open the destination database:", err)
|
||||
}
|
||||
defer destDb.Close()
|
||||
err = destDb.Ping()
|
||||
if err != nil {
|
||||
t.Fatal("Failed to connect to the destination database:", err)
|
||||
}
|
||||
|
||||
// Check the driver connections.
|
||||
if len(driverConns) != 2 {
|
||||
t.Fatalf("Expected 2 driver connections, but found %v.", len(driverConns))
|
||||
}
|
||||
srcDbDriverConn := driverConns[0]
|
||||
if srcDbDriverConn == nil {
|
||||
t.Fatal("The source database driver connection is nil.")
|
||||
}
|
||||
destDbDriverConn := driverConns[1]
|
||||
if destDbDriverConn == nil {
|
||||
t.Fatal("The destination database driver connection is nil.")
|
||||
}
|
||||
|
||||
// Generate some test data for the given ID.
|
||||
var generateTestData = func(id int) string {
|
||||
return fmt.Sprintf("test-%v", id)
|
||||
}
|
||||
|
||||
// Populate the source database with a test table containing some test data.
|
||||
tx, err := srcDb.Begin()
|
||||
if err != nil {
|
||||
t.Fatal("Failed to begin a transaction when populating the source database:", err)
|
||||
}
|
||||
_, err = srcDb.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)")
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
t.Fatal("Failed to create the source database \"test\" table:", err)
|
||||
}
|
||||
for id := 0; id < testRowCount; id++ {
|
||||
_, err = srcDb.Exec("INSERT INTO test (id, value) VALUES (?, ?)", id, generateTestData(id))
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
t.Fatal("Failed to insert a row into the source database \"test\" table:", err)
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
t.Fatal("Failed to populate the source database:", err)
|
||||
}
|
||||
|
||||
// Confirm that the destination database is initially empty.
|
||||
var destTableCount int
|
||||
err = destDb.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'").Scan(&destTableCount)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to check the destination table count:", err)
|
||||
}
|
||||
if destTableCount != 0 {
|
||||
t.Fatalf("The destination database is not empty; %v table(s) found.", destTableCount)
|
||||
}
|
||||
|
||||
// Prepare to perform the backup.
|
||||
backup, err := destDbDriverConn.Backup("main", srcDbDriverConn, "main")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to initialize the backup:", err)
|
||||
}
|
||||
|
||||
// Allow the initial page count and remaining values to be retrieved.
|
||||
// According to <https://www.sqlite.org/c3ref/backup_finish.html>, the page count and remaining values are "... only updated by sqlite3_backup_step()."
|
||||
isDone, err := backup.Step(0)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to perform an initial 0-page backup step:", err)
|
||||
}
|
||||
if isDone {
|
||||
t.Fatal("Backup is unexpectedly done.")
|
||||
}
|
||||
|
||||
// Check that the page count and remaining values are reasonable.
|
||||
initialPageCount := backup.PageCount()
|
||||
if initialPageCount <= 0 {
|
||||
t.Fatalf("Unexpected initial page count value: %v", initialPageCount)
|
||||
}
|
||||
initialRemaining := backup.Remaining()
|
||||
if initialRemaining <= 0 {
|
||||
t.Fatalf("Unexpected initial remaining value: %v", initialRemaining)
|
||||
}
|
||||
if initialRemaining != initialPageCount {
|
||||
t.Fatalf("Initial remaining value differs from the initial page count value; remaining: %v; page count: %v", initialRemaining, initialPageCount)
|
||||
}
|
||||
|
||||
// Perform the backup.
|
||||
if usePerPageSteps {
|
||||
var startTime = time.Now().Unix()
|
||||
|
||||
// Test backing-up using a page-by-page approach.
|
||||
var latestRemaining = initialRemaining
|
||||
for {
|
||||
// Perform the backup step.
|
||||
isDone, err = backup.Step(1)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to perform a backup step:", err)
|
||||
}
|
||||
|
||||
// The page count should remain unchanged from its initial value.
|
||||
currentPageCount := backup.PageCount()
|
||||
if currentPageCount != initialPageCount {
|
||||
t.Fatalf("Current page count differs from the initial page count; initial page count: %v; current page count: %v", initialPageCount, currentPageCount)
|
||||
}
|
||||
|
||||
// There should now be one less page remaining.
|
||||
currentRemaining := backup.Remaining()
|
||||
expectedRemaining := latestRemaining - 1
|
||||
if currentRemaining != expectedRemaining {
|
||||
t.Fatalf("Unexpected remaining value; expected remaining value: %v; actual remaining value: %v", expectedRemaining, currentRemaining)
|
||||
}
|
||||
latestRemaining = currentRemaining
|
||||
|
||||
if isDone {
|
||||
break
|
||||
}
|
||||
|
||||
// Limit the runtime of the backup attempt.
|
||||
if (time.Now().Unix() - startTime) > usePagePerStepsTimeoutSeconds {
|
||||
t.Fatal("Backup is taking longer than expected.")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Test the copying of all remaining pages.
|
||||
isDone, err = backup.Step(-1)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to perform a backup step:", err)
|
||||
}
|
||||
if !isDone {
|
||||
t.Fatal("Backup is unexpectedly not done.")
|
||||
}
|
||||
}
|
||||
|
||||
// Check that the page count and remaining values are reasonable.
|
||||
finalPageCount := backup.PageCount()
|
||||
if finalPageCount != initialPageCount {
|
||||
t.Fatalf("Final page count differs from the initial page count; initial page count: %v; final page count: %v", initialPageCount, finalPageCount)
|
||||
}
|
||||
finalRemaining := backup.Remaining()
|
||||
if finalRemaining != 0 {
|
||||
t.Fatalf("Unexpected remaining value: %v", finalRemaining)
|
||||
}
|
||||
|
||||
// Finish the backup.
|
||||
err = backup.Finish()
|
||||
if err != nil {
|
||||
t.Fatal("Failed to finish backup:", err)
|
||||
}
|
||||
|
||||
// Confirm that the "test" table now exists in the destination database.
|
||||
var doesTestTableExist bool
|
||||
err = destDb.QueryRow("SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'test' LIMIT 1) AS test_table_exists").Scan(&doesTestTableExist)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to check if the \"test\" table exists in the destination database:", err)
|
||||
}
|
||||
if !doesTestTableExist {
|
||||
t.Fatal("The \"test\" table could not be found in the destination database.")
|
||||
}
|
||||
|
||||
// Confirm that the number of rows in the destination database's "test" table matches that of the source table.
|
||||
var actualTestTableRowCount int
|
||||
err = destDb.QueryRow("SELECT COUNT(*) FROM test").Scan(&actualTestTableRowCount)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to determine the rowcount of the \"test\" table in the destination database:", err)
|
||||
}
|
||||
if testRowCount != actualTestTableRowCount {
|
||||
t.Fatalf("Unexpected destination \"test\" table row count; expected: %v; found: %v", testRowCount, actualTestTableRowCount)
|
||||
}
|
||||
|
||||
// Check each of the rows in the destination database.
|
||||
for id := 0; id < testRowCount; id++ {
|
||||
var checkedValue string
|
||||
err = destDb.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&checkedValue)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to query the \"test\" table in the destination database:", err)
|
||||
}
|
||||
|
||||
var expectedValue = generateTestData(id)
|
||||
if checkedValue != expectedValue {
|
||||
t.Fatalf("Unexpected value in the \"test\" table in the destination database; expected value: %v; actual value: %v", expectedValue, checkedValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackupStepByStep(t *testing.T) {
|
||||
testBackup(t, testRowCount, true)
|
||||
}
|
||||
|
||||
func TestBackupAllRemainingPages(t *testing.T) {
|
||||
testBackup(t, testRowCount, false)
|
||||
}
|
||||
|
||||
// Test the error reporting when preparing to perform a backup.
|
||||
func TestBackupError(t *testing.T) {
|
||||
const driverName = "sqlite3_TestBackupError"
|
||||
|
||||
// The driver's connection will be needed in order to perform the backup.
|
||||
var dbDriverConn *SQLiteConn
|
||||
sql.Register(driverName, &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
dbDriverConn = conn
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// Connect to the database.
|
||||
dbTempFilename := TempFilename(t)
|
||||
defer os.Remove(dbTempFilename)
|
||||
db, err := sql.Open(driverName, dbTempFilename)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open the database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
db.Ping()
|
||||
|
||||
// Need the driver connection in order to perform the backup.
|
||||
if dbDriverConn == nil {
|
||||
t.Fatal("Failed to get the driver connection.")
|
||||
}
|
||||
|
||||
// Prepare to perform the backup.
|
||||
// Intentionally using the same connection for both the source and destination databases, to trigger an error result.
|
||||
backup, err := dbDriverConn.Backup("main", dbDriverConn, "main")
|
||||
if err == nil {
|
||||
t.Fatal("Failed to get the expected error result.")
|
||||
}
|
||||
const expectedError = "source and destination must be distinct"
|
||||
if err.Error() != expectedError {
|
||||
t.Fatalf("Unexpected error message; expected value: \"%v\"; actual value: \"%v\"", expectedError, err.Error())
|
||||
}
|
||||
if backup != nil {
|
||||
t.Fatal("Failed to get the expected nil backup result.")
|
||||
}
|
||||
}
|
32
callback.go
32
callback.go
|
@ -11,7 +11,11 @@ package sqlite3
|
|||
// code for SQLite custom functions is in here.
|
||||
|
||||
/*
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
|
||||
void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
|
||||
|
@ -36,8 +40,8 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value
|
|||
}
|
||||
|
||||
//export stepTrampoline
|
||||
func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
|
||||
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
|
||||
func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
|
||||
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
|
||||
ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo)
|
||||
ai.Step(ctx, args)
|
||||
}
|
||||
|
@ -49,6 +53,30 @@ func doneTrampoline(ctx *C.sqlite3_context) {
|
|||
ai.Done(ctx)
|
||||
}
|
||||
|
||||
//export compareTrampoline
|
||||
func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.char) C.int {
|
||||
cmp := lookupHandle(handlePtr).(func(string, string) int)
|
||||
return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
|
||||
}
|
||||
|
||||
//export commitHookTrampoline
|
||||
func commitHookTrampoline(handle uintptr) int {
|
||||
callback := lookupHandle(handle).(func() int)
|
||||
return callback()
|
||||
}
|
||||
|
||||
//export rollbackHookTrampoline
|
||||
func rollbackHookTrampoline(handle uintptr) {
|
||||
callback := lookupHandle(handle).(func())
|
||||
callback()
|
||||
}
|
||||
|
||||
//export updateHookTrampoline
|
||||
func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) {
|
||||
callback := lookupHandle(handle).(func(int, string, string, int64))
|
||||
callback(op, C.GoString(db), C.GoString(table), rowid)
|
||||
}
|
||||
|
||||
// Use handles to avoid passing Go pointers to C.
|
||||
|
||||
type handleVal struct {
|
||||
|
|
9
error.go
9
error.go
|
@ -7,12 +7,16 @@ package sqlite3
|
|||
|
||||
import "C"
|
||||
|
||||
// ErrNo inherit errno.
|
||||
type ErrNo int
|
||||
|
||||
// ErrNoMask is mask code.
|
||||
const ErrNoMask C.int = 0xff
|
||||
|
||||
// ErrNoExtended is extended errno.
|
||||
type ErrNoExtended int
|
||||
|
||||
// Error implement sqlite error code.
|
||||
type Error struct {
|
||||
Code ErrNo /* The error code returned by SQLite */
|
||||
ExtendedCode ErrNoExtended /* The extended error code returned by SQLite */
|
||||
|
@ -52,14 +56,17 @@ var (
|
|||
ErrWarning = ErrNo(28) /* Warnings from sqlite3_log() */
|
||||
)
|
||||
|
||||
// Error return error message from errno.
|
||||
func (err ErrNo) Error() string {
|
||||
return Error{Code: err}.Error()
|
||||
}
|
||||
|
||||
// Extend return extended errno.
|
||||
func (err ErrNo) Extend(by int) ErrNoExtended {
|
||||
return ErrNoExtended(int(err) | (by << 8))
|
||||
}
|
||||
|
||||
// Error return error message that is extended code.
|
||||
func (err ErrNoExtended) Error() string {
|
||||
return Error{Code: ErrNo(C.int(err) & ErrNoMask), ExtendedCode: err}.Error()
|
||||
}
|
||||
|
@ -121,7 +128,7 @@ var (
|
|||
ErrConstraintTrigger = ErrConstraint.Extend(7)
|
||||
ErrConstraintUnique = ErrConstraint.Extend(8)
|
||||
ErrConstraintVTab = ErrConstraint.Extend(9)
|
||||
ErrConstraintRowId = ErrConstraint.Extend(10)
|
||||
ErrConstraintRowID = ErrConstraint.Extend(10)
|
||||
ErrNoticeRecoverWAL = ErrNotice.Extend(1)
|
||||
ErrNoticeRecoverRollback = ErrNotice.Extend(2)
|
||||
ErrWarningAutoIndex = ErrWarning.Extend(1)
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
|
||||
func TestSimpleError(t *testing.T) {
|
||||
e := ErrError.Error()
|
||||
if e != "SQL logic error or missing database" {
|
||||
if e != "SQL logic error or missing database" && e != "SQL logic error" {
|
||||
t.Error("wrong error code: " + e)
|
||||
}
|
||||
}
|
||||
|
|
38721
sqlite3-binding.c
38721
sqlite3-binding.c
File diff suppressed because it is too large
Load Diff
2752
sqlite3-binding.h
2752
sqlite3-binding.h
File diff suppressed because it is too large
Load Diff
586
sqlite3.go
586
sqlite3.go
|
@ -1,3 +1,5 @@
|
|||
// +build cgo
|
||||
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
|
@ -14,9 +16,13 @@ package sqlite3
|
|||
/*
|
||||
#cgo CFLAGS: -std=gnu99
|
||||
#cgo CFLAGS: -DSQLITE_HAS_CODEC
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
|
||||
#cgo LDFLAGS: -lcrypto
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE=1 -DHAVE_USLEEP=1
|
||||
#cgo linux,!android CFLAGS: -DHAVE_PREAD64=1 -DHAVE_PWRITE64=1
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
|
||||
#cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15
|
||||
#cgo CFLAGS: -DSQLITE_DISABLE_INTRINSIC
|
||||
#cgo CFLAGS: -Wno-deprecated-declarations
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
|
@ -107,9 +113,41 @@ int _sqlite3_create_function(
|
|||
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||
void doneTrampoline(sqlite3_context*);
|
||||
|
||||
int compareTrampoline(void*, int, char*, int, char*);
|
||||
int commitHookTrampoline(void*);
|
||||
void rollbackHookTrampoline(void*);
|
||||
void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64);
|
||||
|
||||
#ifdef SQLITE_LIMIT_WORKER_THREADS
|
||||
# define _SQLITE_HAS_LIMIT
|
||||
# define SQLITE_LIMIT_LENGTH 0
|
||||
# define SQLITE_LIMIT_SQL_LENGTH 1
|
||||
# define SQLITE_LIMIT_COLUMN 2
|
||||
# define SQLITE_LIMIT_EXPR_DEPTH 3
|
||||
# define SQLITE_LIMIT_COMPOUND_SELECT 4
|
||||
# define SQLITE_LIMIT_VDBE_OP 5
|
||||
# define SQLITE_LIMIT_FUNCTION_ARG 6
|
||||
# define SQLITE_LIMIT_ATTACHED 7
|
||||
# define SQLITE_LIMIT_LIKE_PATTERN_LENGTH 8
|
||||
# define SQLITE_LIMIT_VARIABLE_NUMBER 9
|
||||
# define SQLITE_LIMIT_TRIGGER_DEPTH 10
|
||||
# define SQLITE_LIMIT_WORKER_THREADS 11
|
||||
# else
|
||||
# define SQLITE_LIMIT_WORKER_THREADS 11
|
||||
#endif
|
||||
|
||||
static int _sqlite3_limit(sqlite3* db, int limitId, int newLimit) {
|
||||
#ifndef _SQLITE_HAS_LIMIT
|
||||
return -1;
|
||||
#else
|
||||
return sqlite3_limit(db, limitId, newLimit);
|
||||
#endif
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
|
@ -120,14 +158,15 @@ import (
|
|||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Timestamp formats understood by both this module and SQLite.
|
||||
// The first format in the slice will be used when saving time values
|
||||
// into the database. When parsing a string from a timestamp or
|
||||
// datetime column, the formats are tried in order.
|
||||
// SQLiteTimestampFormats is timestamp formats understood by both this module
|
||||
// and SQLite. The first format in the slice will be used when saving time
|
||||
// values into the database. When parsing a string from a timestamp or datetime
|
||||
// column, the formats are tried in order.
|
||||
var SQLiteTimestampFormats = []string{
|
||||
// By default, store timestamps with whatever timezone they come with.
|
||||
// When parsed, they will be returned with the same timezone.
|
||||
|
@ -142,26 +181,39 @@ var SQLiteTimestampFormats = []string{
|
|||
"2006-01-02",
|
||||
}
|
||||
|
||||
const (
|
||||
columnDate string = "date"
|
||||
columnDatetime string = "datetime"
|
||||
columnTimestamp string = "timestamp"
|
||||
)
|
||||
|
||||
func init() {
|
||||
sql.Register("sqlite3", &SQLiteDriver{})
|
||||
}
|
||||
|
||||
// Version returns SQLite library version information.
|
||||
func Version() (libVersion string, libVersionNumber int, sourceId string) {
|
||||
func Version() (libVersion string, libVersionNumber int, sourceID string) {
|
||||
libVersion = C.GoString(C.sqlite3_libversion())
|
||||
libVersionNumber = int(C.sqlite3_libversion_number())
|
||||
sourceId = C.GoString(C.sqlite3_sourceid())
|
||||
return libVersion, libVersionNumber, sourceId
|
||||
sourceID = C.GoString(C.sqlite3_sourceid())
|
||||
return libVersion, libVersionNumber, sourceID
|
||||
}
|
||||
|
||||
// Driver struct.
|
||||
const (
|
||||
SQLITE_DELETE = C.SQLITE_DELETE
|
||||
SQLITE_INSERT = C.SQLITE_INSERT
|
||||
SQLITE_UPDATE = C.SQLITE_UPDATE
|
||||
)
|
||||
|
||||
// SQLiteDriver implement sql.Driver.
|
||||
type SQLiteDriver struct {
|
||||
Extensions []string
|
||||
ConnectHook func(*SQLiteConn) error
|
||||
}
|
||||
|
||||
// Conn struct.
|
||||
// SQLiteConn implement sql.Conn.
|
||||
type SQLiteConn struct {
|
||||
mu sync.Mutex
|
||||
db *C.sqlite3
|
||||
loc *time.Location
|
||||
txlock string
|
||||
|
@ -169,35 +221,36 @@ type SQLiteConn struct {
|
|||
aggregators []*aggInfo
|
||||
}
|
||||
|
||||
// Tx struct.
|
||||
// SQLiteTx implemen sql.Tx.
|
||||
type SQLiteTx struct {
|
||||
c *SQLiteConn
|
||||
}
|
||||
|
||||
// Stmt struct.
|
||||
// SQLiteStmt implement sql.Stmt.
|
||||
type SQLiteStmt struct {
|
||||
mu sync.Mutex
|
||||
c *SQLiteConn
|
||||
s *C.sqlite3_stmt
|
||||
nv int
|
||||
nn []string
|
||||
t string
|
||||
closed bool
|
||||
cls bool
|
||||
}
|
||||
|
||||
// Result struct.
|
||||
// SQLiteResult implement sql.Result.
|
||||
type SQLiteResult struct {
|
||||
id int64
|
||||
changes int64
|
||||
}
|
||||
|
||||
// Rows struct.
|
||||
// SQLiteRows implement sql.Rows.
|
||||
type SQLiteRows struct {
|
||||
s *SQLiteStmt
|
||||
nc int
|
||||
cols []string
|
||||
decltype []string
|
||||
cls bool
|
||||
closed bool
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
type functionInfo struct {
|
||||
|
@ -303,16 +356,90 @@ func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
|
|||
|
||||
// Commit transaction.
|
||||
func (tx *SQLiteTx) Commit() error {
|
||||
_, err := tx.c.exec("COMMIT")
|
||||
_, err := tx.c.exec(context.Background(), "COMMIT", nil)
|
||||
if err != nil && err.(Error).Code == C.SQLITE_BUSY {
|
||||
// sqlite3 will leave the transaction open in this scenario.
|
||||
// However, database/sql considers the transaction complete once we
|
||||
// return from Commit() - we must clean up to honour its semantics.
|
||||
tx.c.exec(context.Background(), "ROLLBACK", nil)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Rollback transaction.
|
||||
func (tx *SQLiteTx) Rollback() error {
|
||||
_, err := tx.c.exec("ROLLBACK")
|
||||
_, err := tx.c.exec(context.Background(), "ROLLBACK", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// RegisterCollation makes a Go function available as a collation.
|
||||
//
|
||||
// cmp receives two UTF-8 strings, a and b. The result should be 0 if
|
||||
// a==b, -1 if a < b, and +1 if a > b.
|
||||
//
|
||||
// cmp must always return the same result given the same
|
||||
// inputs. Additionally, it must have the following properties for all
|
||||
// strings A, B and C: if A==B then B==A; if A==B and B==C then A==C;
|
||||
// if A<B then B>A; if A<B and B<C then A<C.
|
||||
//
|
||||
// If cmp does not obey these constraints, sqlite3's behavior is
|
||||
// undefined when the collation is used.
|
||||
func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int) error {
|
||||
handle := newHandle(c, cmp)
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
rv := C.sqlite3_create_collation(c.db, cname, C.SQLITE_UTF8, unsafe.Pointer(handle), (*[0]byte)(unsafe.Pointer(C.compareTrampoline)))
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterCommitHook sets the commit hook for a connection.
|
||||
//
|
||||
// If the callback returns non-zero the transaction will become a rollback.
|
||||
//
|
||||
// If there is an existing commit hook for this connection, it will be
|
||||
// removed. If callback is nil the existing hook (if any) will be removed
|
||||
// without creating a new one.
|
||||
func (c *SQLiteConn) RegisterCommitHook(callback func() int) {
|
||||
if callback == nil {
|
||||
C.sqlite3_commit_hook(c.db, nil, nil)
|
||||
} else {
|
||||
C.sqlite3_commit_hook(c.db, (*[0]byte)(C.commitHookTrampoline), unsafe.Pointer(newHandle(c, callback)))
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRollbackHook sets the rollback hook for a connection.
|
||||
//
|
||||
// If there is an existing rollback hook for this connection, it will be
|
||||
// removed. If callback is nil the existing hook (if any) will be removed
|
||||
// without creating a new one.
|
||||
func (c *SQLiteConn) RegisterRollbackHook(callback func()) {
|
||||
if callback == nil {
|
||||
C.sqlite3_rollback_hook(c.db, nil, nil)
|
||||
} else {
|
||||
C.sqlite3_rollback_hook(c.db, (*[0]byte)(C.rollbackHookTrampoline), unsafe.Pointer(newHandle(c, callback)))
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterUpdateHook sets the update hook for a connection.
|
||||
//
|
||||
// The parameters to the callback are the operation (one of the constants
|
||||
// SQLITE_INSERT, SQLITE_DELETE, or SQLITE_UPDATE), the database name, the
|
||||
// table name, and the rowid.
|
||||
//
|
||||
// If there is an existing update hook for this connection, it will be
|
||||
// removed. If callback is nil the existing hook (if any) will be removed
|
||||
// without creating a new one.
|
||||
func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64)) {
|
||||
if callback == nil {
|
||||
C.sqlite3_update_hook(c.db, nil, nil)
|
||||
} else {
|
||||
C.sqlite3_update_hook(c.db, (*[0]byte)(C.updateHookTrampoline), unsafe.Pointer(newHandle(c, callback)))
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterFunc makes a Go function available as a SQLite function.
|
||||
//
|
||||
// The Go function can have arguments of the following types: any
|
||||
|
@ -383,13 +510,17 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
|
|||
if pure {
|
||||
opts |= C.SQLITE_DETERMINISTIC
|
||||
}
|
||||
rv := C._sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), C.uintptr_t(newHandle(c, &fi)), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil)
|
||||
rv := sqlite3CreateFunction(c.db, cname, C.int(numArgs), C.int(opts), newHandle(c, &fi), C.callbackTrampoline, nil, nil)
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sqlite3CreateFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTextRep C.int, pApp uintptr, xFunc unsafe.Pointer, xStep unsafe.Pointer, xFinal unsafe.Pointer) C.int {
|
||||
return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(pApp), (*[0]byte)(xFunc), (*[0]byte)(xStep), (*[0]byte)(xFinal))
|
||||
}
|
||||
|
||||
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
|
||||
//
|
||||
// Because aggregation is incremental, it's implemented in Go with a
|
||||
|
@ -508,7 +639,7 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool
|
|||
if pure {
|
||||
opts |= C.SQLITE_DETERMINISTIC
|
||||
}
|
||||
rv := C._sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(newHandle(c, &ai)), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline)))
|
||||
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
|
@ -520,22 +651,38 @@ func (c *SQLiteConn) AutoCommit() bool {
|
|||
return int(C.sqlite3_get_autocommit(c.db)) != 0
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) lastError() Error {
|
||||
func (c *SQLiteConn) lastError() error {
|
||||
return lastError(c.db)
|
||||
}
|
||||
|
||||
func lastError(db *C.sqlite3) error {
|
||||
rv := C.sqlite3_errcode(db)
|
||||
if rv == C.SQLITE_OK {
|
||||
return nil
|
||||
}
|
||||
return Error{
|
||||
Code: ErrNo(C.sqlite3_errcode(c.db)),
|
||||
ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(c.db)),
|
||||
err: C.GoString(C.sqlite3_errmsg(c.db)),
|
||||
Code: ErrNo(rv),
|
||||
ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(db)),
|
||||
err: C.GoString(C.sqlite3_errmsg(db)),
|
||||
}
|
||||
}
|
||||
|
||||
// Implements Execer
|
||||
// Exec implements Execer.
|
||||
func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||||
if len(args) == 0 {
|
||||
return c.exec(query)
|
||||
list := make([]namedValue, len(args))
|
||||
for i, v := range args {
|
||||
list[i] = namedValue{
|
||||
Ordinal: i + 1,
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
return c.exec(context.Background(), query, list)
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) exec(ctx context.Context, query string, args []namedValue) (driver.Result, error) {
|
||||
start := 0
|
||||
for {
|
||||
s, err := c.Prepare(query)
|
||||
s, err := c.prepare(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -543,14 +690,19 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
|
|||
if s.(*SQLiteStmt).s != nil {
|
||||
na := s.NumInput()
|
||||
if len(args) < na {
|
||||
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
|
||||
s.Close()
|
||||
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
|
||||
}
|
||||
res, err = s.Exec(args[:na])
|
||||
for i := 0; i < na; i++ {
|
||||
args[i].Ordinal -= start
|
||||
}
|
||||
res, err = s.(*SQLiteStmt).exec(ctx, args[:na])
|
||||
if err != nil && err != driver.ErrSkip {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
args = args[na:]
|
||||
start += na
|
||||
}
|
||||
tail := s.(*SQLiteStmt).t
|
||||
s.Close()
|
||||
|
@ -561,24 +713,46 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
|
|||
}
|
||||
}
|
||||
|
||||
// Implements Queryer
|
||||
type namedValue struct {
|
||||
Name string
|
||||
Ordinal int
|
||||
Value driver.Value
|
||||
}
|
||||
|
||||
// Query implements Queryer.
|
||||
func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, v := range args {
|
||||
list[i] = namedValue{
|
||||
Ordinal: i + 1,
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
return c.query(context.Background(), query, list)
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) {
|
||||
start := 0
|
||||
for {
|
||||
s, err := c.Prepare(query)
|
||||
s, err := c.prepare(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.(*SQLiteStmt).cls = true
|
||||
na := s.NumInput()
|
||||
if len(args) < na {
|
||||
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
|
||||
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
|
||||
}
|
||||
rows, err := s.Query(args[:na])
|
||||
for i := 0; i < na; i++ {
|
||||
args[i].Ordinal -= start
|
||||
}
|
||||
rows, err := s.(*SQLiteStmt).query(ctx, args[:na])
|
||||
if err != nil && err != driver.ErrSkip {
|
||||
s.Close()
|
||||
return nil, err
|
||||
return rows, err
|
||||
}
|
||||
args = args[na:]
|
||||
start += na
|
||||
tail := s.(*SQLiteStmt).t
|
||||
if tail == "" {
|
||||
return rows, nil
|
||||
|
@ -589,21 +763,13 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
|
|||
}
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {
|
||||
pcmd := C.CString(cmd)
|
||||
defer C.free(unsafe.Pointer(pcmd))
|
||||
|
||||
var rowid, changes C.longlong
|
||||
rv := C._sqlite3_exec(c.db, pcmd, &rowid, &changes)
|
||||
if rv != C.SQLITE_OK {
|
||||
return nil, c.lastError()
|
||||
}
|
||||
return &SQLiteResult{int64(rowid), int64(changes)}, nil
|
||||
}
|
||||
|
||||
// Begin transaction.
|
||||
func (c *SQLiteConn) Begin() (driver.Tx, error) {
|
||||
if _, err := c.exec(c.txlock); err != nil {
|
||||
return c.begin(context.Background())
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) begin(ctx context.Context) (driver.Tx, error) {
|
||||
if _, err := c.exec(ctx, c.txlock, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SQLiteTx{c}, nil
|
||||
|
@ -627,6 +793,12 @@ func errorString(err Error) string {
|
|||
// _txlock=XXX
|
||||
// Specify locking behavior for transactions. XXX can be "immediate",
|
||||
// "deferred", "exclusive".
|
||||
// _foreign_keys=X
|
||||
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
|
||||
// _recursive_triggers=X
|
||||
// Enable or disable recursive triggers. X can be 1 or 0.
|
||||
// _mutex=XXX
|
||||
// Specify mutex mode. XXX can be "no", "full".
|
||||
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||
if C.sqlite3_threadsafe() == 0 {
|
||||
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
|
||||
|
@ -634,7 +806,10 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
|
||||
var loc *time.Location
|
||||
txlock := "BEGIN"
|
||||
busy_timeout := 5000
|
||||
busyTimeout := 5000
|
||||
foreignKeys := -1
|
||||
recursiveTriggers := -1
|
||||
mutex := C.int(C.SQLITE_OPEN_FULLMUTEX)
|
||||
pos := strings.IndexRune(dsn, '?')
|
||||
if pos >= 1 {
|
||||
params, err := url.ParseQuery(dsn[pos+1:])
|
||||
|
@ -660,7 +835,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("Invalid _busy_timeout: %v: %v", val, err)
|
||||
}
|
||||
busy_timeout = int(iv)
|
||||
busyTimeout = int(iv)
|
||||
}
|
||||
|
||||
// _txlock
|
||||
|
@ -677,6 +852,42 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// _foreign_keys
|
||||
if val := params.Get("_foreign_keys"); val != "" {
|
||||
switch val {
|
||||
case "1":
|
||||
foreignKeys = 1
|
||||
case "0":
|
||||
foreignKeys = 0
|
||||
default:
|
||||
return nil, fmt.Errorf("Invalid _foreign_keys: %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
// _recursive_triggers
|
||||
if val := params.Get("_recursive_triggers"); val != "" {
|
||||
switch val {
|
||||
case "1":
|
||||
recursiveTriggers = 1
|
||||
case "0":
|
||||
recursiveTriggers = 0
|
||||
default:
|
||||
return nil, fmt.Errorf("Invalid _recursive_triggers: %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
// _mutex
|
||||
if val := params.Get("_mutex"); val != "" {
|
||||
switch val {
|
||||
case "no":
|
||||
mutex = C.SQLITE_OPEN_NOMUTEX
|
||||
case "full":
|
||||
mutex = C.SQLITE_OPEN_FULLMUTEX
|
||||
default:
|
||||
return nil, fmt.Errorf("Invalid _mutex: %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(dsn, "file:") {
|
||||
dsn = dsn[:pos]
|
||||
}
|
||||
|
@ -686,9 +897,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
name := C.CString(dsn)
|
||||
defer C.free(unsafe.Pointer(name))
|
||||
rv := C._sqlite3_open_v2(name, &db,
|
||||
C.SQLITE_OPEN_FULLMUTEX|
|
||||
C.SQLITE_OPEN_READWRITE|
|
||||
C.SQLITE_OPEN_CREATE,
|
||||
mutex|C.SQLITE_OPEN_READWRITE|C.SQLITE_OPEN_CREATE,
|
||||
nil)
|
||||
if rv != 0 {
|
||||
return nil, Error{Code: ErrNo(rv)}
|
||||
|
@ -697,21 +906,56 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
return nil, errors.New("sqlite succeeded without returning a database")
|
||||
}
|
||||
|
||||
rv = C.sqlite3_busy_timeout(db, C.int(busy_timeout))
|
||||
rv = C.sqlite3_busy_timeout(db, C.int(busyTimeout))
|
||||
if rv != C.SQLITE_OK {
|
||||
C.sqlite3_close_v2(db)
|
||||
return nil, Error{Code: ErrNo(rv)}
|
||||
}
|
||||
|
||||
exec := func(s string) error {
|
||||
cs := C.CString(s)
|
||||
rv := C.sqlite3_exec(db, cs, nil, nil, nil)
|
||||
C.free(unsafe.Pointer(cs))
|
||||
if rv != C.SQLITE_OK {
|
||||
return lastError(db)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if foreignKeys == 0 {
|
||||
if err := exec("PRAGMA foreign_keys = OFF;"); err != nil {
|
||||
C.sqlite3_close_v2(db)
|
||||
return nil, err
|
||||
}
|
||||
} else if foreignKeys == 1 {
|
||||
if err := exec("PRAGMA foreign_keys = ON;"); err != nil {
|
||||
C.sqlite3_close_v2(db)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if recursiveTriggers == 0 {
|
||||
if err := exec("PRAGMA recursive_triggers = OFF;"); err != nil {
|
||||
C.sqlite3_close_v2(db)
|
||||
return nil, err
|
||||
}
|
||||
} else if recursiveTriggers == 1 {
|
||||
if err := exec("PRAGMA recursive_triggers = ON;"); err != nil {
|
||||
C.sqlite3_close_v2(db)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
|
||||
|
||||
if len(d.Extensions) > 0 {
|
||||
if err := conn.loadExtensions(d.Extensions); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if d.ConnectHook != nil {
|
||||
if err := d.ConnectHook(conn); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
@ -721,18 +965,33 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
|
||||
// Close the connection.
|
||||
func (c *SQLiteConn) Close() error {
|
||||
deleteHandles(c)
|
||||
rv := C.sqlite3_close_v2(c.db)
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
deleteHandles(c)
|
||||
c.mu.Lock()
|
||||
c.db = nil
|
||||
c.mu.Unlock()
|
||||
runtime.SetFinalizer(c, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) dbConnOpen() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.db != nil
|
||||
}
|
||||
|
||||
// Prepare the query string. Return a new statement.
|
||||
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
|
||||
return c.prepare(context.Background(), query)
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
pquery := C.CString(query)
|
||||
defer C.free(unsafe.Pointer(pquery))
|
||||
var s *C.sqlite3_stmt
|
||||
|
@ -745,29 +1004,54 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
|
|||
if tail != nil && *tail != '\000' {
|
||||
t = strings.TrimSpace(C.GoString(tail))
|
||||
}
|
||||
nv := int(C.sqlite3_bind_parameter_count(s))
|
||||
var nn []string
|
||||
for i := 0; i < nv; i++ {
|
||||
pn := C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1)))
|
||||
if len(pn) > 1 && pn[0] == '$' && 48 <= pn[1] && pn[1] <= 57 {
|
||||
nn = append(nn, C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1))))
|
||||
}
|
||||
}
|
||||
ss := &SQLiteStmt{c: c, s: s, nv: nv, nn: nn, t: t}
|
||||
ss := &SQLiteStmt{c: c, s: s, t: t}
|
||||
runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// Run-Time Limit Categories.
|
||||
// See: http://www.sqlite.org/c3ref/c_limit_attached.html
|
||||
const (
|
||||
SQLITE_LIMIT_LENGTH = C.SQLITE_LIMIT_LENGTH
|
||||
SQLITE_LIMIT_SQL_LENGTH = C.SQLITE_LIMIT_SQL_LENGTH
|
||||
SQLITE_LIMIT_COLUMN = C.SQLITE_LIMIT_COLUMN
|
||||
SQLITE_LIMIT_EXPR_DEPTH = C.SQLITE_LIMIT_EXPR_DEPTH
|
||||
SQLITE_LIMIT_COMPOUND_SELECT = C.SQLITE_LIMIT_COMPOUND_SELECT
|
||||
SQLITE_LIMIT_VDBE_OP = C.SQLITE_LIMIT_VDBE_OP
|
||||
SQLITE_LIMIT_FUNCTION_ARG = C.SQLITE_LIMIT_FUNCTION_ARG
|
||||
SQLITE_LIMIT_ATTACHED = C.SQLITE_LIMIT_ATTACHED
|
||||
SQLITE_LIMIT_LIKE_PATTERN_LENGTH = C.SQLITE_LIMIT_LIKE_PATTERN_LENGTH
|
||||
SQLITE_LIMIT_VARIABLE_NUMBER = C.SQLITE_LIMIT_VARIABLE_NUMBER
|
||||
SQLITE_LIMIT_TRIGGER_DEPTH = C.SQLITE_LIMIT_TRIGGER_DEPTH
|
||||
SQLITE_LIMIT_WORKER_THREADS = C.SQLITE_LIMIT_WORKER_THREADS
|
||||
)
|
||||
|
||||
// GetLimit returns the current value of a run-time limit.
|
||||
// See: sqlite3_limit, http://www.sqlite.org/c3ref/limit.html
|
||||
func (c *SQLiteConn) GetLimit(id int) int {
|
||||
return int(C._sqlite3_limit(c.db, C.int(id), -1))
|
||||
}
|
||||
|
||||
// SetLimit changes the value of a run-time limits.
|
||||
// Then this method returns the prior value of the limit.
|
||||
// See: sqlite3_limit, http://www.sqlite.org/c3ref/limit.html
|
||||
func (c *SQLiteConn) SetLimit(id int, newVal int) int {
|
||||
return int(C._sqlite3_limit(c.db, C.int(id), C.int(newVal)))
|
||||
}
|
||||
|
||||
// Close the statement.
|
||||
func (s *SQLiteStmt) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.closed {
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
if s.c == nil || s.c.db == nil {
|
||||
if !s.c.dbConnOpen() {
|
||||
return errors.New("sqlite statement with already closed database connection")
|
||||
}
|
||||
rv := C.sqlite3_finalize(s.s)
|
||||
s.s = nil
|
||||
if rv != C.SQLITE_OK {
|
||||
return s.c.lastError()
|
||||
}
|
||||
|
@ -775,9 +1059,9 @@ func (s *SQLiteStmt) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Return a number of parameters.
|
||||
// NumInput return a number of parameters.
|
||||
func (s *SQLiteStmt) NumInput() int {
|
||||
return s.nv
|
||||
return int(C.sqlite3_bind_parameter_count(s.s))
|
||||
}
|
||||
|
||||
type bindArg struct {
|
||||
|
@ -785,37 +1069,30 @@ type bindArg struct {
|
|||
v driver.Value
|
||||
}
|
||||
|
||||
func (s *SQLiteStmt) bind(args []driver.Value) error {
|
||||
var placeHolder = []byte{0}
|
||||
|
||||
func (s *SQLiteStmt) bind(args []namedValue) error {
|
||||
rv := C.sqlite3_reset(s.s)
|
||||
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
|
||||
return s.c.lastError()
|
||||
}
|
||||
|
||||
var vargs []bindArg
|
||||
narg := len(args)
|
||||
vargs = make([]bindArg, narg)
|
||||
if len(s.nn) > 0 {
|
||||
for i, v := range s.nn {
|
||||
if pi, err := strconv.Atoi(v[1:]); err == nil {
|
||||
vargs[i] = bindArg{pi, args[i]}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i, v := range args {
|
||||
vargs[i] = bindArg{i + 1, v}
|
||||
if v.Name != "" {
|
||||
cname := C.CString(":" + v.Name)
|
||||
args[i].Ordinal = int(C.sqlite3_bind_parameter_index(s.s, cname))
|
||||
C.free(unsafe.Pointer(cname))
|
||||
}
|
||||
}
|
||||
|
||||
for _, varg := range vargs {
|
||||
n := C.int(varg.n)
|
||||
v := varg.v
|
||||
switch v := v.(type) {
|
||||
for _, arg := range args {
|
||||
n := C.int(arg.Ordinal)
|
||||
switch v := arg.Value.(type) {
|
||||
case nil:
|
||||
rv = C.sqlite3_bind_null(s.s, n)
|
||||
case string:
|
||||
if len(v) == 0 {
|
||||
b := []byte{0}
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0))
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0))
|
||||
} else {
|
||||
b := []byte(v)
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||
|
@ -823,7 +1100,7 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
|
|||
case int64:
|
||||
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
|
||||
case bool:
|
||||
if bool(v) {
|
||||
if v {
|
||||
rv = C.sqlite3_bind_int(s.s, n, 1)
|
||||
} else {
|
||||
rv = C.sqlite3_bind_int(s.s, n, 0)
|
||||
|
@ -831,11 +1108,11 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
|
|||
case float64:
|
||||
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
rv = C._sqlite3_bind_blob(s.s, n, nil, 0)
|
||||
} else {
|
||||
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v)))
|
||||
ln := len(v)
|
||||
if ln == 0 {
|
||||
v = placeHolder
|
||||
}
|
||||
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
|
||||
case time.Time:
|
||||
b := []byte(v.Format(SQLiteTimestampFormats[0]))
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||
|
@ -849,29 +1126,94 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
|
|||
|
||||
// Query the statement with arguments. Return records.
|
||||
func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, v := range args {
|
||||
list[i] = namedValue{
|
||||
Ordinal: i + 1,
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
return s.query(context.Background(), list)
|
||||
}
|
||||
|
||||
func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows, error) {
|
||||
if err := s.bind(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SQLiteRows{s, int(C.sqlite3_column_count(s.s)), nil, nil, s.cls}, nil
|
||||
|
||||
rows := &SQLiteRows{
|
||||
s: s,
|
||||
nc: int(C.sqlite3_column_count(s.s)),
|
||||
cols: nil,
|
||||
decltype: nil,
|
||||
cls: s.cls,
|
||||
closed: false,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Return last inserted ID.
|
||||
if ctxdone := ctx.Done(); ctxdone != nil {
|
||||
go func(db *C.sqlite3) {
|
||||
select {
|
||||
case <-ctxdone:
|
||||
select {
|
||||
case <-rows.done:
|
||||
default:
|
||||
C.sqlite3_interrupt(db)
|
||||
rows.Close()
|
||||
}
|
||||
case <-rows.done:
|
||||
}
|
||||
}(s.c.db)
|
||||
}
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// LastInsertId teturn last inserted ID.
|
||||
func (r *SQLiteResult) LastInsertId() (int64, error) {
|
||||
return r.id, nil
|
||||
}
|
||||
|
||||
// Return how many rows affected.
|
||||
// RowsAffected return how many rows affected.
|
||||
func (r *SQLiteResult) RowsAffected() (int64, error) {
|
||||
return r.changes, nil
|
||||
}
|
||||
|
||||
// Execute the statement with arguments. Return result object.
|
||||
// Exec execute the statement with arguments. Return result object.
|
||||
func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, v := range args {
|
||||
list[i] = namedValue{
|
||||
Ordinal: i + 1,
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
return s.exec(context.Background(), list)
|
||||
}
|
||||
|
||||
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
|
||||
if err := s.bind(args); err != nil {
|
||||
C.sqlite3_reset(s.s)
|
||||
C.sqlite3_clear_bindings(s.s)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ctxdone := ctx.Done(); ctxdone != nil {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go func(db *C.sqlite3) {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctxdone:
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
C.sqlite3_interrupt(db)
|
||||
}
|
||||
}
|
||||
}(s.c.db)
|
||||
}
|
||||
|
||||
var rowid, changes C.longlong
|
||||
rv := C._sqlite3_step(s.s, &rowid, &changes)
|
||||
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
|
||||
|
@ -880,27 +1222,39 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
|
|||
C.sqlite3_clear_bindings(s.s)
|
||||
return nil, err
|
||||
}
|
||||
return &SQLiteResult{int64(rowid), int64(changes)}, nil
|
||||
|
||||
return &SQLiteResult{id: int64(rowid), changes: int64(changes)}, nil
|
||||
}
|
||||
|
||||
// Close the rows.
|
||||
func (rc *SQLiteRows) Close() error {
|
||||
if rc.s.closed {
|
||||
rc.s.mu.Lock()
|
||||
if rc.s.closed || rc.closed {
|
||||
rc.s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
rc.closed = true
|
||||
if rc.done != nil {
|
||||
close(rc.done)
|
||||
}
|
||||
if rc.cls {
|
||||
rc.s.mu.Unlock()
|
||||
return rc.s.Close()
|
||||
}
|
||||
rv := C.sqlite3_reset(rc.s.s)
|
||||
if rv != C.SQLITE_OK {
|
||||
rc.s.mu.Unlock()
|
||||
return rc.s.c.lastError()
|
||||
}
|
||||
rc.s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return column names.
|
||||
// Columns return column names.
|
||||
func (rc *SQLiteRows) Columns() []string {
|
||||
if rc.nc != len(rc.cols) {
|
||||
rc.s.mu.Lock()
|
||||
defer rc.s.mu.Unlock()
|
||||
if rc.s.s != nil && rc.nc != len(rc.cols) {
|
||||
rc.cols = make([]string, rc.nc)
|
||||
for i := 0; i < rc.nc; i++ {
|
||||
rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
|
||||
|
@ -909,9 +1263,8 @@ func (rc *SQLiteRows) Columns() []string {
|
|||
return rc.cols
|
||||
}
|
||||
|
||||
// Return column types.
|
||||
func (rc *SQLiteRows) DeclTypes() []string {
|
||||
if rc.decltype == nil {
|
||||
func (rc *SQLiteRows) declTypes() []string {
|
||||
if rc.s.s != nil && rc.decltype == nil {
|
||||
rc.decltype = make([]string, rc.nc)
|
||||
for i := 0; i < rc.nc; i++ {
|
||||
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
|
||||
|
@ -920,8 +1273,20 @@ func (rc *SQLiteRows) DeclTypes() []string {
|
|||
return rc.decltype
|
||||
}
|
||||
|
||||
// Move cursor to next.
|
||||
// DeclTypes return column types.
|
||||
func (rc *SQLiteRows) DeclTypes() []string {
|
||||
rc.s.mu.Lock()
|
||||
defer rc.s.mu.Unlock()
|
||||
return rc.declTypes()
|
||||
}
|
||||
|
||||
// Next move cursor to next.
|
||||
func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
||||
if rc.s.closed {
|
||||
return io.EOF
|
||||
}
|
||||
rc.s.mu.Lock()
|
||||
defer rc.s.mu.Unlock()
|
||||
rv := C.sqlite3_step(rc.s.s)
|
||||
if rv == C.SQLITE_DONE {
|
||||
return io.EOF
|
||||
|
@ -934,23 +1299,24 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
rc.DeclTypes()
|
||||
rc.declTypes()
|
||||
|
||||
for i := range dest {
|
||||
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
|
||||
case C.SQLITE_INTEGER:
|
||||
val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
|
||||
switch rc.decltype[i] {
|
||||
case "timestamp", "datetime", "date":
|
||||
case columnTimestamp, columnDatetime, columnDate:
|
||||
var t time.Time
|
||||
// Assume a millisecond unix timestamp if it's 13 digits -- too
|
||||
// large to be a reasonable timestamp in seconds.
|
||||
if val > 1e12 || val < -1e12 {
|
||||
val *= int64(time.Millisecond) // convert ms to nsec
|
||||
t = time.Unix(0, val)
|
||||
} else {
|
||||
val *= int64(time.Second) // convert sec to nsec
|
||||
t = time.Unix(val, 0)
|
||||
}
|
||||
t = time.Unix(0, val).UTC()
|
||||
t = t.UTC()
|
||||
if rc.s.c.loc != nil {
|
||||
t = t.In(rc.s.c.loc)
|
||||
}
|
||||
|
@ -971,10 +1337,10 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
|||
n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i)))
|
||||
switch dest[i].(type) {
|
||||
case sql.RawBytes:
|
||||
dest[i] = (*[1 << 30]byte)(unsafe.Pointer(p))[0:n]
|
||||
dest[i] = (*[1 << 30]byte)(p)[0:n]
|
||||
default:
|
||||
slice := make([]byte, n)
|
||||
copy(slice[:], (*[1 << 30]byte)(unsafe.Pointer(p))[0:n])
|
||||
copy(slice[:], (*[1 << 30]byte)(p)[0:n])
|
||||
dest[i] = slice
|
||||
}
|
||||
case C.SQLITE_NULL:
|
||||
|
@ -987,7 +1353,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
|||
s := C.GoStringN((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))), C.int(n))
|
||||
|
||||
switch rc.decltype[i] {
|
||||
case "timestamp", "datetime", "date":
|
||||
case columnTimestamp, columnDatetime, columnDate:
|
||||
var t time.Time
|
||||
s = strings.TrimSuffix(s, "Z")
|
||||
for _, format := range SQLiteTimestampFormats {
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
// These wrappers are necessary because SQLITE_TRANSIENT
|
||||
// is a pointer constant, and cgo doesn't translate them correctly.
|
||||
|
||||
static inline void my_result_text(sqlite3_context *ctx, char *p, int np) {
|
||||
sqlite3_result_text(ctx, p, np, SQLITE_TRANSIENT);
|
||||
}
|
||||
|
||||
static inline void my_result_blob(sqlite3_context *ctx, void *p, int np) {
|
||||
sqlite3_result_blob(ctx, p, np, SQLITE_TRANSIENT);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"math"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const i64 = unsafe.Sizeof(int(0)) > 4
|
||||
|
||||
// SQLiteContext behave sqlite3_context
|
||||
type SQLiteContext C.sqlite3_context
|
||||
|
||||
// ResultBool sets the result of an SQL function.
|
||||
func (c *SQLiteContext) ResultBool(b bool) {
|
||||
if b {
|
||||
c.ResultInt(1)
|
||||
} else {
|
||||
c.ResultInt(0)
|
||||
}
|
||||
}
|
||||
|
||||
// ResultBlob sets the result of an SQL function.
|
||||
// See: sqlite3_result_blob, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultBlob(b []byte) {
|
||||
if i64 && len(b) > math.MaxInt32 {
|
||||
C.sqlite3_result_error_toobig((*C.sqlite3_context)(c))
|
||||
return
|
||||
}
|
||||
var p *byte
|
||||
if len(b) > 0 {
|
||||
p = &b[0]
|
||||
}
|
||||
C.my_result_blob((*C.sqlite3_context)(c), unsafe.Pointer(p), C.int(len(b)))
|
||||
}
|
||||
|
||||
// ResultDouble sets the result of an SQL function.
|
||||
// See: sqlite3_result_double, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultDouble(d float64) {
|
||||
C.sqlite3_result_double((*C.sqlite3_context)(c), C.double(d))
|
||||
}
|
||||
|
||||
// ResultInt sets the result of an SQL function.
|
||||
// See: sqlite3_result_int, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultInt(i int) {
|
||||
if i64 && (i > math.MaxInt32 || i < math.MinInt32) {
|
||||
C.sqlite3_result_int64((*C.sqlite3_context)(c), C.sqlite3_int64(i))
|
||||
} else {
|
||||
C.sqlite3_result_int((*C.sqlite3_context)(c), C.int(i))
|
||||
}
|
||||
}
|
||||
|
||||
// ResultInt64 sets the result of an SQL function.
|
||||
// See: sqlite3_result_int64, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultInt64(i int64) {
|
||||
C.sqlite3_result_int64((*C.sqlite3_context)(c), C.sqlite3_int64(i))
|
||||
}
|
||||
|
||||
// ResultNull sets the result of an SQL function.
|
||||
// See: sqlite3_result_null, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultNull() {
|
||||
C.sqlite3_result_null((*C.sqlite3_context)(c))
|
||||
}
|
||||
|
||||
// ResultText sets the result of an SQL function.
|
||||
// See: sqlite3_result_text, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultText(s string) {
|
||||
h := (*reflect.StringHeader)(unsafe.Pointer(&s))
|
||||
cs, l := (*C.char)(unsafe.Pointer(h.Data)), C.int(h.Len)
|
||||
C.my_result_text((*C.sqlite3_context)(c), cs, l)
|
||||
}
|
||||
|
||||
// ResultZeroblob sets the result of an SQL function.
|
||||
// See: sqlite3_result_zeroblob, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultZeroblob(n int) {
|
||||
C.sqlite3_result_zeroblob((*C.sqlite3_context)(c), C.int(n))
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
// +build cgo
|
||||
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build go1.8
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
|
||||
"context"
|
||||
)
|
||||
|
||||
// Ping implement Pinger.
|
||||
func (c *SQLiteConn) Ping(ctx context.Context) error {
|
||||
if c.db == nil {
|
||||
return errors.New("Connection was closed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryContext implement QueryerContext.
|
||||
func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, nv := range args {
|
||||
list[i] = namedValue(nv)
|
||||
}
|
||||
return c.query(ctx, query, list)
|
||||
}
|
||||
|
||||
// ExecContext implement ExecerContext.
|
||||
func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, nv := range args {
|
||||
list[i] = namedValue(nv)
|
||||
}
|
||||
return c.exec(ctx, query, list)
|
||||
}
|
||||
|
||||
// PrepareContext implement ConnPrepareContext.
|
||||
func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
return c.prepare(ctx, query)
|
||||
}
|
||||
|
||||
// BeginTx implement ConnBeginTx.
|
||||
func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
return c.begin(ctx)
|
||||
}
|
||||
|
||||
// QueryContext implement QueryerContext.
|
||||
func (s *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, nv := range args {
|
||||
list[i] = namedValue(nv)
|
||||
}
|
||||
return s.query(ctx, list)
|
||||
}
|
||||
|
||||
// ExecContext implement ExecerContext.
|
||||
func (s *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, nv := range args {
|
||||
list[i] = namedValue(nv)
|
||||
}
|
||||
return s.exec(ctx, list)
|
||||
}
|
|
@ -0,0 +1,157 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build go1.8
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNamedParams(t *testing.T) {
|
||||
tempFilename := TempFilename(t)
|
||||
defer os.Remove(tempFilename)
|
||||
db, err := sql.Open("sqlite3", tempFilename)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(`
|
||||
create table foo (id integer, name text, extra text);
|
||||
`)
|
||||
if err != nil {
|
||||
t.Error("Failed to call db.Query:", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`insert into foo(id, name, extra) values(:id, :name, :name)`, sql.Named("name", "foo"), sql.Named("id", 1))
|
||||
if err != nil {
|
||||
t.Error("Failed to call db.Exec:", err)
|
||||
}
|
||||
|
||||
row := db.QueryRow(`select id, extra from foo where id = :id and extra = :extra`, sql.Named("id", 1), sql.Named("extra", "foo"))
|
||||
if row == nil {
|
||||
t.Error("Failed to call db.QueryRow")
|
||||
}
|
||||
var id int
|
||||
var extra string
|
||||
err = row.Scan(&id, &extra)
|
||||
if err != nil {
|
||||
t.Error("Failed to db.Scan:", err)
|
||||
}
|
||||
if id != 1 || extra != "foo" {
|
||||
t.Error("Failed to db.QueryRow: not matched results")
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
testTableStatements = []string{
|
||||
`DROP TABLE IF EXISTS test_table`,
|
||||
`
|
||||
CREATE TABLE IF NOT EXISTS test_table (
|
||||
key1 VARCHAR(64) PRIMARY KEY,
|
||||
key_id VARCHAR(64) NOT NULL,
|
||||
key2 VARCHAR(64) NOT NULL,
|
||||
key3 VARCHAR(64) NOT NULL,
|
||||
key4 VARCHAR(64) NOT NULL,
|
||||
key5 VARCHAR(64) NOT NULL,
|
||||
key6 VARCHAR(64) NOT NULL,
|
||||
data BLOB NOT NULL
|
||||
);`,
|
||||
}
|
||||
letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
)
|
||||
|
||||
func randStringBytes(n int) string {
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func initDatabase(t *testing.T, db *sql.DB, rowCount int64) {
|
||||
t.Logf("Executing db initializing statements")
|
||||
for _, query := range testTableStatements {
|
||||
_, err := db.Exec(query)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
for i := int64(0); i < rowCount; i++ {
|
||||
query := `INSERT INTO test_table
|
||||
(key1, key_id, key2, key3, key4, key5, key6, data)
|
||||
VALUES
|
||||
(?, ?, ?, ?, ?, ?, ?, ?);`
|
||||
args := []interface{}{
|
||||
randStringBytes(50),
|
||||
fmt.Sprint(i),
|
||||
randStringBytes(50),
|
||||
randStringBytes(50),
|
||||
randStringBytes(50),
|
||||
randStringBytes(50),
|
||||
randStringBytes(50),
|
||||
randStringBytes(50),
|
||||
randStringBytes(2048),
|
||||
}
|
||||
_, err := db.Exec(query, args...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestShortTimeout(t *testing.T) {
|
||||
srcTempFilename := TempFilename(t)
|
||||
defer os.Remove(srcTempFilename)
|
||||
|
||||
db, err := sql.Open("sqlite3", srcTempFilename)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
initDatabase(t, db, 100)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
|
||||
defer cancel()
|
||||
query := `SELECT key1, key_id, key2, key3, key4, key5, key6, data
|
||||
FROM test_table
|
||||
ORDER BY key2 ASC`
|
||||
_, err = db.QueryContext(ctx, query)
|
||||
if err != nil && err != context.DeadlineExceeded {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ctx.Err() != nil && ctx.Err() != context.DeadlineExceeded {
|
||||
t.Fatal(ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecCancel(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if _, err = db.Exec("create table foo (id integer primary key)"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for n := 0; n < 100; n++ {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
_, err = db.ExecContext(ctx, "insert into foo (id) values (?)", n)
|
||||
cancel()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -10,5 +10,6 @@ package sqlite3
|
|||
#cgo CFLAGS: -DUSE_LIBSQLITE3
|
||||
#cgo linux LDFLAGS: -lsqlite3
|
||||
#cgo darwin LDFLAGS: -L/usr/local/opt/sqlite/lib -lsqlite3
|
||||
#cgo solaris LDFLAGS: -lsqlite3
|
||||
*/
|
||||
import "C"
|
||||
|
|
|
@ -7,7 +7,11 @@
|
|||
package sqlite3
|
||||
|
||||
/*
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
*/
|
||||
import "C"
|
||||
|
@ -27,6 +31,7 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error {
|
|||
defer C.free(unsafe.Pointer(cext))
|
||||
rv = C.sqlite3_load_extension(c.db, cext, nil, nil)
|
||||
if rv != C.SQLITE_OK {
|
||||
C.sqlite3_enable_load_extension(c.db, 0)
|
||||
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
|
||||
}
|
||||
}
|
||||
|
@ -38,6 +43,7 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// LoadExtension load the sqlite3 extension.
|
||||
func (c *SQLiteConn) LoadExtension(lib string, entry string) error {
|
||||
rv := C.sqlite3_enable_load_extension(c.db, 1)
|
||||
if rv != C.SQLITE_OK {
|
||||
|
|
|
@ -9,5 +9,6 @@ package sqlite3
|
|||
/*
|
||||
#cgo CFLAGS: -I.
|
||||
#cgo linux LDFLAGS: -ldl
|
||||
#cgo solaris LDFLAGS: -lc
|
||||
*/
|
||||
import "C"
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
// Copyright (C) 2018 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// +build solaris
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -D__EXTENSIONS__=1
|
||||
*/
|
||||
import "C"
|
792
sqlite3_test.go
792
sqlite3_test.go
|
@ -6,21 +6,22 @@
|
|||
package sqlite3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/xeodou/go-sqlcipher/sqlite3_test"
|
||||
)
|
||||
|
||||
func TempFilename(t *testing.T) string {
|
||||
|
@ -107,6 +108,64 @@ func TestReadonly(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestForeignKeys(t *testing.T) {
|
||||
cases := map[string]bool{
|
||||
"?_foreign_keys=1": true,
|
||||
"?_foreign_keys=0": false,
|
||||
}
|
||||
for option, want := range cases {
|
||||
fname := TempFilename(t)
|
||||
uri := "file:" + fname + option
|
||||
db, err := sql.Open("sqlite3", uri)
|
||||
if err != nil {
|
||||
os.Remove(fname)
|
||||
t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
|
||||
continue
|
||||
}
|
||||
var enabled bool
|
||||
err = db.QueryRow("PRAGMA foreign_keys;").Scan(&enabled)
|
||||
db.Close()
|
||||
os.Remove(fname)
|
||||
if err != nil {
|
||||
t.Errorf("query foreign_keys for %s: %v", uri, err)
|
||||
continue
|
||||
}
|
||||
if enabled != want {
|
||||
t.Errorf("\"PRAGMA foreign_keys;\" for %q = %t; want %t", uri, enabled, want)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecursiveTriggers(t *testing.T) {
|
||||
cases := map[string]bool{
|
||||
"?_recursive_triggers=1": true,
|
||||
"?_recursive_triggers=0": false,
|
||||
}
|
||||
for option, want := range cases {
|
||||
fname := TempFilename(t)
|
||||
uri := "file:" + fname + option
|
||||
db, err := sql.Open("sqlite3", uri)
|
||||
if err != nil {
|
||||
os.Remove(fname)
|
||||
t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
|
||||
continue
|
||||
}
|
||||
var enabled bool
|
||||
err = db.QueryRow("PRAGMA recursive_triggers;").Scan(&enabled)
|
||||
db.Close()
|
||||
os.Remove(fname)
|
||||
if err != nil {
|
||||
t.Errorf("query recursive_triggers for %s: %v", uri, err)
|
||||
continue
|
||||
}
|
||||
if enabled != want {
|
||||
t.Errorf("\"PRAGMA recursive_triggers;\" for %q = %t; want %t", uri, enabled, want)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
tempFilename := TempFilename(t)
|
||||
defer os.Remove(tempFilename)
|
||||
|
@ -374,6 +433,7 @@ func TestTimestamp(t *testing.T) {
|
|||
}{
|
||||
{"nonsense", time.Time{}},
|
||||
{"0000-00-00 00:00:00", time.Time{}},
|
||||
{time.Time{}.Unix(), time.Time{}},
|
||||
{timestamp1, timestamp1},
|
||||
{timestamp2.Unix(), timestamp2.Truncate(time.Second)},
|
||||
{timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)},
|
||||
|
@ -503,7 +563,7 @@ func TestBoolean(t *testing.T) {
|
|||
t.Fatalf("Expected 1 row but %v", counter)
|
||||
}
|
||||
|
||||
if id != 1 && fbool != true {
|
||||
if id != 1 && !fbool {
|
||||
t.Fatalf("Value for id 1 should be %v, not %v", bool1, fbool)
|
||||
}
|
||||
|
||||
|
@ -525,7 +585,7 @@ func TestBoolean(t *testing.T) {
|
|||
t.Fatalf("Expected 1 row but %v", counter)
|
||||
}
|
||||
|
||||
if id != 2 && fbool != false {
|
||||
if id != 2 && fbool {
|
||||
t.Fatalf("Value for id 2 should be %v, not %v", bool2, fbool)
|
||||
}
|
||||
|
||||
|
@ -811,18 +871,6 @@ func TestTimezoneConversion(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSuite(t *testing.T) {
|
||||
tempFilename := TempFilename(t)
|
||||
defer os.Remove(tempFilename)
|
||||
db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
sqlite3_test.RunTests(t, db, sqlite3_test.SQLITE)
|
||||
}
|
||||
|
||||
// TODO: Execer & Queryer currently disabled
|
||||
// https://github.com/mattn/go-sqlite3/issues/82
|
||||
func TestExecer(t *testing.T) {
|
||||
|
@ -1185,12 +1233,12 @@ func TestDateTimeNow(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestFunctionRegistration(t *testing.T) {
|
||||
addi_8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) }
|
||||
addi_64 := func(a, b int64) int64 { return a + b }
|
||||
addu_8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) }
|
||||
addu_64 := func(a, b uint64) uint64 { return a + b }
|
||||
addi8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) }
|
||||
addi64 := func(a, b int64) int64 { return a + b }
|
||||
addu8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) }
|
||||
addu64 := func(a, b uint64) uint64 { return a + b }
|
||||
addiu := func(a int, b uint) int64 { return int64(a) + int64(b) }
|
||||
addf_32_64 := func(a float32, b float64) float64 { return float64(a) + b }
|
||||
addf32_64 := func(a float32, b float64) float64 { return float64(a) + b }
|
||||
not := func(a bool) bool { return !a }
|
||||
regex := func(re, s string) (bool, error) {
|
||||
return regexp.MatchString(re, s)
|
||||
|
@ -1222,22 +1270,22 @@ func TestFunctionRegistration(t *testing.T) {
|
|||
|
||||
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
if err := conn.RegisterFunc("addi_8_16_32", addi_8_16_32, true); err != nil {
|
||||
if err := conn.RegisterFunc("addi8_16_32", addi8_16_32, true); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.RegisterFunc("addi_64", addi_64, true); err != nil {
|
||||
if err := conn.RegisterFunc("addi64", addi64, true); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.RegisterFunc("addu_8_16_32", addu_8_16_32, true); err != nil {
|
||||
if err := conn.RegisterFunc("addu8_16_32", addu8_16_32, true); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.RegisterFunc("addu_64", addu_64, true); err != nil {
|
||||
if err := conn.RegisterFunc("addu64", addu64, true); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.RegisterFunc("addiu", addiu, true); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.RegisterFunc("addf_32_64", addf_32_64, true); err != nil {
|
||||
if err := conn.RegisterFunc("addf32_64", addf32_64, true); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.RegisterFunc("not", not, true); err != nil {
|
||||
|
@ -1268,12 +1316,12 @@ func TestFunctionRegistration(t *testing.T) {
|
|||
query string
|
||||
expected interface{}
|
||||
}{
|
||||
{"SELECT addi_8_16_32(1,2)", int32(3)},
|
||||
{"SELECT addi_64(1,2)", int64(3)},
|
||||
{"SELECT addu_8_16_32(1,2)", uint32(3)},
|
||||
{"SELECT addu_64(1,2)", uint64(3)},
|
||||
{"SELECT addi8_16_32(1,2)", int32(3)},
|
||||
{"SELECT addi64(1,2)", int64(3)},
|
||||
{"SELECT addu8_16_32(1,2)", uint32(3)},
|
||||
{"SELECT addu64(1,2)", uint64(3)},
|
||||
{"SELECT addiu(1,2)", int64(3)},
|
||||
{"SELECT addf_32_64(1.5,1.5)", float64(3)},
|
||||
{"SELECT addf32_64(1.5,1.5)", float64(3)},
|
||||
{"SELECT not(1)", false},
|
||||
{"SELECT not(0)", true},
|
||||
{`SELECT regex("^foo.*", "foobar")`, true},
|
||||
|
@ -1331,7 +1379,8 @@ func TestAggregatorRegistration(t *testing.T) {
|
|||
|
||||
_, err = db.Exec("create table foo (department integer, profits integer)")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to create table:", err)
|
||||
// trace feature is not implemented
|
||||
t.Skip("Failed to create table:", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
|
||||
|
@ -1358,6 +1407,127 @@ func TestAggregatorRegistration(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func rot13(r rune) rune {
|
||||
switch {
|
||||
case r >= 'A' && r <= 'Z':
|
||||
return 'A' + (r-'A'+13)%26
|
||||
case r >= 'a' && r <= 'z':
|
||||
return 'a' + (r-'a'+13)%26
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func TestCollationRegistration(t *testing.T) {
|
||||
collateRot13 := func(a, b string) int {
|
||||
ra, rb := strings.Map(rot13, a), strings.Map(rot13, b)
|
||||
return strings.Compare(ra, rb)
|
||||
}
|
||||
collateRot13Reverse := func(a, b string) int {
|
||||
return collateRot13(b, a)
|
||||
}
|
||||
|
||||
sql.Register("sqlite3_CollationRegistration", &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
if err := conn.RegisterCollation("rot13", collateRot13); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.RegisterCollation("rot13reverse", collateRot13Reverse); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
db, err := sql.Open("sqlite3_CollationRegistration", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
populate := []string{
|
||||
`CREATE TABLE test (s TEXT)`,
|
||||
`INSERT INTO test VALUES ("aaaa")`,
|
||||
`INSERT INTO test VALUES ("ffff")`,
|
||||
`INSERT INTO test VALUES ("qqqq")`,
|
||||
`INSERT INTO test VALUES ("tttt")`,
|
||||
`INSERT INTO test VALUES ("zzzz")`,
|
||||
}
|
||||
for _, stmt := range populate {
|
||||
if _, err := db.Exec(stmt); err != nil {
|
||||
t.Fatal("Failed to populate test DB:", err)
|
||||
}
|
||||
}
|
||||
|
||||
ops := []struct {
|
||||
query string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"SELECT * FROM test ORDER BY s COLLATE rot13 ASC",
|
||||
[]string{
|
||||
"qqqq",
|
||||
"tttt",
|
||||
"zzzz",
|
||||
"aaaa",
|
||||
"ffff",
|
||||
},
|
||||
},
|
||||
{
|
||||
"SELECT * FROM test ORDER BY s COLLATE rot13 DESC",
|
||||
[]string{
|
||||
"ffff",
|
||||
"aaaa",
|
||||
"zzzz",
|
||||
"tttt",
|
||||
"qqqq",
|
||||
},
|
||||
},
|
||||
{
|
||||
"SELECT * FROM test ORDER BY s COLLATE rot13reverse ASC",
|
||||
[]string{
|
||||
"ffff",
|
||||
"aaaa",
|
||||
"zzzz",
|
||||
"tttt",
|
||||
"qqqq",
|
||||
},
|
||||
},
|
||||
{
|
||||
"SELECT * FROM test ORDER BY s COLLATE rot13reverse DESC",
|
||||
[]string{
|
||||
"qqqq",
|
||||
"tttt",
|
||||
"zzzz",
|
||||
"aaaa",
|
||||
"ffff",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, op := range ops {
|
||||
rows, err := db.Query(op.query)
|
||||
if err != nil {
|
||||
t.Fatalf("Query %q failed: %s", op.query, err)
|
||||
}
|
||||
got := []string{}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var s string
|
||||
if err = rows.Scan(&s); err != nil {
|
||||
t.Fatalf("Reading row for %q: %s", op.query, err)
|
||||
}
|
||||
got = append(got, s)
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatalf("Reading rows for %q: %s", op.query, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, op.want) {
|
||||
t.Fatalf("Unexpected output from %q\ngot:\n%s\n\nwant:\n%s", op.query, strings.Join(got, "\n"), strings.Join(op.want, "\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeclTypes(t *testing.T) {
|
||||
|
||||
d := SQLiteDriver{}
|
||||
|
@ -1393,21 +1563,150 @@ func TestDeclTypes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPinger(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = db.Ping()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
db.Close()
|
||||
err = db.Ping()
|
||||
if err == nil {
|
||||
t.Fatal("Should be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAndTransactionHooks(t *testing.T) {
|
||||
var events []string
|
||||
var commitHookReturn = 0
|
||||
|
||||
sql.Register("sqlite3_UpdateHook", &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
conn.RegisterCommitHook(func() int {
|
||||
events = append(events, "commit")
|
||||
return commitHookReturn
|
||||
})
|
||||
conn.RegisterRollbackHook(func() {
|
||||
events = append(events, "rollback")
|
||||
})
|
||||
conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
|
||||
events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid))
|
||||
})
|
||||
return nil
|
||||
},
|
||||
})
|
||||
db, err := sql.Open("sqlite3_UpdateHook", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
statements := []string{
|
||||
"create table foo (id integer primary key)",
|
||||
"insert into foo values (9)",
|
||||
"update foo set id = 99 where id = 9",
|
||||
"delete from foo where id = 99",
|
||||
}
|
||||
for _, statement := range statements {
|
||||
_, err = db.Exec(statement)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to prepare test data [%v]: %v", statement, err)
|
||||
}
|
||||
}
|
||||
|
||||
commitHookReturn = 1
|
||||
_, err = db.Exec("insert into foo values (5)")
|
||||
if err == nil {
|
||||
t.Error("Commit hook failed to rollback transaction")
|
||||
}
|
||||
|
||||
var expected = []string{
|
||||
"commit",
|
||||
fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT),
|
||||
"commit",
|
||||
fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE),
|
||||
"commit",
|
||||
fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE),
|
||||
"commit",
|
||||
fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT),
|
||||
"commit",
|
||||
"rollback",
|
||||
}
|
||||
if !reflect.DeepEqual(events, expected) {
|
||||
t.Errorf("Expected notifications %v but got %v", expected, events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilAndEmptyBytes(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
actualNil := []byte("use this to use an actual nil not a reference to nil")
|
||||
emptyBytes := []byte{}
|
||||
for tsti, tst := range []struct {
|
||||
name string
|
||||
columnType string
|
||||
insertBytes []byte
|
||||
expectedBytes []byte
|
||||
}{
|
||||
{"actual nil blob", "blob", actualNil, nil},
|
||||
{"referenced nil blob", "blob", nil, nil},
|
||||
{"empty blob", "blob", emptyBytes, emptyBytes},
|
||||
{"actual nil text", "text", actualNil, nil},
|
||||
{"referenced nil text", "text", nil, nil},
|
||||
{"empty text", "text", emptyBytes, emptyBytes},
|
||||
} {
|
||||
if _, err = db.Exec(fmt.Sprintf("create table tbl%d (txt %s)", tsti, tst.columnType)); err != nil {
|
||||
t.Fatal(tst.name, err)
|
||||
}
|
||||
if bytes.Equal(tst.insertBytes, actualNil) {
|
||||
if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), nil); err != nil {
|
||||
t.Fatal(tst.name, err)
|
||||
}
|
||||
} else {
|
||||
if _, err = db.Exec(fmt.Sprintf("insert into tbl%d (txt) values (?)", tsti), &tst.insertBytes); err != nil {
|
||||
t.Fatal(tst.name, err)
|
||||
}
|
||||
}
|
||||
rows, err := db.Query(fmt.Sprintf("select txt from tbl%d", tsti))
|
||||
if err != nil {
|
||||
t.Fatal(tst.name, err)
|
||||
}
|
||||
if !rows.Next() {
|
||||
t.Fatal(tst.name, "no rows")
|
||||
}
|
||||
var scanBytes []byte
|
||||
if err = rows.Scan(&scanBytes); err != nil {
|
||||
t.Fatal(tst.name, err)
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatal(tst.name, err)
|
||||
}
|
||||
if tst.expectedBytes == nil && scanBytes != nil {
|
||||
t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
|
||||
} else if !bytes.Equal(scanBytes, tst.expectedBytes) {
|
||||
t.Errorf("%s: %#v != %#v", tst.name, scanBytes, tst.expectedBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var customFunctionOnce sync.Once
|
||||
|
||||
func BenchmarkCustomFunctions(b *testing.B) {
|
||||
customFunctionOnce.Do(func() {
|
||||
custom_add := func(a, b int64) int64 {
|
||||
customAdd := func(a, b int64) int64 {
|
||||
return a + b
|
||||
}
|
||||
|
||||
sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
// Impure function to force sqlite to reexecute it each time.
|
||||
if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return conn.RegisterFunc("custom_add", customAdd, false)
|
||||
},
|
||||
})
|
||||
})
|
||||
|
@ -1427,3 +1726,422 @@ func BenchmarkCustomFunctions(b *testing.B) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuite(t *testing.T) {
|
||||
tempFilename := TempFilename(t)
|
||||
defer os.Remove(tempFilename)
|
||||
d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer d.Close()
|
||||
|
||||
db = &TestDB{t, d, SQLITE, sync.Once{}}
|
||||
testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
|
||||
|
||||
if !testing.Short() {
|
||||
for _, b := range benchmarks {
|
||||
fmt.Printf("%-20s", b.Name)
|
||||
r := testing.Benchmark(b.F)
|
||||
fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
|
||||
}
|
||||
}
|
||||
db.tearDown()
|
||||
}
|
||||
|
||||
// Dialect is a type of dialect of databases.
|
||||
type Dialect int
|
||||
|
||||
// Dialects for databases.
|
||||
const (
|
||||
SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
|
||||
POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
|
||||
MYSQL // MYSQL mean MySQL dialect
|
||||
)
|
||||
|
||||
// DB provide context for the tests
|
||||
type TestDB struct {
|
||||
*testing.T
|
||||
*sql.DB
|
||||
dialect Dialect
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
var db *TestDB
|
||||
|
||||
// the following tables will be created and dropped during the test
|
||||
var testTables = []string{"foo", "bar", "t", "bench"}
|
||||
|
||||
var tests = []testing.InternalTest{
|
||||
{Name: "TestResult", F: testResult},
|
||||
{Name: "TestBlobs", F: testBlobs},
|
||||
{Name: "TestManyQueryRow", F: testManyQueryRow},
|
||||
{Name: "TestTxQuery", F: testTxQuery},
|
||||
{Name: "TestPreparedStmt", F: testPreparedStmt},
|
||||
}
|
||||
|
||||
var benchmarks = []testing.InternalBenchmark{
|
||||
{Name: "BenchmarkExec", F: benchmarkExec},
|
||||
{Name: "BenchmarkQuery", F: benchmarkQuery},
|
||||
{Name: "BenchmarkParams", F: benchmarkParams},
|
||||
{Name: "BenchmarkStmt", F: benchmarkStmt},
|
||||
{Name: "BenchmarkRows", F: benchmarkRows},
|
||||
{Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
|
||||
}
|
||||
|
||||
func (db *TestDB) mustExec(sql string, args ...interface{}) sql.Result {
|
||||
res, err := db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
db.Fatalf("Error running %q: %v", sql, err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (db *TestDB) tearDown() {
|
||||
for _, tbl := range testTables {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
db.mustExec("drop table if exists " + tbl)
|
||||
case MYSQL, POSTGRESQL:
|
||||
db.mustExec("drop table if exists " + tbl)
|
||||
default:
|
||||
db.Fatal("unknown dialect")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// q replaces ? parameters if needed
|
||||
func (db *TestDB) q(sql string) string {
|
||||
switch db.dialect {
|
||||
case POSTGRESQL: // replace with $1, $2, ..
|
||||
qrx := regexp.MustCompile(`\?`)
|
||||
n := 0
|
||||
return qrx.ReplaceAllStringFunc(sql, func(string) string {
|
||||
n++
|
||||
return "$" + strconv.Itoa(n)
|
||||
})
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func (db *TestDB) blobType(size int) string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return fmt.Sprintf("blob[%d]", size)
|
||||
case POSTGRESQL:
|
||||
return "bytea"
|
||||
case MYSQL:
|
||||
return fmt.Sprintf("VARBINARY(%d)", size)
|
||||
}
|
||||
panic("unknown dialect")
|
||||
}
|
||||
|
||||
func (db *TestDB) serialPK() string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return "integer primary key autoincrement"
|
||||
case POSTGRESQL:
|
||||
return "serial primary key"
|
||||
case MYSQL:
|
||||
return "integer primary key auto_increment"
|
||||
}
|
||||
panic("unknown dialect")
|
||||
}
|
||||
|
||||
func (db *TestDB) now() string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return "datetime('now')"
|
||||
case POSTGRESQL:
|
||||
return "now()"
|
||||
case MYSQL:
|
||||
return "now()"
|
||||
}
|
||||
panic("unknown dialect")
|
||||
}
|
||||
|
||||
func makeBench() {
|
||||
if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
for i := 0; i < 100; i++ {
|
||||
if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// testResult is test for result
|
||||
func testResult(t *testing.T) {
|
||||
db.tearDown()
|
||||
db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
|
||||
|
||||
for i := 1; i < 3; i++ {
|
||||
r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
|
||||
n, err := r.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Errorf("got %v, want %v", n, 1)
|
||||
}
|
||||
n, err = r.LastInsertId()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != int64(i) {
|
||||
t.Errorf("got %v, want %v", n, i)
|
||||
}
|
||||
}
|
||||
if _, err := db.Exec("error!"); err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
// testBlobs is test for blobs
|
||||
func testBlobs(t *testing.T) {
|
||||
db.tearDown()
|
||||
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
|
||||
db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
|
||||
|
||||
want := fmt.Sprintf("%x", blob)
|
||||
|
||||
b := make([]byte, 16)
|
||||
err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
|
||||
got := fmt.Sprintf("%x", b)
|
||||
if err != nil {
|
||||
t.Errorf("[]byte scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for []byte, got %q; want %q", got, want)
|
||||
}
|
||||
|
||||
err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
|
||||
want = string(blob)
|
||||
if err != nil {
|
||||
t.Errorf("string scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for string, got %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// testManyQueryRow is test for many query row
|
||||
func testManyQueryRow(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Log("skipping in short mode")
|
||||
return
|
||||
}
|
||||
db.tearDown()
|
||||
db.mustExec("create table foo (id integer primary key, name varchar(50))")
|
||||
db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
||||
var name string
|
||||
for i := 0; i < 10000; i++ {
|
||||
err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
|
||||
if err != nil || name != "bob" {
|
||||
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// testTxQuery is test for transactional query
|
||||
func testTxQuery(t *testing.T) {
|
||||
db.tearDown()
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if !r.Next() {
|
||||
if r.Err() != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Fatal("expected one rows")
|
||||
}
|
||||
|
||||
var name string
|
||||
err = r.Scan(&name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// testPreparedStmt is test for prepared statement
|
||||
func testPreparedStmt(t *testing.T) {
|
||||
db.tearDown()
|
||||
db.mustExec("CREATE TABLE t (count INT)")
|
||||
sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 1: %v", err)
|
||||
}
|
||||
ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 2: %v", err)
|
||||
}
|
||||
|
||||
for n := 1; n <= 3; n++ {
|
||||
if _, err := ins.Exec(n); err != nil {
|
||||
t.Fatalf("insert(%d) = %v", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
const nRuns = 10
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < nRuns; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
count := 0
|
||||
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
|
||||
t.Errorf("Query: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err := ins.Exec(rand.Intn(100)); err != nil {
|
||||
t.Errorf("Insert: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Benchmarks need to use panic() since b.Error errors are lost when
|
||||
// running via testing.Benchmark() I would like to run these via go
|
||||
// test -bench but calling Benchmark() from a benchmark test
|
||||
// currently hangs go.
|
||||
|
||||
// benchmarkExec is benchmark for exec
|
||||
func benchmarkExec(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := db.Exec("select 1"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkQuery is benchmark for query
|
||||
func benchmarkQuery(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkParams is benchmark for params
|
||||
func benchmarkParams(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkStmt is benchmark for statement
|
||||
func benchmarkStmt(b *testing.B) {
|
||||
st, err := db.Prepare("select ?, ?, ?, ?")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkRows is benchmark for rows
|
||||
func benchmarkRows(b *testing.B) {
|
||||
db.once.Do(makeBench)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
var t time.Time
|
||||
r, err := db.Query("select * from bench")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for r.Next() {
|
||||
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
if err = r.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// benchmarkStmtRows is benchmark for statement rows
|
||||
func benchmarkStmtRows(b *testing.B) {
|
||||
db.once.Do(makeBench)
|
||||
|
||||
st, err := db.Prepare("select * from bench")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
var t time.Time
|
||||
r, err := st.Query()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for r.Next() {
|
||||
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
if err = r.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,409 +0,0 @@
|
|||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Dialect int
|
||||
|
||||
const (
|
||||
SQLITE Dialect = iota
|
||||
POSTGRESQL
|
||||
MYSQL
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
*testing.T
|
||||
*sql.DB
|
||||
dialect Dialect
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
var db *DB
|
||||
|
||||
// the following tables will be created and dropped during the test
|
||||
var testTables = []string{"foo", "bar", "t", "bench"}
|
||||
|
||||
var tests = []testing.InternalTest{
|
||||
{"TestBlobs", TestBlobs},
|
||||
{"TestManyQueryRow", TestManyQueryRow},
|
||||
{"TestTxQuery", TestTxQuery},
|
||||
{"TestPreparedStmt", TestPreparedStmt},
|
||||
}
|
||||
|
||||
var benchmarks = []testing.InternalBenchmark{
|
||||
{"BenchmarkExec", BenchmarkExec},
|
||||
{"BenchmarkQuery", BenchmarkQuery},
|
||||
{"BenchmarkParams", BenchmarkParams},
|
||||
{"BenchmarkStmt", BenchmarkStmt},
|
||||
{"BenchmarkRows", BenchmarkRows},
|
||||
{"BenchmarkStmtRows", BenchmarkStmtRows},
|
||||
}
|
||||
|
||||
// RunTests runs the SQL test suite
|
||||
func RunTests(t *testing.T, d *sql.DB, dialect Dialect) {
|
||||
db = &DB{t, d, dialect, sync.Once{}}
|
||||
testing.RunTests(func(string, string) (bool, error) { return true, nil }, tests)
|
||||
|
||||
if !testing.Short() {
|
||||
for _, b := range benchmarks {
|
||||
fmt.Printf("%-20s", b.Name)
|
||||
r := testing.Benchmark(b.F)
|
||||
fmt.Printf("%10d %10.0f req/s\n", r.N, float64(r.N)/r.T.Seconds())
|
||||
}
|
||||
}
|
||||
db.tearDown()
|
||||
}
|
||||
|
||||
func (db *DB) mustExec(sql string, args ...interface{}) sql.Result {
|
||||
res, err := db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
db.Fatalf("Error running %q: %v", sql, err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (db *DB) tearDown() {
|
||||
for _, tbl := range testTables {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
db.mustExec("drop table if exists " + tbl)
|
||||
case MYSQL, POSTGRESQL:
|
||||
db.mustExec("drop table if exists " + tbl)
|
||||
default:
|
||||
db.Fatal("unkown dialect")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// q replaces ? parameters if needed
|
||||
func (db *DB) q(sql string) string {
|
||||
switch db.dialect {
|
||||
case POSTGRESQL: // repace with $1, $2, ..
|
||||
qrx := regexp.MustCompile(`\?`)
|
||||
n := 0
|
||||
return qrx.ReplaceAllStringFunc(sql, func(string) string {
|
||||
n++
|
||||
return "$" + strconv.Itoa(n)
|
||||
})
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func (db *DB) blobType(size int) string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return fmt.Sprintf("blob[%d]", size)
|
||||
case POSTGRESQL:
|
||||
return "bytea"
|
||||
case MYSQL:
|
||||
return fmt.Sprintf("VARBINARY(%d)", size)
|
||||
}
|
||||
panic("unkown dialect")
|
||||
}
|
||||
|
||||
func (db *DB) serialPK() string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return "integer primary key autoincrement"
|
||||
case POSTGRESQL:
|
||||
return "serial primary key"
|
||||
case MYSQL:
|
||||
return "integer primary key auto_increment"
|
||||
}
|
||||
panic("unkown dialect")
|
||||
}
|
||||
|
||||
func (db *DB) now() string {
|
||||
switch db.dialect {
|
||||
case SQLITE:
|
||||
return "datetime('now')"
|
||||
case POSTGRESQL:
|
||||
return "now()"
|
||||
case MYSQL:
|
||||
return "now()"
|
||||
}
|
||||
panic("unkown dialect")
|
||||
}
|
||||
|
||||
func makeBench() {
|
||||
if _, err := db.Exec("create table bench (n varchar(32), i integer, d double, s varchar(32), t datetime)"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
st, err := db.Prepare("insert into bench values (?, ?, ?, ?, ?)")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
for i := 0; i < 100; i++ {
|
||||
if _, err = st.Exec(nil, i, float64(i), fmt.Sprintf("%d", i), time.Now()); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResult(t *testing.T) {
|
||||
db.tearDown()
|
||||
db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
|
||||
|
||||
for i := 1; i < 3; i++ {
|
||||
r := db.mustExec(db.q("insert into test (name) values (?)"), fmt.Sprintf("row %d", i))
|
||||
n, err := r.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Errorf("got %v, want %v", n, 1)
|
||||
}
|
||||
n, err = r.LastInsertId()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != int64(i) {
|
||||
t.Errorf("got %v, want %v", n, i)
|
||||
}
|
||||
}
|
||||
if _, err := db.Exec("error!"); err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBlobs(t *testing.T) {
|
||||
db.tearDown()
|
||||
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
db.mustExec("create table foo (id integer primary key, bar " + db.blobType(16) + ")")
|
||||
db.mustExec(db.q("insert into foo (id, bar) values(?,?)"), 0, blob)
|
||||
|
||||
want := fmt.Sprintf("%x", blob)
|
||||
|
||||
b := make([]byte, 16)
|
||||
err := db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&b)
|
||||
got := fmt.Sprintf("%x", b)
|
||||
if err != nil {
|
||||
t.Errorf("[]byte scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for []byte, got %q; want %q", got, want)
|
||||
}
|
||||
|
||||
err = db.QueryRow(db.q("select bar from foo where id = ?"), 0).Scan(&got)
|
||||
want = string(blob)
|
||||
if err != nil {
|
||||
t.Errorf("string scan: %v", err)
|
||||
} else if got != want {
|
||||
t.Errorf("for string, got %q; want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManyQueryRow(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Log("skipping in short mode")
|
||||
return
|
||||
}
|
||||
db.tearDown()
|
||||
db.mustExec("create table foo (id integer primary key, name varchar(50))")
|
||||
db.mustExec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
||||
var name string
|
||||
for i := 0; i < 10000; i++ {
|
||||
err := db.QueryRow(db.q("select name from foo where id = ?"), 1).Scan(&name)
|
||||
if err != nil || name != "bob" {
|
||||
t.Fatalf("on query %d: err=%v, name=%q", i, err, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTxQuery(t *testing.T) {
|
||||
db.tearDown()
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec("create table foo (id integer primary key, name varchar(50))")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(db.q("insert into foo (id, name) values(?,?)"), 1, "bob")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r, err := tx.Query(db.q("select name from foo where id = ?"), 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
if !r.Next() {
|
||||
if r.Err() != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Fatal("expected one rows")
|
||||
}
|
||||
|
||||
var name string
|
||||
err = r.Scan(&name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreparedStmt(t *testing.T) {
|
||||
db.tearDown()
|
||||
db.mustExec("CREATE TABLE t (count INT)")
|
||||
sel, err := db.Prepare("SELECT count FROM t ORDER BY count DESC")
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 1: %v", err)
|
||||
}
|
||||
ins, err := db.Prepare(db.q("INSERT INTO t (count) VALUES (?)"))
|
||||
if err != nil {
|
||||
t.Fatalf("prepare 2: %v", err)
|
||||
}
|
||||
|
||||
for n := 1; n <= 3; n++ {
|
||||
if _, err := ins.Exec(n); err != nil {
|
||||
t.Fatalf("insert(%d) = %v", n, err)
|
||||
}
|
||||
}
|
||||
|
||||
const nRuns = 10
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < nRuns; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
count := 0
|
||||
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
|
||||
t.Errorf("Query: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err := ins.Exec(rand.Intn(100)); err != nil {
|
||||
t.Errorf("Insert: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Benchmarks need to use panic() since b.Error errors are lost when
|
||||
// running via testing.Benchmark() I would like to run these via go
|
||||
// test -bench but calling Benchmark() from a benchmark test
|
||||
// currently hangs go.
|
||||
|
||||
func BenchmarkExec(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := db.Exec("select 1"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkQuery(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParams(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStmt(b *testing.B) {
|
||||
st, err := db.Prepare("select ?, ?, ?, ?")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
// var t time.Time
|
||||
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRows(b *testing.B) {
|
||||
db.once.Do(makeBench)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
var t time.Time
|
||||
r, err := db.Query("select * from bench")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for r.Next() {
|
||||
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
if err = r.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStmtRows(b *testing.B) {
|
||||
db.once.Do(makeBench)
|
||||
|
||||
st, err := db.Prepare("select * from bench")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer st.Close()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
var n sql.NullString
|
||||
var i int
|
||||
var f float64
|
||||
var s string
|
||||
var t time.Time
|
||||
r, err := st.Query()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for r.Next() {
|
||||
if err = r.Scan(&n, &i, &f, &s, &t); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
if err = r.Err(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,287 @@
|
|||
// Copyright (C) 2016 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// +build trace
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
|
||||
int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Trace... constants identify the possible events causing callback invocation.
|
||||
// Values are same as the corresponding SQLite Trace Event Codes.
|
||||
const (
|
||||
TraceStmt = uint32(C.SQLITE_TRACE_STMT)
|
||||
TraceProfile = uint32(C.SQLITE_TRACE_PROFILE)
|
||||
TraceRow = uint32(C.SQLITE_TRACE_ROW)
|
||||
TraceClose = uint32(C.SQLITE_TRACE_CLOSE)
|
||||
)
|
||||
|
||||
type TraceInfo struct {
|
||||
// Pack together the shorter fields, to keep the struct smaller.
|
||||
// On a 64-bit machine there would be padding
|
||||
// between EventCode and ConnHandle; having AutoCommit here is "free":
|
||||
EventCode uint32
|
||||
AutoCommit bool
|
||||
ConnHandle uintptr
|
||||
|
||||
// Usually filled, unless EventCode = TraceClose = SQLITE_TRACE_CLOSE:
|
||||
// identifier for a prepared statement:
|
||||
StmtHandle uintptr
|
||||
|
||||
// Two strings filled when EventCode = TraceStmt = SQLITE_TRACE_STMT:
|
||||
// (1) either the unexpanded SQL text of the prepared statement, or
|
||||
// an SQL comment that indicates the invocation of a trigger;
|
||||
// (2) expanded SQL, if requested and if (1) is not an SQL comment.
|
||||
StmtOrTrigger string
|
||||
ExpandedSQL string // only if requested (TraceConfig.WantExpandedSQL = true)
|
||||
|
||||
// filled when EventCode = TraceProfile = SQLITE_TRACE_PROFILE:
|
||||
// estimated number of nanoseconds that the prepared statement took to run:
|
||||
RunTimeNanosec int64
|
||||
|
||||
DBError Error
|
||||
}
|
||||
|
||||
// TraceUserCallback gives the signature for a trace function
|
||||
// provided by the user (Go application programmer).
|
||||
// SQLite 3.14 documentation (as of September 2, 2016)
|
||||
// for SQL Trace Hook = sqlite3_trace_v2():
|
||||
// The integer return value from the callback is currently ignored,
|
||||
// though this may change in future releases. Callback implementations
|
||||
// should return zero to ensure future compatibility.
|
||||
type TraceUserCallback func(TraceInfo) int
|
||||
|
||||
type TraceConfig struct {
|
||||
Callback TraceUserCallback
|
||||
EventMask uint32
|
||||
WantExpandedSQL bool
|
||||
}
|
||||
|
||||
func fillDBError(dbErr *Error, db *C.sqlite3) {
|
||||
// See SQLiteConn.lastError(), in file 'sqlite3.go' at the time of writing (Sept 5, 2016)
|
||||
dbErr.Code = ErrNo(C.sqlite3_errcode(db))
|
||||
dbErr.ExtendedCode = ErrNoExtended(C.sqlite3_extended_errcode(db))
|
||||
dbErr.err = C.GoString(C.sqlite3_errmsg(db))
|
||||
}
|
||||
|
||||
func fillExpandedSQL(info *TraceInfo, db *C.sqlite3, pStmt unsafe.Pointer) {
|
||||
if pStmt == nil {
|
||||
panic("No SQLite statement pointer in P arg of trace_v2 callback")
|
||||
}
|
||||
|
||||
expSQLiteCStr := C.sqlite3_expanded_sql((*C.sqlite3_stmt)(pStmt))
|
||||
if expSQLiteCStr == nil {
|
||||
fillDBError(&info.DBError, db)
|
||||
return
|
||||
}
|
||||
info.ExpandedSQL = C.GoString(expSQLiteCStr)
|
||||
}
|
||||
|
||||
//export traceCallbackTrampoline
|
||||
func traceCallbackTrampoline(
|
||||
traceEventCode C.uint,
|
||||
// Parameter named 'C' in SQLite docs = Context given at registration:
|
||||
ctx unsafe.Pointer,
|
||||
// Parameter named 'P' in SQLite docs (Primary event data?):
|
||||
p unsafe.Pointer,
|
||||
// Parameter named 'X' in SQLite docs (eXtra event data?):
|
||||
xValue unsafe.Pointer) C.int {
|
||||
|
||||
eventCode := uint32(traceEventCode)
|
||||
|
||||
if ctx == nil {
|
||||
panic(fmt.Sprintf("No context (ev 0x%x)", traceEventCode))
|
||||
}
|
||||
|
||||
contextDB := (*C.sqlite3)(ctx)
|
||||
connHandle := uintptr(ctx)
|
||||
|
||||
var traceConf TraceConfig
|
||||
var found bool
|
||||
if eventCode == TraceClose {
|
||||
// clean up traceMap: 'pop' means get and delete
|
||||
traceConf, found = popTraceMapping(connHandle)
|
||||
} else {
|
||||
traceConf, found = lookupTraceMapping(connHandle)
|
||||
}
|
||||
|
||||
if !found {
|
||||
panic(fmt.Sprintf("Mapping not found for handle 0x%x (ev 0x%x)",
|
||||
connHandle, eventCode))
|
||||
}
|
||||
|
||||
var info TraceInfo
|
||||
|
||||
info.EventCode = eventCode
|
||||
info.AutoCommit = (int(C.sqlite3_get_autocommit(contextDB)) != 0)
|
||||
info.ConnHandle = connHandle
|
||||
|
||||
switch eventCode {
|
||||
case TraceStmt:
|
||||
info.StmtHandle = uintptr(p)
|
||||
|
||||
var xStr string
|
||||
if xValue != nil {
|
||||
xStr = C.GoString((*C.char)(xValue))
|
||||
}
|
||||
info.StmtOrTrigger = xStr
|
||||
if !strings.HasPrefix(xStr, "--") {
|
||||
// Not SQL comment, therefore the current event
|
||||
// is not related to a trigger.
|
||||
// The user might want to receive the expanded SQL;
|
||||
// let's check:
|
||||
if traceConf.WantExpandedSQL {
|
||||
fillExpandedSQL(&info, contextDB, p)
|
||||
}
|
||||
}
|
||||
|
||||
case TraceProfile:
|
||||
info.StmtHandle = uintptr(p)
|
||||
|
||||
if xValue == nil {
|
||||
panic("NULL pointer in X arg of trace_v2 callback for SQLITE_TRACE_PROFILE event")
|
||||
}
|
||||
|
||||
info.RunTimeNanosec = *(*int64)(xValue)
|
||||
|
||||
// sample the error //TODO: is it safe? is it useful?
|
||||
fillDBError(&info.DBError, contextDB)
|
||||
|
||||
case TraceRow:
|
||||
info.StmtHandle = uintptr(p)
|
||||
|
||||
case TraceClose:
|
||||
handle := uintptr(p)
|
||||
if handle != info.ConnHandle {
|
||||
panic(fmt.Sprintf("Different conn handle 0x%x (expected 0x%x) in SQLITE_TRACE_CLOSE event.",
|
||||
handle, info.ConnHandle))
|
||||
}
|
||||
|
||||
default:
|
||||
// Pass unsupported events to the user callback (if configured);
|
||||
// let the user callback decide whether to panic or ignore them.
|
||||
}
|
||||
|
||||
// Do not execute user callback when the event was not requested by user!
|
||||
// Remember that the Close event is always selected when
|
||||
// registering this callback trampoline with SQLite --- for cleanup.
|
||||
// In the future there may be more events forced to "selected" in SQLite
|
||||
// for the driver's needs.
|
||||
if traceConf.EventMask&eventCode == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
r := 0
|
||||
if traceConf.Callback != nil {
|
||||
r = traceConf.Callback(info)
|
||||
}
|
||||
return C.int(r)
|
||||
}
|
||||
|
||||
type traceMapEntry struct {
|
||||
config TraceConfig
|
||||
}
|
||||
|
||||
var traceMapLock sync.Mutex
|
||||
var traceMap = make(map[uintptr]traceMapEntry)
|
||||
|
||||
func addTraceMapping(connHandle uintptr, traceConf TraceConfig) {
|
||||
traceMapLock.Lock()
|
||||
defer traceMapLock.Unlock()
|
||||
|
||||
oldEntryCopy, found := traceMap[connHandle]
|
||||
if found {
|
||||
panic(fmt.Sprintf("Adding trace config %v: handle 0x%x already registered (%v).",
|
||||
traceConf, connHandle, oldEntryCopy.config))
|
||||
}
|
||||
traceMap[connHandle] = traceMapEntry{config: traceConf}
|
||||
fmt.Printf("Added trace config %v: handle 0x%x.\n", traceConf, connHandle)
|
||||
}
|
||||
|
||||
func lookupTraceMapping(connHandle uintptr) (TraceConfig, bool) {
|
||||
traceMapLock.Lock()
|
||||
defer traceMapLock.Unlock()
|
||||
|
||||
entryCopy, found := traceMap[connHandle]
|
||||
return entryCopy.config, found
|
||||
}
|
||||
|
||||
// 'pop' = get and delete from map before returning the value to the caller
|
||||
func popTraceMapping(connHandle uintptr) (TraceConfig, bool) {
|
||||
traceMapLock.Lock()
|
||||
defer traceMapLock.Unlock()
|
||||
|
||||
entryCopy, found := traceMap[connHandle]
|
||||
if found {
|
||||
delete(traceMap, connHandle)
|
||||
fmt.Printf("Pop handle 0x%x: deleted trace config %v.\n", connHandle, entryCopy.config)
|
||||
}
|
||||
return entryCopy.config, found
|
||||
}
|
||||
|
||||
// SetTrace installs or removes the trace callback for the given database connection.
|
||||
// It's not named 'RegisterTrace' because only one callback can be kept and called.
|
||||
// Calling SetTrace a second time on same database connection
|
||||
// overrides (cancels) any prior callback and all its settings:
|
||||
// event mask, etc.
|
||||
func (c *SQLiteConn) SetTrace(requested *TraceConfig) error {
|
||||
connHandle := uintptr(unsafe.Pointer(c.db))
|
||||
|
||||
_, _ = popTraceMapping(connHandle)
|
||||
|
||||
if requested == nil {
|
||||
// The traceMap entry was deleted already by popTraceMapping():
|
||||
// can disable all events now, no need to watch for TraceClose.
|
||||
err := c.setSQLiteTrace(0)
|
||||
return err
|
||||
}
|
||||
|
||||
reqCopy := *requested
|
||||
|
||||
// Disable potentially expensive operations
|
||||
// if their result will not be used. We are doing this
|
||||
// just in case the caller provided nonsensical input.
|
||||
if reqCopy.EventMask&TraceStmt == 0 {
|
||||
reqCopy.WantExpandedSQL = false
|
||||
}
|
||||
|
||||
addTraceMapping(connHandle, reqCopy)
|
||||
|
||||
// The callback trampoline function does cleanup on Close event,
|
||||
// regardless of the presence or absence of the user callback.
|
||||
// Therefore it needs the Close event to be selected:
|
||||
actualEventMask := uint(reqCopy.EventMask | TraceClose)
|
||||
err := c.setSQLiteTrace(actualEventMask)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) setSQLiteTrace(sqliteEventMask uint) error {
|
||||
rv := C.sqlite3_trace_v2(c.db,
|
||||
C.uint(sqliteEventMask),
|
||||
(*[0]byte)(unsafe.Pointer(C.traceCallbackTrampoline)),
|
||||
unsafe.Pointer(c.db)) // Fourth arg is same as first: we are
|
||||
// passing the database connection handle as callback context.
|
||||
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package sqlite3
|
||||
|
||||
/*
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ColumnTypeDatabaseTypeName implement RowsColumnTypeDatabaseTypeName.
|
||||
func (rc *SQLiteRows) ColumnTypeDatabaseTypeName(i int) string {
|
||||
return C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i)))
|
||||
}
|
||||
|
||||
/*
|
||||
func (rc *SQLiteRows) ColumnTypeLength(index int) (length int64, ok bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (rc *SQLiteRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
|
||||
return 0, 0, false
|
||||
}
|
||||
*/
|
||||
|
||||
// ColumnTypeNullable implement RowsColumnTypeNullable.
|
||||
func (rc *SQLiteRows) ColumnTypeNullable(i int) (nullable, ok bool) {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// ColumnTypeScanType implement RowsColumnTypeScanType.
|
||||
func (rc *SQLiteRows) ColumnTypeScanType(i int) reflect.Type {
|
||||
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
|
||||
case C.SQLITE_INTEGER:
|
||||
switch C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))) {
|
||||
case "timestamp", "datetime", "date":
|
||||
return reflect.TypeOf(time.Time{})
|
||||
case "boolean":
|
||||
return reflect.TypeOf(false)
|
||||
}
|
||||
return reflect.TypeOf(int64(0))
|
||||
case C.SQLITE_FLOAT:
|
||||
return reflect.TypeOf(float64(0))
|
||||
case C.SQLITE_BLOB:
|
||||
return reflect.SliceOf(reflect.TypeOf(byte(0)))
|
||||
case C.SQLITE_NULL:
|
||||
return reflect.TypeOf(nil)
|
||||
case C.SQLITE_TEXT:
|
||||
return reflect.TypeOf("")
|
||||
}
|
||||
return reflect.SliceOf(reflect.TypeOf(byte(0)))
|
||||
}
|
|
@ -0,0 +1,646 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// +build vtable
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -std=gnu99
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
|
||||
#cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_COLUMN_METADATA=1
|
||||
#cgo CFLAGS: -Wno-deprecated-declarations
|
||||
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
#include <memory.h>
|
||||
|
||||
static inline char *_sqlite3_mprintf(char *zFormat, char *arg) {
|
||||
return sqlite3_mprintf(zFormat, arg);
|
||||
}
|
||||
|
||||
typedef struct goVTab goVTab;
|
||||
|
||||
struct goVTab {
|
||||
sqlite3_vtab base;
|
||||
void *vTab;
|
||||
};
|
||||
|
||||
uintptr_t goMInit(void *db, void *pAux, int argc, char **argv, char **pzErr, int isCreate);
|
||||
|
||||
static int cXInit(sqlite3 *db, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char **pzErr, int isCreate) {
|
||||
void *vTab = (void *)goMInit(db, pAux, argc, (char**)argv, pzErr, isCreate);
|
||||
if (!vTab || *pzErr) {
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
goVTab *pvTab = (goVTab *)sqlite3_malloc(sizeof(goVTab));
|
||||
if (!pvTab) {
|
||||
*pzErr = sqlite3_mprintf("%s", "Out of memory");
|
||||
return SQLITE_NOMEM;
|
||||
}
|
||||
memset(pvTab, 0, sizeof(goVTab));
|
||||
pvTab->vTab = vTab;
|
||||
|
||||
*ppVTab = (sqlite3_vtab *)pvTab;
|
||||
*pzErr = 0;
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
static inline int cXCreate(sqlite3 *db, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char **pzErr) {
|
||||
return cXInit(db, pAux, argc, argv, ppVTab, pzErr, 1);
|
||||
}
|
||||
static inline int cXConnect(sqlite3 *db, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char **pzErr) {
|
||||
return cXInit(db, pAux, argc, argv, ppVTab, pzErr, 0);
|
||||
}
|
||||
|
||||
char* goVBestIndex(void *pVTab, void *icp);
|
||||
|
||||
static inline int cXBestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *info) {
|
||||
char *pzErr = goVBestIndex(((goVTab*)pVTab)->vTab, info);
|
||||
if (pzErr) {
|
||||
if (pVTab->zErrMsg)
|
||||
sqlite3_free(pVTab->zErrMsg);
|
||||
pVTab->zErrMsg = pzErr;
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
char* goVRelease(void *pVTab, int isDestroy);
|
||||
|
||||
static int cXRelease(sqlite3_vtab *pVTab, int isDestroy) {
|
||||
char *pzErr = goVRelease(((goVTab*)pVTab)->vTab, isDestroy);
|
||||
if (pzErr) {
|
||||
if (pVTab->zErrMsg)
|
||||
sqlite3_free(pVTab->zErrMsg);
|
||||
pVTab->zErrMsg = pzErr;
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
if (pVTab->zErrMsg)
|
||||
sqlite3_free(pVTab->zErrMsg);
|
||||
sqlite3_free(pVTab);
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
static inline int cXDisconnect(sqlite3_vtab *pVTab) {
|
||||
return cXRelease(pVTab, 0);
|
||||
}
|
||||
static inline int cXDestroy(sqlite3_vtab *pVTab) {
|
||||
return cXRelease(pVTab, 1);
|
||||
}
|
||||
|
||||
typedef struct goVTabCursor goVTabCursor;
|
||||
|
||||
struct goVTabCursor {
|
||||
sqlite3_vtab_cursor base;
|
||||
void *vTabCursor;
|
||||
};
|
||||
|
||||
uintptr_t goVOpen(void *pVTab, char **pzErr);
|
||||
|
||||
static int cXOpen(sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor) {
|
||||
void *vTabCursor = (void *)goVOpen(((goVTab*)pVTab)->vTab, &(pVTab->zErrMsg));
|
||||
goVTabCursor *pCursor = (goVTabCursor *)sqlite3_malloc(sizeof(goVTabCursor));
|
||||
if (!pCursor) {
|
||||
return SQLITE_NOMEM;
|
||||
}
|
||||
memset(pCursor, 0, sizeof(goVTabCursor));
|
||||
pCursor->vTabCursor = vTabCursor;
|
||||
*ppCursor = (sqlite3_vtab_cursor *)pCursor;
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
static int setErrMsg(sqlite3_vtab_cursor *pCursor, char *pzErr) {
|
||||
if (pCursor->pVtab->zErrMsg)
|
||||
sqlite3_free(pCursor->pVtab->zErrMsg);
|
||||
pCursor->pVtab->zErrMsg = pzErr;
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
|
||||
char* goVClose(void *pCursor);
|
||||
|
||||
static int cXClose(sqlite3_vtab_cursor *pCursor) {
|
||||
char *pzErr = goVClose(((goVTabCursor*)pCursor)->vTabCursor);
|
||||
if (pzErr) {
|
||||
return setErrMsg(pCursor, pzErr);
|
||||
}
|
||||
sqlite3_free(pCursor);
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
char* goVFilter(void *pCursor, int idxNum, char* idxName, int argc, sqlite3_value **argv);
|
||||
|
||||
static int cXFilter(sqlite3_vtab_cursor *pCursor, int idxNum, const char *idxStr, int argc, sqlite3_value **argv) {
|
||||
char *pzErr = goVFilter(((goVTabCursor*)pCursor)->vTabCursor, idxNum, (char*)idxStr, argc, argv);
|
||||
if (pzErr) {
|
||||
return setErrMsg(pCursor, pzErr);
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
char* goVNext(void *pCursor);
|
||||
|
||||
static int cXNext(sqlite3_vtab_cursor *pCursor) {
|
||||
char *pzErr = goVNext(((goVTabCursor*)pCursor)->vTabCursor);
|
||||
if (pzErr) {
|
||||
return setErrMsg(pCursor, pzErr);
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
int goVEof(void *pCursor);
|
||||
|
||||
static inline int cXEof(sqlite3_vtab_cursor *pCursor) {
|
||||
return goVEof(((goVTabCursor*)pCursor)->vTabCursor);
|
||||
}
|
||||
|
||||
char* goVColumn(void *pCursor, void *cp, int col);
|
||||
|
||||
static int cXColumn(sqlite3_vtab_cursor *pCursor, sqlite3_context *ctx, int i) {
|
||||
char *pzErr = goVColumn(((goVTabCursor*)pCursor)->vTabCursor, ctx, i);
|
||||
if (pzErr) {
|
||||
return setErrMsg(pCursor, pzErr);
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
char* goVRowid(void *pCursor, sqlite3_int64 *pRowid);
|
||||
|
||||
static int cXRowid(sqlite3_vtab_cursor *pCursor, sqlite3_int64 *pRowid) {
|
||||
char *pzErr = goVRowid(((goVTabCursor*)pCursor)->vTabCursor, pRowid);
|
||||
if (pzErr) {
|
||||
return setErrMsg(pCursor, pzErr);
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
char* goVUpdate(void *pVTab, int argc, sqlite3_value **argv, sqlite3_int64 *pRowid);
|
||||
|
||||
static int cXUpdate(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, sqlite3_int64 *pRowid) {
|
||||
char *pzErr = goVUpdate(((goVTab*)pVTab)->vTab, argc, argv, pRowid);
|
||||
if (pzErr) {
|
||||
if (pVTab->zErrMsg)
|
||||
sqlite3_free(pVTab->zErrMsg);
|
||||
pVTab->zErrMsg = pzErr;
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
static sqlite3_module goModule = {
|
||||
0, // iVersion
|
||||
cXCreate, // xCreate - create a table
|
||||
cXConnect, // xConnect - connect to an existing table
|
||||
cXBestIndex, // xBestIndex - Determine search strategy
|
||||
cXDisconnect, // xDisconnect - Disconnect from a table
|
||||
cXDestroy, // xDestroy - Drop a table
|
||||
cXOpen, // xOpen - open a cursor
|
||||
cXClose, // xClose - close a cursor
|
||||
cXFilter, // xFilter - configure scan constraints
|
||||
cXNext, // xNext - advance a cursor
|
||||
cXEof, // xEof
|
||||
cXColumn, // xColumn - read data
|
||||
cXRowid, // xRowid - read data
|
||||
cXUpdate, // xUpdate - write data
|
||||
// Not implemented
|
||||
0, // xBegin - begin transaction
|
||||
0, // xSync - sync transaction
|
||||
0, // xCommit - commit transaction
|
||||
0, // xRollback - rollback transaction
|
||||
0, // xFindFunction - function overloading
|
||||
0, // xRename - rename the table
|
||||
0, // xSavepoint
|
||||
0, // xRelease
|
||||
0 // xRollbackTo
|
||||
};
|
||||
|
||||
void goMDestroy(void*);
|
||||
|
||||
static int _sqlite3_create_module(sqlite3 *db, const char *zName, uintptr_t pClientData) {
|
||||
return sqlite3_create_module_v2(db, zName, &goModule, (void*) pClientData, goMDestroy);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type sqliteModule struct {
|
||||
c *SQLiteConn
|
||||
name string
|
||||
module Module
|
||||
}
|
||||
|
||||
type sqliteVTab struct {
|
||||
module *sqliteModule
|
||||
vTab VTab
|
||||
}
|
||||
|
||||
type sqliteVTabCursor struct {
|
||||
vTab *sqliteVTab
|
||||
vTabCursor VTabCursor
|
||||
}
|
||||
|
||||
// Op is type of operations.
|
||||
type Op uint8
|
||||
|
||||
// Op mean identity of operations.
|
||||
const (
|
||||
OpEQ Op = 2
|
||||
OpGT = 4
|
||||
OpLE = 8
|
||||
OpLT = 16
|
||||
OpGE = 32
|
||||
OpMATCH = 64
|
||||
OpLIKE = 65 /* 3.10.0 and later only */
|
||||
OpGLOB = 66 /* 3.10.0 and later only */
|
||||
OpREGEXP = 67 /* 3.10.0 and later only */
|
||||
OpScanUnique = 1 /* Scan visits at most 1 row */
|
||||
)
|
||||
|
||||
// InfoConstraint give information of constraint.
|
||||
type InfoConstraint struct {
|
||||
Column int
|
||||
Op Op
|
||||
Usable bool
|
||||
}
|
||||
|
||||
// InfoOrderBy give information of order-by.
|
||||
type InfoOrderBy struct {
|
||||
Column int
|
||||
Desc bool
|
||||
}
|
||||
|
||||
func constraints(info *C.sqlite3_index_info) []InfoConstraint {
|
||||
l := info.nConstraint
|
||||
slice := (*[1 << 30]C.struct_sqlite3_index_constraint)(unsafe.Pointer(info.aConstraint))[:l:l]
|
||||
|
||||
cst := make([]InfoConstraint, 0, l)
|
||||
for _, c := range slice {
|
||||
var usable bool
|
||||
if c.usable > 0 {
|
||||
usable = true
|
||||
}
|
||||
cst = append(cst, InfoConstraint{
|
||||
Column: int(c.iColumn),
|
||||
Op: Op(c.op),
|
||||
Usable: usable,
|
||||
})
|
||||
}
|
||||
return cst
|
||||
}
|
||||
|
||||
func orderBys(info *C.sqlite3_index_info) []InfoOrderBy {
|
||||
l := info.nOrderBy
|
||||
slice := (*[1 << 30]C.struct_sqlite3_index_orderby)(unsafe.Pointer(info.aOrderBy))[:l:l]
|
||||
|
||||
ob := make([]InfoOrderBy, 0, l)
|
||||
for _, c := range slice {
|
||||
var desc bool
|
||||
if c.desc > 0 {
|
||||
desc = true
|
||||
}
|
||||
ob = append(ob, InfoOrderBy{
|
||||
Column: int(c.iColumn),
|
||||
Desc: desc,
|
||||
})
|
||||
}
|
||||
return ob
|
||||
}
|
||||
|
||||
// IndexResult is a Go struct representation of what eventually ends up in the
|
||||
// output fields for `sqlite3_index_info`
|
||||
// See: https://www.sqlite.org/c3ref/index_info.html
|
||||
type IndexResult struct {
|
||||
Used []bool // aConstraintUsage
|
||||
IdxNum int
|
||||
IdxStr string
|
||||
AlreadyOrdered bool // orderByConsumed
|
||||
EstimatedCost float64
|
||||
EstimatedRows float64
|
||||
}
|
||||
|
||||
// mPrintf is a utility wrapper around sqlite3_mprintf
|
||||
func mPrintf(format, arg string) *C.char {
|
||||
cf := C.CString(format)
|
||||
defer C.free(unsafe.Pointer(cf))
|
||||
ca := C.CString(arg)
|
||||
defer C.free(unsafe.Pointer(ca))
|
||||
return C._sqlite3_mprintf(cf, ca)
|
||||
}
|
||||
|
||||
//export goMInit
|
||||
func goMInit(db, pClientData unsafe.Pointer, argc C.int, argv **C.char, pzErr **C.char, isCreate C.int) C.uintptr_t {
|
||||
m := lookupHandle(uintptr(pClientData)).(*sqliteModule)
|
||||
if m.c.db != (*C.sqlite3)(db) {
|
||||
*pzErr = mPrintf("%s", "Inconsistent db handles")
|
||||
return 0
|
||||
}
|
||||
args := make([]string, argc)
|
||||
var A []*C.char
|
||||
slice := reflect.SliceHeader{Data: uintptr(unsafe.Pointer(argv)), Len: int(argc), Cap: int(argc)}
|
||||
a := reflect.NewAt(reflect.TypeOf(A), unsafe.Pointer(&slice)).Elem().Interface()
|
||||
for i, s := range a.([]*C.char) {
|
||||
args[i] = C.GoString(s)
|
||||
}
|
||||
var vTab VTab
|
||||
var err error
|
||||
if isCreate == 1 {
|
||||
vTab, err = m.module.Create(m.c, args)
|
||||
} else {
|
||||
vTab, err = m.module.Connect(m.c, args)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
*pzErr = mPrintf("%s", err.Error())
|
||||
return 0
|
||||
}
|
||||
vt := sqliteVTab{m, vTab}
|
||||
*pzErr = nil
|
||||
return C.uintptr_t(newHandle(m.c, &vt))
|
||||
}
|
||||
|
||||
//export goVRelease
|
||||
func goVRelease(pVTab unsafe.Pointer, isDestroy C.int) *C.char {
|
||||
vt := lookupHandle(uintptr(pVTab)).(*sqliteVTab)
|
||||
var err error
|
||||
if isDestroy == 1 {
|
||||
err = vt.vTab.Destroy()
|
||||
} else {
|
||||
err = vt.vTab.Disconnect()
|
||||
}
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goVOpen
|
||||
func goVOpen(pVTab unsafe.Pointer, pzErr **C.char) C.uintptr_t {
|
||||
vt := lookupHandle(uintptr(pVTab)).(*sqliteVTab)
|
||||
vTabCursor, err := vt.vTab.Open()
|
||||
if err != nil {
|
||||
*pzErr = mPrintf("%s", err.Error())
|
||||
return 0
|
||||
}
|
||||
vtc := sqliteVTabCursor{vt, vTabCursor}
|
||||
*pzErr = nil
|
||||
return C.uintptr_t(newHandle(vt.module.c, &vtc))
|
||||
}
|
||||
|
||||
//export goVBestIndex
|
||||
func goVBestIndex(pVTab unsafe.Pointer, icp unsafe.Pointer) *C.char {
|
||||
vt := lookupHandle(uintptr(pVTab)).(*sqliteVTab)
|
||||
info := (*C.sqlite3_index_info)(icp)
|
||||
csts := constraints(info)
|
||||
res, err := vt.vTab.BestIndex(csts, orderBys(info))
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
if len(res.Used) != len(csts) {
|
||||
return mPrintf("Result.Used != expected value", "")
|
||||
}
|
||||
|
||||
// Get a pointer to constraint_usage struct so we can update in place.
|
||||
l := info.nConstraint
|
||||
s := (*[1 << 30]C.struct_sqlite3_index_constraint_usage)(unsafe.Pointer(info.aConstraintUsage))[:l:l]
|
||||
index := 1
|
||||
for i := C.int(0); i < info.nConstraint; i++ {
|
||||
if res.Used[i] {
|
||||
s[i].argvIndex = C.int(index)
|
||||
s[i].omit = C.uchar(1)
|
||||
index++
|
||||
}
|
||||
}
|
||||
|
||||
info.idxNum = C.int(res.IdxNum)
|
||||
idxStr := C.CString(res.IdxStr)
|
||||
defer C.free(unsafe.Pointer(idxStr))
|
||||
info.idxStr = idxStr
|
||||
info.needToFreeIdxStr = C.int(0)
|
||||
if res.AlreadyOrdered {
|
||||
info.orderByConsumed = C.int(1)
|
||||
}
|
||||
info.estimatedCost = C.double(res.EstimatedCost)
|
||||
info.estimatedRows = C.sqlite3_int64(res.EstimatedRows)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goVClose
|
||||
func goVClose(pCursor unsafe.Pointer) *C.char {
|
||||
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
|
||||
err := vtc.vTabCursor.Close()
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goMDestroy
|
||||
func goMDestroy(pClientData unsafe.Pointer) {
|
||||
m := lookupHandle(uintptr(pClientData)).(*sqliteModule)
|
||||
m.module.DestroyModule()
|
||||
}
|
||||
|
||||
//export goVFilter
|
||||
func goVFilter(pCursor unsafe.Pointer, idxNum C.int, idxName *C.char, argc C.int, argv **C.sqlite3_value) *C.char {
|
||||
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
|
||||
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
|
||||
vals := make([]interface{}, 0, argc)
|
||||
for _, v := range args {
|
||||
conv, err := callbackArgGeneric(v)
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
vals = append(vals, conv.Interface())
|
||||
}
|
||||
err := vtc.vTabCursor.Filter(int(idxNum), C.GoString(idxName), vals)
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goVNext
|
||||
func goVNext(pCursor unsafe.Pointer) *C.char {
|
||||
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
|
||||
err := vtc.vTabCursor.Next()
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goVEof
|
||||
func goVEof(pCursor unsafe.Pointer) C.int {
|
||||
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
|
||||
err := vtc.vTabCursor.EOF()
|
||||
if err {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
//export goVColumn
|
||||
func goVColumn(pCursor, cp unsafe.Pointer, col C.int) *C.char {
|
||||
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
|
||||
c := (*SQLiteContext)(cp)
|
||||
err := vtc.vTabCursor.Column(c, int(col))
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goVRowid
|
||||
func goVRowid(pCursor unsafe.Pointer, pRowid *C.sqlite3_int64) *C.char {
|
||||
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
|
||||
rowid, err := vtc.vTabCursor.Rowid()
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
*pRowid = C.sqlite3_int64(rowid)
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goVUpdate
|
||||
func goVUpdate(pVTab unsafe.Pointer, argc C.int, argv **C.sqlite3_value, pRowid *C.sqlite3_int64) *C.char {
|
||||
vt := lookupHandle(uintptr(pVTab)).(*sqliteVTab)
|
||||
|
||||
var tname string
|
||||
if n, ok := vt.vTab.(interface {
|
||||
TableName() string
|
||||
}); ok {
|
||||
tname = n.TableName() + " "
|
||||
}
|
||||
|
||||
err := fmt.Errorf("virtual %s table %sis read-only", vt.module.name, tname)
|
||||
if v, ok := vt.vTab.(VTabUpdater); ok {
|
||||
// convert argv
|
||||
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
|
||||
vals := make([]interface{}, 0, argc)
|
||||
for _, v := range args {
|
||||
conv, err := callbackArgGeneric(v)
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
|
||||
// work around for SQLITE_NULL
|
||||
x := conv.Interface()
|
||||
if z, ok := x.([]byte); ok && z == nil {
|
||||
x = nil
|
||||
}
|
||||
|
||||
vals = append(vals, x)
|
||||
}
|
||||
|
||||
switch {
|
||||
case argc == 1:
|
||||
err = v.Delete(vals[0])
|
||||
|
||||
case argc > 1 && vals[0] == nil:
|
||||
var id int64
|
||||
id, err = v.Insert(vals[1], vals[2:])
|
||||
if err == nil {
|
||||
*pRowid = C.sqlite3_int64(id)
|
||||
}
|
||||
|
||||
case argc > 1:
|
||||
err = v.Update(vals[1], vals[2:])
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Module is a "virtual table module", it defines the implementation of a
|
||||
// virtual tables. See: http://sqlite.org/c3ref/module.html
|
||||
type Module interface {
|
||||
// http://sqlite.org/vtab.html#xcreate
|
||||
Create(c *SQLiteConn, args []string) (VTab, error)
|
||||
// http://sqlite.org/vtab.html#xconnect
|
||||
Connect(c *SQLiteConn, args []string) (VTab, error)
|
||||
// http://sqlite.org/c3ref/create_module.html
|
||||
DestroyModule()
|
||||
}
|
||||
|
||||
// VTab describes a particular instance of the virtual table.
|
||||
// See: http://sqlite.org/c3ref/vtab.html
|
||||
type VTab interface {
|
||||
// http://sqlite.org/vtab.html#xbestindex
|
||||
BestIndex([]InfoConstraint, []InfoOrderBy) (*IndexResult, error)
|
||||
// http://sqlite.org/vtab.html#xdisconnect
|
||||
Disconnect() error
|
||||
// http://sqlite.org/vtab.html#sqlite3_module.xDestroy
|
||||
Destroy() error
|
||||
// http://sqlite.org/vtab.html#xopen
|
||||
Open() (VTabCursor, error)
|
||||
}
|
||||
|
||||
// VTabUpdater is a type that allows a VTab to be inserted, updated, or
|
||||
// deleted.
|
||||
// See: https://sqlite.org/vtab.html#xupdate
|
||||
type VTabUpdater interface {
|
||||
Delete(interface{}) error
|
||||
Insert(interface{}, []interface{}) (int64, error)
|
||||
Update(interface{}, []interface{}) error
|
||||
}
|
||||
|
||||
// VTabCursor describes cursors that point into the virtual table and are used
|
||||
// to loop through the virtual table. See: http://sqlite.org/c3ref/vtab_cursor.html
|
||||
type VTabCursor interface {
|
||||
// http://sqlite.org/vtab.html#xclose
|
||||
Close() error
|
||||
// http://sqlite.org/vtab.html#xfilter
|
||||
Filter(idxNum int, idxStr string, vals []interface{}) error
|
||||
// http://sqlite.org/vtab.html#xnext
|
||||
Next() error
|
||||
// http://sqlite.org/vtab.html#xeof
|
||||
EOF() bool
|
||||
// http://sqlite.org/vtab.html#xcolumn
|
||||
Column(c *SQLiteContext, col int) error
|
||||
// http://sqlite.org/vtab.html#xrowid
|
||||
Rowid() (int64, error)
|
||||
}
|
||||
|
||||
// DeclareVTab declares the Schema of a virtual table.
|
||||
// See: http://sqlite.org/c3ref/declare_vtab.html
|
||||
func (c *SQLiteConn) DeclareVTab(sql string) error {
|
||||
zSQL := C.CString(sql)
|
||||
defer C.free(unsafe.Pointer(zSQL))
|
||||
rv := C.sqlite3_declare_vtab(c.db, zSQL)
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateModule registers a virtual table implementation.
|
||||
// See: http://sqlite.org/c3ref/create_module.html
|
||||
func (c *SQLiteConn) CreateModule(moduleName string, module Module) error {
|
||||
mname := C.CString(moduleName)
|
||||
defer C.free(unsafe.Pointer(mname))
|
||||
udm := sqliteModule{c, moduleName, module}
|
||||
rv := C._sqlite3_create_module(c.db, mname, C.uintptr_t(newHandle(c, &udm)))
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,485 @@
|
|||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// +build vtable
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testModule struct {
|
||||
t *testing.T
|
||||
intarray []int
|
||||
}
|
||||
|
||||
type testVTab struct {
|
||||
intarray []int
|
||||
}
|
||||
|
||||
type testVTabCursor struct {
|
||||
vTab *testVTab
|
||||
index int
|
||||
}
|
||||
|
||||
func (m testModule) Create(c *SQLiteConn, args []string) (VTab, error) {
|
||||
if len(args) != 6 {
|
||||
m.t.Fatal("six arguments expected")
|
||||
}
|
||||
if args[0] != "test" {
|
||||
m.t.Fatal("module name")
|
||||
}
|
||||
if args[1] != "main" {
|
||||
m.t.Fatal("db name")
|
||||
}
|
||||
if args[2] != "vtab" {
|
||||
m.t.Fatal("table name")
|
||||
}
|
||||
if args[3] != "'1'" {
|
||||
m.t.Fatal("first arg")
|
||||
}
|
||||
if args[4] != "2" {
|
||||
m.t.Fatal("second arg")
|
||||
}
|
||||
if args[5] != "three" {
|
||||
m.t.Fatal("third argsecond arg")
|
||||
}
|
||||
err := c.DeclareVTab("CREATE TABLE x(test TEXT)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &testVTab{m.intarray}, nil
|
||||
}
|
||||
|
||||
func (m testModule) Connect(c *SQLiteConn, args []string) (VTab, error) {
|
||||
return m.Create(c, args)
|
||||
}
|
||||
|
||||
func (m testModule) DestroyModule() {}
|
||||
|
||||
func (v *testVTab) BestIndex(cst []InfoConstraint, ob []InfoOrderBy) (*IndexResult, error) {
|
||||
used := make([]bool, 0, len(cst))
|
||||
for range cst {
|
||||
used = append(used, false)
|
||||
}
|
||||
return &IndexResult{
|
||||
Used: used,
|
||||
IdxNum: 0,
|
||||
IdxStr: "test-index",
|
||||
AlreadyOrdered: true,
|
||||
EstimatedCost: 100,
|
||||
EstimatedRows: 200,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *testVTab) Disconnect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *testVTab) Destroy() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *testVTab) Open() (VTabCursor, error) {
|
||||
return &testVTabCursor{v, 0}, nil
|
||||
}
|
||||
|
||||
func (vc *testVTabCursor) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vc *testVTabCursor) Filter(idxNum int, idxStr string, vals []interface{}) error {
|
||||
vc.index = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vc *testVTabCursor) Next() error {
|
||||
vc.index++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vc *testVTabCursor) EOF() bool {
|
||||
return vc.index >= len(vc.vTab.intarray)
|
||||
}
|
||||
|
||||
func (vc *testVTabCursor) Column(c *SQLiteContext, col int) error {
|
||||
if col != 0 {
|
||||
return fmt.Errorf("column index out of bounds: %d", col)
|
||||
}
|
||||
c.ResultInt(vc.vTab.intarray[vc.index])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vc *testVTabCursor) Rowid() (int64, error) {
|
||||
return int64(vc.index), nil
|
||||
}
|
||||
|
||||
func TestCreateModule(t *testing.T) {
|
||||
tempFilename := TempFilename(t)
|
||||
defer os.Remove(tempFilename)
|
||||
intarray := []int{1, 2, 3}
|
||||
sql.Register("sqlite3_TestCreateModule", &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
return conn.CreateModule("test", testModule{t, intarray})
|
||||
},
|
||||
})
|
||||
db, err := sql.Open("sqlite3_TestCreateModule", tempFilename)
|
||||
if err != nil {
|
||||
t.Fatalf("could not open db: %v", err)
|
||||
}
|
||||
_, err = db.Exec("CREATE VIRTUAL TABLE vtab USING test('1', 2, three)")
|
||||
if err != nil {
|
||||
t.Fatalf("could not create vtable: %v", err)
|
||||
}
|
||||
|
||||
var i, value int
|
||||
rows, err := db.Query("SELECT rowid, * FROM vtab WHERE test = '3'")
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't select from virtual table: %v", err)
|
||||
}
|
||||
for rows.Next() {
|
||||
rows.Scan(&i, &value)
|
||||
if intarray[i] != value {
|
||||
t.Fatalf("want %v but %v", intarray[i], value)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = db.Exec("DROP TABLE vtab")
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't drop virtual table: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVUpdate(t *testing.T) {
|
||||
tempFilename := TempFilename(t)
|
||||
defer os.Remove(tempFilename)
|
||||
|
||||
// create module
|
||||
updateMod := &vtabUpdateModule{t, make(map[string]*vtabUpdateTable)}
|
||||
|
||||
// register module
|
||||
sql.Register("sqlite3_TestVUpdate", &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
return conn.CreateModule("updatetest", updateMod)
|
||||
},
|
||||
})
|
||||
|
||||
// connect
|
||||
db, err := sql.Open("sqlite3_TestVUpdate", tempFilename)
|
||||
if err != nil {
|
||||
t.Fatalf("could not open db: %v", err)
|
||||
}
|
||||
|
||||
// create test table
|
||||
_, err = db.Exec(`CREATE VIRTUAL TABLE vt USING updatetest(f1 integer, f2 text, f3 text)`)
|
||||
if err != nil {
|
||||
t.Fatalf("could not create updatetest vtable vt, got: %v", err)
|
||||
}
|
||||
|
||||
// check that table is defined properly
|
||||
if len(updateMod.tables) != 1 {
|
||||
t.Fatalf("expected exactly 1 table to exist, got: %d", len(updateMod.tables))
|
||||
}
|
||||
if _, ok := updateMod.tables["vt"]; !ok {
|
||||
t.Fatalf("expected table `vt` to exist in tables")
|
||||
}
|
||||
|
||||
// check nothing in updatetest
|
||||
rows, err := db.Query(`select * from vt`)
|
||||
if err != nil {
|
||||
t.Fatalf("could not query vt, got: %v", err)
|
||||
}
|
||||
i, err := getRowCount(rows)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
if i != 0 {
|
||||
t.Fatalf("expected no rows in vt, got: %d", i)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`delete from vt where f1 = 'yes'`)
|
||||
if err != nil {
|
||||
t.Fatalf("expected error on delete, got nil")
|
||||
}
|
||||
|
||||
// test bad column name
|
||||
_, err = db.Exec(`insert into vt (f4) values('a')`)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error on insert, got nil")
|
||||
}
|
||||
|
||||
// insert to vt
|
||||
res, err := db.Exec(`insert into vt (f1, f2, f3) values (115, 'b', 'c'), (116, 'd', 'e')`)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error on insert, got: %v", err)
|
||||
}
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
if n != 2 {
|
||||
t.Fatalf("expected 1 row affected, got: %d", n)
|
||||
}
|
||||
|
||||
// check vt table
|
||||
vt := updateMod.tables["vt"]
|
||||
if len(vt.data) != 2 {
|
||||
t.Fatalf("expected table vt to have exactly 2 rows, got: %d", len(vt.data))
|
||||
}
|
||||
if !reflect.DeepEqual(vt.data[0], []interface{}{int64(115), "b", "c"}) {
|
||||
t.Fatalf("expected table vt entry 0 to be [115 b c], instead: %v", vt.data[0])
|
||||
}
|
||||
if !reflect.DeepEqual(vt.data[1], []interface{}{int64(116), "d", "e"}) {
|
||||
t.Fatalf("expected table vt entry 1 to be [116 d e], instead: %v", vt.data[1])
|
||||
}
|
||||
|
||||
// query vt
|
||||
var f1 int
|
||||
var f2, f3 string
|
||||
err = db.QueryRow(`select * from vt where f1 = 115`).Scan(&f1, &f2, &f3)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error on vt query, got: %v", err)
|
||||
}
|
||||
|
||||
// check column values
|
||||
if f1 != 115 || f2 != "b" || f3 != "c" {
|
||||
t.Errorf("expected f1==115, f2==b, f3==c, got: %d, %q, %q", f1, f2, f3)
|
||||
}
|
||||
|
||||
// update vt
|
||||
res, err = db.Exec(`update vt set f1=117, f2='f' where f3='e'`)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
n, err = res.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Fatalf("expected exactly one row updated, got: %d", n)
|
||||
}
|
||||
|
||||
// check vt table
|
||||
if len(vt.data) != 2 {
|
||||
t.Fatalf("expected table vt to have exactly 2 rows, got: %d", len(vt.data))
|
||||
}
|
||||
if !reflect.DeepEqual(vt.data[0], []interface{}{int64(115), "b", "c"}) {
|
||||
t.Fatalf("expected table vt entry 0 to be [115 b c], instead: %v", vt.data[0])
|
||||
}
|
||||
if !reflect.DeepEqual(vt.data[1], []interface{}{int64(117), "f", "e"}) {
|
||||
t.Fatalf("expected table vt entry 1 to be [117 f e], instead: %v", vt.data[1])
|
||||
}
|
||||
|
||||
// delete from vt
|
||||
res, err = db.Exec(`delete from vt where f1 = 117`)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
n, err = res.RowsAffected()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Fatalf("expected exactly one row deleted, got: %d", n)
|
||||
}
|
||||
|
||||
// check vt table
|
||||
if len(vt.data) != 1 {
|
||||
t.Fatalf("expected table vt to have exactly 1 row, got: %d", len(vt.data))
|
||||
}
|
||||
if !reflect.DeepEqual(vt.data[0], []interface{}{int64(115), "b", "c"}) {
|
||||
t.Fatalf("expected table vt entry 0 to be [115 b c], instead: %v", vt.data[0])
|
||||
}
|
||||
|
||||
// check updatetest has 1 result
|
||||
rows, err = db.Query(`select * from vt`)
|
||||
if err != nil {
|
||||
t.Fatalf("could not query vt, got: %v", err)
|
||||
}
|
||||
i, err = getRowCount(rows)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got: %v", err)
|
||||
}
|
||||
if i != 1 {
|
||||
t.Fatalf("expected 1 row in vt, got: %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func getRowCount(rows *sql.Rows) (int, error) {
|
||||
var i int
|
||||
for rows.Next() {
|
||||
i++
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
type vtabUpdateModule struct {
|
||||
t *testing.T
|
||||
tables map[string]*vtabUpdateTable
|
||||
}
|
||||
|
||||
func (m *vtabUpdateModule) Create(c *SQLiteConn, args []string) (VTab, error) {
|
||||
if len(args) < 2 {
|
||||
return nil, errors.New("must declare at least one column")
|
||||
}
|
||||
|
||||
// get database name, table name, and column declarations ...
|
||||
dbname, tname, decls := args[1], args[2], args[3:]
|
||||
|
||||
// extract column names + types from parameters declarations
|
||||
cols, typs := make([]string, len(decls)), make([]string, len(decls))
|
||||
for i := 0; i < len(decls); i++ {
|
||||
n, typ := decls[i], ""
|
||||
if j := strings.IndexAny(n, " \t\n"); j != -1 {
|
||||
typ, n = strings.TrimSpace(n[j+1:]), n[:j]
|
||||
}
|
||||
cols[i], typs[i] = n, typ
|
||||
}
|
||||
|
||||
// declare table
|
||||
err := c.DeclareVTab(fmt.Sprintf(`CREATE TABLE "%s"."%s" (%s)`, dbname, tname, strings.Join(decls, ",")))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// create table
|
||||
vtab := &vtabUpdateTable{m.t, dbname, tname, cols, typs, make([][]interface{}, 0)}
|
||||
m.tables[tname] = vtab
|
||||
return vtab, nil
|
||||
}
|
||||
|
||||
func (m *vtabUpdateModule) Connect(c *SQLiteConn, args []string) (VTab, error) {
|
||||
return m.Create(c, args)
|
||||
}
|
||||
|
||||
func (m *vtabUpdateModule) DestroyModule() {}
|
||||
|
||||
type vtabUpdateTable struct {
|
||||
t *testing.T
|
||||
db string
|
||||
name string
|
||||
cols []string
|
||||
typs []string
|
||||
data [][]interface{}
|
||||
}
|
||||
|
||||
func (t *vtabUpdateTable) Open() (VTabCursor, error) {
|
||||
return &vtabUpdateCursor{t, 0}, nil
|
||||
}
|
||||
|
||||
func (t *vtabUpdateTable) BestIndex(cst []InfoConstraint, ob []InfoOrderBy) (*IndexResult, error) {
|
||||
return &IndexResult{Used: make([]bool, len(cst))}, nil
|
||||
}
|
||||
|
||||
func (t *vtabUpdateTable) Disconnect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *vtabUpdateTable) Destroy() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *vtabUpdateTable) Insert(id interface{}, vals []interface{}) (int64, error) {
|
||||
var i int64
|
||||
if id == nil {
|
||||
i, t.data = int64(len(t.data)), append(t.data, vals)
|
||||
return i, nil
|
||||
}
|
||||
|
||||
var ok bool
|
||||
i, ok = id.(int64)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("id is invalid type: %T", id)
|
||||
}
|
||||
|
||||
t.data[i] = vals
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func (t *vtabUpdateTable) Update(id interface{}, vals []interface{}) error {
|
||||
i, ok := id.(int64)
|
||||
if !ok {
|
||||
return fmt.Errorf("id is invalid type: %T", id)
|
||||
}
|
||||
|
||||
if int(i) >= len(t.data) || i < 0 {
|
||||
return fmt.Errorf("invalid row id %d", i)
|
||||
}
|
||||
|
||||
t.data[int(i)] = vals
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *vtabUpdateTable) Delete(id interface{}) error {
|
||||
i, ok := id.(int64)
|
||||
if !ok {
|
||||
return fmt.Errorf("id is invalid type: %T", id)
|
||||
}
|
||||
|
||||
if int(i) >= len(t.data) || i < 0 {
|
||||
return fmt.Errorf("invalid row id %d", i)
|
||||
}
|
||||
|
||||
t.data = append(t.data[:i], t.data[i+1:]...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type vtabUpdateCursor struct {
|
||||
t *vtabUpdateTable
|
||||
i int
|
||||
}
|
||||
|
||||
func (c *vtabUpdateCursor) Column(ctxt *SQLiteContext, col int) error {
|
||||
switch x := c.t.data[c.i][col].(type) {
|
||||
case []byte:
|
||||
ctxt.ResultBlob(x)
|
||||
case bool:
|
||||
ctxt.ResultBool(x)
|
||||
case float64:
|
||||
ctxt.ResultDouble(x)
|
||||
case int:
|
||||
ctxt.ResultInt(x)
|
||||
case int64:
|
||||
ctxt.ResultInt64(x)
|
||||
case nil:
|
||||
ctxt.ResultNull()
|
||||
case string:
|
||||
ctxt.ResultText(x)
|
||||
default:
|
||||
ctxt.ResultText(fmt.Sprintf("%v", x))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *vtabUpdateCursor) Filter(ixNum int, ixName string, vals []interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *vtabUpdateCursor) Next() error {
|
||||
c.i++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *vtabUpdateCursor) EOF() bool {
|
||||
return c.i >= len(c.t.data)
|
||||
}
|
||||
|
||||
func (c *vtabUpdateCursor) Rowid() (int64, error) {
|
||||
return int64(c.i), nil
|
||||
}
|
||||
|
||||
func (c *vtabUpdateCursor) Close() error {
|
||||
return nil
|
||||
}
|
49
sqlite3ext.h
49
sqlite3ext.h
|
@ -1,3 +1,4 @@
|
|||
#ifndef USE_LIBSQLITE3
|
||||
/*
|
||||
** 2006 June 7
|
||||
**
|
||||
|
@ -15,11 +16,9 @@
|
|||
** as extensions by SQLite should #include this file instead of
|
||||
** sqlite3.h.
|
||||
*/
|
||||
#ifndef _SQLITE3EXT_H_
|
||||
#define _SQLITE3EXT_H_
|
||||
#include "sqlite3-binding.h"
|
||||
|
||||
typedef struct sqlite3_api_routines sqlite3_api_routines;
|
||||
#ifndef SQLITE3EXT_H
|
||||
#define SQLITE3EXT_H
|
||||
#include "sqlite3.h"
|
||||
|
||||
/*
|
||||
** The following structure holds pointers to all of the SQLite API
|
||||
|
@ -281,8 +280,31 @@ struct sqlite3_api_routines {
|
|||
int (*db_cacheflush)(sqlite3*);
|
||||
/* Version 3.12.0 and later */
|
||||
int (*system_errno)(sqlite3*);
|
||||
/* Version 3.14.0 and later */
|
||||
int (*trace_v2)(sqlite3*,unsigned,int(*)(unsigned,void*,void*,void*),void*);
|
||||
char *(*expanded_sql)(sqlite3_stmt*);
|
||||
/* Version 3.18.0 and later */
|
||||
void (*set_last_insert_rowid)(sqlite3*,sqlite3_int64);
|
||||
/* Version 3.20.0 and later */
|
||||
int (*prepare_v3)(sqlite3*,const char*,int,unsigned int,
|
||||
sqlite3_stmt**,const char**);
|
||||
int (*prepare16_v3)(sqlite3*,const void*,int,unsigned int,
|
||||
sqlite3_stmt**,const void**);
|
||||
int (*bind_pointer)(sqlite3_stmt*,int,void*,const char*,void(*)(void*));
|
||||
void (*result_pointer)(sqlite3_context*,void*,const char*,void(*)(void*));
|
||||
void *(*value_pointer)(sqlite3_value*,const char*);
|
||||
};
|
||||
|
||||
/*
|
||||
** This is the function signature used for all extension entry points. It
|
||||
** is also defined in the file "loadext.c".
|
||||
*/
|
||||
typedef int (*sqlite3_loadext_entry)(
|
||||
sqlite3 *db, /* Handle to the database. */
|
||||
char **pzErrMsg, /* Used to set error string on failure. */
|
||||
const sqlite3_api_routines *pThunk /* Extension API function pointers. */
|
||||
);
|
||||
|
||||
/*
|
||||
** The following macros redefine the API routines so that they are
|
||||
** redirected through the global sqlite3_api structure.
|
||||
|
@ -526,6 +548,17 @@ struct sqlite3_api_routines {
|
|||
#define sqlite3_db_cacheflush sqlite3_api->db_cacheflush
|
||||
/* Version 3.12.0 and later */
|
||||
#define sqlite3_system_errno sqlite3_api->system_errno
|
||||
/* Version 3.14.0 and later */
|
||||
#define sqlite3_trace_v2 sqlite3_api->trace_v2
|
||||
#define sqlite3_expanded_sql sqlite3_api->expanded_sql
|
||||
/* Version 3.18.0 and later */
|
||||
#define sqlite3_set_last_insert_rowid sqlite3_api->set_last_insert_rowid
|
||||
/* Version 3.20.0 and later */
|
||||
#define sqlite3_prepare_v3 sqlite3_api->prepare_v3
|
||||
#define sqlite3_prepare16_v3 sqlite3_api->prepare16_v3
|
||||
#define sqlite3_bind_pointer sqlite3_api->bind_pointer
|
||||
#define sqlite3_result_pointer sqlite3_api->result_pointer
|
||||
#define sqlite3_value_pointer sqlite3_api->value_pointer
|
||||
#endif /* !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION) */
|
||||
|
||||
#if !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION)
|
||||
|
@ -543,4 +576,8 @@ struct sqlite3_api_routines {
|
|||
# define SQLITE_EXTENSION_INIT3 /*no-op*/
|
||||
#endif
|
||||
|
||||
#endif /* _SQLITE3EXT_H_ */
|
||||
#endif /* SQLITE3EXT_H */
|
||||
|
||||
// If users really want to link against the system sqlite3 we
|
||||
// need to make this file a noop.
|
||||
#endif
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
// +build !cgo
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
sql.Register("sqlite3", &SQLiteDriverMock{})
|
||||
}
|
||||
|
||||
type SQLiteDriverMock struct{}
|
||||
|
||||
var errorMsg = errors.New("Binary was compiled with 'CGO_ENABLED=0', go-sqlite3 requires cgo to work. This is a stub")
|
||||
|
||||
func (SQLiteDriverMock) Open(s string) (driver.Conn, error) {
|
||||
return nil, errorMsg
|
||||
}
|
|
@ -0,0 +1,111 @@
|
|||
// +build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/PuerkitoBio/goquery"
|
||||
)
|
||||
|
||||
func main() {
|
||||
site := "https://www.sqlite.org/download.html"
|
||||
fmt.Printf("scraping %v\n", site)
|
||||
doc, err := goquery.NewDocument(site)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
var url string
|
||||
doc.Find("a").Each(func(_ int, s *goquery.Selection) {
|
||||
if url == "" && strings.HasPrefix(s.Text(), "sqlite-amalgamation-") {
|
||||
url = "https://www.sqlite.org/2018/" + s.Text()
|
||||
}
|
||||
})
|
||||
if url == "" {
|
||||
return
|
||||
}
|
||||
fmt.Printf("downloading %v\n", url)
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
resp.Body.Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("extracting %v\n", path.Base(url))
|
||||
r, err := zip.NewReader(bytes.NewReader(b), resp.ContentLength)
|
||||
if err != nil {
|
||||
resp.Body.Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
for _, zf := range r.File {
|
||||
var f *os.File
|
||||
switch path.Base(zf.Name) {
|
||||
case "sqlite3.c":
|
||||
f, err = os.Create("sqlite3-binding.c")
|
||||
case "sqlite3.h":
|
||||
f, err = os.Create("sqlite3-binding.h")
|
||||
case "sqlite3ext.h":
|
||||
f, err = os.Create("sqlite3ext.h")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
zr, err := zf.Open()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = io.WriteString(f, "#ifndef USE_LIBSQLITE3\n")
|
||||
if err != nil {
|
||||
zr.Close()
|
||||
f.Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
scanner := bufio.NewScanner(zr)
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
if text == `#include "sqlite3.h"` {
|
||||
text = `#include "sqlite3-binding.h"`
|
||||
}
|
||||
_, err = fmt.Fprintln(f, text)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
err = scanner.Err()
|
||||
if err != nil {
|
||||
zr.Close()
|
||||
f.Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
_, err = io.WriteString(f, "#else // USE_LIBSQLITE3\n // If users really want to link against the system sqlite3 we\n// need to make this file a noop.\n #endif")
|
||||
if err != nil {
|
||||
zr.Close()
|
||||
f.Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
zr.Close()
|
||||
f.Close()
|
||||
fmt.Printf("extracted %v\n", filepath.Base(f.Name()))
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue