Update sqlcipher to 3.11.0

This commit is contained in:
xeodou 2016-07-07 15:43:10 +08:00
commit 959ff350e1
24 changed files with 54647 additions and 17316 deletions

View File

@ -1,9 +1,13 @@
language: go language: go
sudo: required
dist: trusty
go: go:
- 1.5
- 1.6
- tip - tip
before_install: before_install:
- go get github.com/axw/gocov/gocov
- go get github.com/mattn/goveralls - go get github.com/mattn/goveralls
- go get golang.org/x/tools/cmd/cover - go get golang.org/x/tools/cmd/cover
script: script:
- $HOME/gopath/bin/goveralls -repotoken 3qJVUE0iQwqnCbmNcDsjYu1nh4J4KIFXx - $HOME/gopath/bin/goveralls -repotoken 3qJVUE0iQwqnCbmNcDsjYu1nh4J4KIFXx
- go test -v . -tags "libsqlite3"

View File

@ -1,6 +1,6 @@
The MIT License (MIT) The MIT License (MIT)
Copyright (c) 2014 Yasuhiro Matsumoto Copyright (c) 2014 Xeodou Li
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal

View File

@ -4,6 +4,7 @@ go-sqlcipher
SQLCipher driver conforming to the built-in database/sql interface and using the latest sqlite3 code. SQLCipher driver conforming to the built-in database/sql interface and using the latest sqlite3 code.
which is which is
`3.8.8.3 2015-02-25 13:29:11 9d6c1880fb75660bbabd693175579529785f8a6b` `3.8.8.3 2015-02-25 13:29:11 9d6c1880fb75660bbabd693175579529785f8a6b`
@ -39,6 +40,10 @@ This package can be installed with the go get command:
go get github.com/xeodou/go-sqlcipher go get github.com/xeodou/go-sqlcipher
_go-sqlcipher_ is *cgo* package.
If you want to build your app using go-sqlcipher, you need gcc.
However, if you install _go-sqlcipher_ with `go install github.com/xeodou/go-sqlcipher`, you don't need gcc to build your app anymore.
Documentation Documentation
------------- -------------
@ -50,10 +55,24 @@ FAQ
--- ---
The golang code is copy from [go-sqlite3](https://github.com/mattn/go-sqlite3) The golang code is copy from [go-sqlite3](https://github.com/mattn/go-sqlite3)
If you have some issue, you can maybe you can find from https://github.com/mattn/go-sqlite3/issues If you have some issue, maybe you can find from https://github.com/mattn/go-sqlite3/issues
Here is some help from go-sqlite3 project. Here is some help from go-sqlite3 project.
* Want to build go-sqlite3 with libsqlite3 on my linux.
Use `go build --tags "libsqlite3 linux"`
* Want to build go-sqlite3 with libsqlite3 on OS X.
Install sqlite3 from homebrew: `brew install sqlite3`
Use `go build --tags "libsqlite3 darwin"`
* Want to build go-sqlite3 with icu extension.
Use `go build --tags "icu"`
* Can't build go-sqlite3 on windows 64bit. * Can't build go-sqlite3 on windows 64bit.
> Probably, you are using go 1.0, go1.0 has a problem when it comes to compiling/linking on windows 64bit. > Probably, you are using go 1.0, go1.0 has a problem when it comes to compiling/linking on windows 64bit.
@ -64,7 +83,7 @@ Here is some help from go-sqlite3 project.
> You can pass some arguments into the connection string, for example, a URI. > You can pass some arguments into the connection string, for example, a URI.
> See: https://github.com/mattn/go-sqlite3/issues/39 > See: https://github.com/mattn/go-sqlite3/issues/39
* Do you want cross compiling? mingw on Linux or Mac? * Do you want to cross compile? mingw on Linux or Mac?
> See: https://github.com/mattn/go-sqlite3/issues/106 > See: https://github.com/mattn/go-sqlite3/issues/106
> See also: http://www.limitlessfx.com/cross-compile-golang-app-for-windows-from-linux.html > See also: http://www.limitlessfx.com/cross-compile-golang-app-for-windows-from-linux.html
@ -88,6 +107,11 @@ The -binding suffix was added to avoid build failures under gccgo.
In this repository, those files are amalgamation code that copied from SQLCipher. The license of those codes are depend on the license of SQLCipher. In this repository, those files are amalgamation code that copied from SQLCipher. The license of those codes are depend on the license of SQLCipher.
In this repository, those files are an amalgamation of code that was copied from SQLite3. The license of that code is the same as the license of SQLite3.
Original repository https://github.com/mattn/go-sqlite3 is under MIT.
Author Author
------ ------

View File

@ -0,0 +1,133 @@
package main
import (
"database/sql"
"fmt"
"log"
"math"
"math/rand"
sqlite "github.com/mattn/go-sqlite3"
)
// Computes x^y
func pow(x, y int64) int64 {
return int64(math.Pow(float64(x), float64(y)))
}
// Computes the bitwise exclusive-or of all its arguments
func xor(xs ...int64) int64 {
var ret int64
for _, x := range xs {
ret ^= x
}
return ret
}
// Returns a random number. It's actually deterministic here because
// we don't seed the RNG, but it's an example of a non-pure function
// from SQLite's POV.
func getrand() int64 {
return rand.Int63()
}
// Computes the standard deviation of a GROUPed BY set of values
type stddev struct {
xs []int64
// Running average calculation
sum int64
n int64
}
func newStddev() *stddev { return &stddev{} }
func (s *stddev) Step(x int64) {
s.xs = append(s.xs, x)
s.sum += x
s.n++
}
func (s *stddev) Done() float64 {
mean := float64(s.sum) / float64(s.n)
var sqDiff []float64
for _, x := range s.xs {
sqDiff = append(sqDiff, math.Pow(float64(x)-mean, 2))
}
var dev float64
for _, x := range sqDiff {
dev += x
}
dev /= float64(len(sqDiff))
return math.Sqrt(dev)
}
func main() {
sql.Register("sqlite3_custom", &sqlite.SQLiteDriver{
ConnectHook: func(conn *sqlite.SQLiteConn) error {
if err := conn.RegisterFunc("pow", pow, true); err != nil {
return err
}
if err := conn.RegisterFunc("xor", xor, true); err != nil {
return err
}
if err := conn.RegisterFunc("rand", getrand, false); err != nil {
return err
}
if err := conn.RegisterAggregator("stddev", newStddev, true); err != nil {
return err
}
return nil
},
})
db, err := sql.Open("sqlite3_custom", ":memory:")
if err != nil {
log.Fatal("Failed to open database:", err)
}
defer db.Close()
var i int64
err = db.QueryRow("SELECT pow(2,3)").Scan(&i)
if err != nil {
log.Fatal("POW query error:", err)
}
fmt.Println("pow(2,3) =", i) // 8
err = db.QueryRow("SELECT xor(1,2,3,4,5,6)").Scan(&i)
if err != nil {
log.Fatal("XOR query error:", err)
}
fmt.Println("xor(1,2,3,4,5) =", i) // 7
err = db.QueryRow("SELECT rand()").Scan(&i)
if err != nil {
log.Fatal("RAND query error:", err)
}
fmt.Println("rand() =", i) // pseudorandom
_, err = db.Exec("create table foo (department integer, profits integer)")
if err != nil {
log.Fatal("Failed to create table:", err)
}
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115)")
if err != nil {
log.Fatal("Failed to insert records:", err)
}
rows, err := db.Query("select department, stddev(profits) from foo group by department")
if err != nil {
log.Fatal("STDDEV query error:", err)
}
defer rows.Close()
for rows.Next() {
var dept int64
var dev float64
if err := rows.Scan(&dept, &dev); err != nil {
log.Fatal(err)
}
fmt.Printf("dept=%d stddev=%f\n", dept, dev)
}
if err := rows.Err(); err != nil {
log.Fatal(err)
}
}

View File

@ -54,7 +54,7 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
bk.Step(-1) _, err = bk.Step(-1)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -52,10 +52,16 @@ func main() {
for rows.Next() { for rows.Next() {
var id int var id int
var name string var name string
rows.Scan(&id, &name) err = rows.Scan(&id, &name)
if err != nil {
log.Fatal(err)
}
fmt.Println(id, name) fmt.Println(id, name)
} }
rows.Close() err = rows.Err()
if err != nil {
log.Fatal(err)
}
stmt, err = db.Prepare("select name from foo where id = ?") stmt, err = db.Prepare("select name from foo where id = ?")
if err != nil { if err != nil {
@ -87,7 +93,14 @@ func main() {
for rows.Next() { for rows.Next() {
var id int var id int
var name string var name string
rows.Scan(&id, &name) err = rows.Scan(&id, &name)
if err != nil {
log.Fatal(err)
}
fmt.Println(id, name) fmt.Println(id, name)
} }
err = rows.Err()
if err != nil {
log.Fatal(err)
}
} }

View File

@ -6,7 +6,11 @@
package sqlite3 package sqlite3
/* /*
#ifndef USE_LIBSQLITE3
#include <sqlite3-binding.h> #include <sqlite3-binding.h>
#else
#include <sqlite3.h>
#endif
#include <stdlib.h> #include <stdlib.h>
*/ */
import "C" import "C"

336
callback.go Normal file
View File

@ -0,0 +1,336 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3
// You can't export a Go function to C and have definitions in the C
// preamble in the same file, so we have to have callbackTrampoline in
// its own file. Because we need a separate file anyway, the support
// code for SQLite custom functions is in here.
/*
#include <sqlite3-binding.h>
#include <stdlib.h>
void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
*/
import "C"
import (
"errors"
"fmt"
"math"
"reflect"
"sync"
"unsafe"
)
//export callbackTrampoline
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
fi := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*functionInfo)
fi.Call(ctx, args)
}
//export stepTrampoline
func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo)
ai.Step(ctx, args)
}
//export doneTrampoline
func doneTrampoline(ctx *C.sqlite3_context) {
handle := uintptr(C.sqlite3_user_data(ctx))
ai := lookupHandle(handle).(*aggInfo)
ai.Done(ctx)
}
// Use handles to avoid passing Go pointers to C.
type handleVal struct {
db *SQLiteConn
val interface{}
}
var handleLock sync.Mutex
var handleVals = make(map[uintptr]handleVal)
var handleIndex uintptr = 100
func newHandle(db *SQLiteConn, v interface{}) uintptr {
handleLock.Lock()
defer handleLock.Unlock()
i := handleIndex
handleIndex++
handleVals[i] = handleVal{db, v}
return i
}
func lookupHandle(handle uintptr) interface{} {
handleLock.Lock()
defer handleLock.Unlock()
r, ok := handleVals[handle]
if !ok {
if handle >= 100 && handle < handleIndex {
panic("deleted handle")
} else {
panic("invalid handle")
}
}
return r.val
}
func deleteHandles(db *SQLiteConn) {
handleLock.Lock()
defer handleLock.Unlock()
for handle, val := range handleVals {
if val.db == db {
delete(handleVals, handle)
}
}
}
// This is only here so that tests can refer to it.
type callbackArgRaw C.sqlite3_value
type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
type callbackArgCast struct {
f callbackArgConverter
typ reflect.Type
}
func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
val, err := c.f(v)
if err != nil {
return reflect.Value{}, err
}
if !val.Type().ConvertibleTo(c.typ) {
return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
}
return val.Convert(c.typ), nil
}
func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
}
return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
}
func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
}
i := int64(C.sqlite3_value_int64(v))
val := false
if i != 0 {
val = true
}
return reflect.ValueOf(val), nil
}
func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
}
return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
}
func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
switch C.sqlite3_value_type(v) {
case C.SQLITE_BLOB:
l := C.sqlite3_value_bytes(v)
p := C.sqlite3_value_blob(v)
return reflect.ValueOf(C.GoBytes(p, l)), nil
case C.SQLITE_TEXT:
l := C.sqlite3_value_bytes(v)
c := unsafe.Pointer(C.sqlite3_value_text(v))
return reflect.ValueOf(C.GoBytes(c, l)), nil
default:
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
}
}
func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
switch C.sqlite3_value_type(v) {
case C.SQLITE_BLOB:
l := C.sqlite3_value_bytes(v)
p := (*C.char)(C.sqlite3_value_blob(v))
return reflect.ValueOf(C.GoStringN(p, l)), nil
case C.SQLITE_TEXT:
c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
return reflect.ValueOf(C.GoString(c)), nil
default:
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
}
}
func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) {
switch C.sqlite3_value_type(v) {
case C.SQLITE_INTEGER:
return callbackArgInt64(v)
case C.SQLITE_FLOAT:
return callbackArgFloat64(v)
case C.SQLITE_TEXT:
return callbackArgString(v)
case C.SQLITE_BLOB:
return callbackArgBytes(v)
case C.SQLITE_NULL:
// Interpret NULL as a nil byte slice.
var ret []byte
return reflect.ValueOf(ret), nil
default:
panic("unreachable")
}
}
func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
switch typ.Kind() {
case reflect.Interface:
if typ.NumMethod() != 0 {
return nil, errors.New("the only supported interface type is interface{}")
}
return callbackArgGeneric, nil
case reflect.Slice:
if typ.Elem().Kind() != reflect.Uint8 {
return nil, errors.New("the only supported slice type is []byte")
}
return callbackArgBytes, nil
case reflect.String:
return callbackArgString, nil
case reflect.Bool:
return callbackArgBool, nil
case reflect.Int64:
return callbackArgInt64, nil
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
c := callbackArgCast{callbackArgInt64, typ}
return c.Run, nil
case reflect.Float64:
return callbackArgFloat64, nil
case reflect.Float32:
c := callbackArgCast{callbackArgFloat64, typ}
return c.Run, nil
default:
return nil, fmt.Errorf("don't know how to convert to %s", typ)
}
}
func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) {
var args []reflect.Value
if len(argv) < len(converters) {
return nil, fmt.Errorf("function requires at least %d arguments", len(converters))
}
for i, arg := range argv[:len(converters)] {
v, err := converters[i](arg)
if err != nil {
return nil, err
}
args = append(args, v)
}
if variadic != nil {
for _, arg := range argv[len(converters):] {
v, err := variadic(arg)
if err != nil {
return nil, err
}
args = append(args, v)
}
}
return args, nil
}
type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
switch v.Type().Kind() {
case reflect.Int64:
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
v = v.Convert(reflect.TypeOf(int64(0)))
case reflect.Bool:
b := v.Interface().(bool)
if b {
v = reflect.ValueOf(int64(1))
} else {
v = reflect.ValueOf(int64(0))
}
default:
return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
}
C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
return nil
}
func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
switch v.Type().Kind() {
case reflect.Float64:
case reflect.Float32:
v = v.Convert(reflect.TypeOf(float64(0)))
default:
return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
}
C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
return nil
}
func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
return fmt.Errorf("cannot convert %s to BLOB", v.Type())
}
i := v.Interface()
if i == nil || len(i.([]byte)) == 0 {
C.sqlite3_result_null(ctx)
} else {
bs := i.([]byte)
C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
}
return nil
}
func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
if v.Type().Kind() != reflect.String {
return fmt.Errorf("cannot convert %s to TEXT", v.Type())
}
C._sqlite3_result_text(ctx, C.CString(v.Interface().(string)))
return nil
}
func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
switch typ.Kind() {
case reflect.Slice:
if typ.Elem().Kind() != reflect.Uint8 {
return nil, errors.New("the only supported slice type is []byte")
}
return callbackRetBlob, nil
case reflect.String:
return callbackRetText, nil
case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
return callbackRetInteger, nil
case reflect.Float32, reflect.Float64:
return callbackRetFloat, nil
default:
return nil, fmt.Errorf("don't know how to convert to %s", typ)
}
}
func callbackError(ctx *C.sqlite3_context, err error) {
cstr := C.CString(err.Error())
defer C.free(unsafe.Pointer(cstr))
C.sqlite3_result_error(ctx, cstr, -1)
}
// Test support code. Tests are not allowed to import "C", so we can't
// declare any functions that use C.sqlite3_value.
func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
return func(*C.sqlite3_value) (reflect.Value, error) {
return v, err
}
}

97
callback_test.go Normal file
View File

@ -0,0 +1,97 @@
package sqlite3
import (
"errors"
"math"
"reflect"
"testing"
)
func TestCallbackArgCast(t *testing.T) {
intConv := callbackSyntheticForTests(reflect.ValueOf(int64(math.MaxInt64)), nil)
floatConv := callbackSyntheticForTests(reflect.ValueOf(float64(math.MaxFloat64)), nil)
errConv := callbackSyntheticForTests(reflect.Value{}, errors.New("test"))
tests := []struct {
f callbackArgConverter
o reflect.Value
}{
{intConv, reflect.ValueOf(int8(-1))},
{intConv, reflect.ValueOf(int16(-1))},
{intConv, reflect.ValueOf(int32(-1))},
{intConv, reflect.ValueOf(uint8(math.MaxUint8))},
{intConv, reflect.ValueOf(uint16(math.MaxUint16))},
{intConv, reflect.ValueOf(uint32(math.MaxUint32))},
// Special case, int64->uint64 is only 1<<63 - 1, not 1<<64 - 1
{intConv, reflect.ValueOf(uint64(math.MaxInt64))},
{floatConv, reflect.ValueOf(float32(math.Inf(1)))},
}
for _, test := range tests {
conv := callbackArgCast{test.f, test.o.Type()}
val, err := conv.Run(nil)
if err != nil {
t.Errorf("Couldn't convert to %s: %s", test.o.Type(), err)
} else if !reflect.DeepEqual(val.Interface(), test.o.Interface()) {
t.Errorf("Unexpected result from converting to %s: got %v, want %v", test.o.Type(), val.Interface(), test.o.Interface())
}
}
conv := callbackArgCast{errConv, reflect.TypeOf(int8(0))}
_, err := conv.Run(nil)
if err == nil {
t.Errorf("Expected error during callbackArgCast, but got none")
}
}
func TestCallbackConverters(t *testing.T) {
tests := []struct {
v interface{}
err bool
}{
// Unfortunately, we can't tell which converter was returned,
// but we can at least check which types can be converted.
{[]byte{0}, false},
{"text", false},
{true, false},
{int8(0), false},
{int16(0), false},
{int32(0), false},
{int64(0), false},
{uint8(0), false},
{uint16(0), false},
{uint32(0), false},
{uint64(0), false},
{int(0), false},
{uint(0), false},
{float64(0), false},
{float32(0), false},
{func() {}, true},
{complex64(complex(0, 0)), true},
{complex128(complex(0, 0)), true},
{struct{}{}, true},
{map[string]string{}, true},
{[]string{}, true},
{(*int8)(nil), true},
{make(chan int), true},
}
for _, test := range tests {
_, err := callbackArg(reflect.TypeOf(test.v))
if test.err && err == nil {
t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v))
} else if !test.err && err != nil {
t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err)
}
}
for _, test := range tests {
_, err := callbackRet(reflect.TypeOf(test.v))
if test.err && err == nil {
t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v))
} else if !test.err && err != nil {
t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err)
}
}
}

35
doc.go
View File

@ -1,7 +1,7 @@
/* /*
Package sqlite3 provides interface to SQLite3 databases. Package sqlite3 provides interface to SQLite3 databases.
This works as driver for database/sql. This works as a driver for database/sql.
Installation Installation
@ -9,7 +9,7 @@ Installation
Supported Types Supported Types
Currently, go-sqlite3 support following data types. Currently, go-sqlite3 supports the following data types.
+------------------------------+ +------------------------------+
|go | sqlite3 | |go | sqlite3 |
@ -26,8 +26,8 @@ Currently, go-sqlite3 support following data types.
SQLite3 Extension SQLite3 Extension
You can write your own extension module for sqlite3. For example, below is a You can write your own extension module for sqlite3. For example, below is an
extension for Regexp matcher operation. extension for a Regexp matcher operation.
#include <pcre.h> #include <pcre.h>
#include <string.h> #include <string.h>
@ -63,8 +63,8 @@ extension for Regexp matcher operation.
(void*)db, regexp_func, NULL, NULL); (void*)db, regexp_func, NULL, NULL);
} }
It need to build as so/dll shared library. And you need to register It needs to be built as a so/dll shared library. And you need to register
extension module like below. the extension module like below.
sql.Register("sqlite3_with_extensions", sql.Register("sqlite3_with_extensions",
&sqlite3.SQLiteDriver{ &sqlite3.SQLiteDriver{
@ -79,9 +79,9 @@ Then, you can use this extension.
Connection Hook Connection Hook
You can hook and inject your codes when connection established. database/sql You can hook and inject your code when the connection is established. database/sql
doesn't provide the way to get native go-sqlite3 interfaces. So if you want, doesn't provide a way to get native go-sqlite3 interfaces. So if you want,
you need to hook ConnectHook and get the SQLiteConn. you need to set ConnectHook and get the SQLiteConn.
sql.Register("sqlite3_with_hook_example", sql.Register("sqlite3_with_hook_example",
&sqlite3.SQLiteDriver{ &sqlite3.SQLiteDriver{
@ -91,5 +91,22 @@ you need to hook ConnectHook and get the SQLiteConn.
}, },
}) })
Go SQlite3 Extensions
If you want to register Go functions as SQLite extension functions,
call RegisterFunction from ConnectHook.
regex = func(re, s string) (bool, error) {
return regexp.MatchString(re, s)
}
sql.Register("sqlite3_with_go_func",
&sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
return conn.RegisterFunc("regexp", regex, true)
},
})
See the documentation of RegisterFunc for more details.
*/ */
package sqlite3 package sqlite3

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -2,6 +2,12 @@
// //
// Use of this source code is governed by an MIT-style // Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//
// Copyright (C) 2014 Xeodou <xeodou@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
//
package sqlite3 package sqlite3
@ -9,9 +15,13 @@ package sqlite3
#cgo CFLAGS: -std=gnu99 #cgo CFLAGS: -std=gnu99
#cgo CFLAGS: -DSQLITE_HAS_CODEC #cgo CFLAGS: -DSQLITE_HAS_CODEC
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE #cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS #cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
#cgo LDFLAGS: -lcrypto #cgo LDFLAGS: -lcrypto
#ifndef USE_LIBSQLITE3
#include <sqlite3-binding.h> #include <sqlite3-binding.h>
#else
#include <sqlite3.h>
#endif
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
@ -27,6 +37,10 @@ package sqlite3
# define SQLITE_OPEN_FULLMUTEX 0 # define SQLITE_OPEN_FULLMUTEX 0
#endif #endif
#ifndef SQLITE_DETERMINISTIC
# define SQLITE_DETERMINISTIC 0
#endif
static int static int
_sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) { _sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) {
#ifdef SQLITE_OPEN_URI #ifdef SQLITE_OPEN_URI
@ -50,24 +64,49 @@ _sqlite3_bind_blob(sqlite3_stmt *stmt, int n, void *p, int np) {
#include <stdint.h> #include <stdint.h>
static int static int
_sqlite3_exec(sqlite3* db, const char* pcmd, long* rowid, long* changes) _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* changes)
{ {
int rv = sqlite3_exec(db, pcmd, 0, 0, 0); int rv = sqlite3_exec(db, pcmd, 0, 0, 0);
*rowid = (long) sqlite3_last_insert_rowid(db); *rowid = (long long) sqlite3_last_insert_rowid(db);
*changes = (long) sqlite3_changes(db); *changes = (long long) sqlite3_changes(db);
return rv; return rv;
} }
static int static int
_sqlite3_step(sqlite3_stmt* stmt, long* rowid, long* changes) _sqlite3_step(sqlite3_stmt* stmt, long long* rowid, long long* changes)
{ {
int rv = sqlite3_step(stmt); int rv = sqlite3_step(stmt);
sqlite3* db = sqlite3_db_handle(stmt); sqlite3* db = sqlite3_db_handle(stmt);
*rowid = (long) sqlite3_last_insert_rowid(db); *rowid = (long long) sqlite3_last_insert_rowid(db);
*changes = (long) sqlite3_changes(db); *changes = (long long) sqlite3_changes(db);
return rv; return rv;
} }
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
sqlite3_result_text(ctx, s, -1, &free);
}
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) {
sqlite3_result_blob(ctx, b, l, SQLITE_TRANSIENT);
}
int _sqlite3_create_function(
sqlite3 *db,
const char *zFunctionName,
int nArg,
int eTextRep,
uintptr_t pApp,
void (*xFunc)(sqlite3_context*,int,sqlite3_value**),
void (*xStep)(sqlite3_context*,int,sqlite3_value**),
void (*xFinal)(sqlite3_context*)
) {
return sqlite3_create_function(db, zFunctionName, nArg, eTextRep, (void*) pApp, xFunc, xStep, xFinal);
}
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void doneTrampoline(sqlite3_context*);
*/ */
import "C" import "C"
import ( import (
@ -77,6 +116,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/url" "net/url"
"reflect"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@ -89,6 +129,10 @@ import (
// into the database. When parsing a string from a timestamp or // into the database. When parsing a string from a timestamp or
// datetime column, the formats are tried in order. // datetime column, the formats are tried in order.
var SQLiteTimestampFormats = []string{ var SQLiteTimestampFormats = []string{
// By default, store timestamps with whatever timezone they come with.
// When parsed, they will be returned with the same timezone.
"2006-01-02 15:04:05.999999999-07:00",
"2006-01-02T15:04:05.999999999-07:00",
"2006-01-02 15:04:05.999999999", "2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999", "2006-01-02T15:04:05.999999999",
"2006-01-02 15:04:05", "2006-01-02 15:04:05",
@ -96,14 +140,13 @@ var SQLiteTimestampFormats = []string{
"2006-01-02 15:04", "2006-01-02 15:04",
"2006-01-02T15:04", "2006-01-02T15:04",
"2006-01-02", "2006-01-02",
"2006-01-02 15:04:05-07:00",
} }
func init() { func init() {
sql.Register("sqlite3", &SQLiteDriver{}) sql.Register("sqlite3", &SQLiteDriver{})
} }
// Return SQLite library Version information. // Version returns SQLite library version information.
func Version() (libVersion string, libVersionNumber int, sourceId string) { func Version() (libVersion string, libVersionNumber int, sourceId string) {
libVersion = C.GoString(C.sqlite3_libversion()) libVersion = C.GoString(C.sqlite3_libversion())
libVersionNumber = int(C.sqlite3_libversion_number()) libVersionNumber = int(C.sqlite3_libversion_number())
@ -122,6 +165,8 @@ type SQLiteConn struct {
db *C.sqlite3 db *C.sqlite3
loc *time.Location loc *time.Location
txlock string txlock string
funcs []*functionInfo
aggregators []*aggInfo
} }
// Tx struct. // Tx struct.
@ -155,6 +200,107 @@ type SQLiteRows struct {
cls bool cls bool
} }
type functionInfo struct {
f reflect.Value
argConverters []callbackArgConverter
variadicConverter callbackArgConverter
retConverter callbackRetConverter
}
func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
args, err := callbackConvertArgs(argv, fi.argConverters, fi.variadicConverter)
if err != nil {
callbackError(ctx, err)
return
}
ret := fi.f.Call(args)
if len(ret) == 2 && ret[1].Interface() != nil {
callbackError(ctx, ret[1].Interface().(error))
return
}
err = fi.retConverter(ctx, ret[0])
if err != nil {
callbackError(ctx, err)
return
}
}
type aggInfo struct {
constructor reflect.Value
// Active aggregator objects for aggregations in flight. The
// aggregators are indexed by a counter stored in the aggregation
// user data space provided by sqlite.
active map[int64]reflect.Value
next int64
stepArgConverters []callbackArgConverter
stepVariadicConverter callbackArgConverter
doneRetConverter callbackRetConverter
}
func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) {
aggIdx := (*int64)(C.sqlite3_aggregate_context(ctx, C.int(8)))
if *aggIdx == 0 {
*aggIdx = ai.next
ret := ai.constructor.Call(nil)
if len(ret) == 2 && ret[1].Interface() != nil {
return 0, reflect.Value{}, ret[1].Interface().(error)
}
if ret[0].IsNil() {
return 0, reflect.Value{}, errors.New("aggregator constructor returned nil state")
}
ai.next++
ai.active[*aggIdx] = ret[0]
}
return *aggIdx, ai.active[*aggIdx], nil
}
func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
_, agg, err := ai.agg(ctx)
if err != nil {
callbackError(ctx, err)
return
}
args, err := callbackConvertArgs(argv, ai.stepArgConverters, ai.stepVariadicConverter)
if err != nil {
callbackError(ctx, err)
return
}
ret := agg.MethodByName("Step").Call(args)
if len(ret) == 1 && ret[0].Interface() != nil {
callbackError(ctx, ret[0].Interface().(error))
return
}
}
func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
idx, agg, err := ai.agg(ctx)
if err != nil {
callbackError(ctx, err)
return
}
defer func() { delete(ai.active, idx) }()
ret := agg.MethodByName("Done").Call(nil)
if len(ret) == 2 && ret[1].Interface() != nil {
callbackError(ctx, ret[1].Interface().(error))
return
}
err = ai.doneRetConverter(ctx, ret[0])
if err != nil {
callbackError(ctx, err)
return
}
}
// Commit transaction. // Commit transaction.
func (tx *SQLiteTx) Commit() error { func (tx *SQLiteTx) Commit() error {
_, err := tx.c.exec("COMMIT") _, err := tx.c.exec("COMMIT")
@ -167,6 +313,208 @@ func (tx *SQLiteTx) Rollback() error {
return err return err
} }
// RegisterFunc makes a Go function available as a SQLite function.
//
// The Go function can have arguments of the following types: any
// numeric type except complex, bool, []byte, string and
// interface{}. interface{} arguments are given the direct translation
// of the SQLite data type: int64 for INTEGER, float64 for FLOAT,
// []byte for BLOB, string for TEXT.
//
// The function can additionally be variadic, as long as the type of
// the variadic argument is one of the above.
//
// If pure is true. SQLite will assume that the function's return
// value depends only on its inputs, and make more aggressive
// optimizations in its queries.
//
// See _example/go_custom_funcs for a detailed example.
func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) error {
var fi functionInfo
fi.f = reflect.ValueOf(impl)
t := fi.f.Type()
if t.Kind() != reflect.Func {
return errors.New("Non-function passed to RegisterFunc")
}
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite functions must return 1 or 2 values")
}
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("Second return value of SQLite function must be error")
}
numArgs := t.NumIn()
if t.IsVariadic() {
numArgs--
}
for i := 0; i < numArgs; i++ {
conv, err := callbackArg(t.In(i))
if err != nil {
return err
}
fi.argConverters = append(fi.argConverters, conv)
}
if t.IsVariadic() {
conv, err := callbackArg(t.In(numArgs).Elem())
if err != nil {
return err
}
fi.variadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
numArgs = -1
}
conv, err := callbackRet(t.Out(0))
if err != nil {
return err
}
fi.retConverter = conv
// fi must outlast the database connection, or we'll have dangling pointers.
c.funcs = append(c.funcs, &fi)
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := C._sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), C.uintptr_t(newHandle(c, &fi)), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil)
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
//
// Because aggregation is incremental, it's implemented in Go with a
// type that has 2 methods: func Step(values) accumulates one row of
// data into the accumulator, and func Done() ret finalizes and
// returns the aggregate value. "values" and "ret" may be any type
// supported by RegisterFunc.
//
// RegisterAggregator takes as implementation a constructor function
// that constructs an instance of the aggregator type each time an
// aggregation begins. The constructor must return a pointer to a
// type, or an interface that implements Step() and Done().
//
// The constructor function and the Step/Done methods may optionally
// return an error in addition to their other return values.
//
// See _example/go_custom_funcs for a detailed example.
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
var ai aggInfo
ai.constructor = reflect.ValueOf(impl)
t := ai.constructor.Type()
if t.Kind() != reflect.Func {
return errors.New("non-function passed to RegisterAggregator")
}
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
}
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("Second return value of SQLite function must be error")
}
if t.NumIn() != 0 {
return errors.New("SQLite aggregator constructors must not have arguments")
}
agg := t.Out(0)
switch agg.Kind() {
case reflect.Ptr, reflect.Interface:
default:
return errors.New("SQlite aggregator constructor must return a pointer object")
}
stepFn, found := agg.MethodByName("Step")
if !found {
return errors.New("SQlite aggregator doesn't have a Step() function")
}
step := stepFn.Type
if step.NumOut() != 0 && step.NumOut() != 1 {
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
}
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("type of SQlite aggregator Step() return value must be error")
}
stepNArgs := step.NumIn()
start := 0
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
stepNArgs--
start++
}
if step.IsVariadic() {
stepNArgs--
}
for i := start; i < start+stepNArgs; i++ {
conv, err := callbackArg(step.In(i))
if err != nil {
return err
}
ai.stepArgConverters = append(ai.stepArgConverters, conv)
}
if step.IsVariadic() {
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
if err != nil {
return err
}
ai.stepVariadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
stepNArgs = -1
}
doneFn, found := agg.MethodByName("Done")
if !found {
return errors.New("SQlite aggregator doesn't have a Done() function")
}
done := doneFn.Type
doneNArgs := done.NumIn()
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
doneNArgs--
}
if doneNArgs != 0 {
return errors.New("SQlite aggregator Done() function must have no arguments")
}
if done.NumOut() != 1 && done.NumOut() != 2 {
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
}
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("second return value of SQLite aggregator Done() function must be error")
}
conv, err := callbackRet(done.Out(0))
if err != nil {
return err
}
ai.doneRetConverter = conv
ai.active = make(map[int64]reflect.Value)
ai.next = 1
// ai must outlast the database connection, or we'll have dangling pointers.
c.aggregators = append(c.aggregators, &ai)
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := C._sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(newHandle(c, &ai)), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline)))
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}
// AutoCommit return which currently auto commit or not. // AutoCommit return which currently auto commit or not.
func (c *SQLiteConn) AutoCommit() bool { func (c *SQLiteConn) AutoCommit() bool {
return int(C.sqlite3_get_autocommit(c.db)) != 0 return int(C.sqlite3_get_autocommit(c.db)) != 0
@ -245,7 +593,7 @@ func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {
pcmd := C.CString(cmd) pcmd := C.CString(cmd)
defer C.free(unsafe.Pointer(pcmd)) defer C.free(unsafe.Pointer(pcmd))
var rowid, changes C.long var rowid, changes C.longlong
rv := C._sqlite3_exec(c.db, pcmd, &rowid, &changes) rv := C._sqlite3_exec(c.db, pcmd, &rowid, &changes)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return nil, c.lastError() return nil, c.lastError()
@ -266,12 +614,12 @@ func errorString(err Error) string {
} }
// Open database and return a new connection. // Open database and return a new connection.
// You can specify DSN string with URI filename. // You can specify a DSN string using a URI as the filename.
// test.db // test.db
// file:test.db?cache=shared&mode=memory // file:test.db?cache=shared&mode=memory
// :memory: // :memory:
// file::memory: // file::memory:
// go-sqlite handle especially query parameters. // go-sqlite3 adds the following query parameters to those used by SQLite:
// _loc=XXX // _loc=XXX
// Specify location of time format. It's possible to specify "auto". // Specify location of time format. It's possible to specify "auto".
// _busy_timeout=XXX // _busy_timeout=XXX
@ -357,23 +705,8 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock} conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
if len(d.Extensions) > 0 { if len(d.Extensions) > 0 {
rv = C.sqlite3_enable_load_extension(db, 1) if err := conn.loadExtensions(d.Extensions); err != nil {
if rv != C.SQLITE_OK { return nil, err
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
for _, extension := range d.Extensions {
cext := C.CString(extension)
defer C.free(unsafe.Pointer(cext))
rv = C.sqlite3_load_extension(db, cext, nil, nil)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
}
rv = C.sqlite3_enable_load_extension(db, 0)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
} }
} }
@ -388,6 +721,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
// Close the connection. // Close the connection.
func (c *SQLiteConn) Close() error { func (c *SQLiteConn) Close() error {
deleteHandles(c)
rv := C.sqlite3_close_v2(c.db) rv := C.sqlite3_close_v2(c.db)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return c.lastError() return c.lastError()
@ -397,7 +731,7 @@ func (c *SQLiteConn) Close() error {
return nil return nil
} }
// Prepare query string. Return a new statement. // Prepare the query string. Return a new statement.
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
pquery := C.CString(query) pquery := C.CString(query)
defer C.free(unsafe.Pointer(pquery)) defer C.free(unsafe.Pointer(pquery))
@ -497,13 +831,13 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
case float64: case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v)) rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte: case []byte:
var p *byte if len(v) == 0 {
if len(v) > 0 { rv = C._sqlite3_bind_blob(s.s, n, nil, 0)
p = &v[0] } else {
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v)))
} }
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
case time.Time: case time.Time:
b := []byte(v.UTC().Format(SQLiteTimestampFormats[0])) b := []byte(v.Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
} }
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
@ -538,7 +872,7 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
C.sqlite3_clear_bindings(s.s) C.sqlite3_clear_bindings(s.s)
return nil, err return nil, err
} }
var rowid, changes C.long var rowid, changes C.longlong
rv := C._sqlite3_step(s.s, &rowid, &changes) rv := C._sqlite3_step(s.s, &rowid, &changes)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
err := s.c.lastError() err := s.c.lastError()
@ -575,6 +909,17 @@ func (rc *SQLiteRows) Columns() []string {
return rc.cols return rc.cols
} }
// Return column types.
func (rc *SQLiteRows) DeclTypes() []string {
if rc.decltype == nil {
rc.decltype = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
}
}
return rc.decltype
}
// Move cursor to next. // Move cursor to next.
func (rc *SQLiteRows) Next(dest []driver.Value) error { func (rc *SQLiteRows) Next(dest []driver.Value) error {
rv := C.sqlite3_step(rc.s.s) rv := C.sqlite3_step(rc.s.s)
@ -589,12 +934,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
return nil return nil
} }
if rc.decltype == nil { rc.DeclTypes()
rc.decltype = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
}
}
for i := range dest { for i := range dest {
switch C.sqlite3_column_type(rc.s.s, C.int(i)) { switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
@ -602,18 +942,15 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i))) val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
switch rc.decltype[i] { switch rc.decltype[i] {
case "timestamp", "datetime", "date": case "timestamp", "datetime", "date":
unixTimestamp := strconv.FormatInt(val, 10)
var t time.Time var t time.Time
if len(unixTimestamp) == 13 { // Assume a millisecond unix timestamp if it's 13 digits -- too
duration, err := time.ParseDuration(unixTimestamp + "ms") // large to be a reasonable timestamp in seconds.
if err != nil { if val > 1e12 || val < -1e12 {
return fmt.Errorf("error parsing %s value %d, %s", rc.decltype[i], val, err) val *= int64(time.Millisecond) // convert ms to nsec
}
epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)
t = epoch.Add(duration)
} else { } else {
t = time.Unix(val, 0) val *= int64(time.Second) // convert sec to nsec
} }
t = time.Unix(0, val).UTC()
if rc.s.c.loc != nil { if rc.s.c.loc != nil {
t = t.In(rc.s.c.loc) t = t.In(rc.s.c.loc)
} }

View File

@ -12,12 +12,12 @@ import (
) )
func TestFTS3(t *testing.T) { func TestFTS3(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("DROP TABLE foo") _, err = db.Exec("DROP TABLE foo")
@ -81,3 +81,50 @@ func TestFTS3(t *testing.T) {
t.Fatal("Result should be only one") t.Fatal("Result should be only one")
} }
} }
func TestFTS4(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec("DROP TABLE foo")
_, err = db.Exec("CREATE VIRTUAL TABLE foo USING fts4(tokenize=unicode61, id INTEGER PRIMARY KEY, value TEXT)")
switch {
case err != nil && err.Error() == "unknown tokenizer: unicode61":
t.Skip("FTS4 not supported")
case err != nil:
t.Fatal("Failed to create table:", err)
}
_, err = db.Exec("INSERT INTO foo(id, value) VALUES(?, ?)", 1, `février`)
if err != nil {
t.Fatal("Failed to insert value:", err)
}
rows, err := db.Query("SELECT value FROM foo WHERE value MATCH 'fevrier'")
if err != nil {
t.Fatal("Unable to query foo table:", err)
}
defer rows.Close()
var value string
if !rows.Next() {
t.Fatal("Result should be only one")
}
if err := rows.Scan(&value); err != nil {
t.Fatal("Unable to scan results:", err)
}
if value != `février` {
t.Fatal("Value should be `février`, but:", value)
}
if rows.Next() {
t.Fatal("Result should be only one")
}
}

13
sqlite3_fts5.go Normal file
View File

@ -0,0 +1,13 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build fts5
package sqlite3
/*
#cgo CFLAGS: -DSQLITE_ENABLE_FTS5
#cgo LDFLAGS: -lm
*/
import "C"

13
sqlite3_icu.go Normal file
View File

@ -0,0 +1,13 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build icu
package sqlite3
/*
#cgo LDFLAGS: -licuuc -licui18n
#cgo CFLAGS: -DSQLITE_ENABLE_ICU
*/
import "C"

12
sqlite3_json1.go Normal file
View File

@ -0,0 +1,12 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build json1
package sqlite3
/*
#cgo CFLAGS: -DSQLITE_ENABLE_JSON1
*/
import "C"

14
sqlite3_libsqlite3.go Normal file
View File

@ -0,0 +1,14 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build libsqlite3
package sqlite3
/*
#cgo CFLAGS: -DUSE_LIBSQLITE3
#cgo linux LDFLAGS: -lsqlite3
#cgo darwin LDFLAGS: -L/usr/local/opt/sqlite/lib -lsqlite3
*/
import "C"

63
sqlite3_load_extension.go Normal file
View File

@ -0,0 +1,63 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build !sqlite_omit_load_extension
package sqlite3
/*
#include <sqlite3-binding.h>
#include <stdlib.h>
*/
import "C"
import (
"errors"
"unsafe"
)
func (c *SQLiteConn) loadExtensions(extensions []string) error {
rv := C.sqlite3_enable_load_extension(c.db, 1)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
for _, extension := range extensions {
cext := C.CString(extension)
defer C.free(unsafe.Pointer(cext))
rv = C.sqlite3_load_extension(c.db, cext, nil, nil)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
}
rv = C.sqlite3_enable_load_extension(c.db, 0)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
return nil
}
func (c *SQLiteConn) LoadExtension(lib string, entry string) error {
rv := C.sqlite3_enable_load_extension(c.db, 1)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
clib := C.CString(lib)
defer C.free(unsafe.Pointer(clib))
centry := C.CString(entry)
defer C.free(unsafe.Pointer(centry))
rv = C.sqlite3_load_extension(c.db, clib, centry, nil)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
rv = C.sqlite3_enable_load_extension(c.db, 0)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
}
return nil
}

View File

@ -0,0 +1,23 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build sqlite_omit_load_extension
package sqlite3
/*
#cgo CFLAGS: -DSQLITE_OMIT_LOAD_EXTENSION
*/
import "C"
import (
"errors"
)
func (c *SQLiteConn) loadExtensions(extensions []string) error {
return errors.New("Extensions have been disabled for static builds")
}
func (c *SQLiteConn) LoadExtension(lib string, entry string) error {
return errors.New("Extensions have been disabled for static builds")
}

View File

@ -6,31 +6,36 @@
package sqlite3 package sqlite3
import ( import (
"crypto/rand"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net/url" "net/url"
"os" "os"
"path/filepath" "reflect"
"regexp"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
"github.com/xeodou/go-sqlcipher/sqlite3_test" "github.com/xeodou/go-sqlcipher/sqlite3_test"
) )
func TempFilename() string { func TempFilename(t *testing.T) string {
randBytes := make([]byte, 16) f, err := ioutil.TempFile("", "go-sqlite3-test-")
rand.Read(randBytes) if err != nil {
return filepath.Join(os.TempDir(), "foo"+hex.EncodeToString(randBytes)+".db") t.Fatal(err)
}
f.Close()
return f.Name()
} }
func doTestOpen(t *testing.T, option string) (string, error) { func doTestOpen(t *testing.T, option string) (string, error) {
var url string var url string
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
if option != "" { if option != "" {
url = tempFilename + option url = tempFilename + option
} else { } else {
@ -81,13 +86,34 @@ func TestOpen(t *testing.T) {
} }
} }
func TestReadonly(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db1, err := sql.Open("sqlite3", "file:"+tempFilename)
if err != nil {
t.Fatal(err)
}
db1.Exec("CREATE TABLE test (x int, y float)")
db2, err := sql.Open("sqlite3", "file:"+tempFilename+"?mode=ro")
if err != nil {
t.Fatal(err)
}
_ = db2
_, err = db2.Exec("INSERT INTO test VALUES (1, 3.14)")
if err == nil {
t.Fatal("didn't expect INSERT into read-only database to work")
}
}
func TestClose(t *testing.T) { func TestClose(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
_, err = db.Exec("drop table foo") _, err = db.Exec("drop table foo")
_, err = db.Exec("create table foo (id integer)") _, err = db.Exec("create table foo (id integer)")
@ -108,12 +134,12 @@ func TestClose(t *testing.T) {
} }
func TestInsert(t *testing.T) { func TestInsert(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("drop table foo") _, err = db.Exec("drop table foo")
@ -142,17 +168,17 @@ func TestInsert(t *testing.T) {
var result int var result int
rows.Scan(&result) rows.Scan(&result)
if result != 123 { if result != 123 {
t.Errorf("Fetched %q; expected %q", 123, result) t.Errorf("Expected %d for fetched result, but %d:", 123, result)
} }
} }
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("drop table foo") _, err = db.Exec("drop table foo")
@ -207,17 +233,17 @@ func TestUpdate(t *testing.T) {
var result int var result int
rows.Scan(&result) rows.Scan(&result)
if result != 234 { if result != 234 {
t.Errorf("Fetched %q; expected %q", 234, result) t.Errorf("Expected %d for fetched result, but %d:", 234, result)
} }
} }
func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("drop table foo") _, err = db.Exec("drop table foo")
@ -273,12 +299,12 @@ func TestDelete(t *testing.T) {
} }
func TestBooleanRoundtrip(t *testing.T) { func TestBooleanRoundtrip(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("DROP TABLE foo") _, err = db.Exec("DROP TABLE foo")
@ -321,13 +347,15 @@ func TestBooleanRoundtrip(t *testing.T) {
} }
} }
func timezone(t time.Time) string { return t.Format("-07:00") }
func TestTimestamp(t *testing.T) { func TestTimestamp(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("DROP TABLE foo") _, err = db.Exec("DROP TABLE foo")
@ -339,6 +367,7 @@ func TestTimestamp(t *testing.T) {
timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC)
timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC) timestamp2 := time.Date(2006, time.January, 2, 15, 4, 5, 123456789, time.UTC)
timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC) timestamp3 := time.Date(2012, time.November, 4, 0, 0, 0, 0, time.UTC)
tzTest := time.FixedZone("TEST", -9*3600-13*60)
tests := []struct { tests := []struct {
value interface{} value interface{}
expected time.Time expected time.Time
@ -346,9 +375,9 @@ func TestTimestamp(t *testing.T) {
{"nonsense", time.Time{}}, {"nonsense", time.Time{}},
{"0000-00-00 00:00:00", time.Time{}}, {"0000-00-00 00:00:00", time.Time{}},
{timestamp1, timestamp1}, {timestamp1, timestamp1},
{timestamp1.Unix(), timestamp1}, {timestamp2.Unix(), timestamp2.Truncate(time.Second)},
{timestamp1.UnixNano() / int64(time.Millisecond), timestamp1}, {timestamp2.UnixNano() / int64(time.Millisecond), timestamp2.Truncate(time.Millisecond)},
{timestamp1.In(time.FixedZone("TEST", -7*3600)), timestamp1}, {timestamp1.In(tzTest), timestamp1.In(tzTest)},
{timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1}, {timestamp1.Format("2006-01-02 15:04:05.000"), timestamp1},
{timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1}, {timestamp1.Format("2006-01-02T15:04:05.000"), timestamp1},
{timestamp1.Format("2006-01-02 15:04:05"), timestamp1}, {timestamp1.Format("2006-01-02 15:04:05"), timestamp1},
@ -356,6 +385,7 @@ func TestTimestamp(t *testing.T) {
{timestamp2, timestamp2}, {timestamp2, timestamp2},
{"2006-01-02 15:04:05.123456789", timestamp2}, {"2006-01-02 15:04:05.123456789", timestamp2},
{"2006-01-02T15:04:05.123456789", timestamp2}, {"2006-01-02T15:04:05.123456789", timestamp2},
{"2006-01-02T05:51:05.123456789-09:13", timestamp2.In(tzTest)},
{"2012-11-04", timestamp3}, {"2012-11-04", timestamp3},
{"2012-11-04 00:00", timestamp3}, {"2012-11-04 00:00", timestamp3},
{"2012-11-04 00:00:00", timestamp3}, {"2012-11-04 00:00:00", timestamp3},
@ -363,6 +393,14 @@ func TestTimestamp(t *testing.T) {
{"2012-11-04T00:00", timestamp3}, {"2012-11-04T00:00", timestamp3},
{"2012-11-04T00:00:00", timestamp3}, {"2012-11-04T00:00:00", timestamp3},
{"2012-11-04T00:00:00.000", timestamp3}, {"2012-11-04T00:00:00.000", timestamp3},
{"2006-01-02T15:04:05.123456789Z", timestamp2},
{"2012-11-04Z", timestamp3},
{"2012-11-04 00:00Z", timestamp3},
{"2012-11-04 00:00:00Z", timestamp3},
{"2012-11-04 00:00:00.000Z", timestamp3},
{"2012-11-04T00:00Z", timestamp3},
{"2012-11-04T00:00:00Z", timestamp3},
{"2012-11-04T00:00:00.000Z", timestamp3},
} }
for i := range tests { for i := range tests {
_, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value) _, err = db.Exec("INSERT INTO foo(id, ts, dt) VALUES(?, ?, ?)", i, tests[i].value, tests[i].value)
@ -397,6 +435,14 @@ func TestTimestamp(t *testing.T) {
if !tests[id].expected.Equal(dt) { if !tests[id].expected.Equal(dt) {
t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt) t.Errorf("Datetime value for id %v (%v) should be %v, not %v", id, tests[id].value, tests[id].expected, dt)
} }
if timezone(tests[id].expected) != timezone(ts) {
t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value,
timezone(tests[id].expected), timezone(ts))
}
if timezone(tests[id].expected) != timezone(dt) {
t.Errorf("Timezone for id %v (%v) should be %v, not %v", id, tests[id].value,
timezone(tests[id].expected), timezone(dt))
}
} }
if seen != len(tests) { if seen != len(tests) {
@ -405,13 +451,13 @@ func TestTimestamp(t *testing.T) {
} }
func TestBoolean(t *testing.T) { func TestBoolean(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("CREATE TABLE foo(id INTEGER, fbool BOOLEAN)") _, err = db.Exec("CREATE TABLE foo(id INTEGER, fbool BOOLEAN)")
@ -497,13 +543,12 @@ func TestBoolean(t *testing.T) {
} }
func TestFloat32(t *testing.T) { func TestFloat32(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("CREATE TABLE foo(id INTEGER)") _, err = db.Exec("CREATE TABLE foo(id INTEGER)")
@ -535,13 +580,12 @@ func TestFloat32(t *testing.T) {
} }
func TestNull(t *testing.T) { func TestNull(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
rows, err := db.Query("SELECT 3.141592") rows, err := db.Query("SELECT 3.141592")
@ -567,13 +611,12 @@ func TestNull(t *testing.T) {
} }
func TestTransaction(t *testing.T) { func TestTransaction(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("CREATE TABLE foo(id INTEGER)") _, err = db.Exec("CREATE TABLE foo(id INTEGER)")
@ -627,14 +670,14 @@ func TestTransaction(t *testing.T) {
} }
func TestWAL(t *testing.T) { func TestWAL(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
if _, err = db.Exec("PRAGMA journal_mode=WAL;"); err != nil { if _, err = db.Exec("PRAGMA journal_mode=WAL;"); err != nil {
t.Fatal("Failed to Exec PRAGMA journal_mode:", err) t.Fatal("Failed to Exec PRAGMA journal_mode:", err)
} }
@ -675,12 +718,12 @@ func TestWAL(t *testing.T) {
func TestTimezoneConversion(t *testing.T) { func TestTimezoneConversion(t *testing.T) {
zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} zones := []string{"UTC", "US/Central", "US/Pacific", "Local"}
for _, tz := range zones { for _, tz := range zones {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename+"?_loc="+url.QueryEscape(tz)) db, err := sql.Open("sqlite3", tempFilename+"?_loc="+url.QueryEscape(tz))
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec("DROP TABLE foo") _, err = db.Exec("DROP TABLE foo")
@ -769,7 +812,9 @@ func TestTimezoneConversion(t *testing.T) {
} }
func TestSuite(t *testing.T) { func TestSuite(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:") tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -781,12 +826,12 @@ func TestSuite(t *testing.T) {
// TODO: Execer & Queryer currently disabled // TODO: Execer & Queryer currently disabled
// https://github.com/mattn/go-sqlite3/issues/82 // https://github.com/mattn/go-sqlite3/issues/82
func TestExecer(t *testing.T) { func TestExecer(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec(` _, err = db.Exec(`
@ -801,12 +846,12 @@ func TestExecer(t *testing.T) {
} }
func TestQueryer(t *testing.T) { func TestQueryer(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec(` _, err = db.Exec(`
@ -842,7 +887,8 @@ func TestQueryer(t *testing.T) {
} }
func TestStress(t *testing.T) { func TestStress(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
@ -880,7 +926,8 @@ func TestStress(t *testing.T) {
func TestDateTimeLocal(t *testing.T) { func TestDateTimeLocal(t *testing.T) {
zone := "Asia/Tokyo" zone := "Asia/Tokyo"
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename+"?_loc="+zone) db, err := sql.Open("sqlite3", tempFilename+"?_loc="+zone)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
@ -947,12 +994,12 @@ func TestVersion(t *testing.T) {
} }
func TestNumberNamedParams(t *testing.T) { func TestNumberNamedParams(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec(` _, err = db.Exec(`
@ -983,7 +1030,7 @@ func TestNumberNamedParams(t *testing.T) {
} }
func TestEncryptoDatabase(t *testing.T) { func TestEncryptoDatabase(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
@ -1062,12 +1109,12 @@ func TestEncryptoDatabase(t *testing.T) {
} }
func TestStringContainingZero(t *testing.T) { func TestStringContainingZero(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
_, err = db.Exec(` _, err = db.Exec(`
@ -1122,7 +1169,8 @@ func (t TimeStamp) Value() (driver.Value, error) {
} }
func TestDateTimeNow(t *testing.T) { func TestDateTimeNow(t *testing.T) {
tempFilename := TempFilename() tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename) db, err := sql.Open("sqlite3", tempFilename)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
@ -1135,3 +1183,247 @@ func TestDateTimeNow(t *testing.T) {
t.Fatal("Failed to scan datetime:", err) t.Fatal("Failed to scan datetime:", err)
} }
} }
func TestFunctionRegistration(t *testing.T) {
addi_8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) }
addi_64 := func(a, b int64) int64 { return a + b }
addu_8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) }
addu_64 := func(a, b uint64) uint64 { return a + b }
addiu := func(a int, b uint) int64 { return int64(a) + int64(b) }
addf_32_64 := func(a float32, b float64) float64 { return float64(a) + b }
not := func(a bool) bool { return !a }
regex := func(re, s string) (bool, error) {
return regexp.MatchString(re, s)
}
generic := func(a interface{}) int64 {
switch a.(type) {
case int64:
return 1
case float64:
return 2
case []byte:
return 3
case string:
return 4
default:
panic("unreachable")
}
}
variadic := func(a, b int64, c ...int64) int64 {
ret := a + b
for _, d := range c {
ret += d
}
return ret
}
variadicGeneric := func(a ...interface{}) int64 {
return int64(len(a))
}
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
if err := conn.RegisterFunc("addi_8_16_32", addi_8_16_32, true); err != nil {
return err
}
if err := conn.RegisterFunc("addi_64", addi_64, true); err != nil {
return err
}
if err := conn.RegisterFunc("addu_8_16_32", addu_8_16_32, true); err != nil {
return err
}
if err := conn.RegisterFunc("addu_64", addu_64, true); err != nil {
return err
}
if err := conn.RegisterFunc("addiu", addiu, true); err != nil {
return err
}
if err := conn.RegisterFunc("addf_32_64", addf_32_64, true); err != nil {
return err
}
if err := conn.RegisterFunc("not", not, true); err != nil {
return err
}
if err := conn.RegisterFunc("regex", regex, true); err != nil {
return err
}
if err := conn.RegisterFunc("generic", generic, true); err != nil {
return err
}
if err := conn.RegisterFunc("variadic", variadic, true); err != nil {
return err
}
if err := conn.RegisterFunc("variadicGeneric", variadicGeneric, true); err != nil {
return err
}
return nil
},
})
db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:")
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
ops := []struct {
query string
expected interface{}
}{
{"SELECT addi_8_16_32(1,2)", int32(3)},
{"SELECT addi_64(1,2)", int64(3)},
{"SELECT addu_8_16_32(1,2)", uint32(3)},
{"SELECT addu_64(1,2)", uint64(3)},
{"SELECT addiu(1,2)", int64(3)},
{"SELECT addf_32_64(1.5,1.5)", float64(3)},
{"SELECT not(1)", false},
{"SELECT not(0)", true},
{`SELECT regex("^foo.*", "foobar")`, true},
{`SELECT regex("^foo.*", "barfoobar")`, false},
{"SELECT generic(1)", int64(1)},
{"SELECT generic(1.1)", int64(2)},
{`SELECT generic(NULL)`, int64(3)},
{`SELECT generic("foo")`, int64(4)},
{"SELECT variadic(1,2)", int64(3)},
{"SELECT variadic(1,2,3,4)", int64(10)},
{"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)},
{`SELECT variadicGeneric(1,"foo",2.3, NULL)`, int64(4)},
}
for _, op := range ops {
ret := reflect.New(reflect.TypeOf(op.expected))
err = db.QueryRow(op.query).Scan(ret.Interface())
if err != nil {
t.Errorf("Query %q failed: %s", op.query, err)
} else if !reflect.DeepEqual(ret.Elem().Interface(), op.expected) {
t.Errorf("Query %q returned wrong value: got %v (%T), want %v (%T)", op.query, ret.Elem().Interface(), ret.Elem().Interface(), op.expected, op.expected)
}
}
}
type sumAggregator int64
func (s *sumAggregator) Step(x int64) {
*s += sumAggregator(x)
}
func (s *sumAggregator) Done() int64 {
return int64(*s)
}
func TestAggregatorRegistration(t *testing.T) {
customSum := func() *sumAggregator {
var ret sumAggregator
return &ret
}
sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
return err
}
return nil
},
})
db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec("create table foo (department integer, profits integer)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
if err != nil {
t.Fatal("Failed to insert records:", err)
}
tests := []struct {
dept, sum int64
}{
{1, 30},
{2, 42},
}
for _, test := range tests {
var ret int64
err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
if err != nil {
t.Fatal("Query failed:", err)
}
if ret != test.sum {
t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
}
}
}
func TestDeclTypes(t *testing.T) {
d := SQLiteDriver{}
conn, err := d.Open(":memory:")
if err != nil {
t.Fatal("Failed to begin transaction:", err)
}
defer conn.Close()
sqlite3conn := conn.(*SQLiteConn)
_, err = sqlite3conn.Exec("create table foo (id integer not null primary key, name text)", nil)
if err != nil {
t.Fatal("Failed to create table:", err)
}
_, err = sqlite3conn.Exec("insert into foo(name) values(\"bar\")", nil)
if err != nil {
t.Fatal("Failed to insert:", err)
}
rs, err := sqlite3conn.Query("select * from foo", nil)
if err != nil {
t.Fatal("Failed to select:", err)
}
defer rs.Close()
declTypes := rs.(*SQLiteRows).DeclTypes()
if !reflect.DeepEqual(declTypes, []string{"integer", "text"}) {
t.Fatal("Unexpected declTypes:", declTypes)
}
}
var customFunctionOnce sync.Once
func BenchmarkCustomFunctions(b *testing.B) {
customFunctionOnce.Do(func() {
custom_add := func(a, b int64) int64 {
return a + b
}
sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
// Impure function to force sqlite to reexecute it each time.
if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil {
return err
}
return nil
},
})
})
db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:")
if err != nil {
b.Fatal("Failed to open database:", err)
}
defer db.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
var i int64
err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i)
if err != nil {
b.Fatal("Failed to run custom add:", err)
}
}
}

View File

@ -275,12 +275,11 @@ func TestPreparedStmt(t *testing.T) {
} }
const nRuns = 10 const nRuns = 10
ch := make(chan bool) var wg sync.WaitGroup
for i := 0; i < nRuns; i++ { for i := 0; i < nRuns; i++ {
wg.Add(1)
go func() { go func() {
defer func() { defer wg.Done()
ch <- true
}()
for j := 0; j < 10; j++ { for j := 0; j < 10; j++ {
count := 0 count := 0
if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows { if err := sel.QueryRow().Scan(&count); err != nil && err != sql.ErrNoRows {
@ -294,9 +293,7 @@ func TestPreparedStmt(t *testing.T) {
} }
}() }()
} }
for i := 0; i < nRuns; i++ { wg.Wait()
<-ch
}
} }
// Benchmarks need to use panic() since b.Error errors are lost when // Benchmarks need to use panic() since b.Error errors are lost when
@ -318,7 +315,7 @@ func BenchmarkQuery(b *testing.B) {
var i int var i int
var f float64 var f float64
var s string var s string
// var t time.Time // var t time.Time
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil { if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
panic(err) panic(err)
} }
@ -331,7 +328,7 @@ func BenchmarkParams(b *testing.B) {
var i int var i int
var f float64 var f float64
var s string var s string
// var t time.Time // var t time.Time
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err) panic(err)
} }
@ -350,7 +347,7 @@ func BenchmarkStmt(b *testing.B) {
var i int var i int
var f float64 var f float64
var s string var s string
// var t time.Time // var t time.Time
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err) panic(err)
} }

View File

@ -8,7 +8,7 @@ package sqlite3
/* /*
#cgo CFLAGS: -I. -fno-stack-check -fno-stack-protector -mno-stack-arg-probe #cgo CFLAGS: -I. -fno-stack-check -fno-stack-protector -mno-stack-arg-probe
#cgo windows,386 CFLAGS: -D_localtime32=localtime #cgo windows,386 CFLAGS: -D_USE_32BIT_TIME_T
#cgo LDFLAGS: -lmingwex -lmingw32 #cgo LDFLAGS: -lmingwex -lmingw32
*/ */
import "C" import "C"

View File

@ -267,6 +267,23 @@ struct sqlite3_api_routines {
void (*result_text64)(sqlite3_context*,const char*,sqlite3_uint64, void (*result_text64)(sqlite3_context*,const char*,sqlite3_uint64,
void(*)(void*), unsigned char); void(*)(void*), unsigned char);
int (*strglob)(const char*,const char*); int (*strglob)(const char*,const char*);
<<<<<<< HEAD
=======
/* Version 3.8.11 and later */
sqlite3_value *(*value_dup)(const sqlite3_value*);
void (*value_free)(sqlite3_value*);
int (*result_zeroblob64)(sqlite3_context*,sqlite3_uint64);
int (*bind_zeroblob64)(sqlite3_stmt*, int, sqlite3_uint64);
/* Version 3.9.0 and later */
unsigned int (*value_subtype)(sqlite3_value*);
void (*result_subtype)(sqlite3_context*,unsigned int);
/* Version 3.10.0 and later */
int (*status64)(int,sqlite3_int64*,sqlite3_int64*,int);
int (*strlike)(const char*,const char*,unsigned int);
int (*db_cacheflush)(sqlite3*);
/* Version 3.12.0 and later */
int (*system_errno)(sqlite3*);
>>>>>>> mattn/master
}; };
/* /*
@ -280,7 +297,7 @@ struct sqlite3_api_routines {
** the API. So the redefinition macros are only valid if the ** the API. So the redefinition macros are only valid if the
** SQLITE_CORE macros is undefined. ** SQLITE_CORE macros is undefined.
*/ */
#ifndef SQLITE_CORE #if !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION)
#define sqlite3_aggregate_context sqlite3_api->aggregate_context #define sqlite3_aggregate_context sqlite3_api->aggregate_context
#ifndef SQLITE_OMIT_DEPRECATED #ifndef SQLITE_OMIT_DEPRECATED
#define sqlite3_aggregate_count sqlite3_api->aggregate_count #define sqlite3_aggregate_count sqlite3_api->aggregate_count
@ -407,6 +424,7 @@ struct sqlite3_api_routines {
#define sqlite3_value_text16le sqlite3_api->value_text16le #define sqlite3_value_text16le sqlite3_api->value_text16le
#define sqlite3_value_type sqlite3_api->value_type #define sqlite3_value_type sqlite3_api->value_type
#define sqlite3_vmprintf sqlite3_api->vmprintf #define sqlite3_vmprintf sqlite3_api->vmprintf
#define sqlite3_vsnprintf sqlite3_api->vsnprintf
#define sqlite3_overload_function sqlite3_api->overload_function #define sqlite3_overload_function sqlite3_api->overload_function
#define sqlite3_prepare_v2 sqlite3_api->prepare_v2 #define sqlite3_prepare_v2 sqlite3_api->prepare_v2
#define sqlite3_prepare16_v2 sqlite3_api->prepare16_v2 #define sqlite3_prepare16_v2 sqlite3_api->prepare16_v2
@ -497,9 +515,29 @@ struct sqlite3_api_routines {
#define sqlite3_result_blob64 sqlite3_api->result_blob64 #define sqlite3_result_blob64 sqlite3_api->result_blob64
#define sqlite3_result_text64 sqlite3_api->result_text64 #define sqlite3_result_text64 sqlite3_api->result_text64
#define sqlite3_strglob sqlite3_api->strglob #define sqlite3_strglob sqlite3_api->strglob
<<<<<<< HEAD
#endif /* SQLITE_CORE */ #endif /* SQLITE_CORE */
#ifndef SQLITE_CORE #ifndef SQLITE_CORE
=======
/* Version 3.8.11 and later */
#define sqlite3_value_dup sqlite3_api->value_dup
#define sqlite3_value_free sqlite3_api->value_free
#define sqlite3_result_zeroblob64 sqlite3_api->result_zeroblob64
#define sqlite3_bind_zeroblob64 sqlite3_api->bind_zeroblob64
/* Version 3.9.0 and later */
#define sqlite3_value_subtype sqlite3_api->value_subtype
#define sqlite3_result_subtype sqlite3_api->result_subtype
/* Version 3.10.0 and later */
#define sqlite3_status64 sqlite3_api->status64
#define sqlite3_strlike sqlite3_api->strlike
#define sqlite3_db_cacheflush sqlite3_api->db_cacheflush
/* Version 3.12.0 and later */
#define sqlite3_system_errno sqlite3_api->system_errno
#endif /* !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION) */
#if !defined(SQLITE_CORE) && !defined(SQLITE_OMIT_LOAD_EXTENSION)
>>>>>>> mattn/master
/* This case when the file really is being compiled as a loadable /* This case when the file really is being compiled as a loadable
** extension */ ** extension */
# define SQLITE_EXTENSION_INIT1 const sqlite3_api_routines *sqlite3_api=0; # define SQLITE_EXTENSION_INIT1 const sqlite3_api_routines *sqlite3_api=0;