Merge pull request #1 from mattn/master

Bring master up-to-date
This commit is contained in:
Philip O'Toole 2016-02-23 01:18:14 -05:00
commit 3e97a4ca68
29 changed files with 57596 additions and 16192 deletions

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
*.db
*.exe
*.dll

View File

@ -4,6 +4,6 @@ go:
before_install: before_install:
- go get github.com/axw/gocov/gocov - go get github.com/axw/gocov/gocov
- go get github.com/mattn/goveralls - go get github.com/mattn/goveralls
- go get code.google.com/p/go.tools/cmd/cover - go get golang.org/x/tools/cmd/cover
script: script:
- $HOME/gopath/bin/goveralls -repotoken 3qJVUE0iQwqnCbmNcDsjYu1nh4J4KIFXx - $HOME/gopath/bin/goveralls -repotoken 3qJVUE0iQwqnCbmNcDsjYu1nh4J4KIFXx

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Yasuhiro Matsumoto
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,8 +1,9 @@
go-sqlite3 go-sqlite3
========== ==========
[![Build Status](https://travis-ci.org/mattn/go-sqlite3.png?branch=master)](https://travis-ci.org/mattn/go-sqlite3) [![Build Status](https://travis-ci.org/mattn/go-sqlite3.svg?branch=master)](https://travis-ci.org/mattn/go-sqlite3)
[![Coverage Status](https://coveralls.io/repos/mattn/go-sqlite3/badge.png?branch=master)](https://coveralls.io/r/mattn/go-sqlite3?branch=master) [![Coverage Status](https://coveralls.io/repos/mattn/go-sqlite3/badge.svg?branch=master)](https://coveralls.io/r/mattn/go-sqlite3?branch=master)
[![GoDoc](https://godoc.org/github.com/mattn/go-sqlite3?status.svg)](http://godoc.org/github.com/mattn/go-sqlite3)
Description Description
----------- -----------
@ -16,6 +17,10 @@ This package can be installed with the go get command:
go get github.com/mattn/go-sqlite3 go get github.com/mattn/go-sqlite3
_go-sqlite3_ is *cgo* package.
If you want to build your app using go-sqlite3, you need gcc.
However, if you install _go-sqlite3_ with `go install github.com/mattn/go-sqlite3`, you don't need gcc to build your app anymore.
Documentation Documentation
------------- -------------
@ -26,6 +31,14 @@ Examples can be found under the `./_example` directory
FAQ FAQ
--- ---
* Want to build go-sqlite3 with libsqlite3 on my linux.
Use `go build --tags "libsqlite3 linux"`
* Want to build go-sqlite3 with icu extension.
Use `go build --tags "icu"`
* Can't build go-sqlite3 on windows 64bit. * Can't build go-sqlite3 on windows 64bit.
> Probably, you are using go 1.0, go1.0 has a problem when it comes to compiling/linking on windows 64bit. > Probably, you are using go 1.0, go1.0 has a problem when it comes to compiling/linking on windows 64bit.
@ -36,7 +49,27 @@ FAQ
> You can pass some arguments into the connection string, for example, a URI. > You can pass some arguments into the connection string, for example, a URI.
> See: https://github.com/mattn/go-sqlite3/issues/39 > See: https://github.com/mattn/go-sqlite3/issues/39
* Do you want cross compiling? mingw on Linux or Mac?
> See: https://github.com/mattn/go-sqlite3/issues/106
> See also: http://www.limitlessfx.com/cross-compile-golang-app-for-windows-from-linux.html
* Want to get time.Time with current locale
Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`.
License License
------- -------
MIT: http://mattn.mit-license.org/2012 MIT: http://mattn.mit-license.org/2012
sqlite3-binding.c, sqlite3-binding.h, sqlite3ext.h
The -binding suffix was added to avoid build failures under gccgo.
In this repository, those files are amalgamation code that copied from SQLite3. The license of those codes are depend on the license of SQLite3.
Author
------
Yasuhiro Matsumoto (a.k.a mattn)

View File

@ -0,0 +1,133 @@
package main
import (
"database/sql"
"fmt"
"log"
"math"
"math/rand"
sqlite "github.com/mattn/go-sqlite3"
)
// Computes x^y
func pow(x, y int64) int64 {
return int64(math.Pow(float64(x), float64(y)))
}
// Computes the bitwise exclusive-or of all its arguments
func xor(xs ...int64) int64 {
var ret int64
for _, x := range xs {
ret ^= x
}
return ret
}
// Returns a random number. It's actually deterministic here because
// we don't seed the RNG, but it's an example of a non-pure function
// from SQLite's POV.
func getrand() int64 {
return rand.Int63()
}
// Computes the standard deviation of a GROUPed BY set of values
type stddev struct {
xs []int64
// Running average calculation
sum int64
n int64
}
func newStddev() *stddev { return &stddev{} }
func (s *stddev) Step(x int64) {
s.xs = append(s.xs, x)
s.sum += x
s.n++
}
func (s *stddev) Done() float64 {
mean := float64(s.sum) / float64(s.n)
var sqDiff []float64
for _, x := range s.xs {
sqDiff = append(sqDiff, math.Pow(float64(x)-mean, 2))
}
var dev float64
for _, x := range sqDiff {
dev += x
}
dev /= float64(len(sqDiff))
return math.Sqrt(dev)
}
func main() {
sql.Register("sqlite3_custom", &sqlite.SQLiteDriver{
ConnectHook: func(conn *sqlite.SQLiteConn) error {
if err := conn.RegisterFunc("pow", pow, true); err != nil {
return err
}
if err := conn.RegisterFunc("xor", xor, true); err != nil {
return err
}
if err := conn.RegisterFunc("rand", getrand, false); err != nil {
return err
}
if err := conn.RegisterAggregator("stddev", newStddev, true); err != nil {
return err
}
return nil
},
})
db, err := sql.Open("sqlite3_custom", ":memory:")
if err != nil {
log.Fatal("Failed to open database:", err)
}
defer db.Close()
var i int64
err = db.QueryRow("SELECT pow(2,3)").Scan(&i)
if err != nil {
log.Fatal("POW query error:", err)
}
fmt.Println("pow(2,3) =", i) // 8
err = db.QueryRow("SELECT xor(1,2,3,4,5,6)").Scan(&i)
if err != nil {
log.Fatal("XOR query error:", err)
}
fmt.Println("xor(1,2,3,4,5) =", i) // 7
err = db.QueryRow("SELECT rand()").Scan(&i)
if err != nil {
log.Fatal("RAND query error:", err)
}
fmt.Println("rand() =", i) // pseudorandom
_, err = db.Exec("create table foo (department integer, profits integer)")
if err != nil {
log.Fatal("Failed to create table:", err)
}
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115)")
if err != nil {
log.Fatal("Failed to insert records:", err)
}
rows, err := db.Query("select department, stddev(profits) from foo group by department")
if err != nil {
log.Fatal("STDDEV query error:", err)
}
defer rows.Close()
for rows.Next() {
var dept int64
var dev float64
if err := rows.Scan(&dept, &dev); err != nil {
log.Fatal(err)
}
fmt.Printf("dept=%d stddev=%f\n", dept, dev)
}
if err := rows.Err(); err != nil {
log.Fatal(err)
}
}

View File

@ -10,12 +10,12 @@ import (
func main() { func main() {
sqlite3conn := []*sqlite3.SQLiteConn{} sqlite3conn := []*sqlite3.SQLiteConn{}
sql.Register("sqlite3_with_hook_example", sql.Register("sqlite3_with_hook_example",
&sqlite3.SQLiteDriver{ &sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error { ConnectHook: func(conn *sqlite3.SQLiteConn) error {
sqlite3conn = append(sqlite3conn, conn) sqlite3conn = append(sqlite3conn, conn)
return nil return nil
}, },
}) })
os.Remove("./foo.db") os.Remove("./foo.db")
os.Remove("./bar.db") os.Remove("./bar.db")
@ -54,7 +54,7 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
bk.Step(-1) _, err = bk.Step(-1)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -1,6 +1,6 @@
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <sqlite3.h> #include <sqlite3-binding.h>
#include <sqlite3ext.h> #include <sqlite3ext.h>
#include <curl/curl.h> #include <curl/curl.h>
#include "picojson.h" #include "picojson.h"

View File

@ -55,7 +55,6 @@ func main() {
rows.Scan(&id, &name) rows.Scan(&id, &name)
fmt.Println(id, name) fmt.Println(id, name)
} }
rows.Close()
stmt, err = db.Prepare("select name from foo where id = ?") stmt, err = db.Prepare("select name from foo where id = ?")
if err != nil { if err != nil {

View File

@ -1,26 +1,34 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3 package sqlite3
/* /*
#include <sqlite3.h> #include <sqlite3-binding.h>
#include <stdlib.h> #include <stdlib.h>
*/ */
import "C" import "C"
import ( import (
"runtime"
"unsafe" "unsafe"
) )
type Backup struct { type SQLiteBackup struct {
b *C.sqlite3_backup b *C.sqlite3_backup
} }
func (c *SQLiteConn) Backup(dest string, conn *SQLiteConn, src string) (*Backup, error) { func (c *SQLiteConn) Backup(dest string, conn *SQLiteConn, src string) (*SQLiteBackup, error) {
destptr := C.CString(dest) destptr := C.CString(dest)
defer C.free(unsafe.Pointer(destptr)) defer C.free(unsafe.Pointer(destptr))
srcptr := C.CString(src) srcptr := C.CString(src)
defer C.free(unsafe.Pointer(srcptr)) defer C.free(unsafe.Pointer(srcptr))
if b := C.sqlite3_backup_init(c.db, destptr, conn.db, srcptr); b != nil { if b := C.sqlite3_backup_init(c.db, destptr, conn.db, srcptr); b != nil {
return &Backup{b: b}, nil bb := &SQLiteBackup{b: b}
runtime.SetFinalizer(bb, (*SQLiteBackup).Finish)
return bb, nil
} }
return nil, c.lastError() return nil, c.lastError()
} }
@ -29,28 +37,34 @@ func (c *SQLiteConn) Backup(dest string, conn *SQLiteConn, src string) (*Backup,
// This function returns a boolean indicating if the backup is done and // This function returns a boolean indicating if the backup is done and
// an error signalling any other error. Done is returned if the underlying C // an error signalling any other error. Done is returned if the underlying C
// function returns SQLITE_DONE (Code 101) // function returns SQLITE_DONE (Code 101)
func (b *Backup) Step(p int) (bool, error) { func (b *SQLiteBackup) Step(p int) (bool, error) {
ret := C.sqlite3_backup_step(b.b, C.int(p)) ret := C.sqlite3_backup_step(b.b, C.int(p))
if ret == 101 { if ret == C.SQLITE_DONE {
return true, nil return true, nil
} else if ret != 0 { } else if ret != 0 && ret != C.SQLITE_LOCKED && ret != C.SQLITE_BUSY {
return false, Error{Code: ErrNo(ret)} return false, Error{Code: ErrNo(ret)}
} }
return false, nil return false, nil
} }
func (b *Backup) Remaining() int { func (b *SQLiteBackup) Remaining() int {
return int(C.sqlite3_backup_remaining(b.b)) return int(C.sqlite3_backup_remaining(b.b))
} }
func (b *Backup) PageCount() int { func (b *SQLiteBackup) PageCount() int {
return int(C.sqlite3_backup_pagecount(b.b)) return int(C.sqlite3_backup_pagecount(b.b))
} }
func (b *Backup) Finish() error { func (b *SQLiteBackup) Finish() error {
return b.Close()
}
func (b *SQLiteBackup) Close() error {
ret := C.sqlite3_backup_finish(b.b) ret := C.sqlite3_backup_finish(b.b)
if ret != 0 { if ret != 0 {
return Error{Code: ErrNo(ret)} return Error{Code: ErrNo(ret)}
} }
b.b = nil
runtime.SetFinalizer(b, nil)
return nil return nil
} }

336
callback.go Normal file
View File

@ -0,0 +1,336 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3
// You can't export a Go function to C and have definitions in the C
// preamble in the same file, so we have to have callbackTrampoline in
// its own file. Because we need a separate file anyway, the support
// code for SQLite custom functions is in here.
/*
#include <sqlite3-binding.h>
#include <stdlib.h>
void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
*/
import "C"
import (
"errors"
"fmt"
"math"
"reflect"
"sync"
"unsafe"
)
//export callbackTrampoline
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
fi := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*functionInfo)
fi.Call(ctx, args)
}
//export stepTrampoline
func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo)
ai.Step(ctx, args)
}
//export doneTrampoline
func doneTrampoline(ctx *C.sqlite3_context) {
handle := uintptr(C.sqlite3_user_data(ctx))
ai := lookupHandle(handle).(*aggInfo)
ai.Done(ctx)
}
// Use handles to avoid passing Go pointers to C.
type handleVal struct {
db *SQLiteConn
val interface{}
}
var handleLock sync.Mutex
var handleVals = make(map[uintptr]handleVal)
var handleIndex uintptr = 100
func newHandle(db *SQLiteConn, v interface{}) uintptr {
handleLock.Lock()
defer handleLock.Unlock()
i := handleIndex
handleIndex++
handleVals[i] = handleVal{db, v}
return i
}
func lookupHandle(handle uintptr) interface{} {
handleLock.Lock()
defer handleLock.Unlock()
r, ok := handleVals[handle]
if !ok {
if handle >= 100 && handle < handleIndex {
panic("deleted handle")
} else {
panic("invalid handle")
}
}
return r.val
}
func deleteHandles(db *SQLiteConn) {
handleLock.Lock()
defer handleLock.Unlock()
for handle, val := range handleVals {
if val.db == db {
delete(handleVals, handle)
}
}
}
// This is only here so that tests can refer to it.
type callbackArgRaw C.sqlite3_value
type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
type callbackArgCast struct {
f callbackArgConverter
typ reflect.Type
}
func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
val, err := c.f(v)
if err != nil {
return reflect.Value{}, err
}
if !val.Type().ConvertibleTo(c.typ) {
return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
}
return val.Convert(c.typ), nil
}
func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
}
return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
}
func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
}
i := int64(C.sqlite3_value_int64(v))
val := false
if i != 0 {
val = true
}
return reflect.ValueOf(val), nil
}
func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
}
return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
}
func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
switch C.sqlite3_value_type(v) {
case C.SQLITE_BLOB:
l := C.sqlite3_value_bytes(v)
p := C.sqlite3_value_blob(v)
return reflect.ValueOf(C.GoBytes(p, l)), nil
case C.SQLITE_TEXT:
l := C.sqlite3_value_bytes(v)
c := unsafe.Pointer(C.sqlite3_value_text(v))
return reflect.ValueOf(C.GoBytes(c, l)), nil
default:
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
}
}
func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
switch C.sqlite3_value_type(v) {
case C.SQLITE_BLOB:
l := C.sqlite3_value_bytes(v)
p := (*C.char)(C.sqlite3_value_blob(v))
return reflect.ValueOf(C.GoStringN(p, l)), nil
case C.SQLITE_TEXT:
c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
return reflect.ValueOf(C.GoString(c)), nil
default:
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
}
}
func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) {
switch C.sqlite3_value_type(v) {
case C.SQLITE_INTEGER:
return callbackArgInt64(v)
case C.SQLITE_FLOAT:
return callbackArgFloat64(v)
case C.SQLITE_TEXT:
return callbackArgString(v)
case C.SQLITE_BLOB:
return callbackArgBytes(v)
case C.SQLITE_NULL:
// Interpret NULL as a nil byte slice.
var ret []byte
return reflect.ValueOf(ret), nil
default:
panic("unreachable")
}
}
func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
switch typ.Kind() {
case reflect.Interface:
if typ.NumMethod() != 0 {
return nil, errors.New("the only supported interface type is interface{}")
}
return callbackArgGeneric, nil
case reflect.Slice:
if typ.Elem().Kind() != reflect.Uint8 {
return nil, errors.New("the only supported slice type is []byte")
}
return callbackArgBytes, nil
case reflect.String:
return callbackArgString, nil
case reflect.Bool:
return callbackArgBool, nil
case reflect.Int64:
return callbackArgInt64, nil
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
c := callbackArgCast{callbackArgInt64, typ}
return c.Run, nil
case reflect.Float64:
return callbackArgFloat64, nil
case reflect.Float32:
c := callbackArgCast{callbackArgFloat64, typ}
return c.Run, nil
default:
return nil, fmt.Errorf("don't know how to convert to %s", typ)
}
}
func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) {
var args []reflect.Value
if len(argv) < len(converters) {
return nil, fmt.Errorf("function requires at least %d arguments", len(converters))
}
for i, arg := range argv[:len(converters)] {
v, err := converters[i](arg)
if err != nil {
return nil, err
}
args = append(args, v)
}
if variadic != nil {
for _, arg := range argv[len(converters):] {
v, err := variadic(arg)
if err != nil {
return nil, err
}
args = append(args, v)
}
}
return args, nil
}
type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
switch v.Type().Kind() {
case reflect.Int64:
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
v = v.Convert(reflect.TypeOf(int64(0)))
case reflect.Bool:
b := v.Interface().(bool)
if b {
v = reflect.ValueOf(int64(1))
} else {
v = reflect.ValueOf(int64(0))
}
default:
return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
}
C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
return nil
}
func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
switch v.Type().Kind() {
case reflect.Float64:
case reflect.Float32:
v = v.Convert(reflect.TypeOf(float64(0)))
default:
return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
}
C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
return nil
}
func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
return fmt.Errorf("cannot convert %s to BLOB", v.Type())
}
i := v.Interface()
if i == nil || len(i.([]byte)) == 0 {
C.sqlite3_result_null(ctx)
} else {
bs := i.([]byte)
C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
}
return nil
}
func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
if v.Type().Kind() != reflect.String {
return fmt.Errorf("cannot convert %s to TEXT", v.Type())
}
C._sqlite3_result_text(ctx, C.CString(v.Interface().(string)))
return nil
}
func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
switch typ.Kind() {
case reflect.Slice:
if typ.Elem().Kind() != reflect.Uint8 {
return nil, errors.New("the only supported slice type is []byte")
}
return callbackRetBlob, nil
case reflect.String:
return callbackRetText, nil
case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
return callbackRetInteger, nil
case reflect.Float32, reflect.Float64:
return callbackRetFloat, nil
default:
return nil, fmt.Errorf("don't know how to convert to %s", typ)
}
}
func callbackError(ctx *C.sqlite3_context, err error) {
cstr := C.CString(err.Error())
defer C.free(unsafe.Pointer(cstr))
C.sqlite3_result_error(ctx, cstr, -1)
}
// Test support code. Tests are not allowed to import "C", so we can't
// declare any functions that use C.sqlite3_value.
func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
return func(*C.sqlite3_value) (reflect.Value, error) {
return v, err
}
}

97
callback_test.go Normal file
View File

@ -0,0 +1,97 @@
package sqlite3
import (
"errors"
"math"
"reflect"
"testing"
)
func TestCallbackArgCast(t *testing.T) {
intConv := callbackSyntheticForTests(reflect.ValueOf(int64(math.MaxInt64)), nil)
floatConv := callbackSyntheticForTests(reflect.ValueOf(float64(math.MaxFloat64)), nil)
errConv := callbackSyntheticForTests(reflect.Value{}, errors.New("test"))
tests := []struct {
f callbackArgConverter
o reflect.Value
}{
{intConv, reflect.ValueOf(int8(-1))},
{intConv, reflect.ValueOf(int16(-1))},
{intConv, reflect.ValueOf(int32(-1))},
{intConv, reflect.ValueOf(uint8(math.MaxUint8))},
{intConv, reflect.ValueOf(uint16(math.MaxUint16))},
{intConv, reflect.ValueOf(uint32(math.MaxUint32))},
// Special case, int64->uint64 is only 1<<63 - 1, not 1<<64 - 1
{intConv, reflect.ValueOf(uint64(math.MaxInt64))},
{floatConv, reflect.ValueOf(float32(math.Inf(1)))},
}
for _, test := range tests {
conv := callbackArgCast{test.f, test.o.Type()}
val, err := conv.Run(nil)
if err != nil {
t.Errorf("Couldn't convert to %s: %s", test.o.Type(), err)
} else if !reflect.DeepEqual(val.Interface(), test.o.Interface()) {
t.Errorf("Unexpected result from converting to %s: got %v, want %v", test.o.Type(), val.Interface(), test.o.Interface())
}
}
conv := callbackArgCast{errConv, reflect.TypeOf(int8(0))}
_, err := conv.Run(nil)
if err == nil {
t.Errorf("Expected error during callbackArgCast, but got none")
}
}
func TestCallbackConverters(t *testing.T) {
tests := []struct {
v interface{}
err bool
}{
// Unfortunately, we can't tell which converter was returned,
// but we can at least check which types can be converted.
{[]byte{0}, false},
{"text", false},
{true, false},
{int8(0), false},
{int16(0), false},
{int32(0), false},
{int64(0), false},
{uint8(0), false},
{uint16(0), false},
{uint32(0), false},
{uint64(0), false},
{int(0), false},
{uint(0), false},
{float64(0), false},
{float32(0), false},
{func() {}, true},
{complex64(complex(0, 0)), true},
{complex128(complex(0, 0)), true},
{struct{}{}, true},
{map[string]string{}, true},
{[]string{}, true},
{(*int8)(nil), true},
{make(chan int), true},
}
for _, test := range tests {
_, err := callbackArg(reflect.TypeOf(test.v))
if test.err && err == nil {
t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v))
} else if !test.err && err != nil {
t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err)
}
}
for _, test := range tests {
_, err := callbackRet(reflect.TypeOf(test.v))
if test.err && err == nil {
t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v))
} else if !test.err && err != nil {
t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err)
}
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -17,7 +17,7 @@
*/ */
#ifndef _SQLITE3EXT_H_ #ifndef _SQLITE3EXT_H_
#define _SQLITE3EXT_H_ #define _SQLITE3EXT_H_
#include "sqlite3.h" #include "sqlite3-binding.h"
typedef struct sqlite3_api_routines sqlite3_api_routines; typedef struct sqlite3_api_routines sqlite3_api_routines;
@ -28,7 +28,7 @@ typedef struct sqlite3_api_routines sqlite3_api_routines;
** WARNING: In order to maintain backwards compatibility, add new ** WARNING: In order to maintain backwards compatibility, add new
** interfaces to the end of this structure only. If you insert new ** interfaces to the end of this structure only. If you insert new
** interfaces in the middle of this structure, then older different ** interfaces in the middle of this structure, then older different
** versions of SQLite will not be able to load each others' shared ** versions of SQLite will not be able to load each other's shared
** libraries! ** libraries!
*/ */
struct sqlite3_api_routines { struct sqlite3_api_routines {
@ -250,11 +250,40 @@ struct sqlite3_api_routines {
const char *(*uri_parameter)(const char*,const char*); const char *(*uri_parameter)(const char*,const char*);
char *(*vsnprintf)(int,char*,const char*,va_list); char *(*vsnprintf)(int,char*,const char*,va_list);
int (*wal_checkpoint_v2)(sqlite3*,const char*,int,int*,int*); int (*wal_checkpoint_v2)(sqlite3*,const char*,int,int*,int*);
/* Version 3.8.7 and later */
int (*auto_extension)(void(*)(void));
int (*bind_blob64)(sqlite3_stmt*,int,const void*,sqlite3_uint64,
void(*)(void*));
int (*bind_text64)(sqlite3_stmt*,int,const char*,sqlite3_uint64,
void(*)(void*),unsigned char);
int (*cancel_auto_extension)(void(*)(void));
int (*load_extension)(sqlite3*,const char*,const char*,char**);
void *(*malloc64)(sqlite3_uint64);
sqlite3_uint64 (*msize)(void*);
void *(*realloc64)(void*,sqlite3_uint64);
void (*reset_auto_extension)(void);
void (*result_blob64)(sqlite3_context*,const void*,sqlite3_uint64,
void(*)(void*));
void (*result_text64)(sqlite3_context*,const char*,sqlite3_uint64,
void(*)(void*), unsigned char);
int (*strglob)(const char*,const char*);
/* Version 3.8.11 and later */
sqlite3_value *(*value_dup)(const sqlite3_value*);
void (*value_free)(sqlite3_value*);
int (*result_zeroblob64)(sqlite3_context*,sqlite3_uint64);
int (*bind_zeroblob64)(sqlite3_stmt*, int, sqlite3_uint64);
/* Version 3.9.0 and later */
unsigned int (*value_subtype)(sqlite3_value*);
void (*result_subtype)(sqlite3_context*,unsigned int);
/* Version 3.10.0 and later */
int (*status64)(int,sqlite3_int64*,sqlite3_int64*,int);
int (*strlike)(const char*,const char*,unsigned int);
int (*db_cacheflush)(sqlite3*);
}; };
/* /*
** The following macros redefine the API routines so that they are ** The following macros redefine the API routines so that they are
** redirected throught the global sqlite3_api structure. ** redirected through the global sqlite3_api structure.
** **
** This header file is also used by the loadext.c source file ** This header file is also used by the loadext.c source file
** (part of the main SQLite library - not an extension) so that ** (part of the main SQLite library - not an extension) so that
@ -263,7 +292,7 @@ struct sqlite3_api_routines {
** the API. So the redefinition macros are only valid if the ** the API. So the redefinition macros are only valid if the
** SQLITE_CORE macros is undefined. ** SQLITE_CORE macros is undefined.
*/ */
#ifndef SQLITE_CORE #if !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION)
#define sqlite3_aggregate_context sqlite3_api->aggregate_context #define sqlite3_aggregate_context sqlite3_api->aggregate_context
#ifndef SQLITE_OMIT_DEPRECATED #ifndef SQLITE_OMIT_DEPRECATED
#define sqlite3_aggregate_count sqlite3_api->aggregate_count #define sqlite3_aggregate_count sqlite3_api->aggregate_count
@ -390,6 +419,7 @@ struct sqlite3_api_routines {
#define sqlite3_value_text16le sqlite3_api->value_text16le #define sqlite3_value_text16le sqlite3_api->value_text16le
#define sqlite3_value_type sqlite3_api->value_type #define sqlite3_value_type sqlite3_api->value_type
#define sqlite3_vmprintf sqlite3_api->vmprintf #define sqlite3_vmprintf sqlite3_api->vmprintf
#define sqlite3_vsnprintf sqlite3_api->vsnprintf
#define sqlite3_overload_function sqlite3_api->overload_function #define sqlite3_overload_function sqlite3_api->overload_function
#define sqlite3_prepare_v2 sqlite3_api->prepare_v2 #define sqlite3_prepare_v2 sqlite3_api->prepare_v2
#define sqlite3_prepare16_v2 sqlite3_api->prepare16_v2 #define sqlite3_prepare16_v2 sqlite3_api->prepare16_v2
@ -467,9 +497,34 @@ struct sqlite3_api_routines {
#define sqlite3_uri_parameter sqlite3_api->uri_parameter #define sqlite3_uri_parameter sqlite3_api->uri_parameter
#define sqlite3_uri_vsnprintf sqlite3_api->vsnprintf #define sqlite3_uri_vsnprintf sqlite3_api->vsnprintf
#define sqlite3_wal_checkpoint_v2 sqlite3_api->wal_checkpoint_v2 #define sqlite3_wal_checkpoint_v2 sqlite3_api->wal_checkpoint_v2
#endif /* SQLITE_CORE */ /* Version 3.8.7 and later */
#define sqlite3_auto_extension sqlite3_api->auto_extension
#define sqlite3_bind_blob64 sqlite3_api->bind_blob64
#define sqlite3_bind_text64 sqlite3_api->bind_text64
#define sqlite3_cancel_auto_extension sqlite3_api->cancel_auto_extension
#define sqlite3_load_extension sqlite3_api->load_extension
#define sqlite3_malloc64 sqlite3_api->malloc64
#define sqlite3_msize sqlite3_api->msize
#define sqlite3_realloc64 sqlite3_api->realloc64
#define sqlite3_reset_auto_extension sqlite3_api->reset_auto_extension
#define sqlite3_result_blob64 sqlite3_api->result_blob64
#define sqlite3_result_text64 sqlite3_api->result_text64
#define sqlite3_strglob sqlite3_api->strglob
/* Version 3.8.11 and later */
#define sqlite3_value_dup sqlite3_api->value_dup
#define sqlite3_value_free sqlite3_api->value_free
#define sqlite3_result_zeroblob64 sqlite3_api->result_zeroblob64
#define sqlite3_bind_zeroblob64 sqlite3_api->bind_zeroblob64
/* Version 3.9.0 and later */
#define sqlite3_value_subtype sqlite3_api->value_subtype
#define sqlite3_result_subtype sqlite3_api->result_subtype
/* Version 3.10.0 and later */
#define sqlite3_status64 sqlite3_api->status64
#define sqlite3_strlike sqlite3_api->strlike
#define sqlite3_db_cacheflush sqlite3_api->db_cacheflush
#endif /* !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION) */
#ifndef SQLITE_CORE #if !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION)
/* This case when the file really is being compiled as a loadable /* This case when the file really is being compiled as a loadable
** extension */ ** extension */
# define SQLITE_EXTENSION_INIT1 const sqlite3_api_routines *sqlite3_api=0; # define SQLITE_EXTENSION_INIT1 const sqlite3_api_routines *sqlite3_api=0;

17
doc.go
View File

@ -91,5 +91,22 @@ you need to hook ConnectHook and get the SQLiteConn.
}, },
}) })
Go SQlite3 Extensions
If you want to register Go functions as SQLite extension functions,
call RegisterFunction from ConnectHook.
regex = func(re, s string) (bool, error) {
return regexp.MatchString(re, s)
}
sql.Register("sqlite3_with_go_func",
&sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
return conn.RegisterFunc("regex", regex, true)
},
})
See the documentation of RegisterFunc for more details.
*/ */
package sqlite3 package sqlite3

View File

@ -1,3 +1,8 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3 package sqlite3
import "C" import "C"

View File

@ -1,3 +1,8 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3 package sqlite3
import ( import (
@ -226,6 +231,12 @@ func TestExtendedErrorCodes_Unique(t *testing.T) {
t.Errorf("Wrong extended error code: %d != %d", t.Errorf("Wrong extended error code: %d != %d",
sqliteErr.ExtendedCode, ErrConstraintUnique) sqliteErr.ExtendedCode, ErrConstraintUnique)
} }
extended := sqliteErr.Code.Extend(3).Error()
expected := "constraint failed"
if extended != expected {
t.Errorf("Wrong basic error code: %q != %q",
extended, expected)
}
} }
} }

4
sqlite3-binding.c Normal file
View File

@ -0,0 +1,4 @@
#ifndef USE_LIBSQLITE3
# include "code/sqlite3-binding.c"
#endif

5
sqlite3-binding.h Normal file
View File

@ -0,0 +1,5 @@
#ifndef USE_LIBSQLITE3
#include "code/sqlite3-binding.h"
#else
#include <sqlite3.h>
#endif

View File

@ -1,7 +1,15 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3 package sqlite3
/* /*
#include <sqlite3.h> #cgo CFLAGS: -std=gnu99
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
#include <sqlite3-binding.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
@ -39,23 +47,62 @@ _sqlite3_bind_blob(sqlite3_stmt *stmt, int n, void *p, int np) {
#include <stdio.h> #include <stdio.h>
#include <stdint.h> #include <stdint.h>
static long static int
_sqlite3_last_insert_rowid(sqlite3* db) { _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* changes)
return (long) sqlite3_last_insert_rowid(db); {
int rv = sqlite3_exec(db, pcmd, 0, 0, 0);
*rowid = (long long) sqlite3_last_insert_rowid(db);
*changes = (long long) sqlite3_changes(db);
return rv;
} }
static long static int
_sqlite3_changes(sqlite3* db) { _sqlite3_step(sqlite3_stmt* stmt, long long* rowid, long long* changes)
return (long) sqlite3_changes(db); {
int rv = sqlite3_step(stmt);
sqlite3* db = sqlite3_db_handle(stmt);
*rowid = (long long) sqlite3_last_insert_rowid(db);
*changes = (long long) sqlite3_changes(db);
return rv;
} }
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
sqlite3_result_text(ctx, s, -1, &free);
}
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) {
sqlite3_result_blob(ctx, b, l, SQLITE_TRANSIENT);
}
int _sqlite3_create_function(
sqlite3 *db,
const char *zFunctionName,
int nArg,
int eTextRep,
uintptr_t pApp,
void (*xFunc)(sqlite3_context*,int,sqlite3_value**),
void (*xStep)(sqlite3_context*,int,sqlite3_value**),
void (*xFinal)(sqlite3_context*)
) {
return sqlite3_create_function(db, zFunctionName, nArg, eTextRep, (void*) pApp, xFunc, xStep, xFinal);
}
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void doneTrampoline(sqlite3_context*);
*/ */
import "C" import "C"
import ( import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors" "errors"
"fmt"
"io" "io"
"net/url"
"reflect"
"runtime"
"strconv"
"strings" "strings"
"time" "time"
"unsafe" "unsafe"
@ -66,6 +113,10 @@ import (
// into the database. When parsing a string from a timestamp or // into the database. When parsing a string from a timestamp or
// datetime column, the formats are tried in order. // datetime column, the formats are tried in order.
var SQLiteTimestampFormats = []string{ var SQLiteTimestampFormats = []string{
// By default, store timestamps with whatever timezone they come with.
// When parsed, they will be returned with the same timezone.
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02T15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999", "2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999", "2006-01-02T15:04:05.999999999",
"2006-01-02 15:04:05", "2006-01-02 15:04:05",
@ -73,13 +124,20 @@ var SQLiteTimestampFormats = []string{
"2006-01-02 15:04", "2006-01-02 15:04",
"2006-01-02T15:04", "2006-01-02T15:04",
"2006-01-02", "2006-01-02",
"2006-01-02 15:04:05-07:00",
} }
func init() { func init() {
sql.Register("sqlite3", &SQLiteDriver{}) sql.Register("sqlite3", &SQLiteDriver{})
} }
// Return SQLite library Version information.
func Version() (libVersion string, libVersionNumber int, sourceId string) {
libVersion = C.GoString(C.sqlite3_libversion())
libVersionNumber = int(C.sqlite3_libversion_number())
sourceId = C.GoString(C.sqlite3_sourceid())
return libVersion, libVersionNumber, sourceId
}
// Driver struct. // Driver struct.
type SQLiteDriver struct { type SQLiteDriver struct {
Extensions []string Extensions []string
@ -88,7 +146,11 @@ type SQLiteDriver struct {
// Conn struct. // Conn struct.
type SQLiteConn struct { type SQLiteConn struct {
db *C.sqlite3 db *C.sqlite3
loc *time.Location
txlock string
funcs []*functionInfo
aggregators []*aggInfo
} }
// Tx struct. // Tx struct.
@ -100,6 +162,8 @@ type SQLiteTx struct {
type SQLiteStmt struct { type SQLiteStmt struct {
c *SQLiteConn c *SQLiteConn
s *C.sqlite3_stmt s *C.sqlite3_stmt
nv int
nn []string
t string t string
closed bool closed bool
cls bool cls bool
@ -120,6 +184,107 @@ type SQLiteRows struct {
cls bool cls bool
} }
type functionInfo struct {
f reflect.Value
argConverters []callbackArgConverter
variadicConverter callbackArgConverter
retConverter callbackRetConverter
}
func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
args, err := callbackConvertArgs(argv, fi.argConverters, fi.variadicConverter)
if err != nil {
callbackError(ctx, err)
return
}
ret := fi.f.Call(args)
if len(ret) == 2 && ret[1].Interface() != nil {
callbackError(ctx, ret[1].Interface().(error))
return
}
err = fi.retConverter(ctx, ret[0])
if err != nil {
callbackError(ctx, err)
return
}
}
type aggInfo struct {
constructor reflect.Value
// Active aggregator objects for aggregations in flight. The
// aggregators are indexed by a counter stored in the aggregation
// user data space provided by sqlite.
active map[int64]reflect.Value
next int64
stepArgConverters []callbackArgConverter
stepVariadicConverter callbackArgConverter
doneRetConverter callbackRetConverter
}
func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) {
aggIdx := (*int64)(C.sqlite3_aggregate_context(ctx, C.int(8)))
if *aggIdx == 0 {
*aggIdx = ai.next
ret := ai.constructor.Call(nil)
if len(ret) == 2 && ret[1].Interface() != nil {
return 0, reflect.Value{}, ret[1].Interface().(error)
}
if ret[0].IsNil() {
return 0, reflect.Value{}, errors.New("aggregator constructor returned nil state")
}
ai.next++
ai.active[*aggIdx] = ret[0]
}
return *aggIdx, ai.active[*aggIdx], nil
}
func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
_, agg, err := ai.agg(ctx)
if err != nil {
callbackError(ctx, err)
return
}
args, err := callbackConvertArgs(argv, ai.stepArgConverters, ai.stepVariadicConverter)
if err != nil {
callbackError(ctx, err)
return
}
ret := agg.MethodByName("Step").Call(args)
if len(ret) == 1 && ret[0].Interface() != nil {
callbackError(ctx, ret[0].Interface().(error))
return
}
}
func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
idx, agg, err := ai.agg(ctx)
if err != nil {
callbackError(ctx, err)
return
}
defer func() { delete(ai.active, idx) }()
ret := agg.MethodByName("Done").Call(nil)
if len(ret) == 2 && ret[1].Interface() != nil {
callbackError(ctx, ret[1].Interface().(error))
return
}
err = ai.doneRetConverter(ctx, ret[0])
if err != nil {
callbackError(ctx, err)
return
}
}
// Commit transaction. // Commit transaction.
func (tx *SQLiteTx) Commit() error { func (tx *SQLiteTx) Commit() error {
_, err := tx.c.exec("COMMIT") _, err := tx.c.exec("COMMIT")
@ -132,6 +297,208 @@ func (tx *SQLiteTx) Rollback() error {
return err return err
} }
// RegisterFunc makes a Go function available as a SQLite function.
//
// The Go function can have arguments of the following types: any
// numeric type except complex, bool, []byte, string and
// interface{}. interface{} arguments are given the direct translation
// of the SQLite data type: int64 for INTEGER, float64 for FLOAT,
// []byte for BLOB, string for TEXT.
//
// The function can additionally be variadic, as long as the type of
// the variadic argument is one of the above.
//
// If pure is true. SQLite will assume that the function's return
// value depends only on its inputs, and make more aggressive
// optimizations in its queries.
//
// See _example/go_custom_funcs for a detailed example.
func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) error {
var fi functionInfo
fi.f = reflect.ValueOf(impl)
t := fi.f.Type()
if t.Kind() != reflect.Func {
return errors.New("Non-function passed to RegisterFunc")
}
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite functions must return 1 or 2 values")
}
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("Second return value of SQLite function must be error")
}
numArgs := t.NumIn()
if t.IsVariadic() {
numArgs--
}
for i := 0; i < numArgs; i++ {
conv, err := callbackArg(t.In(i))
if err != nil {
return err
}
fi.argConverters = append(fi.argConverters, conv)
}
if t.IsVariadic() {
conv, err := callbackArg(t.In(numArgs).Elem())
if err != nil {
return err
}
fi.variadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
numArgs = -1
}
conv, err := callbackRet(t.Out(0))
if err != nil {
return err
}
fi.retConverter = conv
// fi must outlast the database connection, or we'll have dangling pointers.
c.funcs = append(c.funcs, &fi)
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := C._sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), C.uintptr_t(newHandle(c, &fi)), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil)
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
//
// Because aggregation is incremental, it's implemented in Go with a
// type that has 2 methods: func Step(values) accumulates one row of
// data into the accumulator, and func Done() ret finalizes and
// returns the aggregate value. "values" and "ret" may be any type
// supported by RegisterFunc.
//
// RegisterAggregator takes as implementation a constructor function
// that constructs an instance of the aggregator type each time an
// aggregation begins. The constructor must return a pointer to a
// type, or an interface that implements Step() and Done().
//
// The constructor function and the Step/Done methods may optionally
// return an error in addition to their other return values.
//
// See _example/go_custom_funcs for a detailed example.
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
var ai aggInfo
ai.constructor = reflect.ValueOf(impl)
t := ai.constructor.Type()
if t.Kind() != reflect.Func {
return errors.New("non-function passed to RegisterAggregator")
}
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
}
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("Second return value of SQLite function must be error")
}
if t.NumIn() != 0 {
return errors.New("SQLite aggregator constructors must not have arguments")
}
agg := t.Out(0)
switch agg.Kind() {
case reflect.Ptr, reflect.Interface:
default:
return errors.New("SQlite aggregator constructor must return a pointer object")
}
stepFn, found := agg.MethodByName("Step")
if !found {
return errors.New("SQlite aggregator doesn't have a Step() function")
}
step := stepFn.Type
if step.NumOut() != 0 && step.NumOut() != 1 {
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
}
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("type of SQlite aggregator Step() return value must be error")
}
stepNArgs := step.NumIn()
start := 0
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
stepNArgs--
start++
}
if step.IsVariadic() {
stepNArgs--
}
for i := start; i < start+stepNArgs; i++ {
conv, err := callbackArg(step.In(i))
if err != nil {
return err
}
ai.stepArgConverters = append(ai.stepArgConverters, conv)
}
if step.IsVariadic() {
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
if err != nil {
return err
}
ai.stepVariadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
stepNArgs = -1
}
doneFn, found := agg.MethodByName("Done")
if !found {
return errors.New("SQlite aggregator doesn't have a Done() function")
}
done := doneFn.Type
doneNArgs := done.NumIn()
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
doneNArgs--
}
if doneNArgs != 0 {
return errors.New("SQlite aggregator Done() function must have no arguments")
}
if done.NumOut() != 1 && done.NumOut() != 2 {
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
}
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("second return value of SQLite aggregator Done() function must be error")
}
conv, err := callbackRet(done.Out(0))
if err != nil {
return err
}
ai.doneRetConverter = conv
ai.active = make(map[int64]reflect.Value)
ai.next = 1
// ai must outlast the database connection, or we'll have dangling pointers.
c.aggregators = append(c.aggregators, &ai)
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := C._sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(newHandle(c, &ai)), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline)))
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}
// AutoCommit return which currently auto commit or not. // AutoCommit return which currently auto commit or not.
func (c *SQLiteConn) AutoCommit() bool { func (c *SQLiteConn) AutoCommit() bool {
return int(C.sqlite3_get_autocommit(c.db)) != 0 return int(C.sqlite3_get_autocommit(c.db)) != 0
@ -159,6 +526,9 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
var res driver.Result var res driver.Result
if s.(*SQLiteStmt).s != nil { if s.(*SQLiteStmt).s != nil {
na := s.NumInput() na := s.NumInput()
if len(args) < na {
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
}
res, err = s.Exec(args[:na]) res, err = s.Exec(args[:na])
if err != nil && err != driver.ErrSkip { if err != nil && err != driver.ErrSkip {
s.Close() s.Close()
@ -184,6 +554,9 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
} }
s.(*SQLiteStmt).cls = true s.(*SQLiteStmt).cls = true
na := s.NumInput() na := s.NumInput()
if len(args) < na {
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
}
rows, err := s.Query(args[:na]) rows, err := s.Query(args[:na])
if err != nil && err != driver.ErrSkip { if err != nil && err != driver.ErrSkip {
s.Close() s.Close()
@ -194,6 +567,7 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
if tail == "" { if tail == "" {
return rows, nil return rows, nil
} }
rows.Close()
s.Close() s.Close()
query = tail query = tail
} }
@ -202,19 +576,18 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
func (c *SQLiteConn) exec(cmd string) (driver.Result, error) { func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {
pcmd := C.CString(cmd) pcmd := C.CString(cmd)
defer C.free(unsafe.Pointer(pcmd)) defer C.free(unsafe.Pointer(pcmd))
rv := C.sqlite3_exec(c.db, pcmd, nil, nil, nil)
var rowid, changes C.longlong
rv := C._sqlite3_exec(c.db, pcmd, &rowid, &changes)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return nil, c.lastError() return nil, c.lastError()
} }
return &SQLiteResult{ return &SQLiteResult{int64(rowid), int64(changes)}, nil
int64(C._sqlite3_last_insert_rowid(c.db)),
int64(C._sqlite3_changes(c.db)),
}, nil
} }
// Begin transaction. // Begin transaction.
func (c *SQLiteConn) Begin() (driver.Tx, error) { func (c *SQLiteConn) Begin() (driver.Tx, error) {
if _, err := c.exec("BEGIN"); err != nil { if _, err := c.exec(c.txlock); err != nil {
return nil, err return nil, err
} }
return &SQLiteTx{c}, nil return &SQLiteTx{c}, nil
@ -230,11 +603,69 @@ func errorString(err Error) string {
// file:test.db?cache=shared&mode=memory // file:test.db?cache=shared&mode=memory
// :memory: // :memory:
// file::memory: // file::memory:
// go-sqlite3 adds the following query parameters to those used by SQLite:
// _loc=XXX
// Specify location of time format. It's possible to specify "auto".
// _busy_timeout=XXX
// Specify value for sqlite3_busy_timeout.
// _txlock=XXX
// Specify locking behavior for transactions. XXX can be "immediate",
// "deferred", "exclusive".
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 { if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation") return nil, errors.New("sqlite library was not compiled for thread-safe operation")
} }
var loc *time.Location
txlock := "BEGIN"
busy_timeout := 5000
pos := strings.IndexRune(dsn, '?')
if pos >= 1 {
params, err := url.ParseQuery(dsn[pos+1:])
if err != nil {
return nil, err
}
// _loc
if val := params.Get("_loc"); val != "" {
if val == "auto" {
loc = time.Local
} else {
loc, err = time.LoadLocation(val)
if err != nil {
return nil, fmt.Errorf("Invalid _loc: %v: %v", val, err)
}
}
}
// _busy_timeout
if val := params.Get("_busy_timeout"); val != "" {
iv, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("Invalid _busy_timeout: %v: %v", val, err)
}
busy_timeout = int(iv)
}
// _txlock
if val := params.Get("_txlock"); val != "" {
switch val {
case "immediate":
txlock = "BEGIN IMMEDIATE"
case "exclusive":
txlock = "BEGIN EXCLUSIVE"
case "deferred":
txlock = "BEGIN"
default:
return nil, fmt.Errorf("Invalid _txlock: %v", val)
}
}
if !strings.HasPrefix(dsn, "file:") {
dsn = dsn[:pos]
}
}
var db *C.sqlite3 var db *C.sqlite3
name := C.CString(dsn) name := C.CString(dsn)
defer C.free(unsafe.Pointer(name)) defer C.free(unsafe.Pointer(name))
@ -250,38 +681,17 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, errors.New("sqlite succeeded without returning a database") return nil, errors.New("sqlite succeeded without returning a database")
} }
rv = C.sqlite3_busy_timeout(db, 5000) rv = C.sqlite3_busy_timeout(db, C.int(busy_timeout))
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return nil, Error{Code: ErrNo(rv)} return nil, Error{Code: ErrNo(rv)}
} }
conn := &SQLiteConn{db} conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
if len(d.Extensions) > 0 { if len(d.Extensions) > 0 {
rv = C.sqlite3_enable_load_extension(db, 1) if err := conn.loadExtensions(d.Extensions); err != nil {
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
stmt, err := conn.Prepare("SELECT load_extension(?);")
if err != nil {
return nil, err return nil, err
} }
for _, extension := range d.Extensions {
if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
return nil, err
}
}
if err = stmt.Close(); err != nil {
return nil, err
}
rv = C.sqlite3_enable_load_extension(db, 0)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
} }
if d.ConnectHook != nil { if d.ConnectHook != nil {
@ -289,17 +699,19 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, err return nil, err
} }
} }
runtime.SetFinalizer(conn, (*SQLiteConn).Close)
return conn, nil return conn, nil
} }
// Close the connection. // Close the connection.
func (c *SQLiteConn) Close() error { func (c *SQLiteConn) Close() error {
deleteHandles(c)
rv := C.sqlite3_close_v2(c.db) rv := C.sqlite3_close_v2(c.db)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return c.lastError() return c.lastError()
} }
c.db = nil c.db = nil
runtime.SetFinalizer(c, nil)
return nil return nil
} }
@ -314,10 +726,20 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
return nil, c.lastError() return nil, c.lastError()
} }
var t string var t string
if tail != nil && C.strlen(tail) > 0 { if tail != nil && *tail != '\000' {
t = strings.TrimSpace(C.GoString(tail)) t = strings.TrimSpace(C.GoString(tail))
} }
return &SQLiteStmt{c: c, s: s, t: t}, nil nv := int(C.sqlite3_bind_parameter_count(s))
var nn []string
for i := 0; i < nv; i++ {
pn := C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1)))
if len(pn) > 1 && pn[0] == '$' && 48 <= pn[1] && pn[1] <= 57 {
nn = append(nn, C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1))))
}
}
ss := &SQLiteStmt{c: c, s: s, nv: nv, nn: nn, t: t}
runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
return ss, nil
} }
// Close the statement. // Close the statement.
@ -333,12 +755,18 @@ func (s *SQLiteStmt) Close() error {
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return s.c.lastError() return s.c.lastError()
} }
runtime.SetFinalizer(s, nil)
return nil return nil
} }
// Return a number of parameters. // Return a number of parameters.
func (s *SQLiteStmt) NumInput() int { func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s)) return s.nv
}
type bindArg struct {
n int
v driver.Value
} }
func (s *SQLiteStmt) bind(args []driver.Value) error { func (s *SQLiteStmt) bind(args []driver.Value) error {
@ -347,8 +775,24 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
return s.c.lastError() return s.c.lastError()
} }
for i, v := range args { var vargs []bindArg
n := C.int(i + 1) narg := len(args)
vargs = make([]bindArg, narg)
if len(s.nn) > 0 {
for i, v := range s.nn {
if pi, err := strconv.Atoi(v[1:]); err == nil {
vargs[i] = bindArg{pi, args[i]}
}
}
} else {
for i, v := range args {
vargs[i] = bindArg{i + 1, v}
}
}
for _, varg := range vargs {
n := C.int(varg.n)
v := varg.v
switch v := v.(type) { switch v := v.(type) {
case nil: case nil:
rv = C.sqlite3_bind_null(s.s, n) rv = C.sqlite3_bind_null(s.s, n)
@ -371,13 +815,13 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
case float64: case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v)) rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte: case []byte:
var p *byte if len(v) == 0 {
if len(v) > 0 { rv = C._sqlite3_bind_blob(s.s, n, nil, 0)
p = &v[0] } else {
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v)))
} }
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
case time.Time: case time.Time:
b := []byte(v.UTC().Format(SQLiteTimestampFormats[0])) b := []byte(v.Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
} }
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
@ -408,18 +852,19 @@ func (r *SQLiteResult) RowsAffected() (int64, error) {
// Execute the statement with arguments. Return result object. // Execute the statement with arguments. Return result object.
func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
if err := s.bind(args); err != nil { if err := s.bind(args); err != nil {
C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
return nil, err return nil, err
} }
rv := C.sqlite3_step(s.s) var rowid, changes C.longlong
rv := C._sqlite3_step(s.s, &rowid, &changes)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return nil, s.c.lastError() err := s.c.lastError()
C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s)
return nil, err
} }
return &SQLiteResult{int64(rowid), int64(changes)}, nil
res := &SQLiteResult{
int64(C._sqlite3_last_insert_rowid(s.c.db)),
int64(C._sqlite3_changes(s.c.db)),
}
return res, nil
} }
// Close the rows. // Close the rows.
@ -474,8 +919,20 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
case C.SQLITE_INTEGER: case C.SQLITE_INTEGER:
val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i))) val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
switch rc.decltype[i] { switch rc.decltype[i] {
case "timestamp", "datetime": case "timestamp", "datetime", "date":
dest[i] = time.Unix(val, 0) var t time.Time
// Assume a millisecond unix timestamp if it's 13 digits -- too
// large to be a reasonable timestamp in seconds.
if val > 1e12 || val < -1e12 {
val *= int64(time.Millisecond) // convert ms to nsec
} else {
val *= int64(time.Second) // convert sec to nsec
}
t = time.Unix(0, val).UTC()
if rc.s.c.loc != nil {
t = t.In(rc.s.c.loc)
}
dest[i] = t
case "boolean": case "boolean":
dest[i] = val > 0 dest[i] = val > 0
default: default:
@ -502,19 +959,29 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
dest[i] = nil dest[i] = nil
case C.SQLITE_TEXT: case C.SQLITE_TEXT:
var err error var err error
s := C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i))))) var timeVal time.Time
n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i)))
s := C.GoStringN((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))), C.int(n))
switch rc.decltype[i] { switch rc.decltype[i] {
case "timestamp", "datetime": case "timestamp", "datetime", "date":
var t time.Time
s = strings.TrimSuffix(s, "Z")
for _, format := range SQLiteTimestampFormats { for _, format := range SQLiteTimestampFormats {
if dest[i], err = time.Parse(format, s); err == nil { if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
t = timeVal
break break
} }
} }
if err != nil { if err != nil {
// The column is a time value, so return the zero time on parse failure. // The column is a time value, so return the zero time on parse failure.
dest[i] = time.Time{} t = time.Time{}
} }
if rc.s.c.loc != nil {
t = t.In(rc.s.c.loc)
}
dest[i] = t
default: default:
dest[i] = []byte(s) dest[i] = []byte(s)
} }

127
sqlite3_fts3_test.go Normal file
View File

@ -0,0 +1,127 @@
// Copyright (C) 2015 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3
import (
"database/sql"
"os"
"testing"
)
func TestFTS3(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec("DROP TABLE foo")
_, err = db.Exec("CREATE VIRTUAL TABLE foo USING fts3(id INTEGER PRIMARY KEY, value TEXT)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
_, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 1, `今日の 晩御飯は 天麩羅よ`)
if err != nil {
t.Fatal("Failed to insert value:", err)
}
_, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 2, `今日は いい 天気だ`)
if err != nil {
t.Fatal("Failed to insert value:", err)
}
rows, err := db.Query("SELECT id, value FROM foo WHERE value MATCH '今日* 天*'")
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
defer rows.Close()
for rows.Next() {
var id int
var value string
if err := rows.Scan(&id, &value); err != nil {
t.Error("Unable to scan results:", err)
continue
}
if id == 1 && value != `今日の 晩御飯は 天麩羅よ` {
t.Error("Value for id 1 should be `今日の 晩御飯は 天麩羅よ`, but:", value)
} else if id == 2 && value != `今日は いい 天気だ` {
t.Error("Value for id 2 should be `今日は いい 天気だ`, but:", value)
}
}
rows, err = db.Query("SELECT value FROM foo WHERE value MATCH '今日* 天麩羅*'")
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
defer rows.Close()
var value string
if !rows.Next() {
t.Fatal("Result should be only one")
}
if err := rows.Scan(&value); err != nil {
t.Fatal("Unable to scan results:", err)
}
if value != `今日の 晩御飯は 天麩羅よ` {
t.Fatal("Value should be `今日の 晩御飯は 天麩羅よ`, but:", value)
}
if rows.Next() {
t.Fatal("Result should be only one")
}
}
func TestFTS4(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec("DROP TABLE foo")
_, err = db.Exec("CREATE VIRTUAL TABLE foo USING fts4(tokenize=unicode61, id INTEGER PRIMARY KEY, value TEXT)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
_, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 1, `février`)
if err != nil {
t.Fatal("Failed to insert value:", err)
}
rows, err := db.Query("SELECT value FROM foo WHERE value MATCH 'fevrier'")
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
defer rows.Close()
var value string
if !rows.Next() {
t.Fatal("Result should be only one")
}
if err := rows.Scan(&value); err != nil {
t.Fatal("Unable to scan results:", err)
}
if value != `février` {
t.Fatal("Value should be `février`, but:", value)
}
if rows.Next() {
t.Fatal("Result should be only one")
}
}

13
sqlite3_icu.go Normal file
View File

@ -0,0 +1,13 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build icu
package sqlite3
/*
#cgo LDFLAGS: -licuuc -licui18n
#cgo CFLAGS: -DSQLITE_ENABLE_ICU
*/
import "C"

13
sqlite3_libsqlite3.go Normal file
View File

@ -0,0 +1,13 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build libsqlite3
package sqlite3
/*
#cgo CFLAGS: -DUSE_LIBSQLITE3
#cgo LDFLAGS: -lsqlite3
*/
import "C"

39
sqlite3_load_extension.go Normal file
View File

@ -0,0 +1,39 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build !sqlite_omit_load_extension
package sqlite3
/*
#include <sqlite3-binding.h>
#include <stdlib.h>
*/
import "C"
import (
"errors"
"unsafe"
)
func (c *SQLiteConn) loadExtensions(extensions []string) error {
rv := C.sqlite3_enable_load_extension(c.db, 1)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
for _, extension := range extensions {
cext := C.CString(extension)
defer C.free(unsafe.Pointer(cext))
rv = C.sqlite3_load_extension(c.db, cext, nil, nil)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
}
rv = C.sqlite3_enable_load_extension(c.db, 0)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
return nil
}

View File

@ -0,0 +1,19 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build sqlite_omit_load_extension
package sqlite3
/*
#cgo CFLAGS: -DSQLITE_OMIT_LOAD_EXTENSION
*/
import "C"
import (
"errors"
)
func (c *SQLiteConn) loadExtensions(extensions []string) error {
return errors.New("Extensions have been disabled for static builds")
}

View File

@ -1,3 +1,7 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build !windows // +build !windows
package sqlite3 package sqlite3
@ -5,6 +9,5 @@ package sqlite3
/* /*
#cgo CFLAGS: -I. #cgo CFLAGS: -I.
#cgo linux LDFLAGS: -ldl #cgo linux LDFLAGS: -ldl
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE
*/ */
import "C" import "C"

View File

@ -1,28 +1,49 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3 package sqlite3
import ( import (
"crypto/rand"
"database/sql" "database/sql"
"encoding/hex" "database/sql/driver"
"errors"
"fmt"
"io/ioutil"
"net/url"
"os" "os"
"path/filepath" "reflect"
"regexp"
"strings"
"sync"
"testing" "testing"
"time" "time"
"./sqlite3_test" "github.com/mattn/go-sqlite3/sqlite3_test"
) )
func TempFilename() string { func TempFilename(t *testing.T) string {
randBytes := make([]byte, 16) f, err := ioutil.TempFile("", "go-sqlite3-test-")
rand.Read(randBytes) if err != nil {
return filepath.Join(os.TempDir(), "foo"+hex.EncodeToString(randBytes)+".db") t.Fatal(err)
}
f.Close()
return f.Name()
} }
func TestOpen(t *testing.T) { func doTestOpen(t *testing.T, option string) (string, error) {
tempFilename := TempFilename() var url string
db, err := sql.Open("sqlite3", tempFilename) tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
if option != "" {
url = tempFilename + option
} else {
url = tempFilename
}
db, err := sql.Open("sqlite3", url)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) return "Failed to open database:", err
} }
defer os.Remove(tempFilename) defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
@ -30,21 +51,69 @@ func TestOpen(t *testing.T) {
_, err = db.Exec("drop table foo") _, err = db.Exec("drop table foo")
_, err = db.Exec("create table foo (id integer)") _, err = db.Exec("create table foo (id integer)")
if err != nil { if err != nil {
t.Fatal("Failed to create table:", err) return "Failed to create table:", err
} }
if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() { if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() {
t.Error("Failed to create ./foo.db") return "Failed to create ./foo.db", nil
}
return "", nil
}
func TestOpen(t *testing.T) {
cases := map[string]bool{
"": true,
"?_txlock=immediate": true,
"?_txlock=deferred": true,
"?_txlock=exclusive": true,
"?_txlock=bogus": false,
}
for option, expectedPass := range cases {
result, err := doTestOpen(t, option)
if result == "" {
if !expectedPass {
errmsg := fmt.Sprintf("_txlock error not caught at dbOpen with option: %s", option)
t.Fatal(errmsg)
}
} else if expectedPass {
if err == nil {
t.Fatal(result)
} else {
t.Fatal(result, err)
}
}
}
}
func TestReadonly(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db1, err := sql.Open("sqlite3", "file:"+tempFilename)
if err != nil {
t.Fatal(err)
}
db1.Exec("CREATE TABLE test (x int, y float)")
db2, err := sql.Open("sqlite3", "file:"+tempFilename+"?mode=ro")
if err != nil {
t.Fatal(err)
}
_ = db2
_, err = db2.Exec("INSERT INTO test VALUES (1, 3.14)")
if err == nil {
t.Fatal("didn't expect INSERT into read-only database to work")
} }
} }
func TestClose(t *testing.T) { func TestClose(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
_, err = db.Exec("drop table foo") _, err = db.Exec("drop table foo")
_, err = db.Exec("create table foo (id integer)") _, err = db.Exec("create table foo (id integer)")
@ -65,12 +134,12 @@ func TestClose(t *testing.T) {
} }
func TestInsert(t *testing.T) { func TestInsert(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("drop table foo") _, err = db.Exec("drop table foo")
@ -104,12 +173,12 @@ func TestInsert(t *testing.T) {
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("drop table foo") _, err = db.Exec("drop table foo")
@ -169,12 +238,12 @@ func TestUpdate(t *testing.T) {
} }
func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("drop table foo") _, err = db.Exec("drop table foo")
@ -230,12 +299,12 @@ func TestDelete(t *testing.T) {
} }
func TestBooleanRoundtrip(t *testing.T) { func TestBooleanRoundtrip(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("DROP TABLE foo") _, err = db.Exec("DROP TABLE foo")
@ -278,13 +347,15 @@ func TestBooleanRoundtrip(t *testing.T) {
} }
} }
func timezone(t time.Time) string { return t.Format("-07:00") }
func TestTimestamp(t *testing.T) { func TestTimestamp(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("DROP TABLE foo") _, err = db.Exec("DROP TABLE foo")
@ -296,6 +367,7 @@ func TestTimestamp(t *testing.T) {
timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC)
timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC)
timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC)
tzTest := time.FixedZone("TEST", -9*3600-13*60)
tests := []struct { tests := []struct {
value interface{} value interface{}
expected time.Time expected time.Time
@ -303,8 +375,9 @@ func TestTimestamp(t *testing.T) {
{"nonsense", time.Time{}}, {"nonsense", time.Time{}},
{"0000-00-00 00:00:00", time.Time{}}, {"0000-00-00 00:00:00", time.Time{}},
{timestamp1, timestamp1}, {timestamp1, timestamp1},
{timestamp1.Unix(), timestamp1}, {timestamp2.Unix(), timestamp2.Truncate(time.Second)},
{timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1}, {timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)},
{timestamp1.In(tzTest), timestamp1.In(tzTest)},
{timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1}, {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1},
{timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1}, {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1},
{timestamp1.Format("2006-01-02 15:04:05"), timestamp1}, {timestamp1.Format("2006-01-02 15:04:05"), timestamp1},
@ -312,6 +385,7 @@ func TestTimestamp(t *testing.T) {
{timestamp2, timestamp2}, {timestamp2, timestamp2},
{"2006-01-02 15:04:05.123456789", timestamp2}, {"2006-01-02 15:04:05.123456789", timestamp2},
{"2006-01-02T15:04:05.123456789", timestamp2}, {"2006-01-02T15:04:05.123456789", timestamp2},
{"2006-01-02T05:51:05.123456789-09:13", timestamp2.In(tzTest)},
{"2012-11-04", timestamp3}, {"2012-11-04", timestamp3},
{"2012-11-04 00:00", timestamp3}, {"2012-11-04 00:00", timestamp3},
{"2012-11-04 00:00:00", timestamp3}, {"2012-11-04 00:00:00", timestamp3},
@ -319,6 +393,14 @@ func TestTimestamp(t *testing.T) {
{"2012-11-04T00:00", timestamp3}, {"2012-11-04T00:00", timestamp3},
{"2012-11-04T00:00:00", timestamp3}, {"2012-11-04T00:00:00", timestamp3},
{"2012-11-04T00:00:00.000", timestamp3}, {"2012-11-04T00:00:00.000", timestamp3},
{"2006-01-02T15:04:05.123456789Z", timestamp2},
{"2012-11-04Z", timestamp3},
{"2012-11-04 00:00Z", timestamp3},
{"2012-11-04 00:00:00Z", timestamp3},
{"2012-11-04 00:00:00.000Z", timestamp3},
{"2012-11-04T00:00Z", timestamp3},
{"2012-11-04T00:00:00Z", timestamp3},
{"2012-11-04T00:00:00.000Z", timestamp3},
} }
for i := range tests { for i := range tests {
_, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value)
@ -353,6 +435,14 @@ func TestTimestamp(t *testing.T) {
if !tests[id].expected.Equal(dt) { if !tests[id].expected.Equal(dt) {
t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt)
} }
if timezone(tests[id].expected) != timezone(ts) {
t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value,
timezone(tests[id].expected), timezone(ts))
}
if timezone(tests[id].expected) != timezone(dt) {
t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value,
timezone(tests[id].expected), timezone(dt))
}
} }
if seen != len(tests) { if seen != len(tests) {
@ -361,13 +451,13 @@ func TestTimestamp(t *testing.T) {
} }
func TestBoolean(t *testing.T) { func TestBoolean(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("CREATE TABLE foo(id INTEGER, fbool BOOLEAN)") _, err = db.Exec("CREATE TABLE foo(id INTEGER, fbool BOOLEAN)")
@ -453,13 +543,12 @@ func TestBoolean(t *testing.T) {
} }
func TestFloat32(t *testing.T) { func TestFloat32(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("CREATE TABLE foo(id INTEGER)") _, err = db.Exec("CREATE TABLE foo(id INTEGER)")
@ -491,13 +580,12 @@ func TestFloat32(t *testing.T) {
} }
func TestNull(t *testing.T) { func TestNull(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
rows, err := db.Query("SELECT 3.141592") rows, err := db.Query("SELECT 3.141592")
@ -523,13 +611,12 @@ func TestNull(t *testing.T) {
} }
func TestTransaction(t *testing.T) { func TestTransaction(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("CREATE TABLE foo(id INTEGER)") _, err = db.Exec("CREATE TABLE foo(id INTEGER)")
@ -583,14 +670,14 @@ func TestTransaction(t *testing.T) {
} }
func TestWAL(t *testing.T) { func TestWAL(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
if _, err = db.Exec("PRAGMA journal_mode=WAL;"); err != nil { if _, err = db.Exec("PRAGMA journal_mode=WAL;"); err != nil {
t.Fatal("Failed to Exec PRAGMA journal_mode:", err) t.Fatal("Failed to Exec PRAGMA journal_mode:", err)
} }
@ -628,8 +715,106 @@ func TestWAL(t *testing.T) {
} }
} }
func TestTimezoneConversion(t *testing.T) {
zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
for _, tz := range zones {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename+"?_loc="+url.QueryEscape(tz))
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec("DROP TABLE foo")
_, err = db.Exec("CREATE TABLE foo(id INTEGER, ts TIMESTAMP, dt DATETIME)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
loc, err := time.LoadLocation(tz)
if err != nil {
t.Fatal("Failed to load location:", err)
}
timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC)
timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC)
timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC)
tests := []struct {
value interface{}
expected time.Time
}{
{"nonsense", time.Time{}.In(loc)},
{"0000-00-00 00:00:00", time.Time{}.In(loc)},
{timestamp1, timestamp1.In(loc)},
{timestamp1.Unix(), timestamp1.In(loc)},
{timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1.In(loc)},
{timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1.In(loc)},
{timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1.In(loc)},
{timestamp1.Format("2006-01-02 15:04:05"), timestamp1.In(loc)},
{timestamp1.Format("2006-01-02T15:04:05"), timestamp1.In(loc)},
{timestamp2, timestamp2.In(loc)},
{"2006-01-02 15:04:05.123456789", timestamp2.In(loc)},
{"2006-01-02T15:04:05.123456789", timestamp2.In(loc)},
{"2012-11-04", timestamp3.In(loc)},
{"2012-11-04 00:00", timestamp3.In(loc)},
{"2012-11-04 00:00:00", timestamp3.In(loc)},
{"2012-11-04 00:00:00.000", timestamp3.In(loc)},
{"2012-11-04T00:00", timestamp3.In(loc)},
{"2012-11-04T00:00:00", timestamp3.In(loc)},
{"2012-11-04T00:00:00.000", timestamp3.In(loc)},
}
for i := range tests {
_, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value)
if err != nil {
t.Fatal("Failed to insert timestamp:", err)
}
}
rows, err := db.Query("SELECT id, ts, dt FROM foo ORDER BY id ASC")
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
defer rows.Close()
seen := 0
for rows.Next() {
var id int
var ts, dt time.Time
if err := rows.Scan(&id, &ts, &dt); err != nil {
t.Error("Unable to scan results:", err)
continue
}
if id < 0 || id >= len(tests) {
t.Error("Bad row id: ", id)
continue
}
seen++
if !tests[id].expected.Equal(ts) {
t.Errorf("Timestamp value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, ts)
}
if !tests[id].expected.Equal(dt) {
t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt)
}
if tests[id].expected.Location().String() != ts.Location().String() {
t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), ts.Location().String())
}
if tests[id].expected.Location().String() != dt.Location().String() {
t.Errorf("Location for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected.Location().String(), dt.Location().String())
}
}
if seen != len(tests) {
t.Errorf("Expected to see %d rows", len(tests))
}
}
}
func TestSuite(t *testing.T) { func TestSuite(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:") tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -641,12 +826,12 @@ func TestSuite(t *testing.T) {
// TODO: Execer & Queryer currently disabled // TODO: Execer & Queryer currently disabled
// https://github.com/mattn/go-sqlite3/issues/82 // https://github.com/mattn/go-sqlite3/issues/82
func TestExecer(t *testing.T) { func TestExecer(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec(` _, err = db.Exec(`
@ -661,12 +846,12 @@ func TestExecer(t *testing.T) {
} }
func TestQueryer(t *testing.T) { func TestQueryer(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec(` _, err = db.Exec(`
@ -702,7 +887,8 @@ func TestQueryer(t *testing.T) {
} }
func TestStress(t *testing.T) { func TestStress(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
@ -737,3 +923,393 @@ func TestStress(t *testing.T) {
db.Close() db.Close()
} }
} }
func TestDateTimeLocal(t *testing.T) {
zone := "Asia/Tokyo"
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename+"?_loc="+zone)
if err != nil {
t.Fatal("Failed to open database:", err)
}
db.Exec("CREATE TABLE foo (dt datetime);")
db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');")
row := db.QueryRow("select * from foo")
var d time.Time
err = row.Scan(&d)
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
if d.Hour() == 15 || !strings.Contains(d.String(), "JST") {
t.Fatal("Result should have timezone", d)
}
db.Close()
db, err = sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
row = db.QueryRow("select * from foo")
err = row.Scan(&d)
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
if d.UTC().Hour() != 15 || !strings.Contains(d.String(), "UTC") {
t.Fatalf("Result should not have timezone %v %v", zone, d.String())
}
_, err = db.Exec("DELETE FROM foo")
if err != nil {
t.Fatal("Failed to delete table:", err)
}
dt, err := time.Parse("2006/1/2 15/4/5 -0700 MST", "2015/3/5 15/16/17 +0900 JST")
if err != nil {
t.Fatal("Failed to parse datetime:", err)
}
db.Exec("INSERT INTO foo VALUES(?);", dt)
db.Close()
db, err = sql.Open("sqlite3", tempFilename+"?_loc="+zone)
if err != nil {
t.Fatal("Failed to open database:", err)
}
row = db.QueryRow("select * from foo")
err = row.Scan(&d)
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
if d.Hour() != 15 || !strings.Contains(d.String(), "JST") {
t.Fatalf("Result should have timezone %v %v", zone, d.String())
}
}
func TestVersion(t *testing.T) {
s, n, id := Version()
if s == "" || n == 0 || id == "" {
t.Errorf("Version failed %q, %d, %q\n", s, n, id)
}
}
func TestNumberNamedParams(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec(`
create table foo (id integer, name text, extra text);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}
_, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, "foo")
if err != nil {
t.Error("Failed to call db.Exec:", err)
}
row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, "foo")
if row == nil {
t.Error("Failed to call db.QueryRow")
}
var id int
var extra string
err = row.Scan(&id, &extra)
if err != nil {
t.Error("Failed to db.Scan:", err)
}
if id != 1 || extra != "foo" {
t.Error("Failed to db.QueryRow: not matched results")
}
}
func TestStringContainingZero(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec(`
create table foo (id integer, name, extra text);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}
const text = "foo\x00bar"
_, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, text)
if err != nil {
t.Error("Failed to call db.Exec:", err)
}
row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, text)
if row == nil {
t.Error("Failed to call db.QueryRow")
}
var id int
var extra string
err = row.Scan(&id, &extra)
if err != nil {
t.Error("Failed to db.Scan:", err)
}
if id != 1 || extra != text {
t.Error("Failed to db.QueryRow: not matched results")
}
}
const CurrentTimeStamp = "2006-01-02 15:04:05"
type TimeStamp struct{ *time.Time }
func (t TimeStamp) Scan(value interface{}) error {
var err error
switch v := value.(type) {
case string:
*t.Time, err = time.Parse(CurrentTimeStamp, v)
case []byte:
*t.Time, err = time.Parse(CurrentTimeStamp, string(v))
default:
err = errors.New("invalid type for current_timestamp")
}
return err
}
func (t TimeStamp) Value() (driver.Value, error) {
return t.Time.Format(CurrentTimeStamp), nil
}
func TestDateTimeNow(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
var d time.Time
err = db.QueryRow("SELECT datetime('now')").Scan(TimeStamp{&d})
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
}
func TestFunctionRegistration(t *testing.T) {
addi_8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) }
addi_64 := func(a, b int64) int64 { return a + b }
addu_8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) }
addu_64 := func(a, b uint64) uint64 { return a + b }
addiu := func(a int, b uint) int64 { return int64(a) + int64(b) }
addf_32_64 := func(a float32, b float64) float64 { return float64(a) + b }
not := func(a bool) bool { return !a }
regex := func(re, s string) (bool, error) {
return regexp.MatchString(re, s)
}
generic := func(a interface{}) int64 {
switch a.(type) {
case int64:
return 1
case float64:
return 2
case []byte:
return 3
case string:
return 4
default:
panic("unreachable")
}
}
variadic := func(a, b int64, c ...int64) int64 {
ret := a + b
for _, d := range c {
ret += d
}
return ret
}
variadicGeneric := func(a ...interface{}) int64 {
return int64(len(a))
}
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
if err := conn.RegisterFunc("addi_8_16_32", addi_8_16_32, true); err != nil {
return err
}
if err := conn.RegisterFunc("addi_64", addi_64, true); err != nil {
return err
}
if err := conn.RegisterFunc("addu_8_16_32", addu_8_16_32, true); err != nil {
return err
}
if err := conn.RegisterFunc("addu_64", addu_64, true); err != nil {
return err
}
if err := conn.RegisterFunc("addiu", addiu, true); err != nil {
return err
}
if err := conn.RegisterFunc("addf_32_64", addf_32_64, true); err != nil {
return err
}
if err := conn.RegisterFunc("not", not, true); err != nil {
return err
}
if err := conn.RegisterFunc("regex", regex, true); err != nil {
return err
}
if err := conn.RegisterFunc("generic", generic, true); err != nil {
return err
}
if err := conn.RegisterFunc("variadic", variadic, true); err != nil {
return err
}
if err := conn.RegisterFunc("variadicGeneric", variadicGeneric, true); err != nil {
return err
}
return nil
},
})
db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:")
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
ops := []struct {
query string
expected interface{}
}{
{"SELECT addi_8_16_32(1,2)", int32(3)},
{"SELECT addi_64(1,2)", int64(3)},
{"SELECT addu_8_16_32(1,2)", uint32(3)},
{"SELECT addu_64(1,2)", uint64(3)},
{"SELECT addiu(1,2)", int64(3)},
{"SELECT addf_32_64(1.5,1.5)", float64(3)},
{"SELECT not(1)", false},
{"SELECT not(0)", true},
{`SELECT regex("^foo.*", "foobar")`, true},
{`SELECT regex("^foo.*", "barfoobar")`, false},
{"SELECT generic(1)", int64(1)},
{"SELECT generic(1.1)", int64(2)},
{`SELECT generic(NULL)`, int64(3)},
{`SELECT generic("foo")`, int64(4)},
{"SELECT variadic(1,2)", int64(3)},
{"SELECT variadic(1,2,3,4)", int64(10)},
{"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)},
{`SELECT variadicGeneric(1,"foo",2.3, NULL)`, int64(4)},
}
for _, op := range ops {
ret := reflect.New(reflect.TypeOf(op.expected))
err = db.QueryRow(op.query).Scan(ret.Interface())
if err != nil {
t.Errorf("Query %q failed: %s", op.query, err)
} else if !reflect.DeepEqual(ret.Elem().Interface(), op.expected) {
t.Errorf("Query %q returned wrong value: got %v (%T), want %v (%T)", op.query, ret.Elem().Interface(), ret.Elem().Interface(), op.expected, op.expected)
}
}
}
type sumAggregator int64
func (s *sumAggregator) Step(x int64) {
*s += sumAggregator(x)
}
func (s *sumAggregator) Done() int64 {
return int64(*s)
}
func TestAggregatorRegistration(t *testing.T) {
customSum := func() *sumAggregator {
var ret sumAggregator
return &ret
}
sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
return err
}
return nil
},
})
db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec("create table foo (department integer, profits integer)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
if err != nil {
t.Fatal("Failed to insert records:", err)
}
tests := []struct {
dept, sum int64
}{
{1, 30},
{2, 42},
}
for _, test := range tests {
var ret int64
err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
if err != nil {
t.Fatal("Query failed:", err)
}
if ret != test.sum {
t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
}
}
}
var customFunctionOnce sync.Once
func BenchmarkCustomFunctions(b *testing.B) {
customFunctionOnce.Do(func() {
custom_add := func(a, b int64) int64 {
return a + b
}
sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
// Impure function to force sqlite to reexecute it each time.
if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil {
return err
}
return nil
},
})
})
db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:")
if err != nil {
b.Fatal("Failed to open database:", err)
}
defer db.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
var i int64
err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i)
if err != nil {
b.Fatal("Failed to run custom add:", err)
}
}
}

View File

@ -275,12 +275,11 @@ func TestPreparedStmt(t *testing.T) {
} }
const nRuns = 10 const nRuns = 10
ch := make(chan bool) var wg sync.WaitGroup
for i := 0; i < nRuns; i++ { for i := 0; i < nRuns; i++ {
wg.Add(1)
go func() { go func() {
defer func() { defer wg.Done()
ch <- true
}()
for j := 0; j < 10; j++ { for j := 0; j < 10; j++ {
count := 0 count := 0
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows { if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
@ -294,9 +293,7 @@ func TestPreparedStmt(t *testing.T) {
} }
}() }()
} }
for i := 0; i < nRuns; i++ { wg.Wait()
<-ch
}
} }
// Benchmarks need to use panic() since b.Error errors are lost when // Benchmarks need to use panic() since b.Error errors are lost when
@ -318,7 +315,7 @@ func BenchmarkQuery(b *testing.B) {
var i int var i int
var f float64 var f float64
var s string var s string
// var t time.Time // var t time.Time
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil { if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
panic(err) panic(err)
} }
@ -331,7 +328,7 @@ func BenchmarkParams(b *testing.B) {
var i int var i int
var f float64 var f float64
var s string var s string
// var t time.Time // var t time.Time
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err) panic(err)
} }
@ -350,7 +347,7 @@ func BenchmarkStmt(b *testing.B) {
var i int var i int
var f float64 var f float64
var s string var s string
// var t time.Time // var t time.Time
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err) panic(err)
} }

View File

@ -1,8 +1,14 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build windows
package sqlite3 package sqlite3
/* /*
#cgo CFLAGS: -I. -fno-stack-check -fno-stack-protector -mno-stack-arg-probe #cgo CFLAGS: -I. -fno-stack-check -fno-stack-protector -mno-stack-arg-probe
#cgo LDFLAGS: -lmingwex -lmingw32 -lmsvcr100 #cgo windows,386 CFLAGS: -D_USE_32BIT_TIME_T
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE #cgo LDFLAGS: -lmingwex -lmingw32
*/ */
import "C" import "C"