Merge pull request from mattn/master

Merge upstream
This commit is contained in:
Philip O'Toole 2017-06-17 12:02:47 -07:00 committed by GitHub
commit b951516ea0
38 changed files with 215174 additions and 195088 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
*.db *.db
*.exe *.exe
*.dll *.dll
*.o

View File

@ -1,9 +1,18 @@
language: go language: go
sudo: required
dist: trusty
env:
- GOTAGS=
- GOTAGS=libsqlite3
- GOTAGS=trace
- GOTAGS=vtable
go: go:
- 1.7
- 1.8
- tip - tip
before_install: before_install:
- go get github.com/axw/gocov/gocov
- go get github.com/mattn/goveralls - go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover - go get golang.org/x/tools/cmd/cover
script: script:
- $HOME/gopath/bin/goveralls -repotoken 3qJVUE0iQwqnCbmNcDsjYu1nh4J4KIFXx - $HOME/gopath/bin/goveralls -repotoken 3qJVUE0iQwqnCbmNcDsjYu1nh4J4KIFXx
- go test -race -v . -tags "$GOTAGS"

View File

@ -1,9 +1,10 @@
go-sqlite3 go-sqlite3
========== ==========
[![GoDoc Reference](https://godoc.org/github.com/mattn/go-sqlite3?status.svg)](http://godoc.org/github.com/mattn/go-sqlite3)
[![Build Status](https://travis-ci.org/mattn/go-sqlite3.svg?branch=master)](https://travis-ci.org/mattn/go-sqlite3) [![Build Status](https://travis-ci.org/mattn/go-sqlite3.svg?branch=master)](https://travis-ci.org/mattn/go-sqlite3)
[![Coverage Status](https://coveralls.io/repos/mattn/go-sqlite3/badge.svg?branch=master)](https://coveralls.io/r/mattn/go-sqlite3?branch=master) [![Coverage Status](https://coveralls.io/repos/mattn/go-sqlite3/badge.svg?branch=master)](https://coveralls.io/r/mattn/go-sqlite3?branch=master)
[![GoDoc](https://godoc.org/github.com/mattn/go-sqlite3?status.svg)](http://godoc.org/github.com/mattn/go-sqlite3) [![Go Report Card](https://goreportcard.com/badge/github.com/mattn/go-sqlite3)](https://goreportcard.com/report/github.com/mattn/go-sqlite3)
Description Description
----------- -----------
@ -35,29 +36,50 @@ FAQ
Use `go build --tags "libsqlite3 linux"` Use `go build --tags "libsqlite3 linux"`
* Want to build go-sqlite3 with libsqlite3 on OS X.
Install sqlite3 from homebrew: `brew install sqlite3`
Use `go build --tags "libsqlite3 darwin"`
* Want to build go-sqlite3 with icu extension. * Want to build go-sqlite3 with icu extension.
Use `go build --tags "icu"` Use `go build --tags "icu"`
Available extensions: `json1`, `fts5`, `icu`
* Can't build go-sqlite3 on windows 64bit. * Can't build go-sqlite3 on windows 64bit.
> Probably, you are using go 1.0, go1.0 has a problem when it comes to compiling/linking on windows 64bit. > Probably, you are using go 1.0, go1.0 has a problem when it comes to compiling/linking on windows 64bit.
> See: https://github.com/mattn/go-sqlite3/issues/27 > See: [#27](https://github.com/mattn/go-sqlite3/issues/27)
* Getting insert error while query is opened. * Getting insert error while query is opened.
> You can pass some arguments into the connection string, for example, a URI. > You can pass some arguments into the connection string, for example, a URI.
> See: https://github.com/mattn/go-sqlite3/issues/39 > See: [#39](https://github.com/mattn/go-sqlite3/issues/39)
* Do you want cross compiling? mingw on Linux or Mac? * Do you want to cross compile? mingw on Linux or Mac?
> See: https://github.com/mattn/go-sqlite3/issues/106 > See: [#106](https://github.com/mattn/go-sqlite3/issues/106)
> See also: http://www.limitlessfx.com/cross-compile-golang-app-for-windows-from-linux.html > See also: http://www.limitlessfx.com/cross-compile-golang-app-for-windows-from-linux.html
* Want to get time.Time with current locale * Want to get time.Time with current locale
Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`. Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`.
* Can I use this in multiple routines concurrently?
Yes for readonly. But, No for writable. See [#50](https://github.com/mattn/go-sqlite3/issues/50), [#51](https://github.com/mattn/go-sqlite3/issues/51), [#209](https://github.com/mattn/go-sqlite3/issues/209).
* Why is it racy if I use a `sql.Open("sqlite3", ":memory:")` database?
Each connection to :memory: opens a brand new in-memory sql database, so if
the stdlib's sql engine happens to open another connection and you've only
specified ":memory:", that connection will see a brand new database. A
workaround is to use "file::memory:?mode=memory&cache=shared". Every
connection to this string will point to the same in-memory database. See
[#204](https://github.com/mattn/go-sqlite3/issues/204) for more info.
License License
------- -------
@ -67,7 +89,7 @@ sqlite3-binding.c, sqlite3-binding.h, sqlite3ext.h
The -binding suffix was added to avoid build failures under gccgo. The -binding suffix was added to avoid build failures under gccgo.
In this repository, those files are amalgamation code that copied from SQLite3. The license of those codes are depend on the license of SQLite3. In this repository, those files are an amalgamation of code that was copied from SQLite3. The license of that code is the same as the license of SQLite3.
Author Author
------ ------

View File

@ -2,9 +2,10 @@ package main
import ( import (
"database/sql" "database/sql"
"github.com/mattn/go-sqlite3"
"log" "log"
"os" "os"
"github.com/mattn/go-sqlite3"
) )
func main() { func main() {
@ -19,36 +20,36 @@ func main() {
os.Remove("./foo.db") os.Remove("./foo.db")
os.Remove("./bar.db") os.Remove("./bar.db")
destDb, err := sql.Open("sqlite3_with_hook_example", "./foo.db") srcDb, err := sql.Open("sqlite3_with_hook_example", "./foo.db")
if err != nil {
log.Fatal(err)
}
defer destDb.Close()
destDb.Ping()
_, err = destDb.Exec("create table foo(id int, value text)")
if err != nil {
log.Fatal(err)
}
_, err = destDb.Exec("insert into foo values(1, 'foo')")
if err != nil {
log.Fatal(err)
}
_, err = destDb.Exec("insert into foo values(2, 'bar')")
if err != nil {
log.Fatal(err)
}
_, err = destDb.Query("select * from foo")
if err != nil {
log.Fatal(err)
}
srcDb, err := sql.Open("sqlite3_with_hook_example", "./bar.db")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer srcDb.Close() defer srcDb.Close()
srcDb.Ping() srcDb.Ping()
_, err = srcDb.Exec("create table foo(id int, value text)")
if err != nil {
log.Fatal(err)
}
_, err = srcDb.Exec("insert into foo values(1, 'foo')")
if err != nil {
log.Fatal(err)
}
_, err = srcDb.Exec("insert into foo values(2, 'bar')")
if err != nil {
log.Fatal(err)
}
_, err = srcDb.Query("select * from foo")
if err != nil {
log.Fatal(err)
}
destDb, err := sql.Open("sqlite3_with_hook_example", "./bar.db")
if err != nil {
log.Fatal(err)
}
defer destDb.Close()
destDb.Ping()
bk, err := sqlite3conn[1].Backup("main", sqlite3conn[0], "main") bk, err := sqlite3conn[1].Backup("main", sqlite3conn[0], "main")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)

View File

@ -3,8 +3,9 @@ package main
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/mattn/go-sqlite3"
"log" "log"
"github.com/mattn/go-sqlite3"
) )
func main() { func main() {
@ -29,8 +30,8 @@ func main() {
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var id, full_name, description, html_url string var id, fullName, description, htmlURL string
rows.Scan(&id, &full_name, &description, &html_url) rows.Scan(&id, &fullName, &description, &htmlURL)
fmt.Printf("%s: %s\n\t%s\n\t%s\n\n", id, full_name, description, html_url) fmt.Printf("%s: %s\n\t%s\n\t%s\n\n", id, fullName, description, htmlURL)
} }
} }

View File

@ -52,9 +52,16 @@ func main() {
for rows.Next() { for rows.Next() {
var id int var id int
var name string var name string
rows.Scan(&id, &name) err = rows.Scan(&id, &name)
if err != nil {
log.Fatal(err)
}
fmt.Println(id, name) fmt.Println(id, name)
} }
err = rows.Err()
if err != nil {
log.Fatal(err)
}
stmt, err = db.Prepare("select name from foo where id = ?") stmt, err = db.Prepare("select name from foo where id = ?")
if err != nil { if err != nil {
@ -86,7 +93,14 @@ func main() {
for rows.Next() { for rows.Next() {
var id int var id int
var name string var name string
rows.Scan(&id, &name) err = rows.Scan(&id, &name)
if err != nil {
log.Fatal(err)
}
fmt.Println(id, name) fmt.Println(id, name)
} }
err = rows.Err()
if err != nil {
log.Fatal(err)
}
} }

264
_example/trace/main.go Normal file
View File

@ -0,0 +1,264 @@
package main
import (
"database/sql"
"fmt"
"log"
"os"
sqlite3 "github.com/mattn/go-sqlite3"
)
func traceCallback(info sqlite3.TraceInfo) int {
// Not very readable but may be useful; uncomment next line in case of doubt:
//fmt.Printf("Trace: %#v\n", info)
var dbErrText string
if info.DBError.Code != 0 || info.DBError.ExtendedCode != 0 {
dbErrText = fmt.Sprintf("; DB error: %#v", info.DBError)
} else {
dbErrText = "."
}
// Show the Statement-or-Trigger text in curly braces ('{', '}')
// since from the *paired* ASCII characters they are
// the least used in SQL syntax, therefore better visual delimiters.
// Maybe show 'ExpandedSQL' the same way as 'StmtOrTrigger'.
//
// A known use of curly braces (outside strings) is
// for ODBC escape sequences. Not likely to appear here.
//
// Template languages, etc. don't matter, we should see their *result*
// at *this* level.
// Strange curly braces in SQL code that reached the database driver
// suggest that there is a bug in the application.
// The braces are likely to be either template syntax or
// a programming language's string interpolation syntax.
var expandedText string
if info.ExpandedSQL != "" {
if info.ExpandedSQL == info.StmtOrTrigger {
expandedText = " = exp"
} else {
expandedText = fmt.Sprintf(" expanded {%q}", info.ExpandedSQL)
}
} else {
expandedText = ""
}
// SQLite docs as of September 6, 2016: Tracing and Profiling Functions
// https://www.sqlite.org/c3ref/profile.html
//
// The profile callback time is in units of nanoseconds, however
// the current implementation is only capable of millisecond resolution
// so the six least significant digits in the time are meaningless.
// Future versions of SQLite might provide greater resolution on the profiler callback.
var runTimeText string
if info.RunTimeNanosec == 0 {
if info.EventCode == sqlite3.TraceProfile {
//runTimeText = "; no time" // seems confusing
runTimeText = "; time 0" // no measurement unit
} else {
//runTimeText = "; no time" // seems useless and confusing
}
} else {
const nanosPerMillisec = 1000000
if info.RunTimeNanosec%nanosPerMillisec == 0 {
runTimeText = fmt.Sprintf("; time %d ms", info.RunTimeNanosec/nanosPerMillisec)
} else {
// unexpected: better than millisecond resolution
runTimeText = fmt.Sprintf("; time %d ns!!!", info.RunTimeNanosec)
}
}
var modeText string
if info.AutoCommit {
modeText = "-AC-"
} else {
modeText = "+Tx+"
}
fmt.Printf("Trace: ev %d %s conn 0x%x, stmt 0x%x {%q}%s%s%s\n",
info.EventCode, modeText, info.ConnHandle, info.StmtHandle,
info.StmtOrTrigger, expandedText,
runTimeText,
dbErrText)
return 0
}
func main() {
eventMask := sqlite3.TraceStmt | sqlite3.TraceProfile | sqlite3.TraceRow | sqlite3.TraceClose
sql.Register("sqlite3_tracing",
&sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
err := conn.SetTrace(&sqlite3.TraceConfig{
Callback: traceCallback,
EventMask: uint(eventMask),
WantExpandedSQL: true,
})
return err
},
})
os.Exit(dbMain())
}
// Harder to do DB work in main().
// It's better with a separate function because
// 'defer' and 'os.Exit' don't go well together.
//
// DO NOT use 'log.Fatal...' below: remember that it's equivalent to
// Print() followed by a call to os.Exit(1) --- and
// we want to avoid Exit() so 'defer' can do cleanup.
// Use 'log.Panic...' instead.
func dbMain() int {
db, err := sql.Open("sqlite3_tracing", ":memory:")
if err != nil {
fmt.Printf("Failed to open database: %#+v\n", err)
return 1
}
defer db.Close()
err = db.Ping()
if err != nil {
log.Panic(err)
}
dbSetup(db)
dbDoInsert(db)
dbDoInsertPrepared(db)
dbDoSelect(db)
dbDoSelectPrepared(db)
return 0
}
// 'DDL' stands for "Data Definition Language":
// Note: "INTEGER PRIMARY KEY NOT NULL AUTOINCREMENT" causes the error
// 'near "AUTOINCREMENT": syntax error'; without "NOT NULL" it works.
const tableDDL = `CREATE TABLE t1 (
id INTEGER PRIMARY KEY AUTOINCREMENT,
note VARCHAR NOT NULL
)`
// 'DML' stands for "Data Manipulation Language":
const insertDML = "INSERT INTO t1 (note) VALUES (?)"
const selectDML = "SELECT id, note FROM t1 WHERE note LIKE ?"
const textPrefix = "bla-1234567890-"
const noteTextPattern = "%Prep%"
const nGenRows = 4 // Number of Rows to Generate (for *each* approach tested)
func dbSetup(db *sql.DB) {
var err error
_, err = db.Exec("DROP TABLE IF EXISTS t1")
if err != nil {
log.Panic(err)
}
_, err = db.Exec(tableDDL)
if err != nil {
log.Panic(err)
}
}
func dbDoInsert(db *sql.DB) {
const Descr = "DB-Exec"
for i := 0; i < nGenRows; i++ {
result, err := db.Exec(insertDML, textPrefix+Descr)
if err != nil {
log.Panic(err)
}
resultDoCheck(result, Descr, i)
}
}
func dbDoInsertPrepared(db *sql.DB) {
const Descr = "DB-Prepare"
stmt, err := db.Prepare(insertDML)
if err != nil {
log.Panic(err)
}
defer stmt.Close()
for i := 0; i < nGenRows; i++ {
result, err := stmt.Exec(textPrefix + Descr)
if err != nil {
log.Panic(err)
}
resultDoCheck(result, Descr, i)
}
}
func resultDoCheck(result sql.Result, callerDescr string, callIndex int) {
lastID, err := result.LastInsertId()
if err != nil {
log.Panic(err)
}
nAffected, err := result.RowsAffected()
if err != nil {
log.Panic(err)
}
log.Printf("Exec result for %s (%d): ID = %d, affected = %d\n", callerDescr, callIndex, lastID, nAffected)
}
func dbDoSelect(db *sql.DB) {
const Descr = "DB-Query"
rows, err := db.Query(selectDML, noteTextPattern)
if err != nil {
log.Panic(err)
}
defer rows.Close()
rowsDoFetch(rows, Descr)
}
func dbDoSelectPrepared(db *sql.DB) {
const Descr = "DB-Prepare"
stmt, err := db.Prepare(selectDML)
if err != nil {
log.Panic(err)
}
defer stmt.Close()
rows, err := stmt.Query(noteTextPattern)
if err != nil {
log.Panic(err)
}
defer rows.Close()
rowsDoFetch(rows, Descr)
}
func rowsDoFetch(rows *sql.Rows, callerDescr string) {
var nRows int
var id int64
var note string
for rows.Next() {
err := rows.Scan(&id, &note)
if err != nil {
log.Panic(err)
}
log.Printf("Row for %s (%d): id=%d, note=%q\n",
callerDescr, nRows, id, note)
nRows++
}
if err := rows.Err(); err != nil {
log.Panic(err)
}
log.Printf("Total %d rows for %s.\n", nRows, callerDescr)
}

38
_example/vtable/main.go Normal file
View File

@ -0,0 +1,38 @@
package main
import (
"database/sql"
"fmt"
"log"
"github.com/mattn/go-sqlite3"
)
func main() {
sql.Register("sqlite3_with_extensions", &sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
return conn.CreateModule("github", &githubModule{})
},
})
db, err := sql.Open("sqlite3_with_extensions", ":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec("create virtual table repo using github(id, full_name, description, html_url)")
if err != nil {
log.Fatal(err)
}
rows, err := db.Query("select id, full_name, description, html_url from repo")
if err != nil {
log.Fatal(err)
}
defer rows.Close()
for rows.Next() {
var id, fullName, description, htmlURL string
rows.Scan(&id, &fullName, &description, &htmlURL)
fmt.Printf("%s: %s\n\t%s\n\t%s\n\n", id, fullName, description, htmlURL)
}
}

111
_example/vtable/vtable.go Normal file
View File

@ -0,0 +1,111 @@
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"github.com/mattn/go-sqlite3"
)
type githubRepo struct {
ID int `json:"id"`
FullName string `json:"full_name"`
Description string `json:"description"`
HTMLURL string `json:"html_url"`
}
type githubModule struct {
}
func (m *githubModule) Create(c *sqlite3.SQLiteConn, args []string) (sqlite3.VTab, error) {
err := c.DeclareVTab(fmt.Sprintf(`
CREATE TABLE %s (
id INT,
full_name TEXT,
description TEXT,
html_url TEXT
)`, args[0]))
if err != nil {
return nil, err
}
return &ghRepoTable{}, nil
}
func (m *githubModule) Connect(c *sqlite3.SQLiteConn, args []string) (sqlite3.VTab, error) {
return m.Create(c, args)
}
func (m *githubModule) DestroyModule() {}
type ghRepoTable struct {
repos []githubRepo
}
func (v *ghRepoTable) Open() (sqlite3.VTabCursor, error) {
resp, err := http.Get("https://api.github.com/repositories")
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var repos []githubRepo
if err := json.Unmarshal(body, &repos); err != nil {
return nil, err
}
return &ghRepoCursor{0, repos}, nil
}
func (v *ghRepoTable) BestIndex(cst []sqlite3.InfoConstraint, ob []sqlite3.InfoOrderBy) (*sqlite3.IndexResult, error) {
return &sqlite3.IndexResult{}, nil
}
func (v *ghRepoTable) Disconnect() error { return nil }
func (v *ghRepoTable) Destroy() error { return nil }
type ghRepoCursor struct {
index int
repos []githubRepo
}
func (vc *ghRepoCursor) Column(c *sqlite3.SQLiteContext, col int) error {
switch col {
case 0:
c.ResultInt(vc.repos[vc.index].ID)
case 1:
c.ResultText(vc.repos[vc.index].FullName)
case 2:
c.ResultText(vc.repos[vc.index].Description)
case 3:
c.ResultText(vc.repos[vc.index].HTMLURL)
}
return nil
}
func (vc *ghRepoCursor) Filter(idxNum int, idxStr string, vals []interface{}) error {
vc.index = 0
return nil
}
func (vc *ghRepoCursor) Next() error {
vc.index++
return nil
}
func (vc *ghRepoCursor) EOF() bool {
return vc.index >= len(vc.repos)
}
func (vc *ghRepoCursor) Rowid() (int64, error) {
return int64(vc.index), nil
}
func (vc *ghRepoCursor) Close() error {
return nil
}

View File

@ -6,7 +6,11 @@
package sqlite3 package sqlite3
/* /*
#ifndef USE_LIBSQLITE3
#include <sqlite3-binding.h> #include <sqlite3-binding.h>
#else
#include <sqlite3.h>
#endif
#include <stdlib.h> #include <stdlib.h>
*/ */
import "C" import "C"
@ -15,10 +19,12 @@ import (
"unsafe" "unsafe"
) )
// SQLiteBackup implement interface of Backup.
type SQLiteBackup struct { type SQLiteBackup struct {
b *C.sqlite3_backup b *C.sqlite3_backup
} }
// Backup make backup from src to dest.
func (c *SQLiteConn) Backup(dest string, conn *SQLiteConn, src string) (*SQLiteBackup, error) { func (c *SQLiteConn) Backup(dest string, conn *SQLiteConn, src string) (*SQLiteBackup, error) {
destptr := C.CString(dest) destptr := C.CString(dest)
defer C.free(unsafe.Pointer(destptr)) defer C.free(unsafe.Pointer(destptr))
@ -33,10 +39,10 @@ func (c *SQLiteConn) Backup(dest string, conn *SQLiteConn, src string) (*SQLiteB
return nil, c.lastError() return nil, c.lastError()
} }
// Backs up for one step. Calls the underlying `sqlite3_backup_step` function. // Step to backs up for one step. Calls the underlying `sqlite3_backup_step`
// This function returns a boolean indicating if the backup is done and // function. This function returns a boolean indicating if the backup is done
// an error signalling any other error. Done is returned if the underlying C // and an error signalling any other error. Done is returned if the underlying
// function returns SQLITE_DONE (Code 101) // C function returns SQLITE_DONE (Code 101)
func (b *SQLiteBackup) Step(p int) (bool, error) { func (b *SQLiteBackup) Step(p int) (bool, error) {
ret := C.sqlite3_backup_step(b.b, C.int(p)) ret := C.sqlite3_backup_step(b.b, C.int(p))
if ret == C.SQLITE_DONE { if ret == C.SQLITE_DONE {
@ -47,24 +53,33 @@ func (b *SQLiteBackup) Step(p int) (bool, error) {
return false, nil return false, nil
} }
// Remaining return whether have the rest for backup.
func (b *SQLiteBackup) Remaining() int { func (b *SQLiteBackup) Remaining() int {
return int(C.sqlite3_backup_remaining(b.b)) return int(C.sqlite3_backup_remaining(b.b))
} }
// PageCount return count of pages.
func (b *SQLiteBackup) PageCount() int { func (b *SQLiteBackup) PageCount() int {
return int(C.sqlite3_backup_pagecount(b.b)) return int(C.sqlite3_backup_pagecount(b.b))
} }
// Finish close backup.
func (b *SQLiteBackup) Finish() error { func (b *SQLiteBackup) Finish() error {
return b.Close() return b.Close()
} }
// Close close backup.
func (b *SQLiteBackup) Close() error { func (b *SQLiteBackup) Close() error {
ret := C.sqlite3_backup_finish(b.b) ret := C.sqlite3_backup_finish(b.b)
// sqlite3_backup_finish() never fails, it just returns the
// error code from previous operations, so clean up before
// checking and returning an error
b.b = nil
runtime.SetFinalizer(b, nil)
if ret != 0 { if ret != 0 {
return Error{Code: ErrNo(ret)} return Error{Code: ErrNo(ret)}
} }
b.b = nil
runtime.SetFinalizer(b, nil)
return nil return nil
} }

290
backup_test.go Normal file
View File

@ -0,0 +1,290 @@
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3
import (
"database/sql"
"fmt"
"os"
"testing"
"time"
)
// The number of rows of test data to create in the source database.
// Can be used to control how many pages are available to be backed up.
const testRowCount = 100
// The maximum number of seconds after which the page-by-page backup is considered to have taken too long.
const usePagePerStepsTimeoutSeconds = 30
// Test the backup functionality.
func testBackup(t *testing.T, testRowCount int, usePerPageSteps bool) {
// This function will be called multiple times.
// It uses sql.Register(), which requires the name parameter value to be unique.
// There does not currently appear to be a way to unregister a registered driver, however.
// So generate a database driver name that will likely be unique.
var driverName = fmt.Sprintf("sqlite3_testBackup_%v_%v_%v", testRowCount, usePerPageSteps, time.Now().UnixNano())
// The driver's connection will be needed in order to perform the backup.
driverConns := []*SQLiteConn{}
sql.Register(driverName, &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
driverConns = append(driverConns, conn)
return nil
},
})
// Connect to the source database.
srcTempFilename := TempFilename(t)
defer os.Remove(srcTempFilename)
srcDb, err := sql.Open(driverName, srcTempFilename)
if err != nil {
t.Fatal("Failed to open the source database:", err)
}
defer srcDb.Close()
err = srcDb.Ping()
if err != nil {
t.Fatal("Failed to connect to the source database:", err)
}
// Connect to the destination database.
destTempFilename := TempFilename(t)
defer os.Remove(destTempFilename)
destDb, err := sql.Open(driverName, destTempFilename)
if err != nil {
t.Fatal("Failed to open the destination database:", err)
}
defer destDb.Close()
err = destDb.Ping()
if err != nil {
t.Fatal("Failed to connect to the destination database:", err)
}
// Check the driver connections.
if len(driverConns) != 2 {
t.Fatalf("Expected 2 driver connections, but found %v.", len(driverConns))
}
srcDbDriverConn := driverConns[0]
if srcDbDriverConn == nil {
t.Fatal("The source database driver connection is nil.")
}
destDbDriverConn := driverConns[1]
if destDbDriverConn == nil {
t.Fatal("The destination database driver connection is nil.")
}
// Generate some test data for the given ID.
var generateTestData = func(id int) string {
return fmt.Sprintf("test-%v", id)
}
// Populate the source database with a test table containing some test data.
tx, err := srcDb.Begin()
if err != nil {
t.Fatal("Failed to begin a transaction when populating the source database:", err)
}
_, err = srcDb.Exec("CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)")
if err != nil {
tx.Rollback()
t.Fatal("Failed to create the source database \"test\" table:", err)
}
for id := 0; id < testRowCount; id++ {
_, err = srcDb.Exec("INSERT INTO test (id, value) VALUES (?, ?)", id, generateTestData(id))
if err != nil {
tx.Rollback()
t.Fatal("Failed to insert a row into the source database \"test\" table:", err)
}
}
err = tx.Commit()
if err != nil {
t.Fatal("Failed to populate the source database:", err)
}
// Confirm that the destination database is initially empty.
var destTableCount int
err = destDb.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table'").Scan(&destTableCount)
if err != nil {
t.Fatal("Failed to check the destination table count:", err)
}
if destTableCount != 0 {
t.Fatalf("The destination database is not empty; %v table(s) found.", destTableCount)
}
// Prepare to perform the backup.
backup, err := destDbDriverConn.Backup("main", srcDbDriverConn, "main")
if err != nil {
t.Fatal("Failed to initialize the backup:", err)
}
// Allow the initial page count and remaining values to be retrieved.
// According to <https://www.sqlite.org/c3ref/backup_finish.html>, the page count and remaining values are "... only updated by sqlite3_backup_step()."
isDone, err := backup.Step(0)
if err != nil {
t.Fatal("Unable to perform an initial 0-page backup step:", err)
}
if isDone {
t.Fatal("Backup is unexpectedly done.")
}
// Check that the page count and remaining values are reasonable.
initialPageCount := backup.PageCount()
if initialPageCount <= 0 {
t.Fatalf("Unexpected initial page count value: %v", initialPageCount)
}
initialRemaining := backup.Remaining()
if initialRemaining <= 0 {
t.Fatalf("Unexpected initial remaining value: %v", initialRemaining)
}
if initialRemaining != initialPageCount {
t.Fatalf("Initial remaining value differs from the initial page count value; remaining: %v; page count: %v", initialRemaining, initialPageCount)
}
// Perform the backup.
if usePerPageSteps {
var startTime = time.Now().Unix()
// Test backing-up using a page-by-page approach.
var latestRemaining = initialRemaining
for {
// Perform the backup step.
isDone, err = backup.Step(1)
if err != nil {
t.Fatal("Failed to perform a backup step:", err)
}
// The page count should remain unchanged from its initial value.
currentPageCount := backup.PageCount()
if currentPageCount != initialPageCount {
t.Fatalf("Current page count differs from the initial page count; initial page count: %v; current page count: %v", initialPageCount, currentPageCount)
}
// There should now be one less page remaining.
currentRemaining := backup.Remaining()
expectedRemaining := latestRemaining - 1
if currentRemaining != expectedRemaining {
t.Fatalf("Unexpected remaining value; expected remaining value: %v; actual remaining value: %v", expectedRemaining, currentRemaining)
}
latestRemaining = currentRemaining
if isDone {
break
}
// Limit the runtime of the backup attempt.
if (time.Now().Unix() - startTime) > usePagePerStepsTimeoutSeconds {
t.Fatal("Backup is taking longer than expected.")
}
}
} else {
// Test the copying of all remaining pages.
isDone, err = backup.Step(-1)
if err != nil {
t.Fatal("Failed to perform a backup step:", err)
}
if !isDone {
t.Fatal("Backup is unexpectedly not done.")
}
}
// Check that the page count and remaining values are reasonable.
finalPageCount := backup.PageCount()
if finalPageCount != initialPageCount {
t.Fatalf("Final page count differs from the initial page count; initial page count: %v; final page count: %v", initialPageCount, finalPageCount)
}
finalRemaining := backup.Remaining()
if finalRemaining != 0 {
t.Fatalf("Unexpected remaining value: %v", finalRemaining)
}
// Finish the backup.
err = backup.Finish()
if err != nil {
t.Fatal("Failed to finish backup:", err)
}
// Confirm that the "test" table now exists in the destination database.
var doesTestTableExist bool
err = destDb.QueryRow("SELECT EXISTS (SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'test' LIMIT 1) AS test_table_exists").Scan(&doesTestTableExist)
if err != nil {
t.Fatal("Failed to check if the \"test\" table exists in the destination database:", err)
}
if !doesTestTableExist {
t.Fatal("The \"test\" table could not be found in the destination database.")
}
// Confirm that the number of rows in the destination database's "test" table matches that of the source table.
var actualTestTableRowCount int
err = destDb.QueryRow("SELECT COUNT(*) FROM test").Scan(&actualTestTableRowCount)
if err != nil {
t.Fatal("Failed to determine the rowcount of the \"test\" table in the destination database:", err)
}
if testRowCount != actualTestTableRowCount {
t.Fatalf("Unexpected destination \"test\" table row count; expected: %v; found: %v", testRowCount, actualTestTableRowCount)
}
// Check each of the rows in the destination database.
for id := 0; id < testRowCount; id++ {
var checkedValue string
err = destDb.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&checkedValue)
if err != nil {
t.Fatal("Failed to query the \"test\" table in the destination database:", err)
}
var expectedValue = generateTestData(id)
if checkedValue != expectedValue {
t.Fatalf("Unexpected value in the \"test\" table in the destination database; expected value: %v; actual value: %v", expectedValue, checkedValue)
}
}
}
func TestBackupStepByStep(t *testing.T) {
testBackup(t, testRowCount, true)
}
func TestBackupAllRemainingPages(t *testing.T) {
testBackup(t, testRowCount, false)
}
// Test the error reporting when preparing to perform a backup.
func TestBackupError(t *testing.T) {
const driverName = "sqlite3_TestBackupError"
// The driver's connection will be needed in order to perform the backup.
var dbDriverConn *SQLiteConn
sql.Register(driverName, &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
dbDriverConn = conn
return nil
},
})
// Connect to the database.
dbTempFilename := TempFilename(t)
defer os.Remove(dbTempFilename)
db, err := sql.Open(driverName, dbTempFilename)
if err != nil {
t.Fatal("Failed to open the database:", err)
}
defer db.Close()
db.Ping()
// Need the driver connection in order to perform the backup.
if dbDriverConn == nil {
t.Fatal("Failed to get the driver connection.")
}
// Prepare to perform the backup.
// Intentionally using the same connection for both the source and destination databases, to trigger an error result.
backup, err := dbDriverConn.Backup("main", dbDriverConn, "main")
if err == nil {
t.Fatal("Failed to get the expected error result.")
}
const expectedError = "source and destination must be distinct"
if err.Error() != expectedError {
t.Fatalf("Unexpected error message; expected value: \"%v\"; actual value: \"%v\"", expectedError, err.Error())
}
if backup != nil {
t.Fatal("Failed to get the expected nil backup result.")
}
}

View File

@ -11,7 +11,11 @@ package sqlite3
// code for SQLite custom functions is in here. // code for SQLite custom functions is in here.
/* /*
#ifndef USE_LIBSQLITE3
#include <sqlite3-binding.h> #include <sqlite3-binding.h>
#else
#include <sqlite3.h>
#endif
#include <stdlib.h> #include <stdlib.h>
void _sqlite3_result_text(sqlite3_context* ctx, const char* s); void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
@ -36,8 +40,8 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value
} }
//export stepTrampoline //export stepTrampoline
func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) { func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc] args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo) ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo)
ai.Step(ctx, args) ai.Step(ctx, args)
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

20
doc.go
View File

@ -1,7 +1,7 @@
/* /*
Package sqlite3 provides interface to SQLite3 databases. Package sqlite3 provides interface to SQLite3 databases.
This works as driver for database/sql. This works as a driver for database/sql.
Installation Installation
@ -9,7 +9,7 @@ Installation
Supported Types Supported Types
Currently, go-sqlite3 support following data types. Currently, go-sqlite3 supports the following data types.
+------------------------------+ +------------------------------+
|go | sqlite3 | |go | sqlite3 |
@ -26,8 +26,8 @@ Currently, go-sqlite3 support following data types.
SQLite3 Extension SQLite3 Extension
You can write your own extension module for sqlite3. For example, below is a You can write your own extension module for sqlite3. For example, below is an
extension for Regexp matcher operation. extension for a Regexp matcher operation.
#include <pcre.h> #include <pcre.h>
#include <string.h> #include <string.h>
@ -63,8 +63,8 @@ extension for Regexp matcher operation.
(void*)db, regexp_func, NULL, NULL); (void*)db, regexp_func, NULL, NULL);
} }
It need to build as so/dll shared library. And you need to register It needs to be built as a so/dll shared library. And you need to register
extension module like below. the extension module like below.
sql.Register("sqlite3_with_extensions", sql.Register("sqlite3_with_extensions",
&sqlite3.SQLiteDriver{ &sqlite3.SQLiteDriver{
@ -79,9 +79,9 @@ Then, you can use this extension.
Connection Hook Connection Hook
You can hook and inject your codes when connection established. database/sql You can hook and inject your code when the connection is established. database/sql
doesn't provide the way to get native go-sqlite3 interfaces. So if you want, doesn't provide a way to get native go-sqlite3 interfaces. So if you want,
you need to hook ConnectHook and get the SQLiteConn. you need to set ConnectHook and get the SQLiteConn.
sql.Register("sqlite3_with_hook_example", sql.Register("sqlite3_with_hook_example",
&sqlite3.SQLiteDriver{ &sqlite3.SQLiteDriver{
@ -102,7 +102,7 @@ call RegisterFunction from ConnectHook.
sql.Register("sqlite3_with_go_func", sql.Register("sqlite3_with_go_func",
&sqlite3.SQLiteDriver{ &sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error { ConnectHook: func(conn *sqlite3.SQLiteConn) error {
return conn.RegisterFunc("regex", regex, true) return conn.RegisterFunc("regexp", regex, true)
}, },
}) })

View File

@ -7,12 +7,16 @@ package sqlite3
import "C" import "C"
// ErrNo inherit errno.
type ErrNo int type ErrNo int
// ErrNoMask is mask code.
const ErrNoMask C.int = 0xff const ErrNoMask C.int = 0xff
// ErrNoExtended is extended errno.
type ErrNoExtended int type ErrNoExtended int
// Error implement sqlite error code.
type Error struct { type Error struct {
Code ErrNo /* The error code returned by SQLite */ Code ErrNo /* The error code returned by SQLite */
ExtendedCode ErrNoExtended /* The extended error code returned by SQLite */ ExtendedCode ErrNoExtended /* The extended error code returned by SQLite */
@ -52,14 +56,17 @@ var (
ErrWarning = ErrNo(28) /* Warnings from sqlite3_log() */ ErrWarning = ErrNo(28) /* Warnings from sqlite3_log() */
) )
// Error return error message from errno.
func (err ErrNo) Error() string { func (err ErrNo) Error() string {
return Error{Code: err}.Error() return Error{Code: err}.Error()
} }
// Extend return extended errno.
func (err ErrNo) Extend(by int) ErrNoExtended { func (err ErrNo) Extend(by int) ErrNoExtended {
return ErrNoExtended(int(err) | (by << 8)) return ErrNoExtended(int(err) | (by << 8))
} }
// Error return error message that is extended code.
func (err ErrNoExtended) Error() string { func (err ErrNoExtended) Error() string {
return Error{Code: ErrNo(C.int(err) & ErrNoMask), ExtendedCode: err}.Error() return Error{Code: ErrNo(C.int(err) & ErrNoMask), ExtendedCode: err}.Error()
} }
@ -121,7 +128,7 @@ var (
ErrConstraintTrigger = ErrConstraint.Extend(7) ErrConstraintTrigger = ErrConstraint.Extend(7)
ErrConstraintUnique = ErrConstraint.Extend(8) ErrConstraintUnique = ErrConstraint.Extend(8)
ErrConstraintVTab = ErrConstraint.Extend(9) ErrConstraintVTab = ErrConstraint.Extend(9)
ErrConstraintRowId = ErrConstraint.Extend(10) ErrConstraintRowID = ErrConstraint.Extend(10)
ErrNoticeRecoverWAL = ErrNotice.Extend(1) ErrNoticeRecoverWAL = ErrNotice.Extend(1)
ErrNoticeRecoverRollback = ErrNotice.Extend(2) ErrNoticeRecoverRollback = ErrNotice.Extend(2)
ErrWarningAutoIndex = ErrWarning.Extend(1) ErrWarningAutoIndex = ErrWarning.Extend(1)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,14 @@ package sqlite3
#cgo CFLAGS: -std=gnu99 #cgo CFLAGS: -std=gnu99
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE #cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61 #cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
#cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15
#cgo CFLAGS: -DSQLITE_DISABLE_INTRINSIC
#cgo CFLAGS: -Wno-deprecated-declarations
#ifndef USE_LIBSQLITE3
#include <sqlite3-binding.h> #include <sqlite3-binding.h>
#else
#include <sqlite3.h>
#endif
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
@ -25,6 +32,10 @@ package sqlite3
# define SQLITE_OPEN_FULLMUTEX 0 # define SQLITE_OPEN_FULLMUTEX 0
#endif #endif
#ifndef SQLITE_DETERMINISTIC
# define SQLITE_DETERMINISTIC 0
#endif
static int static int
_sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) { _sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) {
#ifdef SQLITE_OPEN_URI #ifdef SQLITE_OPEN_URI
@ -89,8 +100,6 @@ int _sqlite3_create_function(
} }
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**); void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void doneTrampoline(sqlite3_context*);
*/ */
import "C" import "C"
import ( import (
@ -106,12 +115,14 @@ import (
"strings" "strings"
"time" "time"
"unsafe" "unsafe"
"golang.org/x/net/context"
) )
// Timestamp formats understood by both this module and SQLite. // SQLiteTimestampFormats is timestamp formats understood by both this module
// The first format in the slice will be used when saving time values // and SQLite. The first format in the slice will be used when saving time
// into the database. When parsing a string from a timestamp or // values into the database. When parsing a string from a timestamp or datetime
// datetime column, the formats are tried in order. // column, the formats are tried in order.
var SQLiteTimestampFormats = []string{ var SQLiteTimestampFormats = []string{
// By default, store timestamps with whatever timezone they come with. // By default, store timestamps with whatever timezone they come with.
// When parsed, they will be returned with the same timezone. // When parsed, they will be returned with the same timezone.
@ -130,21 +141,21 @@ func init() {
sql.Register("sqlite3", &SQLiteDriver{}) sql.Register("sqlite3", &SQLiteDriver{})
} }
// Return SQLite library Version information. // Version returns SQLite library version information.
func Version() (libVersion string, libVersionNumber int, sourceId string) { func Version() (libVersion string, libVersionNumber int, sourceID string) {
libVersion = C.GoString(C.sqlite3_libversion()) libVersion = C.GoString(C.sqlite3_libversion())
libVersionNumber = int(C.sqlite3_libversion_number()) libVersionNumber = int(C.sqlite3_libversion_number())
sourceId = C.GoString(C.sqlite3_sourceid()) sourceID = C.GoString(C.sqlite3_sourceid())
return libVersion, libVersionNumber, sourceId return libVersion, libVersionNumber, sourceID
} }
// Driver struct. // SQLiteDriver implement sql.Driver.
type SQLiteDriver struct { type SQLiteDriver struct {
Extensions []string Extensions []string
ConnectHook func(*SQLiteConn) error ConnectHook func(*SQLiteConn) error
} }
// Conn struct. // SQLiteConn implement sql.Conn.
type SQLiteConn struct { type SQLiteConn struct {
db *C.sqlite3 db *C.sqlite3
loc *time.Location loc *time.Location
@ -153,35 +164,34 @@ type SQLiteConn struct {
aggregators []*aggInfo aggregators []*aggInfo
} }
// Tx struct. // SQLiteTx implemen sql.Tx.
type SQLiteTx struct { type SQLiteTx struct {
c *SQLiteConn c *SQLiteConn
} }
// Stmt struct. // SQLiteStmt implement sql.Stmt.
type SQLiteStmt struct { type SQLiteStmt struct {
c *SQLiteConn c *SQLiteConn
s *C.sqlite3_stmt s *C.sqlite3_stmt
nv int
nn []string
t string t string
closed bool closed bool
cls bool cls bool
} }
// Result struct. // SQLiteResult implement sql.Result.
type SQLiteResult struct { type SQLiteResult struct {
id int64 id int64
changes int64 changes int64
} }
// Rows struct. // SQLiteRows implement sql.Rows.
type SQLiteRows struct { type SQLiteRows struct {
s *SQLiteStmt s *SQLiteStmt
nc int nc int
cols []string cols []string
decltype []string decltype []string
cls bool cls bool
done chan struct{}
} }
type functionInfo struct { type functionInfo struct {
@ -287,13 +297,19 @@ func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
// Commit transaction. // Commit transaction.
func (tx *SQLiteTx) Commit() error { func (tx *SQLiteTx) Commit() error {
_, err := tx.c.exec("COMMIT") _, err := tx.c.exec(context.Background(), "COMMIT", nil)
if err != nil && err.(Error).Code == C.SQLITE_BUSY {
// sqlite3 will leave the transaction open in this scenario.
// However, database/sql considers the transaction complete once we
// return from Commit() - we must clean up to honour its semantics.
tx.c.exec(context.Background(), "ROLLBACK", nil)
}
return err return err
} }
// Rollback transaction. // Rollback transaction.
func (tx *SQLiteTx) Rollback() error { func (tx *SQLiteTx) Rollback() error {
_, err := tx.c.exec("ROLLBACK") _, err := tx.c.exec(context.Background(), "ROLLBACK", nil)
return err return err
} }
@ -367,136 +383,15 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
if pure { if pure {
opts |= C.SQLITE_DETERMINISTIC opts |= C.SQLITE_DETERMINISTIC
} }
rv := C._sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), C.uintptr_t(newHandle(c, &fi)), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil) rv := sqlite3CreateFunction(c.db, cname, C.int(numArgs), C.int(opts), newHandle(c, &fi), C.callbackTrampoline, nil, nil)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return c.lastError() return c.lastError()
} }
return nil return nil
} }
// RegisterAggregator makes a Go type available as a SQLite aggregation function. func sqlite3CreateFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTextRep C.int, pApp uintptr, xFunc unsafe.Pointer, xStep unsafe.Pointer, xFinal unsafe.Pointer) C.int {
// return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(pApp), (*[0]byte)(unsafe.Pointer(xFunc)), (*[0]byte)(unsafe.Pointer(xStep)), (*[0]byte)(unsafe.Pointer(xFinal)))
// Because aggregation is incremental, it's implemented in Go with a
// type that has 2 methods: func Step(values) accumulates one row of
// data into the accumulator, and func Done() ret finalizes and
// returns the aggregate value. "values" and "ret" may be any type
// supported by RegisterFunc.
//
// RegisterAggregator takes as implementation a constructor function
// that constructs an instance of the aggregator type each time an
// aggregation begins. The constructor must return a pointer to a
// type, or an interface that implements Step() and Done().
//
// The constructor function and the Step/Done methods may optionally
// return an error in addition to their other return values.
//
// See _example/go_custom_funcs for a detailed example.
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
var ai aggInfo
ai.constructor = reflect.ValueOf(impl)
t := ai.constructor.Type()
if t.Kind() != reflect.Func {
return errors.New("non-function passed to RegisterAggregator")
}
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
}
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("Second return value of SQLite function must be error")
}
if t.NumIn() != 0 {
return errors.New("SQLite aggregator constructors must not have arguments")
}
agg := t.Out(0)
switch agg.Kind() {
case reflect.Ptr, reflect.Interface:
default:
return errors.New("SQlite aggregator constructor must return a pointer object")
}
stepFn, found := agg.MethodByName("Step")
if !found {
return errors.New("SQlite aggregator doesn't have a Step() function")
}
step := stepFn.Type
if step.NumOut() != 0 && step.NumOut() != 1 {
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
}
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("type of SQlite aggregator Step() return value must be error")
}
stepNArgs := step.NumIn()
start := 0
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
stepNArgs--
start++
}
if step.IsVariadic() {
stepNArgs--
}
for i := start; i < start+stepNArgs; i++ {
conv, err := callbackArg(step.In(i))
if err != nil {
return err
}
ai.stepArgConverters = append(ai.stepArgConverters, conv)
}
if step.IsVariadic() {
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
if err != nil {
return err
}
ai.stepVariadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
stepNArgs = -1
}
doneFn, found := agg.MethodByName("Done")
if !found {
return errors.New("SQlite aggregator doesn't have a Done() function")
}
done := doneFn.Type
doneNArgs := done.NumIn()
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
doneNArgs--
}
if doneNArgs != 0 {
return errors.New("SQlite aggregator Done() function must have no arguments")
}
if done.NumOut() != 1 && done.NumOut() != 2 {
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
}
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("second return value of SQLite aggregator Done() function must be error")
}
conv, err := callbackRet(done.Out(0))
if err != nil {
return err
}
ai.doneRetConverter = conv
ai.active = make(map[int64]reflect.Value)
ai.next = 1
// ai must outlast the database connection, or we'll have dangling pointers.
c.aggregators = append(c.aggregators, &ai)
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := C._sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(newHandle(c, &ai)), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline)))
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
} }
// AutoCommit return which currently auto commit or not. // AutoCommit return which currently auto commit or not.
@ -504,22 +399,38 @@ func (c *SQLiteConn) AutoCommit() bool {
return int(C.sqlite3_get_autocommit(c.db)) != 0 return int(C.sqlite3_get_autocommit(c.db)) != 0
} }
func (c *SQLiteConn) lastError() Error { func (c *SQLiteConn) lastError() error {
return lastError(c.db)
}
func lastError(db *C.sqlite3) error {
rv := C.sqlite3_errcode(db)
if rv == C.SQLITE_OK {
return nil
}
return Error{ return Error{
Code: ErrNo(C.sqlite3_errcode(c.db)), Code: ErrNo(rv),
ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(c.db)), ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(db)),
err: C.GoString(C.sqlite3_errmsg(c.db)), err: C.GoString(C.sqlite3_errmsg(db)),
} }
} }
// Implements Execer // Exec implements Execer.
func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) { func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if len(args) == 0 { list := make([]namedValue, len(args))
return c.exec(query) for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
} }
return c.exec(context.Background(), query, list)
}
func (c *SQLiteConn) exec(ctx context.Context, query string, args []namedValue) (driver.Result, error) {
start := 0
for { for {
s, err := c.Prepare(query) s, err := c.prepare(ctx, query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -527,14 +438,19 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
if s.(*SQLiteStmt).s != nil { if s.(*SQLiteStmt).s != nil {
na := s.NumInput() na := s.NumInput()
if len(args) < na { if len(args) < na {
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args)) s.Close()
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
} }
res, err = s.Exec(args[:na]) for i := 0; i < na; i++ {
args[i].Ordinal -= start
}
res, err = s.(*SQLiteStmt).exec(ctx, args[:na])
if err != nil && err != driver.ErrSkip { if err != nil && err != driver.ErrSkip {
s.Close() s.Close()
return nil, err return nil, err
} }
args = args[na:] args = args[na:]
start += na
} }
tail := s.(*SQLiteStmt).t tail := s.(*SQLiteStmt).t
s.Close() s.Close()
@ -545,24 +461,46 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
} }
} }
// Implements Queryer type namedValue struct {
Name string
Ordinal int
Value driver.Value
}
// Query implements Queryer.
func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) {
list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
return c.query(context.Background(), query, list)
}
func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) {
start := 0
for { for {
s, err := c.Prepare(query) s, err := c.prepare(ctx, query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.(*SQLiteStmt).cls = true s.(*SQLiteStmt).cls = true
na := s.NumInput() na := s.NumInput()
if len(args) < na { if len(args) < na {
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args)) return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
} }
rows, err := s.Query(args[:na]) for i := 0; i < na; i++ {
args[i].Ordinal -= start
}
rows, err := s.(*SQLiteStmt).query(ctx, args[:na])
if err != nil && err != driver.ErrSkip { if err != nil && err != driver.ErrSkip {
s.Close() s.Close()
return nil, err return rows, err
} }
args = args[na:] args = args[na:]
start += na
tail := s.(*SQLiteStmt).t tail := s.(*SQLiteStmt).t
if tail == "" { if tail == "" {
return rows, nil return rows, nil
@ -573,21 +511,13 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
} }
} }
func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {
pcmd := C.CString(cmd)
defer C.free(unsafe.Pointer(pcmd))
var rowid, changes C.longlong
rv := C._sqlite3_exec(c.db, pcmd, &rowid, &changes)
if rv != C.SQLITE_OK {
return nil, c.lastError()
}
return &SQLiteResult{int64(rowid), int64(changes)}, nil
}
// Begin transaction. // Begin transaction.
func (c *SQLiteConn) Begin() (driver.Tx, error) { func (c *SQLiteConn) Begin() (driver.Tx, error) {
if _, err := c.exec(c.txlock); err != nil { return c.begin(context.Background())
}
func (c *SQLiteConn) begin(ctx context.Context) (driver.Tx, error) {
if _, err := c.exec(ctx, c.txlock, nil); err != nil {
return nil, err return nil, err
} }
return &SQLiteTx{c}, nil return &SQLiteTx{c}, nil
@ -598,7 +528,7 @@ func errorString(err Error) string {
} }
// Open database and return a new connection. // Open database and return a new connection.
// You can specify DSN string with URI filename. // You can specify a DSN string using a URI as the filename.
// test.db // test.db
// file:test.db?cache=shared&mode=memory // file:test.db?cache=shared&mode=memory
// :memory: // :memory:
@ -611,6 +541,8 @@ func errorString(err Error) string {
// _txlock=XXX // _txlock=XXX
// Specify locking behavior for transactions. XXX can be "immediate", // Specify locking behavior for transactions. XXX can be "immediate",
// "deferred", "exclusive". // "deferred", "exclusive".
// _foreign_keys=X
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 { if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation") return nil, errors.New("sqlite library was not compiled for thread-safe operation")
@ -618,7 +550,8 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
var loc *time.Location var loc *time.Location
txlock := "BEGIN" txlock := "BEGIN"
busy_timeout := 5000 busyTimeout := 5000
foreignKeys := -1
pos := strings.IndexRune(dsn, '?') pos := strings.IndexRune(dsn, '?')
if pos >= 1 { if pos >= 1 {
params, err := url.ParseQuery(dsn[pos+1:]) params, err := url.ParseQuery(dsn[pos+1:])
@ -644,7 +577,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("Invalid _busy_timeout: %v: %v", val, err) return nil, fmt.Errorf("Invalid _busy_timeout: %v: %v", val, err)
} }
busy_timeout = int(iv) busyTimeout = int(iv)
} }
// _txlock // _txlock
@ -661,6 +594,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
} }
} }
// _foreign_keys
if val := params.Get("_foreign_keys"); val != "" {
switch val {
case "1":
foreignKeys = 1
case "0":
foreignKeys = 0
default:
return nil, fmt.Errorf("Invalid _foreign_keys: %v", val)
}
}
if !strings.HasPrefix(dsn, "file:") { if !strings.HasPrefix(dsn, "file:") {
dsn = dsn[:pos] dsn = dsn[:pos]
} }
@ -681,21 +626,45 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New("sqlite succeeded without returning a database") return nil, errors.New("sqlite succeeded without returning a database")
} }
rv = C.sqlite3_busy_timeout(db, C.int(busy_timeout)) rv = C.sqlite3_busy_timeout(db, C.int(busyTimeout))
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
C.sqlite3_close_v2(db)
return nil, Error{Code: ErrNo(rv)} return nil, Error{Code: ErrNo(rv)}
} }
exec := func(s string) error {
cs := C.CString(s)
rv := C.sqlite3_exec(db, cs, nil, nil, nil)
C.free(unsafe.Pointer(cs))
if rv != C.SQLITE_OK {
return lastError(db)
}
return nil
}
if foreignKeys == 0 {
if err := exec("PRAGMA foreign_keys = OFF;"); err != nil {
C.sqlite3_close_v2(db)
return nil, err
}
} else if foreignKeys == 1 {
if err := exec("PRAGMA foreign_keys = ON;"); err != nil {
C.sqlite3_close_v2(db)
return nil, err
}
}
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock} conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
if len(d.Extensions) > 0 { if len(d.Extensions) > 0 {
if err := conn.loadExtensions(d.Extensions); err != nil { if err := conn.loadExtensions(d.Extensions); err != nil {
conn.Close()
return nil, err return nil, err
} }
} }
if d.ConnectHook != nil { if d.ConnectHook != nil {
if err := d.ConnectHook(conn); err != nil { if err := d.ConnectHook(conn); err != nil {
conn.Close()
return nil, err return nil, err
} }
} }
@ -705,18 +674,22 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
// Close the connection. // Close the connection.
func (c *SQLiteConn) Close() error { func (c *SQLiteConn) Close() error {
deleteHandles(c)
rv := C.sqlite3_close_v2(c.db) rv := C.sqlite3_close_v2(c.db)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return c.lastError() return c.lastError()
} }
deleteHandles(c)
c.db = nil c.db = nil
runtime.SetFinalizer(c, nil) runtime.SetFinalizer(c, nil)
return nil return nil
} }
// Prepare query string. Return a new statement. // Prepare the query string. Return a new statement.
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
return c.prepare(context.Background(), query)
}
func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, error) {
pquery := C.CString(query) pquery := C.CString(query)
defer C.free(unsafe.Pointer(pquery)) defer C.free(unsafe.Pointer(pquery))
var s *C.sqlite3_stmt var s *C.sqlite3_stmt
@ -729,15 +702,7 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
if tail != nil && *tail != '\000' { if tail != nil && *tail != '\000' {
t = strings.TrimSpace(C.GoString(tail)) t = strings.TrimSpace(C.GoString(tail))
} }
nv := int(C.sqlite3_bind_parameter_count(s)) ss := &SQLiteStmt{c: c, s: s, t: t}
var nn []string
for i := 0; i < nv; i++ {
pn := C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1)))
if len(pn) > 1 && pn[0] == '$' && 48 <= pn[1] && pn[1] <= 57 {
nn = append(nn, C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1))))
}
}
ss := &SQLiteStmt{c: c, s: s, nv: nv, nn: nn, t: t}
runtime.SetFinalizer(ss, (*SQLiteStmt).Close) runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
return ss, nil return ss, nil
} }
@ -759,9 +724,9 @@ func (s *SQLiteStmt) Close() error {
return nil return nil
} }
// Return a number of parameters. // NumInput return a number of parameters.
func (s *SQLiteStmt) NumInput() int { func (s *SQLiteStmt) NumInput() int {
return s.nv return int(C.sqlite3_bind_parameter_count(s.s))
} }
type bindArg struct { type bindArg struct {
@ -769,37 +734,30 @@ type bindArg struct {
v driver.Value v driver.Value
} }
func (s *SQLiteStmt) bind(args []driver.Value) error { var placeHolder byte = 0
func (s *SQLiteStmt) bind(args []namedValue) error {
rv := C.sqlite3_reset(s.s) rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return s.c.lastError() return s.c.lastError()
} }
var vargs []bindArg for i, v := range args {
narg := len(args) if v.Name != "" {
vargs = make([]bindArg, narg) cname := C.CString(":" + v.Name)
if len(s.nn) > 0 { args[i].Ordinal = int(C.sqlite3_bind_parameter_index(s.s, cname))
for i, v := range s.nn { C.free(unsafe.Pointer(cname))
if pi, err := strconv.Atoi(v[1:]); err == nil {
vargs[i] = bindArg{pi, args[i]}
}
}
} else {
for i, v := range args {
vargs[i] = bindArg{i + 1, v}
} }
} }
for _, varg := range vargs { for _, arg := range args {
n := C.int(varg.n) n := C.int(arg.Ordinal)
v := varg.v switch v := arg.Value.(type) {
switch v := v.(type) {
case nil: case nil:
rv = C.sqlite3_bind_null(s.s, n) rv = C.sqlite3_bind_null(s.s, n)
case string: case string:
if len(v) == 0 { if len(v) == 0 {
b := []byte{0} rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder)), C.int(0))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0))
} else { } else {
b := []byte(v) b := []byte(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
@ -815,11 +773,13 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
case float64: case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v)) rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte: case []byte:
var ptr *byte
if len(v) == 0 { if len(v) == 0 {
rv = C._sqlite3_bind_blob(s.s, n, nil, 0) ptr = &placeHolder
} else { } else {
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v))) ptr = &v[0]
} }
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(ptr), C.int(len(v)))
case time.Time: case time.Time:
b := []byte(v.Format(SQLiteTimestampFormats[0])) b := []byte(v.Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
@ -833,29 +793,85 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
// Query the statement with arguments. Return records. // Query the statement with arguments. Return records.
func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
return s.query(context.Background(), list)
}
func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows, error) {
if err := s.bind(args); err != nil { if err := s.bind(args); err != nil {
return nil, err return nil, err
} }
return &SQLiteRows{s, int(C.sqlite3_column_count(s.s)), nil, nil, s.cls}, nil
rows := &SQLiteRows{
s: s,
nc: int(C.sqlite3_column_count(s.s)),
cols: nil,
decltype: nil,
cls: s.cls,
done: make(chan struct{}),
}
go func(db *C.sqlite3) {
select {
case <-ctx.Done():
select {
case <-rows.done:
default:
C.sqlite3_interrupt(db)
rows.Close()
}
case <-rows.done:
}
}(s.c.db)
return rows, nil
} }
// Return last inserted ID. // LastInsertId teturn last inserted ID.
func (r *SQLiteResult) LastInsertId() (int64, error) { func (r *SQLiteResult) LastInsertId() (int64, error) {
return r.id, nil return r.id, nil
} }
// Return how many rows affected. // RowsAffected return how many rows affected.
func (r *SQLiteResult) RowsAffected() (int64, error) { func (r *SQLiteResult) RowsAffected() (int64, error) {
return r.changes, nil return r.changes, nil
} }
// Execute the statement with arguments. Return result object. // Exec execute the statement with arguments. Return result object.
func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
return s.exec(context.Background(), list)
}
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
if err := s.bind(args); err != nil { if err := s.bind(args); err != nil {
C.sqlite3_reset(s.s) C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s) C.sqlite3_clear_bindings(s.s)
return nil, err return nil, err
} }
done := make(chan struct{})
defer close(done)
go func(db *C.sqlite3) {
select {
case <-ctx.Done():
C.sqlite3_interrupt(db)
case <-done:
}
}(s.c.db)
var rowid, changes C.longlong var rowid, changes C.longlong
rv := C._sqlite3_step(s.s, &rowid, &changes) rv := C._sqlite3_step(s.s, &rowid, &changes)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
@ -864,7 +880,8 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
C.sqlite3_clear_bindings(s.s) C.sqlite3_clear_bindings(s.s)
return nil, err return nil, err
} }
return &SQLiteResult{int64(rowid), int64(changes)}, nil
return &SQLiteResult{id: int64(rowid), changes: int64(changes)}, nil
} }
// Close the rows. // Close the rows.
@ -872,6 +889,9 @@ func (rc *SQLiteRows) Close() error {
if rc.s.closed { if rc.s.closed {
return nil return nil
} }
if rc.done != nil {
close(rc.done)
}
if rc.cls { if rc.cls {
return rc.s.Close() return rc.s.Close()
} }
@ -882,7 +902,7 @@ func (rc *SQLiteRows) Close() error {
return nil return nil
} }
// Return column names. // Columns return column names.
func (rc *SQLiteRows) Columns() []string { func (rc *SQLiteRows) Columns() []string {
if rc.nc != len(rc.cols) { if rc.nc != len(rc.cols) {
rc.cols = make([]string, rc.nc) rc.cols = make([]string, rc.nc)
@ -893,7 +913,18 @@ func (rc *SQLiteRows) Columns() []string {
return rc.cols return rc.cols
} }
// Move cursor to next. // DeclTypes return column types.
func (rc *SQLiteRows) DeclTypes() []string {
if rc.decltype == nil {
rc.decltype = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
}
}
return rc.decltype
}
// Next move cursor to next.
func (rc *SQLiteRows) Next(dest []driver.Value) error { func (rc *SQLiteRows) Next(dest []driver.Value) error {
rv := C.sqlite3_step(rc.s.s) rv := C.sqlite3_step(rc.s.s)
if rv == C.SQLITE_DONE { if rv == C.SQLITE_DONE {
@ -907,12 +938,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
return nil return nil
} }
if rc.decltype == nil { rc.DeclTypes()
rc.decltype = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
}
}
for i := range dest { for i := range dest {
switch C.sqlite3_column_type(rc.s.s, C.int(i)) { switch C.sqlite3_column_type(rc.s.s, C.int(i)) {

103
sqlite3_context.go Normal file
View File

@ -0,0 +1,103 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3
/*
#ifndef USE_LIBSQLITE3
#include <sqlite3-binding.h>
#else
#include <sqlite3.h>
#endif
#include <stdlib.h>
// These wrappers are necessary because SQLITE_TRANSIENT
// is a pointer constant, and cgo doesn't translate them correctly.
static inline void my_result_text(sqlite3_context *ctx, char *p, int np) {
sqlite3_result_text(ctx, p, np, SQLITE_TRANSIENT);
}
static inline void my_result_blob(sqlite3_context *ctx, void *p, int np) {
sqlite3_result_blob(ctx, p, np, SQLITE_TRANSIENT);
}
*/
import "C"
import (
"math"
"reflect"
"unsafe"
)
const i64 = unsafe.Sizeof(int(0)) > 4
// SQLiteContext behave sqlite3_context
type SQLiteContext C.sqlite3_context
// ResultBool sets the result of an SQL function.
func (c *SQLiteContext) ResultBool(b bool) {
if b {
c.ResultInt(1)
} else {
c.ResultInt(0)
}
}
// ResultBlob sets the result of an SQL function.
// See: sqlite3_result_blob, http://sqlite.org/c3ref/result_blob.html
func (c *SQLiteContext) ResultBlob(b []byte) {
if i64 && len(b) > math.MaxInt32 {
C.sqlite3_result_error_toobig((*C.sqlite3_context)(c))
return
}
var p *byte
if len(b) > 0 {
p = &b[0]
}
C.my_result_blob((*C.sqlite3_context)(c), unsafe.Pointer(p), C.int(len(b)))
}
// ResultDouble sets the result of an SQL function.
// See: sqlite3_result_double, http://sqlite.org/c3ref/result_blob.html
func (c *SQLiteContext) ResultDouble(d float64) {
C.sqlite3_result_double((*C.sqlite3_context)(c), C.double(d))
}
// ResultInt sets the result of an SQL function.
// See: sqlite3_result_int, http://sqlite.org/c3ref/result_blob.html
func (c *SQLiteContext) ResultInt(i int) {
if i64 && (i > math.MaxInt32 || i < math.MinInt32) {
C.sqlite3_result_int64((*C.sqlite3_context)(c), C.sqlite3_int64(i))
} else {
C.sqlite3_result_int((*C.sqlite3_context)(c), C.int(i))
}
}
// ResultInt64 sets the result of an SQL function.
// See: sqlite3_result_int64, http://sqlite.org/c3ref/result_blob.html
func (c *SQLiteContext) ResultInt64(i int64) {
C.sqlite3_result_int64((*C.sqlite3_context)(c), C.sqlite3_int64(i))
}
// ResultNull sets the result of an SQL function.
// See: sqlite3_result_null, http://sqlite.org/c3ref/result_blob.html
func (c *SQLiteContext) ResultNull() {
C.sqlite3_result_null((*C.sqlite3_context)(c))
}
// ResultText sets the result of an SQL function.
// See: sqlite3_result_text, http://sqlite.org/c3ref/result_blob.html
func (c *SQLiteContext) ResultText(s string) {
h := (*reflect.StringHeader)(unsafe.Pointer(&s))
cs, l := (*C.char)(unsafe.Pointer(h.Data)), C.int(h.Len)
C.my_result_text((*C.sqlite3_context)(c), cs, l)
}
// ResultZeroblob sets the result of an SQL function.
// See: sqlite3_result_zeroblob, http://sqlite.org/c3ref/result_blob.html
func (c *SQLiteContext) ResultZeroblob(n int) {
C.sqlite3_result_zeroblob((*C.sqlite3_context)(c), C.int(n))
}

View File

@ -93,7 +93,10 @@ func TestFTS4(t *testing.T) {
_, err = db.Exec("DROP TABLE foo") _, err = db.Exec("DROP TABLE foo")
_, err = db.Exec("CREATE VIRTUAL TABLE foo USING fts4(tokenize=unicode61, id INTEGER PRIMARY KEY, value TEXT)") _, err = db.Exec("CREATE VIRTUAL TABLE foo USING fts4(tokenize=unicode61, id INTEGER PRIMARY KEY, value TEXT)")
if err != nil { switch {
case err != nil && err.Error() == "unknown tokenizer: unicode61":
t.Skip("FTS4 not supported")
case err != nil:
t.Fatal("Failed to create table:", err) t.Fatal("Failed to create table:", err)
} }

13
sqlite3_fts5.go Normal file
View File

@ -0,0 +1,13 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build fts5
package sqlite3
/*
#cgo CFLAGS: -DSQLITE_ENABLE_FTS5
#cgo LDFLAGS: -lm
*/
import "C"

69
sqlite3_go18.go Normal file
View File

@ -0,0 +1,69 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build go1.8
package sqlite3
import (
"database/sql/driver"
"errors"
"context"
)
// Ping implement Pinger.
func (c *SQLiteConn) Ping(ctx context.Context) error {
if c.db == nil {
return errors.New("Connection was closed")
}
return nil
}
// QueryContext implement QueryerContext.
func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return c.query(ctx, query, list)
}
// ExecContext implement ExecerContext.
func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return c.exec(ctx, query, list)
}
// PrepareContext implement ConnPrepareContext.
func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return c.prepare(ctx, query)
}
// BeginTx implement ConnBeginTx.
func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return c.begin(ctx)
}
// QueryContext implement QueryerContext.
func (s *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return s.query(ctx, list)
}
// ExecContext implement ExecerContext.
func (s *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return s.exec(ctx, list)
}

50
sqlite3_go18_test.go Normal file
View File

@ -0,0 +1,50 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build go1.8
package sqlite3
import (
"database/sql"
"os"
"testing"
)
func TestNamedParams(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec(`
create table foo (id integer, name text, extra text);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}
_, err = db.Exec(`insert into foo(id, name, extra) values(:id, :name, :name)`, sql.Named("name", "foo"), sql.Named("id", 1))
if err != nil {
t.Error("Failed to call db.Exec:", err)
}
row := db.QueryRow(`select id, extra from foo where id = :id and extra = :extra`, sql.Named("id", 1), sql.Named("extra", "foo"))
if row == nil {
t.Error("Failed to call db.QueryRow")
}
var id int
var extra string
err = row.Scan(&id, &extra)
if err != nil {
t.Error("Failed to db.Scan:", err)
}
if id != 1 || extra != "foo" {
t.Error("Failed to db.QueryRow: not matched results")
}
}

12
sqlite3_json1.go Normal file
View File

@ -0,0 +1,12 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build json1
package sqlite3
/*
#cgo CFLAGS: -DSQLITE_ENABLE_JSON1
*/
import "C"

View File

@ -8,6 +8,7 @@ package sqlite3
/* /*
#cgo CFLAGS: -DUSE_LIBSQLITE3 #cgo CFLAGS: -DUSE_LIBSQLITE3
#cgo LDFLAGS: -lsqlite3 #cgo linux LDFLAGS: -lsqlite3
#cgo darwin LDFLAGS: -L/usr/local/opt/sqlite/lib -lsqlite3
*/ */
import "C" import "C"

View File

@ -7,7 +7,11 @@
package sqlite3 package sqlite3
/* /*
#ifndef USE_LIBSQLITE3
#include <sqlite3-binding.h> #include <sqlite3-binding.h>
#else
#include <sqlite3.h>
#endif
#include <stdlib.h> #include <stdlib.h>
*/ */
import "C" import "C"
@ -27,6 +31,7 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error {
defer C.free(unsafe.Pointer(cext)) defer C.free(unsafe.Pointer(cext))
rv = C.sqlite3_load_extension(c.db, cext, nil, nil) rv = C.sqlite3_load_extension(c.db, cext, nil, nil)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
C.sqlite3_enable_load_extension(c.db, 0)
return errors.New(C.GoString(C.sqlite3_errmsg(c.db))) return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
} }
} }
@ -37,3 +42,28 @@ func (c *SQLiteConn) loadExtensions(extensions []string) error {
} }
return nil return nil
} }
// LoadExtension load the sqlite3 extension.
func (c *SQLiteConn) LoadExtension(lib string, entry string) error {
rv := C.sqlite3_enable_load_extension(c.db, 1)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
clib := C.CString(lib)
defer C.free(unsafe.Pointer(clib))
centry := C.CString(entry)
defer C.free(unsafe.Pointer(centry))
rv = C.sqlite3_load_extension(c.db, clib, centry, nil)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
rv = C.sqlite3_enable_load_extension(c.db, 0)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
return nil
}

View File

@ -17,3 +17,7 @@ import (
func (c *SQLiteConn) loadExtensions(extensions []string) error { func (c *SQLiteConn) loadExtensions(extensions []string) error {
return errors.New("Extensions have been disabled for static builds") return errors.New("Extensions have been disabled for static builds")
} }
func (c *SQLiteConn) LoadExtension(lib string, entry string) error {
return errors.New("Extensions have been disabled for static builds")
}

View File

@ -107,6 +107,35 @@ func TestReadonly(t *testing.T) {
} }
} }
func TestForeignKeys(t *testing.T) {
cases := map[string]bool{
"?_foreign_keys=1": true,
"?_foreign_keys=0": false,
}
for option, want := range cases {
fname := TempFilename(t)
uri := "file:" + fname + option
db, err := sql.Open("sqlite3", uri)
if err != nil {
os.Remove(fname)
t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
continue
}
var enabled bool
err = db.QueryRow("PRAGMA foreign_keys;").Scan(&enabled)
db.Close()
os.Remove(fname)
if err != nil {
t.Errorf("query foreign_keys for %s: %v", uri, err)
continue
}
if enabled != want {
t.Errorf("\"PRAGMA foreign_keys;\" for %q = %t; want %t", uri, enabled, want)
continue
}
}
}
func TestClose(t *testing.T) { func TestClose(t *testing.T) {
tempFilename := TempFilename(t) tempFilename := TempFilename(t)
defer os.Remove(tempFilename) defer os.Remove(tempFilename)
@ -168,7 +197,7 @@ func TestInsert(t *testing.T) {
var result int var result int
rows.Scan(&result) rows.Scan(&result)
if result != 123 { if result != 123 {
t.Errorf("Fetched %q; expected %q", 123, result) t.Errorf("Expected %d for fetched result, but %d:", 123, result)
} }
} }
@ -207,12 +236,12 @@ func TestUpdate(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("Failed to update record:", err) t.Fatal("Failed to update record:", err)
} }
lastId, err := res.LastInsertId() lastID, err := res.LastInsertId()
if err != nil { if err != nil {
t.Fatal("Failed to get LastInsertId:", err) t.Fatal("Failed to get LastInsertId:", err)
} }
if expected != lastId { if expected != lastID {
t.Errorf("Expected %q for last Id, but %q:", expected, lastId) t.Errorf("Expected %q for last Id, but %q:", expected, lastID)
} }
affected, _ = res.RowsAffected() affected, _ = res.RowsAffected()
if err != nil { if err != nil {
@ -233,7 +262,7 @@ func TestUpdate(t *testing.T) {
var result int var result int
rows.Scan(&result) rows.Scan(&result)
if result != 234 { if result != 234 {
t.Errorf("Fetched %q; expected %q", 234, result) t.Errorf("Expected %d for fetched result, but %d:", 234, result)
} }
} }
@ -272,12 +301,12 @@ func TestDelete(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("Failed to delete record:", err) t.Fatal("Failed to delete record:", err)
} }
lastId, err := res.LastInsertId() lastID, err := res.LastInsertId()
if err != nil { if err != nil {
t.Fatal("Failed to get LastInsertId:", err) t.Fatal("Failed to get LastInsertId:", err)
} }
if expected != lastId { if expected != lastID {
t.Errorf("Expected %q for last Id, but %q:", expected, lastId) t.Errorf("Expected %q for last Id, but %q:", expected, lastID)
} }
affected, err = res.RowsAffected() affected, err = res.RowsAffected()
if err != nil { if err != nil {
@ -993,42 +1022,6 @@ func TestVersion(t *testing.T) {
} }
} }
func TestNumberNamedParams(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec(`
create table foo (id integer, name text, extra text);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}
_, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, "foo")
if err != nil {
t.Error("Failed to call db.Exec:", err)
}
row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, "foo")
if row == nil {
t.Error("Failed to call db.QueryRow")
}
var id int
var extra string
err = row.Scan(&id, &extra)
if err != nil {
t.Error("Failed to db.Scan:", err)
}
if id != 1 || extra != "foo" {
t.Error("Failed to db.QueryRow: not matched results")
}
}
func TestStringContainingZero(t *testing.T) { func TestStringContainingZero(t *testing.T) {
tempFilename := TempFilename(t) tempFilename := TempFilename(t)
defer os.Remove(tempFilename) defer os.Remove(tempFilename)
@ -1106,12 +1099,12 @@ func TestDateTimeNow(t *testing.T) {
} }
func TestFunctionRegistration(t *testing.T) { func TestFunctionRegistration(t *testing.T) {
addi_8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) } addi8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) }
addi_64 := func(a, b int64) int64 { return a + b } addi64 := func(a, b int64) int64 { return a + b }
addu_8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) } addu8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) }
addu_64 := func(a, b uint64) uint64 { return a + b } addu64 := func(a, b uint64) uint64 { return a + b }
addiu := func(a int, b uint) int64 { return int64(a) + int64(b) } addiu := func(a int, b uint) int64 { return int64(a) + int64(b) }
addf_32_64 := func(a float32, b float64) float64 { return float64(a) + b } addf32_64 := func(a float32, b float64) float64 { return float64(a) + b }
not := func(a bool) bool { return !a } not := func(a bool) bool { return !a }
regex := func(re, s string) (bool, error) { regex := func(re, s string) (bool, error) {
return regexp.MatchString(re, s) return regexp.MatchString(re, s)
@ -1143,22 +1136,22 @@ func TestFunctionRegistration(t *testing.T) {
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error { ConnectHook: func(conn *SQLiteConn) error {
if err := conn.RegisterFunc("addi_8_16_32", addi_8_16_32, true); err != nil { if err := conn.RegisterFunc("addi8_16_32", addi8_16_32, true); err != nil {
return err return err
} }
if err := conn.RegisterFunc("addi_64", addi_64, true); err != nil { if err := conn.RegisterFunc("addi64", addi64, true); err != nil {
return err return err
} }
if err := conn.RegisterFunc("addu_8_16_32", addu_8_16_32, true); err != nil { if err := conn.RegisterFunc("addu8_16_32", addu8_16_32, true); err != nil {
return err return err
} }
if err := conn.RegisterFunc("addu_64", addu_64, true); err != nil { if err := conn.RegisterFunc("addu64", addu64, true); err != nil {
return err return err
} }
if err := conn.RegisterFunc("addiu", addiu, true); err != nil { if err := conn.RegisterFunc("addiu", addiu, true); err != nil {
return err return err
} }
if err := conn.RegisterFunc("addf_32_64", addf_32_64, true); err != nil { if err := conn.RegisterFunc("addf32_64", addf32_64, true); err != nil {
return err return err
} }
if err := conn.RegisterFunc("not", not, true); err != nil { if err := conn.RegisterFunc("not", not, true); err != nil {
@ -1189,12 +1182,12 @@ func TestFunctionRegistration(t *testing.T) {
query string query string
expected interface{} expected interface{}
}{ }{
{"SELECT addi_8_16_32(1,2)", int32(3)}, {"SELECT addi8_16_32(1,2)", int32(3)},
{"SELECT addi_64(1,2)", int64(3)}, {"SELECT addi64(1,2)", int64(3)},
{"SELECT addu_8_16_32(1,2)", uint32(3)}, {"SELECT addu8_16_32(1,2)", uint32(3)},
{"SELECT addu_64(1,2)", uint64(3)}, {"SELECT addu64(1,2)", uint64(3)},
{"SELECT addiu(1,2)", int64(3)}, {"SELECT addiu(1,2)", int64(3)},
{"SELECT addf_32_64(1.5,1.5)", float64(3)}, {"SELECT addf32_64(1.5,1.5)", float64(3)},
{"SELECT not(1)", false}, {"SELECT not(1)", false},
{"SELECT not(0)", true}, {"SELECT not(0)", true},
{`SELECT regex("^foo.*", "foobar")`, true}, {`SELECT regex("^foo.*", "foobar")`, true},
@ -1220,62 +1213,54 @@ func TestFunctionRegistration(t *testing.T) {
} }
} }
type sumAggregator int64 func TestDeclTypes(t *testing.T) {
func (s *sumAggregator) Step(x int64) { d := SQLiteDriver{}
*s += sumAggregator(x)
}
func (s *sumAggregator) Done() int64 { conn, err := d.Open(":memory:")
return int64(*s)
}
func TestAggregatorRegistration(t *testing.T) {
customSum := func() *sumAggregator {
var ret sumAggregator
return &ret
}
sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
return err
}
return nil
},
})
db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to begin transaction:", err)
} }
defer db.Close() defer conn.Close()
_, err = db.Exec("create table foo (department integer, profits integer)") sqlite3conn := conn.(*SQLiteConn)
_, err = sqlite3conn.Exec("create table foo (id integer not null primary key, name text)", nil)
if err != nil { if err != nil {
t.Fatal("Failed to create table:", err) t.Fatal("Failed to create table:", err)
} }
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)") _, err = sqlite3conn.Exec("insert into foo(name) values(\"bar\")", nil)
if err != nil { if err != nil {
t.Fatal("Failed to insert records:", err) t.Fatal("Failed to insert:", err)
} }
tests := []struct { rs, err := sqlite3conn.Query("select * from foo", nil)
dept, sum int64 if err != nil {
}{ t.Fatal("Failed to select:", err)
{1, 30},
{2, 42},
} }
defer rs.Close()
for _, test := range tests { declTypes := rs.(*SQLiteRows).DeclTypes()
var ret int64
err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret) if !reflect.DeepEqual(declTypes, []string{"integer", "text"}) {
if err != nil { t.Fatal("Unexpected declTypes:", declTypes)
t.Fatal("Query failed:", err) }
} }
if ret != test.sum {
t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum) func TestPinger(t *testing.T) {
} db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
err = db.Ping()
if err != nil {
t.Fatal(err)
}
db.Close()
err = db.Ping()
if err == nil {
t.Fatal("Should be closed")
} }
} }
@ -1283,14 +1268,14 @@ var customFunctionOnce sync.Once
func BenchmarkCustomFunctions(b *testing.B) { func BenchmarkCustomFunctions(b *testing.B) {
customFunctionOnce.Do(func() { customFunctionOnce.Do(func() {
custom_add := func(a, b int64) int64 { customAdd := func(a, b int64) int64 {
return a + b return a + b
} }
sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{ sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error { ConnectHook: func(conn *SQLiteConn) error {
// Impure function to force sqlite to reexecute it each time. // Impure function to force sqlite to reexecute it each time.
if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil { if err := conn.RegisterFunc("custom_add", customAdd, false); err != nil {
return err return err
} }
return nil return nil

View File

@ -11,14 +11,17 @@ import (
"time" "time"
) )
// Dialect is a type of dialect of databases.
type Dialect int type Dialect int
// Dialects for databases.
const ( const (
SQLITE Dialect = iota SQLITE Dialect = iota // SQLITE mean SQLite3 dialect
POSTGRESQL POSTGRESQL // POSTGRESQL mean PostgreSQL dialect
MYSQL MYSQL // MYSQL mean MySQL dialect
) )
// DB provide context for the tests
type DB struct { type DB struct {
*testing.T *testing.T
*sql.DB *sql.DB
@ -32,19 +35,19 @@ var db *DB
var testTables = []string{"foo", "bar", "t", "bench"} var testTables = []string{"foo", "bar", "t", "bench"}
var tests = []testing.InternalTest{ var tests = []testing.InternalTest{
{"TestBlobs", TestBlobs}, {Name: "TestBlobs", F: TestBlobs},
{"TestManyQueryRow", TestManyQueryRow}, {Name: "TestManyQueryRow", F: TestManyQueryRow},
{"TestTxQuery", TestTxQuery}, {Name: "TestTxQuery", F: TestTxQuery},
{"TestPreparedStmt", TestPreparedStmt}, {Name: "TestPreparedStmt", F: TestPreparedStmt},
} }
var benchmarks = []testing.InternalBenchmark{ var benchmarks = []testing.InternalBenchmark{
{"BenchmarkExec", BenchmarkExec}, {Name: "BenchmarkExec", F: BenchmarkExec},
{"BenchmarkQuery", BenchmarkQuery}, {Name: "BenchmarkQuery", F: BenchmarkQuery},
{"BenchmarkParams", BenchmarkParams}, {Name: "BenchmarkParams", F: BenchmarkParams},
{"BenchmarkStmt", BenchmarkStmt}, {Name: "BenchmarkStmt", F: BenchmarkStmt},
{"BenchmarkRows", BenchmarkRows}, {Name: "BenchmarkRows", F: BenchmarkRows},
{"BenchmarkStmtRows", BenchmarkStmtRows}, {Name: "BenchmarkStmtRows", F: BenchmarkStmtRows},
} }
// RunTests runs the SQL test suite // RunTests runs the SQL test suite
@ -78,7 +81,7 @@ func (db *DB) tearDown() {
case MYSQL, POSTGRESQL: case MYSQL, POSTGRESQL:
db.mustExec("drop table if exists " + tbl) db.mustExec("drop table if exists " + tbl)
default: default:
db.Fatal("unkown dialect") db.Fatal("unknown dialect")
} }
} }
} }
@ -106,7 +109,7 @@ func (db *DB) blobType(size int) string {
case MYSQL: case MYSQL:
return fmt.Sprintf("VARBINARY(%d)", size) return fmt.Sprintf("VARBINARY(%d)", size)
} }
panic("unkown dialect") panic("unknown dialect")
} }
func (db *DB) serialPK() string { func (db *DB) serialPK() string {
@ -118,7 +121,7 @@ func (db *DB) serialPK() string {
case MYSQL: case MYSQL:
return "integer primary key auto_increment" return "integer primary key auto_increment"
} }
panic("unkown dialect") panic("unknown dialect")
} }
func (db *DB) now() string { func (db *DB) now() string {
@ -130,7 +133,7 @@ func (db *DB) now() string {
case MYSQL: case MYSQL:
return "now()" return "now()"
} }
panic("unkown dialect") panic("unknown dialect")
} }
func makeBench() { func makeBench() {
@ -149,6 +152,7 @@ func makeBench() {
} }
} }
// TestResult is test for result
func TestResult(t *testing.T) { func TestResult(t *testing.T) {
db.tearDown() db.tearDown()
db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))") db.mustExec("create temporary table test (id " + db.serialPK() + ", name varchar(10))")
@ -175,6 +179,7 @@ func TestResult(t *testing.T) {
} }
} }
// TestBlobs is test for blobs
func TestBlobs(t *testing.T) { func TestBlobs(t *testing.T) {
db.tearDown() db.tearDown()
var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} var blob = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
@ -201,6 +206,7 @@ func TestBlobs(t *testing.T) {
} }
} }
// TestManyQueryRow is test for many query row
func TestManyQueryRow(t *testing.T) { func TestManyQueryRow(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Log("skipping in short mode") t.Log("skipping in short mode")
@ -218,6 +224,7 @@ func TestManyQueryRow(t *testing.T) {
} }
} }
// TestTxQuery is test for transactional query
func TestTxQuery(t *testing.T) { func TestTxQuery(t *testing.T) {
db.tearDown() db.tearDown()
tx, err := db.Begin() tx, err := db.Begin()
@ -256,6 +263,7 @@ func TestTxQuery(t *testing.T) {
} }
} }
// TestPreparedStmt is test for prepared statement
func TestPreparedStmt(t *testing.T) { func TestPreparedStmt(t *testing.T) {
db.tearDown() db.tearDown()
db.mustExec("CREATE TABLE t (count INT)") db.mustExec("CREATE TABLE t (count INT)")
@ -301,6 +309,7 @@ func TestPreparedStmt(t *testing.T) {
// test -bench but calling Benchmark() from a benchmark test // test -bench but calling Benchmark() from a benchmark test
// currently hangs go. // currently hangs go.
// BenchmarkExec is benchmark for exec
func BenchmarkExec(b *testing.B) { func BenchmarkExec(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if _, err := db.Exec("select 1"); err != nil { if _, err := db.Exec("select 1"); err != nil {
@ -309,6 +318,7 @@ func BenchmarkExec(b *testing.B) {
} }
} }
// BenchmarkQuery is benchmark for query
func BenchmarkQuery(b *testing.B) { func BenchmarkQuery(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var n sql.NullString var n sql.NullString
@ -322,6 +332,7 @@ func BenchmarkQuery(b *testing.B) {
} }
} }
// BenchmarkParams is benchmark for params
func BenchmarkParams(b *testing.B) { func BenchmarkParams(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
var n sql.NullString var n sql.NullString
@ -335,6 +346,7 @@ func BenchmarkParams(b *testing.B) {
} }
} }
// BenchmarkStmt is benchmark for statement
func BenchmarkStmt(b *testing.B) { func BenchmarkStmt(b *testing.B) {
st, err := db.Prepare("select ?, ?, ?, ?") st, err := db.Prepare("select ?, ?, ?, ?")
if err != nil { if err != nil {
@ -354,6 +366,7 @@ func BenchmarkStmt(b *testing.B) {
} }
} }
// BenchmarkRows is benchmark for rows
func BenchmarkRows(b *testing.B) { func BenchmarkRows(b *testing.B) {
db.once.Do(makeBench) db.once.Do(makeBench)
@ -378,6 +391,7 @@ func BenchmarkRows(b *testing.B) {
} }
} }
// BenchmarkStmtRows is benchmark for statement rows
func BenchmarkStmtRows(b *testing.B) { func BenchmarkStmtRows(b *testing.B) {
db.once.Do(makeBench) db.once.Do(makeBench)

414
sqlite3_trace.go Normal file
View File

@ -0,0 +1,414 @@
// Copyright (C) 2016 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build trace
package sqlite3
/*
#ifndef USE_LIBSQLITE3
#include <sqlite3-binding.h>
#else
#include <sqlite3.h>
#endif
#include <stdlib.h>
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void doneTrampoline(sqlite3_context*);
int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
*/
import "C"
import (
"errors"
"fmt"
"reflect"
"strings"
"sync"
"unsafe"
)
// Trace... constants identify the possible events causing callback invocation.
// Values are same as the corresponding SQLite Trace Event Codes.
const (
TraceStmt = C.SQLITE_TRACE_STMT
TraceProfile = C.SQLITE_TRACE_PROFILE
TraceRow = C.SQLITE_TRACE_ROW
TraceClose = C.SQLITE_TRACE_CLOSE
)
type TraceInfo struct {
// Pack together the shorter fields, to keep the struct smaller.
// On a 64-bit machine there would be padding
// between EventCode and ConnHandle; having AutoCommit here is "free":
EventCode uint32
AutoCommit bool
ConnHandle uintptr
// Usually filled, unless EventCode = TraceClose = SQLITE_TRACE_CLOSE:
// identifier for a prepared statement:
StmtHandle uintptr
// Two strings filled when EventCode = TraceStmt = SQLITE_TRACE_STMT:
// (1) either the unexpanded SQL text of the prepared statement, or
// an SQL comment that indicates the invocation of a trigger;
// (2) expanded SQL, if requested and if (1) is not an SQL comment.
StmtOrTrigger string
ExpandedSQL string // only if requested (TraceConfig.WantExpandedSQL = true)
// filled when EventCode = TraceProfile = SQLITE_TRACE_PROFILE:
// estimated number of nanoseconds that the prepared statement took to run:
RunTimeNanosec int64
DBError Error
}
// TraceUserCallback gives the signature for a trace function
// provided by the user (Go application programmer).
// SQLite 3.14 documentation (as of September 2, 2016)
// for SQL Trace Hook = sqlite3_trace_v2():
// The integer return value from the callback is currently ignored,
// though this may change in future releases. Callback implementations
// should return zero to ensure future compatibility.
type TraceUserCallback func(TraceInfo) int
type TraceConfig struct {
Callback TraceUserCallback
EventMask C.uint
WantExpandedSQL bool
}
func fillDBError(dbErr *Error, db *C.sqlite3) {
// See SQLiteConn.lastError(), in file 'sqlite3.go' at the time of writing (Sept 5, 2016)
dbErr.Code = ErrNo(C.sqlite3_errcode(db))
dbErr.ExtendedCode = ErrNoExtended(C.sqlite3_extended_errcode(db))
dbErr.err = C.GoString(C.sqlite3_errmsg(db))
}
func fillExpandedSQL(info *TraceInfo, db *C.sqlite3, pStmt unsafe.Pointer) {
if pStmt == nil {
panic("No SQLite statement pointer in P arg of trace_v2 callback")
}
expSQLiteCStr := C.sqlite3_expanded_sql((*C.sqlite3_stmt)(pStmt))
if expSQLiteCStr == nil {
fillDBError(&info.DBError, db)
return
}
info.ExpandedSQL = C.GoString(expSQLiteCStr)
}
//export traceCallbackTrampoline
func traceCallbackTrampoline(
traceEventCode C.uint,
// Parameter named 'C' in SQLite docs = Context given at registration:
ctx unsafe.Pointer,
// Parameter named 'P' in SQLite docs (Primary event data?):
p unsafe.Pointer,
// Parameter named 'X' in SQLite docs (eXtra event data?):
xValue unsafe.Pointer) C.int {
if ctx == nil {
panic(fmt.Sprintf("No context (ev 0x%x)", traceEventCode))
}
contextDB := (*C.sqlite3)(ctx)
connHandle := uintptr(ctx)
var traceConf TraceConfig
var found bool
if traceEventCode == TraceClose {
// clean up traceMap: 'pop' means get and delete
traceConf, found = popTraceMapping(connHandle)
} else {
traceConf, found = lookupTraceMapping(connHandle)
}
if !found {
panic(fmt.Sprintf("Mapping not found for handle 0x%x (ev 0x%x)",
connHandle, traceEventCode))
}
var info TraceInfo
info.EventCode = uint32(traceEventCode)
info.AutoCommit = (int(C.sqlite3_get_autocommit(contextDB)) != 0)
info.ConnHandle = connHandle
switch traceEventCode {
case TraceStmt:
info.StmtHandle = uintptr(p)
var xStr string
if xValue != nil {
xStr = C.GoString((*C.char)(xValue))
}
info.StmtOrTrigger = xStr
if !strings.HasPrefix(xStr, "--") {
// Not SQL comment, therefore the current event
// is not related to a trigger.
// The user might want to receive the expanded SQL;
// let's check:
if traceConf.WantExpandedSQL {
fillExpandedSQL(&info, contextDB, p)
}
}
case TraceProfile:
info.StmtHandle = uintptr(p)
if xValue == nil {
panic("NULL pointer in X arg of trace_v2 callback for SQLITE_TRACE_PROFILE event")
}
info.RunTimeNanosec = *(*int64)(xValue)
// sample the error //TODO: is it safe? is it useful?
fillDBError(&info.DBError, contextDB)
case TraceRow:
info.StmtHandle = uintptr(p)
case TraceClose:
handle := uintptr(p)
if handle != info.ConnHandle {
panic(fmt.Sprintf("Different conn handle 0x%x (expected 0x%x) in SQLITE_TRACE_CLOSE event.",
handle, info.ConnHandle))
}
default:
// Pass unsupported events to the user callback (if configured);
// let the user callback decide whether to panic or ignore them.
}
// Do not execute user callback when the event was not requested by user!
// Remember that the Close event is always selected when
// registering this callback trampoline with SQLite --- for cleanup.
// In the future there may be more events forced to "selected" in SQLite
// for the driver's needs.
if traceConf.EventMask&traceEventCode == 0 {
return 0
}
r := 0
if traceConf.Callback != nil {
r = traceConf.Callback(info)
}
return C.int(r)
}
type traceMapEntry struct {
config TraceConfig
}
var traceMapLock sync.Mutex
var traceMap = make(map[uintptr]traceMapEntry)
func addTraceMapping(connHandle uintptr, traceConf TraceConfig) {
traceMapLock.Lock()
defer traceMapLock.Unlock()
oldEntryCopy, found := traceMap[connHandle]
if found {
panic(fmt.Sprintf("Adding trace config %v: handle 0x%x already registered (%v).",
traceConf, connHandle, oldEntryCopy.config))
}
traceMap[connHandle] = traceMapEntry{config: traceConf}
fmt.Printf("Added trace config %v: handle 0x%x.\n", traceConf, connHandle)
}
func lookupTraceMapping(connHandle uintptr) (TraceConfig, bool) {
traceMapLock.Lock()
defer traceMapLock.Unlock()
entryCopy, found := traceMap[connHandle]
return entryCopy.config, found
}
// 'pop' = get and delete from map before returning the value to the caller
func popTraceMapping(connHandle uintptr) (TraceConfig, bool) {
traceMapLock.Lock()
defer traceMapLock.Unlock()
entryCopy, found := traceMap[connHandle]
if found {
delete(traceMap, connHandle)
fmt.Printf("Pop handle 0x%x: deleted trace config %v.\n", connHandle, entryCopy.config)
}
return entryCopy.config, found
}
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
//
// Because aggregation is incremental, it's implemented in Go with a
// type that has 2 methods: func Step(values) accumulates one row of
// data into the accumulator, and func Done() ret finalizes and
// returns the aggregate value. "values" and "ret" may be any type
// supported by RegisterFunc.
//
// RegisterAggregator takes as implementation a constructor function
// that constructs an instance of the aggregator type each time an
// aggregation begins. The constructor must return a pointer to a
// type, or an interface that implements Step() and Done().
//
// The constructor function and the Step/Done methods may optionally
// return an error in addition to their other return values.
//
// See _example/go_custom_funcs for a detailed example.
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
var ai aggInfo
ai.constructor = reflect.ValueOf(impl)
t := ai.constructor.Type()
if t.Kind() != reflect.Func {
return errors.New("non-function passed to RegisterAggregator")
}
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
}
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("Second return value of SQLite function must be error")
}
if t.NumIn() != 0 {
return errors.New("SQLite aggregator constructors must not have arguments")
}
agg := t.Out(0)
switch agg.Kind() {
case reflect.Ptr, reflect.Interface:
default:
return errors.New("SQlite aggregator constructor must return a pointer object")
}
stepFn, found := agg.MethodByName("Step")
if !found {
return errors.New("SQlite aggregator doesn't have a Step() function")
}
step := stepFn.Type
if step.NumOut() != 0 && step.NumOut() != 1 {
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
}
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("type of SQlite aggregator Step() return value must be error")
}
stepNArgs := step.NumIn()
start := 0
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
stepNArgs--
start++
}
if step.IsVariadic() {
stepNArgs--
}
for i := start; i < start+stepNArgs; i++ {
conv, err := callbackArg(step.In(i))
if err != nil {
return err
}
ai.stepArgConverters = append(ai.stepArgConverters, conv)
}
if step.IsVariadic() {
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
if err != nil {
return err
}
ai.stepVariadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
stepNArgs = -1
}
doneFn, found := agg.MethodByName("Done")
if !found {
return errors.New("SQlite aggregator doesn't have a Done() function")
}
done := doneFn.Type
doneNArgs := done.NumIn()
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
doneNArgs--
}
if doneNArgs != 0 {
return errors.New("SQlite aggregator Done() function must have no arguments")
}
if done.NumOut() != 1 && done.NumOut() != 2 {
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
}
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("second return value of SQLite aggregator Done() function must be error")
}
conv, err := callbackRet(done.Out(0))
if err != nil {
return err
}
ai.doneRetConverter = conv
ai.active = make(map[int64]reflect.Value)
ai.next = 1
// ai must outlast the database connection, or we'll have dangling pointers.
c.aggregators = append(c.aggregators, &ai)
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}
// SetTrace installs or removes the trace callback for the given database connection.
// It's not named 'RegisterTrace' because only one callback can be kept and called.
// Calling SetTrace a second time on same database connection
// overrides (cancels) any prior callback and all its settings:
// event mask, etc.
func (c *SQLiteConn) SetTrace(requested *TraceConfig) error {
connHandle := uintptr(unsafe.Pointer(c.db))
_, _ = popTraceMapping(connHandle)
if requested == nil {
// The traceMap entry was deleted already by popTraceMapping():
// can disable all events now, no need to watch for TraceClose.
err := c.setSQLiteTrace(0)
return err
}
reqCopy := *requested
// Disable potentially expensive operations
// if their result will not be used. We are doing this
// just in case the caller provided nonsensical input.
if reqCopy.EventMask&TraceStmt == 0 {
reqCopy.WantExpandedSQL = false
}
addTraceMapping(connHandle, reqCopy)
// The callback trampoline function does cleanup on Close event,
// regardless of the presence or absence of the user callback.
// Therefore it needs the Close event to be selected:
actualEventMask := uint(reqCopy.EventMask | TraceClose)
err := c.setSQLiteTrace(actualEventMask)
return err
}
func (c *SQLiteConn) setSQLiteTrace(sqliteEventMask uint) error {
rv := C.sqlite3_trace_v2(c.db,
C.uint(sqliteEventMask),
(*[0]byte)(unsafe.Pointer(C.traceCallbackTrampoline)),
unsafe.Pointer(c.db)) // Fourth arg is same as first: we are
// passing the database connection handle as callback context.
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}

72
sqlite3_trace_test.go Normal file
View File

@ -0,0 +1,72 @@
// Copyright (C) 2016 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build trace
package sqlite3
import (
"database/sql"
"testing"
)
type sumAggregator int64
func (s *sumAggregator) Step(x int64) {
*s += sumAggregator(x)
}
func (s *sumAggregator) Done() int64 {
return int64(*s)
}
func TestAggregatorRegistration(t *testing.T) {
customSum := func() *sumAggregator {
var ret sumAggregator
return &ret
}
sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
return err
}
return nil
},
})
db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec("create table foo (department integer, profits integer)")
if err != nil {
// trace feature is not implemented
t.Skip("Failed to create table:", err)
}
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
if err != nil {
t.Fatal("Failed to insert records:", err)
}
tests := []struct {
dept, sum int64
}{
{1, 30},
{2, 42},
}
for _, test := range tests {
var ret int64
err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
if err != nil {
t.Fatal("Query failed:", err)
}
if ret != test.sum {
t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
}
}
}

57
sqlite3_type.go Normal file
View File

@ -0,0 +1,57 @@
package sqlite3
/*
#ifndef USE_LIBSQLITE3
#include <sqlite3-binding.h>
#else
#include <sqlite3.h>
#endif
*/
import "C"
import (
"reflect"
"time"
)
// ColumnTypeDatabaseTypeName implement RowsColumnTypeDatabaseTypeName.
func (rc *SQLiteRows) ColumnTypeDatabaseTypeName(i int) string {
return C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i)))
}
/*
func (rc *SQLiteRows) ColumnTypeLength(index int) (length int64, ok bool) {
return 0, false
}
func (rc *SQLiteRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
return 0, 0, false
}
*/
// ColumnTypeNullable implement RowsColumnTypeNullable.
func (rc *SQLiteRows) ColumnTypeNullable(i int) (nullable, ok bool) {
return true, true
}
// ColumnTypeScanType implement RowsColumnTypeScanType.
func (rc *SQLiteRows) ColumnTypeScanType(i int) reflect.Type {
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
case C.SQLITE_INTEGER:
switch C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))) {
case "timestamp", "datetime", "date":
return reflect.TypeOf(time.Time{})
case "boolean":
return reflect.TypeOf(false)
}
return reflect.TypeOf(int64(0))
case C.SQLITE_FLOAT:
return reflect.TypeOf(float64(0))
case C.SQLITE_BLOB:
return reflect.SliceOf(reflect.TypeOf(byte(0)))
case C.SQLITE_NULL:
return reflect.TypeOf(nil)
case C.SQLITE_TEXT:
return reflect.TypeOf("")
}
return reflect.SliceOf(reflect.TypeOf(byte(0)))
}

646
sqlite3_vtable.go Normal file
View File

@ -0,0 +1,646 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build vtable
package sqlite3
/*
#cgo CFLAGS: -std=gnu99
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
#cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15
#cgo CFLAGS: -DSQLITE_ENABLE_COLUMN_METADATA=1
#cgo CFLAGS: -Wno-deprecated-declarations
#ifndef USE_LIBSQLITE3
#include <sqlite3-binding.h>
#else
#include <sqlite3.h>
#endif
#include <stdlib.h>
#include <stdint.h>
#include <memory.h>
static inline char *_sqlite3_mprintf(char *zFormat, char *arg) {
return sqlite3_mprintf(zFormat, arg);
}
typedef struct goVTab goVTab;
struct goVTab {
sqlite3_vtab base;
void *vTab;
};
uintptr_t goMInit(void *db, void *pAux, int argc, char **argv, char **pzErr, int isCreate);
static int cXInit(sqlite3 *db, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char **pzErr, int isCreate) {
void *vTab = (void *)goMInit(db, pAux, argc, (char**)argv, pzErr, isCreate);
if (!vTab || *pzErr) {
return SQLITE_ERROR;
}
goVTab *pvTab = (goVTab *)sqlite3_malloc(sizeof(goVTab));
if (!pvTab) {
*pzErr = sqlite3_mprintf("%s", "Out of memory");
return SQLITE_NOMEM;
}
memset(pvTab, 0, sizeof(goVTab));
pvTab->vTab = vTab;
*ppVTab = (sqlite3_vtab *)pvTab;
*pzErr = 0;
return SQLITE_OK;
}
static inline int cXCreate(sqlite3 *db, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char **pzErr) {
return cXInit(db, pAux, argc, argv, ppVTab, pzErr, 1);
}
static inline int cXConnect(sqlite3 *db, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char **pzErr) {
return cXInit(db, pAux, argc, argv, ppVTab, pzErr, 0);
}
char* goVBestIndex(void *pVTab, void *icp);
static inline int cXBestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *info) {
char *pzErr = goVBestIndex(((goVTab*)pVTab)->vTab, info);
if (pzErr) {
if (pVTab->zErrMsg)
sqlite3_free(pVTab->zErrMsg);
pVTab->zErrMsg = pzErr;
return SQLITE_ERROR;
}
return SQLITE_OK;
}
char* goVRelease(void *pVTab, int isDestroy);
static int cXRelease(sqlite3_vtab *pVTab, int isDestroy) {
char *pzErr = goVRelease(((goVTab*)pVTab)->vTab, isDestroy);
if (pzErr) {
if (pVTab->zErrMsg)
sqlite3_free(pVTab->zErrMsg);
pVTab->zErrMsg = pzErr;
return SQLITE_ERROR;
}
if (pVTab->zErrMsg)
sqlite3_free(pVTab->zErrMsg);
sqlite3_free(pVTab);
return SQLITE_OK;
}
static inline int cXDisconnect(sqlite3_vtab *pVTab) {
return cXRelease(pVTab, 0);
}
static inline int cXDestroy(sqlite3_vtab *pVTab) {
return cXRelease(pVTab, 1);
}
typedef struct goVTabCursor goVTabCursor;
struct goVTabCursor {
sqlite3_vtab_cursor base;
void *vTabCursor;
};
uintptr_t goVOpen(void *pVTab, char **pzErr);
static int cXOpen(sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor) {
void *vTabCursor = (void *)goVOpen(((goVTab*)pVTab)->vTab, &(pVTab->zErrMsg));
goVTabCursor *pCursor = (goVTabCursor *)sqlite3_malloc(sizeof(goVTabCursor));
if (!pCursor) {
return SQLITE_NOMEM;
}
memset(pCursor, 0, sizeof(goVTabCursor));
pCursor->vTabCursor = vTabCursor;
*ppCursor = (sqlite3_vtab_cursor *)pCursor;
return SQLITE_OK;
}
static int setErrMsg(sqlite3_vtab_cursor *pCursor, char *pzErr) {
if (pCursor->pVtab->zErrMsg)
sqlite3_free(pCursor->pVtab->zErrMsg);
pCursor->pVtab->zErrMsg = pzErr;
return SQLITE_ERROR;
}
char* goVClose(void *pCursor);
static int cXClose(sqlite3_vtab_cursor *pCursor) {
char *pzErr = goVClose(((goVTabCursor*)pCursor)->vTabCursor);
if (pzErr) {
return setErrMsg(pCursor, pzErr);
}
sqlite3_free(pCursor);
return SQLITE_OK;
}
char* goVFilter(void *pCursor, int idxNum, char* idxName, int argc, sqlite3_value **argv);
static int cXFilter(sqlite3_vtab_cursor *pCursor, int idxNum, const char *idxStr, int argc, sqlite3_value **argv) {
char *pzErr = goVFilter(((goVTabCursor*)pCursor)->vTabCursor, idxNum, (char*)idxStr, argc, argv);
if (pzErr) {
return setErrMsg(pCursor, pzErr);
}
return SQLITE_OK;
}
char* goVNext(void *pCursor);
static int cXNext(sqlite3_vtab_cursor *pCursor) {
char *pzErr = goVNext(((goVTabCursor*)pCursor)->vTabCursor);
if (pzErr) {
return setErrMsg(pCursor, pzErr);
}
return SQLITE_OK;
}
int goVEof(void *pCursor);
static inline int cXEof(sqlite3_vtab_cursor *pCursor) {
return goVEof(((goVTabCursor*)pCursor)->vTabCursor);
}
char* goVColumn(void *pCursor, void *cp, int col);
static int cXColumn(sqlite3_vtab_cursor *pCursor, sqlite3_context *ctx, int i) {
char *pzErr = goVColumn(((goVTabCursor*)pCursor)->vTabCursor, ctx, i);
if (pzErr) {
return setErrMsg(pCursor, pzErr);
}
return SQLITE_OK;
}
char* goVRowid(void *pCursor, sqlite3_int64 *pRowid);
static int cXRowid(sqlite3_vtab_cursor *pCursor, sqlite3_int64 *pRowid) {
char *pzErr = goVRowid(((goVTabCursor*)pCursor)->vTabCursor, pRowid);
if (pzErr) {
return setErrMsg(pCursor, pzErr);
}
return SQLITE_OK;
}
char* goVUpdate(void *pVTab, int argc, sqlite3_value **argv, sqlite3_int64 *pRowid);
static int cXUpdate(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, sqlite3_int64 *pRowid) {
char *pzErr = goVUpdate(((goVTab*)pVTab)->vTab, argc, argv, pRowid);
if (pzErr) {
if (pVTab->zErrMsg)
sqlite3_free(pVTab->zErrMsg);
pVTab->zErrMsg = pzErr;
return SQLITE_ERROR;
}
return SQLITE_OK;
}
static sqlite3_module goModule = {
0, // iVersion
cXCreate, // xCreate - create a table
cXConnect, // xConnect - connect to an existing table
cXBestIndex, // xBestIndex - Determine search strategy
cXDisconnect, // xDisconnect - Disconnect from a table
cXDestroy, // xDestroy - Drop a table
cXOpen, // xOpen - open a cursor
cXClose, // xClose - close a cursor
cXFilter, // xFilter - configure scan constraints
cXNext, // xNext - advance a cursor
cXEof, // xEof
cXColumn, // xColumn - read data
cXRowid, // xRowid - read data
cXUpdate, // xUpdate - write data
// Not implemented
0, // xBegin - begin transaction
0, // xSync - sync transaction
0, // xCommit - commit transaction
0, // xRollback - rollback transaction
0, // xFindFunction - function overloading
0, // xRename - rename the table
0, // xSavepoint
0, // xRelease
0 // xRollbackTo
};
void goMDestroy(void*);
static int _sqlite3_create_module(sqlite3 *db, const char *zName, uintptr_t pClientData) {
return sqlite3_create_module_v2(db, zName, &goModule, (void*) pClientData, goMDestroy);
}
*/
import "C"
import (
"fmt"
"math"
"reflect"
"unsafe"
)
type sqliteModule struct {
c *SQLiteConn
name string
module Module
}
type sqliteVTab struct {
module *sqliteModule
vTab VTab
}
type sqliteVTabCursor struct {
vTab *sqliteVTab
vTabCursor VTabCursor
}
// Op is type of operations.
type Op uint8
// Op mean identity of operations.
const (
OpEQ Op = 2
OpGT = 4
OpLE = 8
OpLT = 16
OpGE = 32
OpMATCH = 64
OpLIKE = 65 /* 3.10.0 and later only */
OpGLOB = 66 /* 3.10.0 and later only */
OpREGEXP = 67 /* 3.10.0 and later only */
OpScanUnique = 1 /* Scan visits at most 1 row */
)
// InfoConstraint give information of constraint.
type InfoConstraint struct {
Column int
Op Op
Usable bool
}
// InfoOrderBy give information of order-by.
type InfoOrderBy struct {
Column int
Desc bool
}
func constraints(info *C.sqlite3_index_info) []InfoConstraint {
l := info.nConstraint
slice := (*[1 << 30]C.struct_sqlite3_index_constraint)(unsafe.Pointer(info.aConstraint))[:l:l]
cst := make([]InfoConstraint, 0, l)
for _, c := range slice {
var usable bool
if c.usable > 0 {
usable = true
}
cst = append(cst, InfoConstraint{
Column: int(c.iColumn),
Op: Op(c.op),
Usable: usable,
})
}
return cst
}
func orderBys(info *C.sqlite3_index_info) []InfoOrderBy {
l := info.nOrderBy
slice := (*[1 << 30]C.struct_sqlite3_index_orderby)(unsafe.Pointer(info.aOrderBy))[:l:l]
ob := make([]InfoOrderBy, 0, l)
for _, c := range slice {
var desc bool
if c.desc > 0 {
desc = true
}
ob = append(ob, InfoOrderBy{
Column: int(c.iColumn),
Desc: desc,
})
}
return ob
}
// IndexResult is a Go struct representation of what eventually ends up in the
// output fields for `sqlite3_index_info`
// See: https://www.sqlite.org/c3ref/index_info.html
type IndexResult struct {
Used []bool // aConstraintUsage
IdxNum int
IdxStr string
AlreadyOrdered bool // orderByConsumed
EstimatedCost float64
EstimatedRows float64
}
// mPrintf is a utility wrapper around sqlite3_mprintf
func mPrintf(format, arg string) *C.char {
cf := C.CString(format)
defer C.free(unsafe.Pointer(cf))
ca := C.CString(arg)
defer C.free(unsafe.Pointer(ca))
return C._sqlite3_mprintf(cf, ca)
}
//export goMInit
func goMInit(db, pClientData unsafe.Pointer, argc C.int, argv **C.char, pzErr **C.char, isCreate C.int) C.uintptr_t {
m := lookupHandle(uintptr(pClientData)).(*sqliteModule)
if m.c.db != (*C.sqlite3)(db) {
*pzErr = mPrintf("%s", "Inconsistent db handles")
return 0
}
args := make([]string, argc)
var A []*C.char
slice := reflect.SliceHeader{Data: uintptr(unsafe.Pointer(argv)), Len: int(argc), Cap: int(argc)}
a := reflect.NewAt(reflect.TypeOf(A), unsafe.Pointer(&slice)).Elem().Interface()
for i, s := range a.([]*C.char) {
args[i] = C.GoString(s)
}
var vTab VTab
var err error
if isCreate == 1 {
vTab, err = m.module.Create(m.c, args)
} else {
vTab, err = m.module.Connect(m.c, args)
}
if err != nil {
*pzErr = mPrintf("%s", err.Error())
return 0
}
vt := sqliteVTab{m, vTab}
*pzErr = nil
return C.uintptr_t(newHandle(m.c, &vt))
}
//export goVRelease
func goVRelease(pVTab unsafe.Pointer, isDestroy C.int) *C.char {
vt := lookupHandle(uintptr(pVTab)).(*sqliteVTab)
var err error
if isDestroy == 1 {
err = vt.vTab.Destroy()
} else {
err = vt.vTab.Disconnect()
}
if err != nil {
return mPrintf("%s", err.Error())
}
return nil
}
//export goVOpen
func goVOpen(pVTab unsafe.Pointer, pzErr **C.char) C.uintptr_t {
vt := lookupHandle(uintptr(pVTab)).(*sqliteVTab)
vTabCursor, err := vt.vTab.Open()
if err != nil {
*pzErr = mPrintf("%s", err.Error())
return 0
}
vtc := sqliteVTabCursor{vt, vTabCursor}
*pzErr = nil
return C.uintptr_t(newHandle(vt.module.c, &vtc))
}
//export goVBestIndex
func goVBestIndex(pVTab unsafe.Pointer, icp unsafe.Pointer) *C.char {
vt := lookupHandle(uintptr(pVTab)).(*sqliteVTab)
info := (*C.sqlite3_index_info)(icp)
csts := constraints(info)
res, err := vt.vTab.BestIndex(csts, orderBys(info))
if err != nil {
return mPrintf("%s", err.Error())
}
if len(res.Used) != len(csts) {
return mPrintf("Result.Used != expected value", "")
}
// Get a pointer to constraint_usage struct so we can update in place.
l := info.nConstraint
s := (*[1 << 30]C.struct_sqlite3_index_constraint_usage)(unsafe.Pointer(info.aConstraintUsage))[:l:l]
index := 1
for i := C.int(0); i < info.nConstraint; i++ {
if res.Used[i] {
s[i].argvIndex = C.int(index)
s[i].omit = C.uchar(1)
index++
}
}
info.idxNum = C.int(res.IdxNum)
idxStr := C.CString(res.IdxStr)
defer C.free(unsafe.Pointer(idxStr))
info.idxStr = idxStr
info.needToFreeIdxStr = C.int(0)
if res.AlreadyOrdered {
info.orderByConsumed = C.int(1)
}
info.estimatedCost = C.double(res.EstimatedCost)
info.estimatedRows = C.sqlite3_int64(res.EstimatedRows)
return nil
}
//export goVClose
func goVClose(pCursor unsafe.Pointer) *C.char {
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
err := vtc.vTabCursor.Close()
if err != nil {
return mPrintf("%s", err.Error())
}
return nil
}
//export goMDestroy
func goMDestroy(pClientData unsafe.Pointer) {
m := lookupHandle(uintptr(pClientData)).(*sqliteModule)
m.module.DestroyModule()
}
//export goVFilter
func goVFilter(pCursor unsafe.Pointer, idxNum C.int, idxName *C.char, argc C.int, argv **C.sqlite3_value) *C.char {
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
vals := make([]interface{}, 0, argc)
for _, v := range args {
conv, err := callbackArgGeneric(v)
if err != nil {
return mPrintf("%s", err.Error())
}
vals = append(vals, conv.Interface())
}
err := vtc.vTabCursor.Filter(int(idxNum), C.GoString(idxName), vals)
if err != nil {
return mPrintf("%s", err.Error())
}
return nil
}
//export goVNext
func goVNext(pCursor unsafe.Pointer) *C.char {
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
err := vtc.vTabCursor.Next()
if err != nil {
return mPrintf("%s", err.Error())
}
return nil
}
//export goVEof
func goVEof(pCursor unsafe.Pointer) C.int {
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
err := vtc.vTabCursor.EOF()
if err {
return 1
}
return 0
}
//export goVColumn
func goVColumn(pCursor, cp unsafe.Pointer, col C.int) *C.char {
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
c := (*SQLiteContext)(cp)
err := vtc.vTabCursor.Column(c, int(col))
if err != nil {
return mPrintf("%s", err.Error())
}
return nil
}
//export goVRowid
func goVRowid(pCursor unsafe.Pointer, pRowid *C.sqlite3_int64) *C.char {
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
rowid, err := vtc.vTabCursor.Rowid()
if err != nil {
return mPrintf("%s", err.Error())
}
*pRowid = C.sqlite3_int64(rowid)
return nil
}
//export goVUpdate
func goVUpdate(pVTab unsafe.Pointer, argc C.int, argv **C.sqlite3_value, pRowid *C.sqlite3_int64) *C.char {
vt := lookupHandle(uintptr(pVTab)).(*sqliteVTab)
var tname string
if n, ok := vt.vTab.(interface {
TableName() string
}); ok {
tname = n.TableName() + " "
}
err := fmt.Errorf("virtual %s table %sis read-only", vt.module.name, tname)
if v, ok := vt.vTab.(VTabUpdater); ok {
// convert argv
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
vals := make([]interface{}, 0, argc)
for _, v := range args {
conv, err := callbackArgGeneric(v)
if err != nil {
return mPrintf("%s", err.Error())
}
// work around for SQLITE_NULL
x := conv.Interface()
if z, ok := x.([]byte); ok && z == nil {
x = nil
}
vals = append(vals, x)
}
switch {
case argc == 1:
err = v.Delete(vals[0])
case argc > 1 && vals[0] == nil:
var id int64
id, err = v.Insert(vals[1], vals[2:])
if err == nil {
*pRowid = C.sqlite3_int64(id)
}
case argc > 1:
err = v.Update(vals[1], vals[2:])
}
}
if err != nil {
return mPrintf("%s", err.Error())
}
return nil
}
// Module is a "virtual table module", it defines the implementation of a
// virtual tables. See: http://sqlite.org/c3ref/module.html
type Module interface {
// http://sqlite.org/vtab.html#xcreate
Create(c *SQLiteConn, args []string) (VTab, error)
// http://sqlite.org/vtab.html#xconnect
Connect(c *SQLiteConn, args []string) (VTab, error)
// http://sqlite.org/c3ref/create_module.html
DestroyModule()
}
// VTab describes a particular instance of the virtual table.
// See: http://sqlite.org/c3ref/vtab.html
type VTab interface {
// http://sqlite.org/vtab.html#xbestindex
BestIndex([]InfoConstraint, []InfoOrderBy) (*IndexResult, error)
// http://sqlite.org/vtab.html#xdisconnect
Disconnect() error
// http://sqlite.org/vtab.html#sqlite3_module.xDestroy
Destroy() error
// http://sqlite.org/vtab.html#xopen
Open() (VTabCursor, error)
}
// VTabUpdater is a type that allows a VTab to be inserted, updated, or
// deleted.
// See: https://sqlite.org/vtab.html#xupdate
type VTabUpdater interface {
Delete(interface{}) error
Insert(interface{}, []interface{}) (int64, error)
Update(interface{}, []interface{}) error
}
// VTabCursor describes cursors that point into the virtual table and are used
// to loop through the virtual table. See: http://sqlite.org/c3ref/vtab_cursor.html
type VTabCursor interface {
// http://sqlite.org/vtab.html#xclose
Close() error
// http://sqlite.org/vtab.html#xfilter
Filter(idxNum int, idxStr string, vals []interface{}) error
// http://sqlite.org/vtab.html#xnext
Next() error
// http://sqlite.org/vtab.html#xeof
EOF() bool
// http://sqlite.org/vtab.html#xcolumn
Column(c *SQLiteContext, col int) error
// http://sqlite.org/vtab.html#xrowid
Rowid() (int64, error)
}
// DeclareVTab declares the Schema of a virtual table.
// See: http://sqlite.org/c3ref/declare_vtab.html
func (c *SQLiteConn) DeclareVTab(sql string) error {
zSQL := C.CString(sql)
defer C.free(unsafe.Pointer(zSQL))
rv := C.sqlite3_declare_vtab(c.db, zSQL)
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}
// CreateModule registers a virtual table implementation.
// See: http://sqlite.org/c3ref/create_module.html
func (c *SQLiteConn) CreateModule(moduleName string, module Module) error {
mname := C.CString(moduleName)
defer C.free(unsafe.Pointer(mname))
udm := sqliteModule{c, moduleName, module}
rv := C._sqlite3_create_module(c.db, mname, C.uintptr_t(newHandle(c, &udm)))
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}

485
sqlite3_vtable_test.go Normal file
View File

@ -0,0 +1,485 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build vtable
package sqlite3
import (
"database/sql"
"errors"
"fmt"
"os"
"reflect"
"strings"
"testing"
)
type testModule struct {
t *testing.T
intarray []int
}
type testVTab struct {
intarray []int
}
type testVTabCursor struct {
vTab *testVTab
index int
}
func (m testModule) Create(c *SQLiteConn, args []string) (VTab, error) {
if len(args) != 6 {
m.t.Fatal("six arguments expected")
}
if args[0] != "test" {
m.t.Fatal("module name")
}
if args[1] != "main" {
m.t.Fatal("db name")
}
if args[2] != "vtab" {
m.t.Fatal("table name")
}
if args[3] != "'1'" {
m.t.Fatal("first arg")
}
if args[4] != "2" {
m.t.Fatal("second arg")
}
if args[5] != "three" {
m.t.Fatal("third argsecond arg")
}
err := c.DeclareVTab("CREATE TABLE x(test TEXT)")
if err != nil {
return nil, err
}
return &testVTab{m.intarray}, nil
}
func (m testModule) Connect(c *SQLiteConn, args []string) (VTab, error) {
return m.Create(c, args)
}
func (m testModule) DestroyModule() {}
func (v *testVTab) BestIndex(cst []InfoConstraint, ob []InfoOrderBy) (*IndexResult, error) {
used := make([]bool, 0, len(cst))
for range cst {
used = append(used, false)
}
return &IndexResult{
Used: used,
IdxNum: 0,
IdxStr: "test-index",
AlreadyOrdered: true,
EstimatedCost: 100,
EstimatedRows: 200,
}, nil
}
func (v *testVTab) Disconnect() error {
return nil
}
func (v *testVTab) Destroy() error {
return nil
}
func (v *testVTab) Open() (VTabCursor, error) {
return &testVTabCursor{v, 0}, nil
}
func (vc *testVTabCursor) Close() error {
return nil
}
func (vc *testVTabCursor) Filter(idxNum int, idxStr string, vals []interface{}) error {
vc.index = 0
return nil
}
func (vc *testVTabCursor) Next() error {
vc.index++
return nil
}
func (vc *testVTabCursor) EOF() bool {
return vc.index >= len(vc.vTab.intarray)
}
func (vc *testVTabCursor) Column(c *SQLiteContext, col int) error {
if col != 0 {
return fmt.Errorf("column index out of bounds: %d", col)
}
c.ResultInt(vc.vTab.intarray[vc.index])
return nil
}
func (vc *testVTabCursor) Rowid() (int64, error) {
return int64(vc.index), nil
}
func TestCreateModule(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
intarray := []int{1, 2, 3}
sql.Register("sqlite3_TestCreateModule", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
return conn.CreateModule("test", testModule{t, intarray})
},
})
db, err := sql.Open("sqlite3_TestCreateModule", tempFilename)
if err != nil {
t.Fatalf("could not open db: %v", err)
}
_, err = db.Exec("CREATE VIRTUAL TABLE vtab USING test('1', 2, three)")
if err != nil {
t.Fatalf("could not create vtable: %v", err)
}
var i, value int
rows, err := db.Query("SELECT rowid, * FROM vtab WHERE test = '3'")
if err != nil {
t.Fatalf("couldn't select from virtual table: %v", err)
}
for rows.Next() {
rows.Scan(&i, &value)
if intarray[i] != value {
t.Fatalf("want %v but %v", intarray[i], value)
}
}
_, err = db.Exec("DROP TABLE vtab")
if err != nil {
t.Fatalf("couldn't drop virtual table: %v", err)
}
}
func TestVUpdate(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
// create module
updateMod := &vtabUpdateModule{t, make(map[string]*vtabUpdateTable)}
// register module
sql.Register("sqlite3_TestVUpdate", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
return conn.CreateModule("updatetest", updateMod)
},
})
// connect
db, err := sql.Open("sqlite3_TestVUpdate", tempFilename)
if err != nil {
t.Fatalf("could not open db: %v", err)
}
// create test table
_, err = db.Exec(`CREATE VIRTUAL TABLE vt USING updatetest(f1 integer, f2 text, f3 text)`)
if err != nil {
t.Fatalf("could not create updatetest vtable vt, got: %v", err)
}
// check that table is defined properly
if len(updateMod.tables) != 1 {
t.Fatalf("expected exactly 1 table to exist, got: %d", len(updateMod.tables))
}
if _, ok := updateMod.tables["vt"]; !ok {
t.Fatalf("expected table `vt` to exist in tables")
}
// check nothing in updatetest
rows, err := db.Query(`select * from vt`)
if err != nil {
t.Fatalf("could not query vt, got: %v", err)
}
i, err := getRowCount(rows)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if i != 0 {
t.Fatalf("expected no rows in vt, got: %d", i)
}
_, err = db.Exec(`delete from vt where f1 = 'yes'`)
if err != nil {
t.Fatalf("expected error on delete, got nil")
}
// test bad column name
_, err = db.Exec(`insert into vt (f4) values('a')`)
if err == nil {
t.Fatalf("expected error on insert, got nil")
}
// insert to vt
res, err := db.Exec(`insert into vt (f1, f2, f3) values (115, 'b', 'c'), (116, 'd', 'e')`)
if err != nil {
t.Fatalf("expected no error on insert, got: %v", err)
}
n, err := res.RowsAffected()
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if n != 2 {
t.Fatalf("expected 1 row affected, got: %d", n)
}
// check vt table
vt := updateMod.tables["vt"]
if len(vt.data) != 2 {
t.Fatalf("expected table vt to have exactly 2 rows, got: %d", len(vt.data))
}
if !reflect.DeepEqual(vt.data[0], []interface{}{int64(115), "b", "c"}) {
t.Fatalf("expected table vt entry 0 to be [115 b c], instead: %v", vt.data[0])
}
if !reflect.DeepEqual(vt.data[1], []interface{}{int64(116), "d", "e"}) {
t.Fatalf("expected table vt entry 1 to be [116 d e], instead: %v", vt.data[1])
}
// query vt
var f1 int
var f2, f3 string
err = db.QueryRow(`select * from vt where f1 = 115`).Scan(&f1, &f2, &f3)
if err != nil {
t.Fatalf("expected no error on vt query, got: %v", err)
}
// check column values
if f1 != 115 || f2 != "b" || f3 != "c" {
t.Errorf("expected f1==115, f2==b, f3==c, got: %d, %q, %q", f1, f2, f3)
}
// update vt
res, err = db.Exec(`update vt set f1=117, f2='f' where f3='e'`)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
n, err = res.RowsAffected()
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if n != 1 {
t.Fatalf("expected exactly one row updated, got: %d", n)
}
// check vt table
if len(vt.data) != 2 {
t.Fatalf("expected table vt to have exactly 2 rows, got: %d", len(vt.data))
}
if !reflect.DeepEqual(vt.data[0], []interface{}{int64(115), "b", "c"}) {
t.Fatalf("expected table vt entry 0 to be [115 b c], instead: %v", vt.data[0])
}
if !reflect.DeepEqual(vt.data[1], []interface{}{int64(117), "f", "e"}) {
t.Fatalf("expected table vt entry 1 to be [117 f e], instead: %v", vt.data[1])
}
// delete from vt
res, err = db.Exec(`delete from vt where f1 = 117`)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
n, err = res.RowsAffected()
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if n != 1 {
t.Fatalf("expected exactly one row deleted, got: %d", n)
}
// check vt table
if len(vt.data) != 1 {
t.Fatalf("expected table vt to have exactly 1 row, got: %d", len(vt.data))
}
if !reflect.DeepEqual(vt.data[0], []interface{}{int64(115), "b", "c"}) {
t.Fatalf("expected table vt entry 0 to be [115 b c], instead: %v", vt.data[0])
}
// check updatetest has 1 result
rows, err = db.Query(`select * from vt`)
if err != nil {
t.Fatalf("could not query vt, got: %v", err)
}
i, err = getRowCount(rows)
if err != nil {
t.Fatalf("expected no error, got: %v", err)
}
if i != 1 {
t.Fatalf("expected 1 row in vt, got: %d", i)
}
}
func getRowCount(rows *sql.Rows) (int, error) {
var i int
for rows.Next() {
i++
}
return i, nil
}
type vtabUpdateModule struct {
t *testing.T
tables map[string]*vtabUpdateTable
}
func (m *vtabUpdateModule) Create(c *SQLiteConn, args []string) (VTab, error) {
if len(args) < 2 {
return nil, errors.New("must declare at least one column")
}
// get database name, table name, and column declarations ...
dbname, tname, decls := args[1], args[2], args[3:]
// extract column names + types from parameters declarations
cols, typs := make([]string, len(decls)), make([]string, len(decls))
for i := 0; i < len(decls); i++ {
n, typ := decls[i], ""
if j := strings.IndexAny(n, " \t\n"); j != -1 {
typ, n = strings.TrimSpace(n[j+1:]), n[:j]
}
cols[i], typs[i] = n, typ
}
// declare table
err := c.DeclareVTab(fmt.Sprintf(`CREATE TABLE "%s"."%s" (%s)`, dbname, tname, strings.Join(decls, ",")))
if err != nil {
return nil, err
}
// create table
vtab := &vtabUpdateTable{m.t, dbname, tname, cols, typs, make([][]interface{}, 0)}
m.tables[tname] = vtab
return vtab, nil
}
func (m *vtabUpdateModule) Connect(c *SQLiteConn, args []string) (VTab, error) {
return m.Create(c, args)
}
func (m *vtabUpdateModule) DestroyModule() {}
type vtabUpdateTable struct {
t *testing.T
db string
name string
cols []string
typs []string
data [][]interface{}
}
func (t *vtabUpdateTable) Open() (VTabCursor, error) {
return &vtabUpdateCursor{t, 0}, nil
}
func (t *vtabUpdateTable) BestIndex(cst []InfoConstraint, ob []InfoOrderBy) (*IndexResult, error) {
return &IndexResult{Used: make([]bool, len(cst))}, nil
}
func (t *vtabUpdateTable) Disconnect() error {
return nil
}
func (t *vtabUpdateTable) Destroy() error {
return nil
}
func (t *vtabUpdateTable) Insert(id interface{}, vals []interface{}) (int64, error) {
var i int64
if id == nil {
i, t.data = int64(len(t.data)), append(t.data, vals)
return i, nil
}
var ok bool
i, ok = id.(int64)
if !ok {
return 0, fmt.Errorf("id is invalid type: %T", id)
}
t.data[i] = vals
return i, nil
}
func (t *vtabUpdateTable) Update(id interface{}, vals []interface{}) error {
i, ok := id.(int64)
if !ok {
return fmt.Errorf("id is invalid type: %T", id)
}
if int(i) >= len(t.data) || i < 0 {
return fmt.Errorf("invalid row id %d", i)
}
t.data[int(i)] = vals
return nil
}
func (t *vtabUpdateTable) Delete(id interface{}) error {
i, ok := id.(int64)
if !ok {
return fmt.Errorf("id is invalid type: %T", id)
}
if int(i) >= len(t.data) || i < 0 {
return fmt.Errorf("invalid row id %d", i)
}
t.data = append(t.data[:i], t.data[i+1:]...)
return nil
}
type vtabUpdateCursor struct {
t *vtabUpdateTable
i int
}
func (c *vtabUpdateCursor) Column(ctxt *SQLiteContext, col int) error {
switch x := c.t.data[c.i][col].(type) {
case []byte:
ctxt.ResultBlob(x)
case bool:
ctxt.ResultBool(x)
case float64:
ctxt.ResultDouble(x)
case int:
ctxt.ResultInt(x)
case int64:
ctxt.ResultInt64(x)
case nil:
ctxt.ResultNull()
case string:
ctxt.ResultText(x)
default:
ctxt.ResultText(fmt.Sprintf("%v", x))
}
return nil
}
func (c *vtabUpdateCursor) Filter(ixNum int, ixName string, vals []interface{}) error {
return nil
}
func (c *vtabUpdateCursor) Next() error {
c.i++
return nil
}
func (c *vtabUpdateCursor) EOF() bool {
return c.i >= len(c.t.data)
}
func (c *vtabUpdateCursor) Rowid() (int64, error) {
return int64(c.i), nil
}
func (c *vtabUpdateCursor) Close() error {
return nil
}

View File

@ -1,3 +1,4 @@
#ifndef USE_LIBSQLITE3
/* /*
** 2006 June 7 ** 2006 June 7
** **
@ -15,11 +16,9 @@
** as extensions by SQLite should #include this file instead of ** as extensions by SQLite should #include this file instead of
** sqlite3.h. ** sqlite3.h.
*/ */
#ifndef _SQLITE3EXT_H_ #ifndef SQLITE3EXT_H
#define _SQLITE3EXT_H_ #define SQLITE3EXT_H
#include "sqlite3-binding.h" #include "sqlite3.h"
typedef struct sqlite3_api_routines sqlite3_api_routines;
/* /*
** The following structure holds pointers to all of the SQLite API ** The following structure holds pointers to all of the SQLite API
@ -279,8 +278,23 @@ struct sqlite3_api_routines {
int (*status64)(int,sqlite3_int64*,sqlite3_int64*,int); int (*status64)(int,sqlite3_int64*,sqlite3_int64*,int);
int (*strlike)(const char*,const char*,unsigned int); int (*strlike)(const char*,const char*,unsigned int);
int (*db_cacheflush)(sqlite3*); int (*db_cacheflush)(sqlite3*);
/* Version 3.12.0 and later */
int (*system_errno)(sqlite3*);
/* Version 3.14.0 and later */
int (*trace_v2)(sqlite3*,unsigned,int(*)(unsigned,void*,void*,void*),void*);
char *(*expanded_sql)(sqlite3_stmt*);
}; };
/*
** This is the function signature used for all extension entry points. It
** is also defined in the file "loadext.c".
*/
typedef int (*sqlite3_loadext_entry)(
sqlite3 *db, /* Handle to the database. */
char **pzErrMsg, /* Used to set error string on failure. */
const sqlite3_api_routines *pThunk /* Extension API function pointers. */
);
/* /*
** The following macros redefine the API routines so that they are ** The following macros redefine the API routines so that they are
** redirected through the global sqlite3_api structure. ** redirected through the global sqlite3_api structure.
@ -522,6 +536,11 @@ struct sqlite3_api_routines {
#define sqlite3_status64 sqlite3_api->status64 #define sqlite3_status64 sqlite3_api->status64
#define sqlite3_strlike sqlite3_api->strlike #define sqlite3_strlike sqlite3_api->strlike
#define sqlite3_db_cacheflush sqlite3_api->db_cacheflush #define sqlite3_db_cacheflush sqlite3_api->db_cacheflush
/* Version 3.12.0 and later */
#define sqlite3_system_errno sqlite3_api->system_errno
/* Version 3.14.0 and later */
#define sqlite3_trace_v2 sqlite3_api->trace_v2
#define sqlite3_expanded_sql sqlite3_api->expanded_sql
#endif /* !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION) */ #endif /* !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION) */
#if !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION) #if !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION)
@ -539,4 +558,8 @@ struct sqlite3_api_routines {
# define SQLITE_EXTENSION_INIT3 /*no-op*/ # define SQLITE_EXTENSION_INIT3 /*no-op*/
#endif #endif
#endif /* _SQLITE3EXT_H_ */ #endif /* SQLITE3EXT_H */
#else // USE_LIBSQLITE3
// If users really want to link against the system sqlite3 we
// need to make this file a noop.
#endif

99
tool/upgrade.go Normal file
View File

@ -0,0 +1,99 @@
// +build ignore
package main
import (
"archive/zip"
"bytes"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"path"
"path/filepath"
"strings"
"github.com/PuerkitoBio/goquery"
)
func main() {
site := "https://www.sqlite.org/download.html"
fmt.Printf("scraping %v\n", site)
doc, err := goquery.NewDocument(site)
if err != nil {
log.Fatal(err)
}
var url string
doc.Find("a").Each(func(_ int, s *goquery.Selection) {
if url == "" && strings.HasPrefix(s.Text(), "sqlite-amalgamation-") {
url = "https://www.sqlite.org/2017/" + s.Text()
}
})
if url == "" {
return
}
fmt.Printf("downloading %v\n", url)
resp, err := http.Get(url)
if err != nil {
log.Fatal(err)
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
resp.Body.Close()
log.Fatal(err)
}
fmt.Printf("extracting %v\n", path.Base(url))
r, err := zip.NewReader(bytes.NewReader(b), resp.ContentLength)
if err != nil {
resp.Body.Close()
log.Fatal(err)
}
resp.Body.Close()
for _, zf := range r.File {
var f *os.File
switch path.Base(zf.Name) {
case "sqlite3.c":
f, err = os.Create("sqlite3-binding.c")
case "sqlite3.h":
f, err = os.Create("sqlite3-binding.h")
case "sqlite3ext.h":
f, err = os.Create("sqlite3ext.h")
default:
continue
}
if err != nil {
log.Fatal(err)
}
zr, err := zf.Open()
if err != nil {
log.Fatal(err)
}
_, err = io.WriteString(f, "#ifndef USE_LIBSQLITE3\n")
if err != nil {
zr.Close()
f.Close()
log.Fatal(err)
}
_, err = io.Copy(f, zr)
if err != nil {
zr.Close()
f.Close()
log.Fatal(err)
}
_, err = io.WriteString(f, "#else // USE_LIBSQLITE3\n // If users really want to link against the system sqlite3 we\n// need to make this file a noop.\n #endif")
if err != nil {
zr.Close()
f.Close()
log.Fatal(err)
}
zr.Close()
f.Close()
fmt.Printf("extracted %v\n", filepath.Base(f.Name()))
}
}