mirror of https://github.com/mattn/go-sqlite3.git
Merge pull request #229 from danderson/master
Implement support for calling Go functions from SQLite
This commit is contained in:
commit
0bb7f1c676
Binary file not shown.
|
@ -0,0 +1,133 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
|
||||||
|
sqlite "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Computes x^y
|
||||||
|
func pow(x, y int64) int64 {
|
||||||
|
return int64(math.Pow(float64(x), float64(y)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Computes the bitwise exclusive-or of all its arguments
|
||||||
|
func xor(xs ...int64) int64 {
|
||||||
|
var ret int64
|
||||||
|
for _, x := range xs {
|
||||||
|
ret ^= x
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a random number. It's actually deterministic here because
|
||||||
|
// we don't seed the RNG, but it's an example of a non-pure function
|
||||||
|
// from SQLite's POV.
|
||||||
|
func getrand() int64 {
|
||||||
|
return rand.Int63()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Computes the standard deviation of a GROUPed BY set of values
|
||||||
|
type stddev struct {
|
||||||
|
xs []int64
|
||||||
|
// Running average calculation
|
||||||
|
sum int64
|
||||||
|
n int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStddev() *stddev { return &stddev{} }
|
||||||
|
|
||||||
|
func (s *stddev) Step(x int64) {
|
||||||
|
s.xs = append(s.xs, x)
|
||||||
|
s.sum += x
|
||||||
|
s.n++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stddev) Done() float64 {
|
||||||
|
mean := float64(s.sum) / float64(s.n)
|
||||||
|
var sqDiff []float64
|
||||||
|
for _, x := range s.xs {
|
||||||
|
sqDiff = append(sqDiff, math.Pow(float64(x)-mean, 2))
|
||||||
|
}
|
||||||
|
var dev float64
|
||||||
|
for _, x := range sqDiff {
|
||||||
|
dev += x
|
||||||
|
}
|
||||||
|
dev /= float64(len(sqDiff))
|
||||||
|
return math.Sqrt(dev)
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
sql.Register("sqlite3_custom", &sqlite.SQLiteDriver{
|
||||||
|
ConnectHook: func(conn *sqlite.SQLiteConn) error {
|
||||||
|
if err := conn.RegisterFunc("pow", pow, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("xor", xor, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("rand", getrand, false); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterAggregator("stddev", newStddev, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3_custom", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var i int64
|
||||||
|
err = db.QueryRow("SELECT pow(2,3)").Scan(&i)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("POW query error:", err)
|
||||||
|
}
|
||||||
|
fmt.Println("pow(2,3) =", i) // 8
|
||||||
|
|
||||||
|
err = db.QueryRow("SELECT xor(1,2,3,4,5,6)").Scan(&i)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("XOR query error:", err)
|
||||||
|
}
|
||||||
|
fmt.Println("xor(1,2,3,4,5) =", i) // 7
|
||||||
|
|
||||||
|
err = db.QueryRow("SELECT rand()").Scan(&i)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("RAND query error:", err)
|
||||||
|
}
|
||||||
|
fmt.Println("rand() =", i) // pseudorandom
|
||||||
|
|
||||||
|
_, err = db.Exec("create table foo (department integer, profits integer)")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Failed to create table:", err)
|
||||||
|
}
|
||||||
|
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115)")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Failed to insert records:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.Query("select department, stddev(profits) from foo group by department")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("STDDEV query error:", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var dept int64
|
||||||
|
var dev float64
|
||||||
|
if err := rows.Scan(&dept, &dev); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("dept=%d stddev=%f\n", dept, dev)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,289 @@
|
||||||
|
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
// You can't export a Go function to C and have definitions in the C
|
||||||
|
// preamble in the same file, so we have to have callbackTrampoline in
|
||||||
|
// its own file. Because we need a separate file anyway, the support
|
||||||
|
// code for SQLite custom functions is in here.
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <sqlite3-binding.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
|
||||||
|
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
//export callbackTrampoline
|
||||||
|
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
|
||||||
|
args := (*[1 << 30]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
|
||||||
|
fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
|
||||||
|
fi.Call(ctx, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
//export stepTrampoline
|
||||||
|
func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
|
||||||
|
args := (*[1 << 30]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
|
||||||
|
ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
|
||||||
|
ai.Step(ctx, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
//export doneTrampoline
|
||||||
|
func doneTrampoline(ctx *C.sqlite3_context) {
|
||||||
|
ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
|
||||||
|
ai.Done(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is only here so that tests can refer to it.
|
||||||
|
type callbackArgRaw C.sqlite3_value
|
||||||
|
|
||||||
|
type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
|
||||||
|
|
||||||
|
type callbackArgCast struct {
|
||||||
|
f callbackArgConverter
|
||||||
|
typ reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
val, err := c.f(v)
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, err
|
||||||
|
}
|
||||||
|
if !val.Type().ConvertibleTo(c.typ) {
|
||||||
|
return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
|
||||||
|
}
|
||||||
|
return val.Convert(c.typ), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
|
||||||
|
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
|
||||||
|
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
|
||||||
|
}
|
||||||
|
i := int64(C.sqlite3_value_int64(v))
|
||||||
|
val := false
|
||||||
|
if i != 0 {
|
||||||
|
val = true
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(val), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
|
||||||
|
return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
switch C.sqlite3_value_type(v) {
|
||||||
|
case C.SQLITE_BLOB:
|
||||||
|
l := C.sqlite3_value_bytes(v)
|
||||||
|
p := C.sqlite3_value_blob(v)
|
||||||
|
return reflect.ValueOf(C.GoBytes(p, l)), nil
|
||||||
|
case C.SQLITE_TEXT:
|
||||||
|
l := C.sqlite3_value_bytes(v)
|
||||||
|
c := unsafe.Pointer(C.sqlite3_value_text(v))
|
||||||
|
return reflect.ValueOf(C.GoBytes(c, l)), nil
|
||||||
|
default:
|
||||||
|
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
switch C.sqlite3_value_type(v) {
|
||||||
|
case C.SQLITE_BLOB:
|
||||||
|
l := C.sqlite3_value_bytes(v)
|
||||||
|
p := (*C.char)(C.sqlite3_value_blob(v))
|
||||||
|
return reflect.ValueOf(C.GoStringN(p, l)), nil
|
||||||
|
case C.SQLITE_TEXT:
|
||||||
|
c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
|
||||||
|
return reflect.ValueOf(C.GoString(c)), nil
|
||||||
|
default:
|
||||||
|
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
switch C.sqlite3_value_type(v) {
|
||||||
|
case C.SQLITE_INTEGER:
|
||||||
|
return callbackArgInt64(v)
|
||||||
|
case C.SQLITE_FLOAT:
|
||||||
|
return callbackArgFloat64(v)
|
||||||
|
case C.SQLITE_TEXT:
|
||||||
|
return callbackArgString(v)
|
||||||
|
case C.SQLITE_BLOB:
|
||||||
|
return callbackArgBytes(v)
|
||||||
|
case C.SQLITE_NULL:
|
||||||
|
// Interpret NULL as a nil byte slice.
|
||||||
|
var ret []byte
|
||||||
|
return reflect.ValueOf(ret), nil
|
||||||
|
default:
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
|
||||||
|
switch typ.Kind() {
|
||||||
|
case reflect.Interface:
|
||||||
|
if typ.NumMethod() != 0 {
|
||||||
|
return nil, errors.New("the only supported interface type is interface{}")
|
||||||
|
}
|
||||||
|
return callbackArgGeneric, nil
|
||||||
|
case reflect.Slice:
|
||||||
|
if typ.Elem().Kind() != reflect.Uint8 {
|
||||||
|
return nil, errors.New("the only supported slice type is []byte")
|
||||||
|
}
|
||||||
|
return callbackArgBytes, nil
|
||||||
|
case reflect.String:
|
||||||
|
return callbackArgString, nil
|
||||||
|
case reflect.Bool:
|
||||||
|
return callbackArgBool, nil
|
||||||
|
case reflect.Int64:
|
||||||
|
return callbackArgInt64, nil
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
|
||||||
|
c := callbackArgCast{callbackArgInt64, typ}
|
||||||
|
return c.Run, nil
|
||||||
|
case reflect.Float64:
|
||||||
|
return callbackArgFloat64, nil
|
||||||
|
case reflect.Float32:
|
||||||
|
c := callbackArgCast{callbackArgFloat64, typ}
|
||||||
|
return c.Run, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("don't know how to convert to %s", typ)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) {
|
||||||
|
var args []reflect.Value
|
||||||
|
|
||||||
|
if len(argv) < len(converters) {
|
||||||
|
return nil, fmt.Errorf("function requires at least %d arguments", len(converters))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, arg := range argv[:len(converters)] {
|
||||||
|
v, err := converters[i](arg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
args = append(args, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
if variadic != nil {
|
||||||
|
for _, arg := range argv[len(converters):] {
|
||||||
|
v, err := variadic(arg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
args = append(args, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
|
||||||
|
|
||||||
|
func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||||
|
switch v.Type().Kind() {
|
||||||
|
case reflect.Int64:
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
|
||||||
|
v = v.Convert(reflect.TypeOf(int64(0)))
|
||||||
|
case reflect.Bool:
|
||||||
|
b := v.Interface().(bool)
|
||||||
|
if b {
|
||||||
|
v = reflect.ValueOf(int64(1))
|
||||||
|
} else {
|
||||||
|
v = reflect.ValueOf(int64(0))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||||
|
switch v.Type().Kind() {
|
||||||
|
case reflect.Float64:
|
||||||
|
case reflect.Float32:
|
||||||
|
v = v.Convert(reflect.TypeOf(float64(0)))
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||||
|
if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
|
||||||
|
return fmt.Errorf("cannot convert %s to BLOB", v.Type())
|
||||||
|
}
|
||||||
|
i := v.Interface()
|
||||||
|
if i == nil || len(i.([]byte)) == 0 {
|
||||||
|
C.sqlite3_result_null(ctx)
|
||||||
|
} else {
|
||||||
|
bs := i.([]byte)
|
||||||
|
C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||||
|
if v.Type().Kind() != reflect.String {
|
||||||
|
return fmt.Errorf("cannot convert %s to TEXT", v.Type())
|
||||||
|
}
|
||||||
|
C._sqlite3_result_text(ctx, C.CString(v.Interface().(string)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
|
||||||
|
switch typ.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
if typ.Elem().Kind() != reflect.Uint8 {
|
||||||
|
return nil, errors.New("the only supported slice type is []byte")
|
||||||
|
}
|
||||||
|
return callbackRetBlob, nil
|
||||||
|
case reflect.String:
|
||||||
|
return callbackRetText, nil
|
||||||
|
case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
|
||||||
|
return callbackRetInteger, nil
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return callbackRetFloat, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("don't know how to convert to %s", typ)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackError(ctx *C.sqlite3_context, err error) {
|
||||||
|
cstr := C.CString(err.Error())
|
||||||
|
defer C.free(unsafe.Pointer(cstr))
|
||||||
|
C.sqlite3_result_error(ctx, cstr, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test support code. Tests are not allowed to import "C", so we can't
|
||||||
|
// declare any functions that use C.sqlite3_value.
|
||||||
|
func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
|
||||||
|
return func(*C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
return v, err
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,97 @@
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCallbackArgCast(t *testing.T) {
|
||||||
|
intConv := callbackSyntheticForTests(reflect.ValueOf(int64(math.MaxInt64)), nil)
|
||||||
|
floatConv := callbackSyntheticForTests(reflect.ValueOf(float64(math.MaxFloat64)), nil)
|
||||||
|
errConv := callbackSyntheticForTests(reflect.Value{}, errors.New("test"))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
f callbackArgConverter
|
||||||
|
o reflect.Value
|
||||||
|
}{
|
||||||
|
{intConv, reflect.ValueOf(int8(-1))},
|
||||||
|
{intConv, reflect.ValueOf(int16(-1))},
|
||||||
|
{intConv, reflect.ValueOf(int32(-1))},
|
||||||
|
{intConv, reflect.ValueOf(uint8(math.MaxUint8))},
|
||||||
|
{intConv, reflect.ValueOf(uint16(math.MaxUint16))},
|
||||||
|
{intConv, reflect.ValueOf(uint32(math.MaxUint32))},
|
||||||
|
// Special case, int64->uint64 is only 1<<63 - 1, not 1<<64 - 1
|
||||||
|
{intConv, reflect.ValueOf(uint64(math.MaxInt64))},
|
||||||
|
{floatConv, reflect.ValueOf(float32(math.Inf(1)))},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
conv := callbackArgCast{test.f, test.o.Type()}
|
||||||
|
val, err := conv.Run(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Couldn't convert to %s: %s", test.o.Type(), err)
|
||||||
|
} else if !reflect.DeepEqual(val.Interface(), test.o.Interface()) {
|
||||||
|
t.Errorf("Unexpected result from converting to %s: got %v, want %v", test.o.Type(), val.Interface(), test.o.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conv := callbackArgCast{errConv, reflect.TypeOf(int8(0))}
|
||||||
|
_, err := conv.Run(nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error during callbackArgCast, but got none")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbackConverters(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
v interface{}
|
||||||
|
err bool
|
||||||
|
}{
|
||||||
|
// Unfortunately, we can't tell which converter was returned,
|
||||||
|
// but we can at least check which types can be converted.
|
||||||
|
{[]byte{0}, false},
|
||||||
|
{"text", false},
|
||||||
|
{true, false},
|
||||||
|
{int8(0), false},
|
||||||
|
{int16(0), false},
|
||||||
|
{int32(0), false},
|
||||||
|
{int64(0), false},
|
||||||
|
{uint8(0), false},
|
||||||
|
{uint16(0), false},
|
||||||
|
{uint32(0), false},
|
||||||
|
{uint64(0), false},
|
||||||
|
{int(0), false},
|
||||||
|
{uint(0), false},
|
||||||
|
{float64(0), false},
|
||||||
|
{float32(0), false},
|
||||||
|
|
||||||
|
{func() {}, true},
|
||||||
|
{complex64(complex(0, 0)), true},
|
||||||
|
{complex128(complex(0, 0)), true},
|
||||||
|
{struct{}{}, true},
|
||||||
|
{map[string]string{}, true},
|
||||||
|
{[]string{}, true},
|
||||||
|
{(*int8)(nil), true},
|
||||||
|
{make(chan int), true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
_, err := callbackArg(reflect.TypeOf(test.v))
|
||||||
|
if test.err && err == nil {
|
||||||
|
t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v))
|
||||||
|
} else if !test.err && err != nil {
|
||||||
|
t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
_, err := callbackRet(reflect.TypeOf(test.v))
|
||||||
|
if test.err && err == nil {
|
||||||
|
t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v))
|
||||||
|
} else if !test.err && err != nil {
|
||||||
|
t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
17
doc.go
17
doc.go
|
@ -91,5 +91,22 @@ you need to hook ConnectHook and get the SQLiteConn.
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Go SQlite3 Extensions
|
||||||
|
|
||||||
|
If you want to register Go functions as SQLite extension functions,
|
||||||
|
call RegisterFunction from ConnectHook.
|
||||||
|
|
||||||
|
regex = func(re, s string) (bool, error) {
|
||||||
|
return regexp.MatchString(re, s)
|
||||||
|
}
|
||||||
|
sql.Register("sqlite3_with_go_func",
|
||||||
|
&sqlite3.SQLiteDriver{
|
||||||
|
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||||
|
return conn.RegisterFunc("regex", regex, true)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
See the documentation of RegisterFunc for more details.
|
||||||
|
|
||||||
*/
|
*/
|
||||||
package sqlite3
|
package sqlite3
|
||||||
|
|
317
sqlite3.go
317
sqlite3.go
|
@ -66,6 +66,17 @@ _sqlite3_step(sqlite3_stmt* stmt, long long* rowid, long long* changes)
|
||||||
return rv;
|
return rv;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
|
||||||
|
sqlite3_result_text(ctx, s, -1, &free);
|
||||||
|
}
|
||||||
|
|
||||||
|
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) {
|
||||||
|
sqlite3_result_blob(ctx, b, l, SQLITE_TRANSIENT);
|
||||||
|
}
|
||||||
|
|
||||||
|
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||||
|
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||||
|
void doneTrampoline(sqlite3_context*);
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
|
@ -75,6 +86,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -120,6 +132,8 @@ type SQLiteConn struct {
|
||||||
db *C.sqlite3
|
db *C.sqlite3
|
||||||
loc *time.Location
|
loc *time.Location
|
||||||
txlock string
|
txlock string
|
||||||
|
funcs []*functionInfo
|
||||||
|
aggregators []*aggInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tx struct.
|
// Tx struct.
|
||||||
|
@ -153,6 +167,107 @@ type SQLiteRows struct {
|
||||||
cls bool
|
cls bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type functionInfo struct {
|
||||||
|
f reflect.Value
|
||||||
|
argConverters []callbackArgConverter
|
||||||
|
variadicConverter callbackArgConverter
|
||||||
|
retConverter callbackRetConverter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
|
||||||
|
args, err := callbackConvertArgs(argv, fi.argConverters, fi.variadicConverter)
|
||||||
|
if err != nil {
|
||||||
|
callbackError(ctx, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := fi.f.Call(args)
|
||||||
|
|
||||||
|
if len(ret) == 2 && ret[1].Interface() != nil {
|
||||||
|
callbackError(ctx, ret[1].Interface().(error))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = fi.retConverter(ctx, ret[0])
|
||||||
|
if err != nil {
|
||||||
|
callbackError(ctx, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type aggInfo struct {
|
||||||
|
constructor reflect.Value
|
||||||
|
|
||||||
|
// Active aggregator objects for aggregations in flight. The
|
||||||
|
// aggregators are indexed by a counter stored in the aggregation
|
||||||
|
// user data space provided by sqlite.
|
||||||
|
active map[int64]reflect.Value
|
||||||
|
next int64
|
||||||
|
|
||||||
|
stepArgConverters []callbackArgConverter
|
||||||
|
stepVariadicConverter callbackArgConverter
|
||||||
|
|
||||||
|
doneRetConverter callbackRetConverter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ai *aggInfo) agg(ctx *C.sqlite3_context) (int64, reflect.Value, error) {
|
||||||
|
aggIdx := (*int64)(C.sqlite3_aggregate_context(ctx, C.int(8)))
|
||||||
|
if *aggIdx == 0 {
|
||||||
|
*aggIdx = ai.next
|
||||||
|
ret := ai.constructor.Call(nil)
|
||||||
|
if len(ret) == 2 && ret[1].Interface() != nil {
|
||||||
|
return 0, reflect.Value{}, ret[1].Interface().(error)
|
||||||
|
}
|
||||||
|
if ret[0].IsNil() {
|
||||||
|
return 0, reflect.Value{}, errors.New("aggregator constructor returned nil state")
|
||||||
|
}
|
||||||
|
ai.next++
|
||||||
|
ai.active[*aggIdx] = ret[0]
|
||||||
|
}
|
||||||
|
return *aggIdx, ai.active[*aggIdx], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ai *aggInfo) Step(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
|
||||||
|
_, agg, err := ai.agg(ctx)
|
||||||
|
if err != nil {
|
||||||
|
callbackError(ctx, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
args, err := callbackConvertArgs(argv, ai.stepArgConverters, ai.stepVariadicConverter)
|
||||||
|
if err != nil {
|
||||||
|
callbackError(ctx, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := agg.MethodByName("Step").Call(args)
|
||||||
|
if len(ret) == 1 && ret[0].Interface() != nil {
|
||||||
|
callbackError(ctx, ret[0].Interface().(error))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
|
||||||
|
idx, agg, err := ai.agg(ctx)
|
||||||
|
if err != nil {
|
||||||
|
callbackError(ctx, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { delete(ai.active, idx) }()
|
||||||
|
|
||||||
|
ret := agg.MethodByName("Done").Call(nil)
|
||||||
|
if len(ret) == 2 && ret[1].Interface() != nil {
|
||||||
|
callbackError(ctx, ret[1].Interface().(error))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ai.doneRetConverter(ctx, ret[0])
|
||||||
|
if err != nil {
|
||||||
|
callbackError(ctx, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Commit transaction.
|
// Commit transaction.
|
||||||
func (tx *SQLiteTx) Commit() error {
|
func (tx *SQLiteTx) Commit() error {
|
||||||
_, err := tx.c.exec("COMMIT")
|
_, err := tx.c.exec("COMMIT")
|
||||||
|
@ -165,6 +280,208 @@ func (tx *SQLiteTx) Rollback() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterFunc makes a Go function available as a SQLite function.
|
||||||
|
//
|
||||||
|
// The Go function can have arguments of the following types: any
|
||||||
|
// numeric type except complex, bool, []byte, string and
|
||||||
|
// interface{}. interface{} arguments are given the direct translation
|
||||||
|
// of the SQLite data type: int64 for INTEGER, float64 for FLOAT,
|
||||||
|
// []byte for BLOB, string for TEXT.
|
||||||
|
//
|
||||||
|
// The function can additionally be variadic, as long as the type of
|
||||||
|
// the variadic argument is one of the above.
|
||||||
|
//
|
||||||
|
// If pure is true. SQLite will assume that the function's return
|
||||||
|
// value depends only on its inputs, and make more aggressive
|
||||||
|
// optimizations in its queries.
|
||||||
|
//
|
||||||
|
// See _example/go_custom_funcs for a detailed example.
|
||||||
|
func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) error {
|
||||||
|
var fi functionInfo
|
||||||
|
fi.f = reflect.ValueOf(impl)
|
||||||
|
t := fi.f.Type()
|
||||||
|
if t.Kind() != reflect.Func {
|
||||||
|
return errors.New("Non-function passed to RegisterFunc")
|
||||||
|
}
|
||||||
|
if t.NumOut() != 1 && t.NumOut() != 2 {
|
||||||
|
return errors.New("SQLite functions must return 1 or 2 values")
|
||||||
|
}
|
||||||
|
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||||
|
return errors.New("Second return value of SQLite function must be error")
|
||||||
|
}
|
||||||
|
|
||||||
|
numArgs := t.NumIn()
|
||||||
|
if t.IsVariadic() {
|
||||||
|
numArgs--
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < numArgs; i++ {
|
||||||
|
conv, err := callbackArg(t.In(i))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fi.argConverters = append(fi.argConverters, conv)
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.IsVariadic() {
|
||||||
|
conv, err := callbackArg(t.In(numArgs).Elem())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fi.variadicConverter = conv
|
||||||
|
// Pass -1 to sqlite so that it allows any number of
|
||||||
|
// arguments. The call helper verifies that the minimum number
|
||||||
|
// of arguments is present for variadic functions.
|
||||||
|
numArgs = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
conv, err := callbackRet(t.Out(0))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fi.retConverter = conv
|
||||||
|
|
||||||
|
// fi must outlast the database connection, or we'll have dangling pointers.
|
||||||
|
c.funcs = append(c.funcs, &fi)
|
||||||
|
|
||||||
|
cname := C.CString(name)
|
||||||
|
defer C.free(unsafe.Pointer(cname))
|
||||||
|
opts := C.SQLITE_UTF8
|
||||||
|
if pure {
|
||||||
|
opts |= C.SQLITE_DETERMINISTIC
|
||||||
|
}
|
||||||
|
rv := C.sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil)
|
||||||
|
if rv != C.SQLITE_OK {
|
||||||
|
return c.lastError()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
|
||||||
|
//
|
||||||
|
// Because aggregation is incremental, it's implemented in Go with a
|
||||||
|
// type that has 2 methods: func Step(values) accumulates one row of
|
||||||
|
// data into the accumulator, and func Done() ret finalizes and
|
||||||
|
// returns the aggregate value. "values" and "ret" may be any type
|
||||||
|
// supported by RegisterFunc.
|
||||||
|
//
|
||||||
|
// RegisterAggregator takes as implementation a constructor function
|
||||||
|
// that constructs an instance of the aggregator type each time an
|
||||||
|
// aggregation begins. The constructor must return a pointer to a
|
||||||
|
// type, or an interface that implements Step() and Done().
|
||||||
|
//
|
||||||
|
// The constructor function and the Step/Done methods may optionally
|
||||||
|
// return an error in addition to their other return values.
|
||||||
|
//
|
||||||
|
// See _example/go_custom_funcs for a detailed example.
|
||||||
|
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
|
||||||
|
var ai aggInfo
|
||||||
|
ai.constructor = reflect.ValueOf(impl)
|
||||||
|
t := ai.constructor.Type()
|
||||||
|
if t.Kind() != reflect.Func {
|
||||||
|
return errors.New("non-function passed to RegisterAggregator")
|
||||||
|
}
|
||||||
|
if t.NumOut() != 1 && t.NumOut() != 2 {
|
||||||
|
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
|
||||||
|
}
|
||||||
|
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||||
|
return errors.New("Second return value of SQLite function must be error")
|
||||||
|
}
|
||||||
|
if t.NumIn() != 0 {
|
||||||
|
return errors.New("SQLite aggregator constructors must not have arguments")
|
||||||
|
}
|
||||||
|
|
||||||
|
agg := t.Out(0)
|
||||||
|
switch agg.Kind() {
|
||||||
|
case reflect.Ptr, reflect.Interface:
|
||||||
|
default:
|
||||||
|
return errors.New("SQlite aggregator constructor must return a pointer object")
|
||||||
|
}
|
||||||
|
stepFn, found := agg.MethodByName("Step")
|
||||||
|
if !found {
|
||||||
|
return errors.New("SQlite aggregator doesn't have a Step() function")
|
||||||
|
}
|
||||||
|
step := stepFn.Type
|
||||||
|
if step.NumOut() != 0 && step.NumOut() != 1 {
|
||||||
|
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
|
||||||
|
}
|
||||||
|
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||||
|
return errors.New("type of SQlite aggregator Step() return value must be error")
|
||||||
|
}
|
||||||
|
|
||||||
|
stepNArgs := step.NumIn()
|
||||||
|
start := 0
|
||||||
|
if agg.Kind() == reflect.Ptr {
|
||||||
|
// Skip over the method receiver
|
||||||
|
stepNArgs--
|
||||||
|
start++
|
||||||
|
}
|
||||||
|
if step.IsVariadic() {
|
||||||
|
stepNArgs--
|
||||||
|
}
|
||||||
|
for i := start; i < start+stepNArgs; i++ {
|
||||||
|
conv, err := callbackArg(step.In(i))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ai.stepArgConverters = append(ai.stepArgConverters, conv)
|
||||||
|
}
|
||||||
|
if step.IsVariadic() {
|
||||||
|
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ai.stepVariadicConverter = conv
|
||||||
|
// Pass -1 to sqlite so that it allows any number of
|
||||||
|
// arguments. The call helper verifies that the minimum number
|
||||||
|
// of arguments is present for variadic functions.
|
||||||
|
stepNArgs = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
doneFn, found := agg.MethodByName("Done")
|
||||||
|
if !found {
|
||||||
|
return errors.New("SQlite aggregator doesn't have a Done() function")
|
||||||
|
}
|
||||||
|
done := doneFn.Type
|
||||||
|
doneNArgs := done.NumIn()
|
||||||
|
if agg.Kind() == reflect.Ptr {
|
||||||
|
// Skip over the method receiver
|
||||||
|
doneNArgs--
|
||||||
|
}
|
||||||
|
if doneNArgs != 0 {
|
||||||
|
return errors.New("SQlite aggregator Done() function must have no arguments")
|
||||||
|
}
|
||||||
|
if done.NumOut() != 1 && done.NumOut() != 2 {
|
||||||
|
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
|
||||||
|
}
|
||||||
|
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||||
|
return errors.New("second return value of SQLite aggregator Done() function must be error")
|
||||||
|
}
|
||||||
|
|
||||||
|
conv, err := callbackRet(done.Out(0))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ai.doneRetConverter = conv
|
||||||
|
ai.active = make(map[int64]reflect.Value)
|
||||||
|
ai.next = 1
|
||||||
|
|
||||||
|
// ai must outlast the database connection, or we'll have dangling pointers.
|
||||||
|
c.aggregators = append(c.aggregators, &ai)
|
||||||
|
|
||||||
|
cname := C.CString(name)
|
||||||
|
defer C.free(unsafe.Pointer(cname))
|
||||||
|
opts := C.SQLITE_UTF8
|
||||||
|
if pure {
|
||||||
|
opts |= C.SQLITE_DETERMINISTIC
|
||||||
|
}
|
||||||
|
rv := C.sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), unsafe.Pointer(&ai), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline)))
|
||||||
|
if rv != C.SQLITE_OK {
|
||||||
|
return c.lastError()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AutoCommit return which currently auto commit or not.
|
// AutoCommit return which currently auto commit or not.
|
||||||
func (c *SQLiteConn) AutoCommit() bool {
|
func (c *SQLiteConn) AutoCommit() bool {
|
||||||
return int(C.sqlite3_get_autocommit(c.db)) != 0
|
return int(C.sqlite3_get_autocommit(c.db)) != 0
|
||||||
|
|
212
sqlite3_test.go
212
sqlite3_test.go
|
@ -15,7 +15,10 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -1058,3 +1061,212 @@ func TestDateTimeNow(t *testing.T) {
|
||||||
t.Fatal("Failed to scan datetime:", err)
|
t.Fatal("Failed to scan datetime:", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFunctionRegistration(t *testing.T) {
|
||||||
|
addi_8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) }
|
||||||
|
addi_64 := func(a, b int64) int64 { return a + b }
|
||||||
|
addu_8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) }
|
||||||
|
addu_64 := func(a, b uint64) uint64 { return a + b }
|
||||||
|
addiu := func(a int, b uint) int64 { return int64(a) + int64(b) }
|
||||||
|
addf_32_64 := func(a float32, b float64) float64 { return float64(a) + b }
|
||||||
|
not := func(a bool) bool { return !a }
|
||||||
|
regex := func(re, s string) (bool, error) {
|
||||||
|
return regexp.MatchString(re, s)
|
||||||
|
}
|
||||||
|
generic := func(a interface{}) int64 {
|
||||||
|
switch a.(type) {
|
||||||
|
case int64:
|
||||||
|
return 1
|
||||||
|
case float64:
|
||||||
|
return 2
|
||||||
|
case []byte:
|
||||||
|
return 3
|
||||||
|
case string:
|
||||||
|
return 4
|
||||||
|
default:
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
variadic := func(a, b int64, c ...int64) int64 {
|
||||||
|
ret := a + b
|
||||||
|
for _, d := range c {
|
||||||
|
ret += d
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
variadicGeneric := func(a ...interface{}) int64 {
|
||||||
|
return int64(len(a))
|
||||||
|
}
|
||||||
|
|
||||||
|
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
|
||||||
|
ConnectHook: func(conn *SQLiteConn) error {
|
||||||
|
if err := conn.RegisterFunc("addi_8_16_32", addi_8_16_32, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("addi_64", addi_64, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("addu_8_16_32", addu_8_16_32, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("addu_64", addu_64, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("addiu", addiu, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("addf_32_64", addf_32_64, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("not", not, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("regex", regex, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("generic", generic, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("variadic", variadic, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("variadicGeneric", variadicGeneric, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
ops := []struct {
|
||||||
|
query string
|
||||||
|
expected interface{}
|
||||||
|
}{
|
||||||
|
{"SELECT addi_8_16_32(1,2)", int32(3)},
|
||||||
|
{"SELECT addi_64(1,2)", int64(3)},
|
||||||
|
{"SELECT addu_8_16_32(1,2)", uint32(3)},
|
||||||
|
{"SELECT addu_64(1,2)", uint64(3)},
|
||||||
|
{"SELECT addiu(1,2)", int64(3)},
|
||||||
|
{"SELECT addf_32_64(1.5,1.5)", float64(3)},
|
||||||
|
{"SELECT not(1)", false},
|
||||||
|
{"SELECT not(0)", true},
|
||||||
|
{`SELECT regex("^foo.*", "foobar")`, true},
|
||||||
|
{`SELECT regex("^foo.*", "barfoobar")`, false},
|
||||||
|
{"SELECT generic(1)", int64(1)},
|
||||||
|
{"SELECT generic(1.1)", int64(2)},
|
||||||
|
{`SELECT generic(NULL)`, int64(3)},
|
||||||
|
{`SELECT generic("foo")`, int64(4)},
|
||||||
|
{"SELECT variadic(1,2)", int64(3)},
|
||||||
|
{"SELECT variadic(1,2,3,4)", int64(10)},
|
||||||
|
{"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)},
|
||||||
|
{`SELECT variadicGeneric(1,"foo",2.3, NULL)`, int64(4)},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, op := range ops {
|
||||||
|
ret := reflect.New(reflect.TypeOf(op.expected))
|
||||||
|
err = db.QueryRow(op.query).Scan(ret.Interface())
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Query %q failed: %s", op.query, err)
|
||||||
|
} else if !reflect.DeepEqual(ret.Elem().Interface(), op.expected) {
|
||||||
|
t.Errorf("Query %q returned wrong value: got %v (%T), want %v (%T)", op.query, ret.Elem().Interface(), ret.Elem().Interface(), op.expected, op.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type sumAggregator int64
|
||||||
|
|
||||||
|
func (s *sumAggregator) Step(x int64) {
|
||||||
|
*s += sumAggregator(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sumAggregator) Done() int64 {
|
||||||
|
return int64(*s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAggregatorRegistration(t *testing.T) {
|
||||||
|
customSum := func() *sumAggregator {
|
||||||
|
var ret sumAggregator
|
||||||
|
return &ret
|
||||||
|
}
|
||||||
|
|
||||||
|
sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
|
||||||
|
ConnectHook: func(conn *SQLiteConn) error {
|
||||||
|
if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err = db.Exec("create table foo (department integer, profits integer)")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to create table:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to insert records:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
dept, sum int64
|
||||||
|
}{
|
||||||
|
{1, 30},
|
||||||
|
{2, 42},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
var ret int64
|
||||||
|
err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Query failed:", err)
|
||||||
|
}
|
||||||
|
if ret != test.sum {
|
||||||
|
t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var customFunctionOnce sync.Once
|
||||||
|
|
||||||
|
func BenchmarkCustomFunctions(b *testing.B) {
|
||||||
|
customFunctionOnce.Do(func() {
|
||||||
|
custom_add := func(a, b int64) int64 {
|
||||||
|
return a + b
|
||||||
|
}
|
||||||
|
|
||||||
|
sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{
|
||||||
|
ConnectHook: func(conn *SQLiteConn) error {
|
||||||
|
// Impure function to force sqlite to reexecute it each time.
|
||||||
|
if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
var i int64
|
||||||
|
err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal("Failed to run custom add:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -318,7 +318,7 @@ func BenchmarkQuery(b *testing.B) {
|
||||||
var i int
|
var i int
|
||||||
var f float64
|
var f float64
|
||||||
var s string
|
var s string
|
||||||
// var t time.Time
|
// var t time.Time
|
||||||
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
|
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -331,7 +331,7 @@ func BenchmarkParams(b *testing.B) {
|
||||||
var i int
|
var i int
|
||||||
var f float64
|
var f float64
|
||||||
var s string
|
var s string
|
||||||
// var t time.Time
|
// var t time.Time
|
||||||
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -350,7 +350,7 @@ func BenchmarkStmt(b *testing.B) {
|
||||||
var i int
|
var i int
|
||||||
var f float64
|
var f float64
|
||||||
var s string
|
var s string
|
||||||
// var t time.Time
|
// var t time.Time
|
||||||
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue