mirror of https://github.com/mattn/go-sqlite3.git
Implement support for passing Go functions as custom functions to SQLite.
Fixes #226.
This commit is contained in:
parent
8897bf1452
commit
cf8fa0af80
|
@ -0,0 +1,20 @@
|
||||||
|
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||||
|
//
|
||||||
|
// Use of this source code is governed by an MIT-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <sqlite3-binding.h>
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import "unsafe"
|
||||||
|
|
||||||
|
//export callbackTrampoline
|
||||||
|
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
|
||||||
|
args := (*[1 << 30]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
|
||||||
|
fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
|
||||||
|
fi.Call(ctx, args)
|
||||||
|
}
|
17
doc.go
17
doc.go
|
@ -91,5 +91,22 @@ you need to hook ConnectHook and get the SQLiteConn.
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
Go SQlite3 Extensions
|
||||||
|
|
||||||
|
If you want to register Go functions as SQLite extension functions,
|
||||||
|
call RegisterFunction from ConnectHook.
|
||||||
|
|
||||||
|
regex = func(re, s string) (bool, error) {
|
||||||
|
return regexp.MatchString(re, s)
|
||||||
|
}
|
||||||
|
sql.Register("sqlite3_with_go_func",
|
||||||
|
&sqlite3.SQLiteDriver{
|
||||||
|
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||||
|
return conn.RegisterFunc("regex", regex, true)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
See the documentation of RegisterFunc for more details.
|
||||||
|
|
||||||
*/
|
*/
|
||||||
package sqlite3
|
package sqlite3
|
||||||
|
|
191
sqlite3.go
191
sqlite3.go
|
@ -66,6 +66,15 @@ _sqlite3_step(sqlite3_stmt* stmt, long long* rowid, long long* changes)
|
||||||
return rv;
|
return rv;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
|
||||||
|
sqlite3_result_text(ctx, s, -1, &free);
|
||||||
|
}
|
||||||
|
|
||||||
|
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) {
|
||||||
|
sqlite3_result_blob(ctx, b, l, SQLITE_TRANSIENT);
|
||||||
|
}
|
||||||
|
|
||||||
|
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
|
@ -75,6 +84,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -120,6 +130,7 @@ type SQLiteConn struct {
|
||||||
db *C.sqlite3
|
db *C.sqlite3
|
||||||
loc *time.Location
|
loc *time.Location
|
||||||
txlock string
|
txlock string
|
||||||
|
funcs []*functionInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tx struct.
|
// Tx struct.
|
||||||
|
@ -153,6 +164,89 @@ type SQLiteRows struct {
|
||||||
cls bool
|
cls bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type functionInfo struct {
|
||||||
|
f reflect.Value
|
||||||
|
argConverters []func(*C.sqlite3_value) (reflect.Value, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) {
|
||||||
|
cstr := C.CString(err.Error())
|
||||||
|
defer C.free(unsafe.Pointer(cstr))
|
||||||
|
C.sqlite3_result_error(ctx, cstr, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
|
||||||
|
var args []reflect.Value
|
||||||
|
for i, arg := range argv {
|
||||||
|
v, err := fi.argConverters[i](arg)
|
||||||
|
if err != nil {
|
||||||
|
fi.error(ctx, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
args = append(args, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := fi.f.Call(args)
|
||||||
|
|
||||||
|
if len(ret) == 2 && ret[1].Interface() != nil {
|
||||||
|
fi.error(ctx, ret[1].Interface().(error))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res := ret[0].Interface()
|
||||||
|
// Normalize ret to one of the types sqlite knows.
|
||||||
|
switch r := res.(type) {
|
||||||
|
case int64, float64, []byte, string:
|
||||||
|
// Already the right type
|
||||||
|
case bool:
|
||||||
|
if r {
|
||||||
|
res = int64(1)
|
||||||
|
} else {
|
||||||
|
res = int64(0)
|
||||||
|
}
|
||||||
|
case int:
|
||||||
|
res = int64(r)
|
||||||
|
case uint:
|
||||||
|
res = int64(r)
|
||||||
|
case uint8:
|
||||||
|
res = int64(r)
|
||||||
|
case uint16:
|
||||||
|
res = int64(r)
|
||||||
|
case uint32:
|
||||||
|
res = int64(r)
|
||||||
|
case uint64:
|
||||||
|
res = int64(r)
|
||||||
|
case int8:
|
||||||
|
res = int64(r)
|
||||||
|
case int16:
|
||||||
|
res = int64(r)
|
||||||
|
case int32:
|
||||||
|
res = int64(r)
|
||||||
|
case float32:
|
||||||
|
res = float64(r)
|
||||||
|
default:
|
||||||
|
fi.error(ctx, errors.New("cannot convert returned type to sqlite type"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r := res.(type) {
|
||||||
|
case int64:
|
||||||
|
C.sqlite3_result_int64(ctx, C.sqlite3_int64(r))
|
||||||
|
case float64:
|
||||||
|
C.sqlite3_result_double(ctx, C.double(r))
|
||||||
|
case []byte:
|
||||||
|
if len(r) == 0 {
|
||||||
|
C.sqlite3_result_null(ctx)
|
||||||
|
} else {
|
||||||
|
C._sqlite3_result_blob(ctx, unsafe.Pointer(&r[0]), C.int(len(r)))
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
C._sqlite3_result_text(ctx, C.CString(r))
|
||||||
|
default:
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Commit transaction.
|
// Commit transaction.
|
||||||
func (tx *SQLiteTx) Commit() error {
|
func (tx *SQLiteTx) Commit() error {
|
||||||
_, err := tx.c.exec("COMMIT")
|
_, err := tx.c.exec("COMMIT")
|
||||||
|
@ -165,6 +259,103 @@ func (tx *SQLiteTx) Rollback() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterFunc makes a Go function available as a SQLite function.
|
||||||
|
//
|
||||||
|
// The function must accept only arguments of type int64, float64,
|
||||||
|
// []byte or string, and return one value of any numeric type except
|
||||||
|
// complex, bool, []byte or string. Optionally, an error can be
|
||||||
|
// provided as a second return value.
|
||||||
|
//
|
||||||
|
// If pure is true. SQLite will assume that the function's return
|
||||||
|
// value depends only on its inputs, and make more aggressive
|
||||||
|
// optimizations in its queries.
|
||||||
|
func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) error {
|
||||||
|
var fi functionInfo
|
||||||
|
fi.f = reflect.ValueOf(impl)
|
||||||
|
t := fi.f.Type()
|
||||||
|
if t.Kind() != reflect.Func {
|
||||||
|
return errors.New("Non-function passed to RegisterFunc")
|
||||||
|
}
|
||||||
|
if t.IsVariadic() {
|
||||||
|
return errors.New("Variadic SQLite functions are not supported")
|
||||||
|
}
|
||||||
|
if t.NumOut() != 1 && t.NumOut() != 2 {
|
||||||
|
return errors.New("SQLite functions must return 1 or 2 values")
|
||||||
|
}
|
||||||
|
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||||
|
return errors.New("Second return value of SQLite function must be error")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < t.NumIn(); i++ {
|
||||||
|
arg := t.In(i)
|
||||||
|
var conv func(*C.sqlite3_value) (reflect.Value, error)
|
||||||
|
switch arg.Kind() {
|
||||||
|
case reflect.Int64:
|
||||||
|
conv = func(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
|
||||||
|
return reflect.Value{}, fmt.Errorf("Argument %d to %s must be an INTEGER", i+1, name)
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
|
||||||
|
}
|
||||||
|
case reflect.Float64:
|
||||||
|
conv = func(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
|
||||||
|
return reflect.Value{}, fmt.Errorf("Argument %d to %s must be a FLOAT", i+1, name)
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
|
||||||
|
}
|
||||||
|
case reflect.Slice:
|
||||||
|
if arg.Elem().Kind() != reflect.Uint8 {
|
||||||
|
return errors.New("The only supported slice type is []byte")
|
||||||
|
}
|
||||||
|
conv = func(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
switch C.sqlite3_value_type(v) {
|
||||||
|
case C.SQLITE_BLOB:
|
||||||
|
l := C.sqlite3_value_bytes(v)
|
||||||
|
p := C.sqlite3_value_blob(v)
|
||||||
|
return reflect.ValueOf(C.GoBytes(p, l)), nil
|
||||||
|
case C.SQLITE_TEXT:
|
||||||
|
l := C.sqlite3_value_bytes(v)
|
||||||
|
c := unsafe.Pointer(C.sqlite3_value_text(v))
|
||||||
|
return reflect.ValueOf(C.GoBytes(c, l)), nil
|
||||||
|
default:
|
||||||
|
return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.String:
|
||||||
|
conv = func(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
switch C.sqlite3_value_type(v) {
|
||||||
|
case C.SQLITE_BLOB:
|
||||||
|
l := C.sqlite3_value_bytes(v)
|
||||||
|
p := (*C.char)(C.sqlite3_value_blob(v))
|
||||||
|
return reflect.ValueOf(C.GoStringN(p, l)), nil
|
||||||
|
case C.SQLITE_TEXT:
|
||||||
|
c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
|
||||||
|
return reflect.ValueOf(C.GoString(c)), nil
|
||||||
|
default:
|
||||||
|
return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fi.argConverters = append(fi.argConverters, conv)
|
||||||
|
}
|
||||||
|
|
||||||
|
// fi must outlast the database connection, or we'll have dangling pointers.
|
||||||
|
c.funcs = append(c.funcs, &fi)
|
||||||
|
|
||||||
|
cname := C.CString(name)
|
||||||
|
defer C.free(unsafe.Pointer(cname))
|
||||||
|
opts := C.SQLITE_UTF8
|
||||||
|
if pure {
|
||||||
|
opts |= C.SQLITE_DETERMINISTIC
|
||||||
|
}
|
||||||
|
rv := C.sqlite3_create_function_v2(c.db, cname, C.int(t.NumIn()), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil)
|
||||||
|
if rv != C.SQLITE_OK {
|
||||||
|
return c.lastError()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AutoCommit return which currently auto commit or not.
|
// AutoCommit return which currently auto commit or not.
|
||||||
func (c *SQLiteConn) AutoCommit() bool {
|
func (c *SQLiteConn) AutoCommit() bool {
|
||||||
return int(C.sqlite3_get_autocommit(c.db)) != 0
|
return int(C.sqlite3_get_autocommit(c.db)) != 0
|
||||||
|
|
108
sqlite3_test.go
108
sqlite3_test.go
|
@ -15,7 +15,9 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -1056,3 +1058,109 @@ func TestDateTimeNow(t *testing.T) {
|
||||||
t.Fatal("Failed to scan datetime:", err)
|
t.Fatal("Failed to scan datetime:", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFunctionRegistration(t *testing.T) {
|
||||||
|
custom_add := func(a, b int64) (int64, error) {
|
||||||
|
return a + b, nil
|
||||||
|
}
|
||||||
|
custom_regex := func(s, re string) bool {
|
||||||
|
matched, err := regexp.MatchString(re, s)
|
||||||
|
if err != nil {
|
||||||
|
// We should really return the error here, but this
|
||||||
|
// function is also testing single return value functions.
|
||||||
|
panic("Bad regexp")
|
||||||
|
}
|
||||||
|
return matched
|
||||||
|
}
|
||||||
|
|
||||||
|
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
|
||||||
|
ConnectHook: func(conn *SQLiteConn) error {
|
||||||
|
if err := conn.RegisterFunc("custom_add", custom_add, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("regexp", custom_regex, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
additions := []struct {
|
||||||
|
a, b, c int64
|
||||||
|
}{
|
||||||
|
{1, 1, 2},
|
||||||
|
{1, 3, 4},
|
||||||
|
{1, -1, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, add := range additions {
|
||||||
|
var i int64
|
||||||
|
err = db.QueryRow("SELECT custom_add($1, $2)", add.a, add.b).Scan(&i)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to call custom_add:", err)
|
||||||
|
}
|
||||||
|
if i != add.c {
|
||||||
|
t.Fatalf("custom_add returned the wrong value, got %d, want %d", i, add.c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
regexes := []struct {
|
||||||
|
re, in string
|
||||||
|
out bool
|
||||||
|
}{
|
||||||
|
{".*", "foo", true},
|
||||||
|
{"^foo.*", "foobar", true},
|
||||||
|
{"^foo.*", "barfoo", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, re := range regexes {
|
||||||
|
var b bool
|
||||||
|
err = db.QueryRow("SELECT regexp($1, $2)", re.in, re.re).Scan(&b)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to call regexp:", err)
|
||||||
|
}
|
||||||
|
if b != re.out {
|
||||||
|
t.Fatalf("regexp returned the wrong value, got %v, want %v", b, re.out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var customFunctionOnce sync.Once
|
||||||
|
|
||||||
|
func BenchmarkCustomFunctions(b *testing.B) {
|
||||||
|
customFunctionOnce.Do(func() {
|
||||||
|
custom_add := func(a, b int64) (int64, error) {
|
||||||
|
return a + b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{
|
||||||
|
ConnectHook: func(conn *SQLiteConn) error {
|
||||||
|
// Impure function to force sqlite to reexecute it each time.
|
||||||
|
if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
var i int64
|
||||||
|
err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal("Failed to run custom add:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -318,7 +318,7 @@ func BenchmarkQuery(b *testing.B) {
|
||||||
var i int
|
var i int
|
||||||
var f float64
|
var f float64
|
||||||
var s string
|
var s string
|
||||||
// var t time.Time
|
// var t time.Time
|
||||||
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
|
if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -331,7 +331,7 @@ func BenchmarkParams(b *testing.B) {
|
||||||
var i int
|
var i int
|
||||||
var f float64
|
var f float64
|
||||||
var s string
|
var s string
|
||||||
// var t time.Time
|
// var t time.Time
|
||||||
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -350,7 +350,7 @@ func BenchmarkStmt(b *testing.B) {
|
||||||
var i int
|
var i int
|
||||||
var f float64
|
var f float64
|
||||||
var s string
|
var s string
|
||||||
// var t time.Time
|
// var t time.Time
|
||||||
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue