Implement support for passing Go functions as custom functions to SQLite.

Fixes #226.
This commit is contained in:
David Anderson 2015-08-20 23:08:48 -07:00
parent 8897bf1452
commit cf8fa0af80
5 changed files with 342 additions and 6 deletions

20
callback.go Normal file
View File

@ -0,0 +1,20 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
package sqlite3
/*
#include <sqlite3-binding.h>
*/
import "C"
import "unsafe"
//export callbackTrampoline
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
args := (*[1 << 30]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
fi.Call(ctx, args)
}

23
doc.go
View File

@ -33,7 +33,7 @@ extension for Regexp matcher operation.
#include <string.h> #include <string.h>
#include <stdio.h> #include <stdio.h>
#include <sqlite3ext.h> #include <sqlite3ext.h>
SQLITE_EXTENSION_INIT1 SQLITE_EXTENSION_INIT1
static void regexp_func(sqlite3_context *context, int argc, sqlite3_value **argv) { static void regexp_func(sqlite3_context *context, int argc, sqlite3_value **argv) {
if (argc >= 2) { if (argc >= 2) {
@ -44,7 +44,7 @@ extension for Regexp matcher operation.
int vec[500]; int vec[500];
int n, rc; int n, rc;
pcre* re = pcre_compile(pattern, 0, &errstr, &erroff, NULL); pcre* re = pcre_compile(pattern, 0, &errstr, &erroff, NULL);
rc = pcre_exec(re, NULL, target, strlen(target), 0, 0, vec, 500); rc = pcre_exec(re, NULL, target, strlen(target), 0, 0, vec, 500);
if (rc <= 0) { if (rc <= 0) {
sqlite3_result_error(context, errstr, 0); sqlite3_result_error(context, errstr, 0);
return; return;
@ -52,7 +52,7 @@ extension for Regexp matcher operation.
sqlite3_result_int(context, 1); sqlite3_result_int(context, 1);
} }
} }
#ifdef _WIN32 #ifdef _WIN32
__declspec(dllexport) __declspec(dllexport)
#endif #endif
@ -91,5 +91,22 @@ you need to hook ConnectHook and get the SQLiteConn.
}, },
}) })
Go SQlite3 Extensions
If you want to register Go functions as SQLite extension functions,
call RegisterFunction from ConnectHook.
regex = func(re, s string) (bool, error) {
return regexp.MatchString(re, s)
}
sql.Register("sqlite3_with_go_func",
&sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
return conn.RegisterFunc("regex", regex, true)
},
})
See the documentation of RegisterFunc for more details.
*/ */
package sqlite3 package sqlite3

View File

@ -66,6 +66,15 @@ _sqlite3_step(sqlite3_stmt* stmt, long long* rowid, long long* changes)
return rv; return rv;
} }
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
sqlite3_result_text(ctx, s, -1, &free);
}
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) {
sqlite3_result_blob(ctx, b, l, SQLITE_TRANSIENT);
}
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
*/ */
import "C" import "C"
import ( import (
@ -75,6 +84,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/url" "net/url"
"reflect"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@ -120,6 +130,7 @@ type SQLiteConn struct {
db *C.sqlite3 db *C.sqlite3
loc *time.Location loc *time.Location
txlock string txlock string
funcs []*functionInfo
} }
// Tx struct. // Tx struct.
@ -153,6 +164,89 @@ type SQLiteRows struct {
cls bool cls bool
} }
type functionInfo struct {
f reflect.Value
argConverters []func(*C.sqlite3_value) (reflect.Value, error)
}
func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) {
cstr := C.CString(err.Error())
defer C.free(unsafe.Pointer(cstr))
C.sqlite3_result_error(ctx, cstr, -1)
}
func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
var args []reflect.Value
for i, arg := range argv {
v, err := fi.argConverters[i](arg)
if err != nil {
fi.error(ctx, err)
return
}
args = append(args, v)
}
ret := fi.f.Call(args)
if len(ret) == 2 && ret[1].Interface() != nil {
fi.error(ctx, ret[1].Interface().(error))
return
}
res := ret[0].Interface()
// Normalize ret to one of the types sqlite knows.
switch r := res.(type) {
case int64, float64, []byte, string:
// Already the right type
case bool:
if r {
res = int64(1)
} else {
res = int64(0)
}
case int:
res = int64(r)
case uint:
res = int64(r)
case uint8:
res = int64(r)
case uint16:
res = int64(r)
case uint32:
res = int64(r)
case uint64:
res = int64(r)
case int8:
res = int64(r)
case int16:
res = int64(r)
case int32:
res = int64(r)
case float32:
res = float64(r)
default:
fi.error(ctx, errors.New("cannot convert returned type to sqlite type"))
return
}
switch r := res.(type) {
case int64:
C.sqlite3_result_int64(ctx, C.sqlite3_int64(r))
case float64:
C.sqlite3_result_double(ctx, C.double(r))
case []byte:
if len(r) == 0 {
C.sqlite3_result_null(ctx)
} else {
C._sqlite3_result_blob(ctx, unsafe.Pointer(&r[0]), C.int(len(r)))
}
case string:
C._sqlite3_result_text(ctx, C.CString(r))
default:
panic("unreachable")
}
}
// Commit transaction. // Commit transaction.
func (tx *SQLiteTx) Commit() error { func (tx *SQLiteTx) Commit() error {
_, err := tx.c.exec("COMMIT") _, err := tx.c.exec("COMMIT")
@ -165,6 +259,103 @@ func (tx *SQLiteTx) Rollback() error {
return err return err
} }
// RegisterFunc makes a Go function available as a SQLite function.
//
// The function must accept only arguments of type int64, float64,
// []byte or string, and return one value of any numeric type except
// complex, bool, []byte or string. Optionally, an error can be
// provided as a second return value.
//
// If pure is true. SQLite will assume that the function's return
// value depends only on its inputs, and make more aggressive
// optimizations in its queries.
func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) error {
var fi functionInfo
fi.f = reflect.ValueOf(impl)
t := fi.f.Type()
if t.Kind() != reflect.Func {
return errors.New("Non-function passed to RegisterFunc")
}
if t.IsVariadic() {
return errors.New("Variadic SQLite functions are not supported")
}
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite functions must return 1 or 2 values")
}
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("Second return value of SQLite function must be error")
}
for i := 0; i < t.NumIn(); i++ {
arg := t.In(i)
var conv func(*C.sqlite3_value) (reflect.Value, error)
switch arg.Kind() {
case reflect.Int64:
conv = func(v *C.sqlite3_value) (reflect.Value, error) {
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
return reflect.Value{}, fmt.Errorf("Argument %d to %s must be an INTEGER", i+1, name)
}
return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
}
case reflect.Float64:
conv = func(v *C.sqlite3_value) (reflect.Value, error) {
if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
return reflect.Value{}, fmt.Errorf("Argument %d to %s must be a FLOAT", i+1, name)
}
return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
}
case reflect.Slice:
if arg.Elem().Kind() != reflect.Uint8 {
return errors.New("The only supported slice type is []byte")
}
conv = func(v *C.sqlite3_value) (reflect.Value, error) {
switch C.sqlite3_value_type(v) {
case C.SQLITE_BLOB:
l := C.sqlite3_value_bytes(v)
p := C.sqlite3_value_blob(v)
return reflect.ValueOf(C.GoBytes(p, l)), nil
case C.SQLITE_TEXT:
l := C.sqlite3_value_bytes(v)
c := unsafe.Pointer(C.sqlite3_value_text(v))
return reflect.ValueOf(C.GoBytes(c, l)), nil
default:
return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name)
}
}
case reflect.String:
conv = func(v *C.sqlite3_value) (reflect.Value, error) {
switch C.sqlite3_value_type(v) {
case C.SQLITE_BLOB:
l := C.sqlite3_value_bytes(v)
p := (*C.char)(C.sqlite3_value_blob(v))
return reflect.ValueOf(C.GoStringN(p, l)), nil
case C.SQLITE_TEXT:
c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
return reflect.ValueOf(C.GoString(c)), nil
default:
return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name)
}
}
}
fi.argConverters = append(fi.argConverters, conv)
}
// fi must outlast the database connection, or we'll have dangling pointers.
c.funcs = append(c.funcs, &fi)
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := C.sqlite3_create_function_v2(c.db, cname, C.int(t.NumIn()), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil)
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}
// AutoCommit return which currently auto commit or not. // AutoCommit return which currently auto commit or not.
func (c *SQLiteConn) AutoCommit() bool { func (c *SQLiteConn) AutoCommit() bool {
return int(C.sqlite3_get_autocommit(c.db)) != 0 return int(C.sqlite3_get_autocommit(c.db)) != 0

View File

@ -15,7 +15,9 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -1056,3 +1058,109 @@ func TestDateTimeNow(t *testing.T) {
t.Fatal("Failed to scan datetime:", err) t.Fatal("Failed to scan datetime:", err)
} }
} }
func TestFunctionRegistration(t *testing.T) {
custom_add := func(a, b int64) (int64, error) {
return a + b, nil
}
custom_regex := func(s, re string) bool {
matched, err := regexp.MatchString(re, s)
if err != nil {
// We should really return the error here, but this
// function is also testing single return value functions.
panic("Bad regexp")
}
return matched
}
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
if err := conn.RegisterFunc("custom_add", custom_add, true); err != nil {
return err
}
if err := conn.RegisterFunc("regexp", custom_regex, true); err != nil {
return err
}
return nil
},
})
db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:")
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
additions := []struct {
a, b, c int64
}{
{1, 1, 2},
{1, 3, 4},
{1, -1, 0},
}
for _, add := range additions {
var i int64
err = db.QueryRow("SELECT custom_add($1, $2)", add.a, add.b).Scan(&i)
if err != nil {
t.Fatal("Failed to call custom_add:", err)
}
if i != add.c {
t.Fatalf("custom_add returned the wrong value, got %d, want %d", i, add.c)
}
}
regexes := []struct {
re, in string
out bool
}{
{".*", "foo", true},
{"^foo.*", "foobar", true},
{"^foo.*", "barfoo", false},
}
for _, re := range regexes {
var b bool
err = db.QueryRow("SELECT regexp($1, $2)", re.in, re.re).Scan(&b)
if err != nil {
t.Fatal("Failed to call regexp:", err)
}
if b != re.out {
t.Fatalf("regexp returned the wrong value, got %v, want %v", b, re.out)
}
}
}
var customFunctionOnce sync.Once
func BenchmarkCustomFunctions(b *testing.B) {
customFunctionOnce.Do(func() {
custom_add := func(a, b int64) (int64, error) {
return a + b, nil
}
sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
// Impure function to force sqlite to reexecute it each time.
if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil {
return err
}
return nil
},
})
})
db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:")
if err != nil {
b.Fatal("Failed to open database:", err)
}
defer db.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
var i int64
err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i)
if err != nil {
b.Fatal("Failed to run custom add:", err)
}
}
}

View File

@ -318,7 +318,7 @@ func BenchmarkQuery(b *testing.B) {
var i int var i int
var f float64 var f float64
var s string var s string
// var t time.Time // var t time.Time
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil { if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
panic(err) panic(err)
} }
@ -331,7 +331,7 @@ func BenchmarkParams(b *testing.B) {
var i int var i int
var f float64 var f float64
var s string var s string
// var t time.Time // var t time.Time
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err) panic(err)
} }
@ -350,7 +350,7 @@ func BenchmarkStmt(b *testing.B) {
var i int var i int
var f float64 var f float64
var s string var s string
// var t time.Time // var t time.Time
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
panic(err) panic(err)
} }