From cf8fa0af80e0d227c79ef2b4635e8d0d77432275 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Thu, 20 Aug 2015 23:08:48 -0700 Subject: [PATCH] Implement support for passing Go functions as custom functions to SQLite. Fixes #226. --- callback.go | 20 +++++ doc.go | 23 ++++- sqlite3.go | 191 ++++++++++++++++++++++++++++++++++++++++ sqlite3_test.go | 108 +++++++++++++++++++++++ sqlite3_test/sqltest.go | 6 +- 5 files changed, 342 insertions(+), 6 deletions(-) create mode 100644 callback.go diff --git a/callback.go b/callback.go new file mode 100644 index 0000000..938d7fe --- /dev/null +++ b/callback.go @@ -0,0 +1,20 @@ +// Copyright (C) 2014 Yasuhiro Matsumoto . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package sqlite3 + +/* +#include +*/ +import "C" + +import "unsafe" + +//export callbackTrampoline +func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) { + args := (*[1 << 30]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc] + fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx))) + fi.Call(ctx, args) +} diff --git a/doc.go b/doc.go index 51364c3..a45d852 100644 --- a/doc.go +++ b/doc.go @@ -33,7 +33,7 @@ extension for Regexp matcher operation. #include #include #include - + SQLITE_EXTENSION_INIT1 static void regexp_func(sqlite3_context *context, int argc, sqlite3_value **argv) { if (argc >= 2) { @@ -44,7 +44,7 @@ extension for Regexp matcher operation. int vec[500]; int n, rc; pcre* re = pcre_compile(pattern, 0, &errstr, &erroff, NULL); - rc = pcre_exec(re, NULL, target, strlen(target), 0, 0, vec, 500); + rc = pcre_exec(re, NULL, target, strlen(target), 0, 0, vec, 500); if (rc <= 0) { sqlite3_result_error(context, errstr, 0); return; @@ -52,7 +52,7 @@ extension for Regexp matcher operation. sqlite3_result_int(context, 1); } } - + #ifdef _WIN32 __declspec(dllexport) #endif @@ -91,5 +91,22 @@ you need to hook ConnectHook and get the SQLiteConn. }, }) +Go SQlite3 Extensions + +If you want to register Go functions as SQLite extension functions, +call RegisterFunction from ConnectHook. + + regex = func(re, s string) (bool, error) { + return regexp.MatchString(re, s) + } + sql.Register("sqlite3_with_go_func", + &sqlite3.SQLiteDriver{ + ConnectHook: func(conn *sqlite3.SQLiteConn) error { + return conn.RegisterFunc("regex", regex, true) + }, + }) + +See the documentation of RegisterFunc for more details. + */ package sqlite3 diff --git a/sqlite3.go b/sqlite3.go index d57d9fb..f995589 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -66,6 +66,15 @@ _sqlite3_step(sqlite3_stmt* stmt, long long* rowid, long long* changes) return rv; } +void _sqlite3_result_text(sqlite3_context* ctx, const char* s) { + sqlite3_result_text(ctx, s, -1, &free); +} + +void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l) { + sqlite3_result_blob(ctx, b, l, SQLITE_TRANSIENT); +} + +void callbackTrampoline(sqlite3_context*, int, sqlite3_value**); */ import "C" import ( @@ -75,6 +84,7 @@ import ( "fmt" "io" "net/url" + "reflect" "runtime" "strconv" "strings" @@ -120,6 +130,7 @@ type SQLiteConn struct { db *C.sqlite3 loc *time.Location txlock string + funcs []*functionInfo } // Tx struct. @@ -153,6 +164,89 @@ type SQLiteRows struct { cls bool } +type functionInfo struct { + f reflect.Value + argConverters []func(*C.sqlite3_value) (reflect.Value, error) +} + +func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { + cstr := C.CString(err.Error()) + defer C.free(unsafe.Pointer(cstr)) + C.sqlite3_result_error(ctx, cstr, -1) +} + +func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { + var args []reflect.Value + for i, arg := range argv { + v, err := fi.argConverters[i](arg) + if err != nil { + fi.error(ctx, err) + return + } + args = append(args, v) + } + + ret := fi.f.Call(args) + + if len(ret) == 2 && ret[1].Interface() != nil { + fi.error(ctx, ret[1].Interface().(error)) + return + } + + res := ret[0].Interface() + // Normalize ret to one of the types sqlite knows. + switch r := res.(type) { + case int64, float64, []byte, string: + // Already the right type + case bool: + if r { + res = int64(1) + } else { + res = int64(0) + } + case int: + res = int64(r) + case uint: + res = int64(r) + case uint8: + res = int64(r) + case uint16: + res = int64(r) + case uint32: + res = int64(r) + case uint64: + res = int64(r) + case int8: + res = int64(r) + case int16: + res = int64(r) + case int32: + res = int64(r) + case float32: + res = float64(r) + default: + fi.error(ctx, errors.New("cannot convert returned type to sqlite type")) + return + } + + switch r := res.(type) { + case int64: + C.sqlite3_result_int64(ctx, C.sqlite3_int64(r)) + case float64: + C.sqlite3_result_double(ctx, C.double(r)) + case []byte: + if len(r) == 0 { + C.sqlite3_result_null(ctx) + } else { + C._sqlite3_result_blob(ctx, unsafe.Pointer(&r[0]), C.int(len(r))) + } + case string: + C._sqlite3_result_text(ctx, C.CString(r)) + default: + panic("unreachable") + } +} + // Commit transaction. func (tx *SQLiteTx) Commit() error { _, err := tx.c.exec("COMMIT") @@ -165,6 +259,103 @@ func (tx *SQLiteTx) Rollback() error { return err } +// RegisterFunc makes a Go function available as a SQLite function. +// +// The function must accept only arguments of type int64, float64, +// []byte or string, and return one value of any numeric type except +// complex, bool, []byte or string. Optionally, an error can be +// provided as a second return value. +// +// If pure is true. SQLite will assume that the function's return +// value depends only on its inputs, and make more aggressive +// optimizations in its queries. +func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) error { + var fi functionInfo + fi.f = reflect.ValueOf(impl) + t := fi.f.Type() + if t.Kind() != reflect.Func { + return errors.New("Non-function passed to RegisterFunc") + } + if t.IsVariadic() { + return errors.New("Variadic SQLite functions are not supported") + } + if t.NumOut() != 1 && t.NumOut() != 2 { + return errors.New("SQLite functions must return 1 or 2 values") + } + if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return errors.New("Second return value of SQLite function must be error") + } + + for i := 0; i < t.NumIn(); i++ { + arg := t.In(i) + var conv func(*C.sqlite3_value) (reflect.Value, error) + switch arg.Kind() { + case reflect.Int64: + conv = func(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { + return reflect.Value{}, fmt.Errorf("Argument %d to %s must be an INTEGER", i+1, name) + } + return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil + } + case reflect.Float64: + conv = func(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_FLOAT { + return reflect.Value{}, fmt.Errorf("Argument %d to %s must be a FLOAT", i+1, name) + } + return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil + } + case reflect.Slice: + if arg.Elem().Kind() != reflect.Uint8 { + return errors.New("The only supported slice type is []byte") + } + conv = func(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := C.sqlite3_value_blob(v) + return reflect.ValueOf(C.GoBytes(p, l)), nil + case C.SQLITE_TEXT: + l := C.sqlite3_value_bytes(v) + c := unsafe.Pointer(C.sqlite3_value_text(v)) + return reflect.ValueOf(C.GoBytes(c, l)), nil + default: + return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name) + } + } + case reflect.String: + conv = func(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := (*C.char)(C.sqlite3_value_blob(v)) + return reflect.ValueOf(C.GoStringN(p, l)), nil + case C.SQLITE_TEXT: + c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v))) + return reflect.ValueOf(C.GoString(c)), nil + default: + return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name) + } + } + } + fi.argConverters = append(fi.argConverters, conv) + } + + // fi must outlast the database connection, or we'll have dangling pointers. + c.funcs = append(c.funcs, &fi) + + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + opts := C.SQLITE_UTF8 + if pure { + opts |= C.SQLITE_DETERMINISTIC + } + rv := C.sqlite3_create_function_v2(c.db, cname, C.int(t.NumIn()), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil) + if rv != C.SQLITE_OK { + return c.lastError() + } + return nil +} + // AutoCommit return which currently auto commit or not. func (c *SQLiteConn) AutoCommit() bool { return int(C.sqlite3_get_autocommit(c.db)) != 0 diff --git a/sqlite3_test.go b/sqlite3_test.go index 423f30e..a58e373 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -15,7 +15,9 @@ import ( "net/url" "os" "path/filepath" + "regexp" "strings" + "sync" "testing" "time" @@ -1056,3 +1058,109 @@ func TestDateTimeNow(t *testing.T) { t.Fatal("Failed to scan datetime:", err) } } + +func TestFunctionRegistration(t *testing.T) { + custom_add := func(a, b int64) (int64, error) { + return a + b, nil + } + custom_regex := func(s, re string) bool { + matched, err := regexp.MatchString(re, s) + if err != nil { + // We should really return the error here, but this + // function is also testing single return value functions. + panic("Bad regexp") + } + return matched + } + + sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.RegisterFunc("custom_add", custom_add, true); err != nil { + return err + } + if err := conn.RegisterFunc("regexp", custom_regex, true); err != nil { + return err + } + return nil + }, + }) + db, err := sql.Open("sqlite3_FunctionRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + additions := []struct { + a, b, c int64 + }{ + {1, 1, 2}, + {1, 3, 4}, + {1, -1, 0}, + } + + for _, add := range additions { + var i int64 + err = db.QueryRow("SELECT custom_add($1, $2)", add.a, add.b).Scan(&i) + if err != nil { + t.Fatal("Failed to call custom_add:", err) + } + if i != add.c { + t.Fatalf("custom_add returned the wrong value, got %d, want %d", i, add.c) + } + } + + regexes := []struct { + re, in string + out bool + }{ + {".*", "foo", true}, + {"^foo.*", "foobar", true}, + {"^foo.*", "barfoo", false}, + } + + for _, re := range regexes { + var b bool + err = db.QueryRow("SELECT regexp($1, $2)", re.in, re.re).Scan(&b) + if err != nil { + t.Fatal("Failed to call regexp:", err) + } + if b != re.out { + t.Fatalf("regexp returned the wrong value, got %v, want %v", b, re.out) + } + } +} + +var customFunctionOnce sync.Once + +func BenchmarkCustomFunctions(b *testing.B) { + customFunctionOnce.Do(func() { + custom_add := func(a, b int64) (int64, error) { + return a + b, nil + } + + sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + // Impure function to force sqlite to reexecute it each time. + if err := conn.RegisterFunc("custom_add", custom_add, false); err != nil { + return err + } + return nil + }, + }) + }) + + db, err := sql.Open("sqlite3_BenchmarkCustomFunctions", ":memory:") + if err != nil { + b.Fatal("Failed to open database:", err) + } + defer db.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var i int64 + err = db.QueryRow("SELECT custom_add(1,2)").Scan(&i) + if err != nil { + b.Fatal("Failed to run custom add:", err) + } + } +} diff --git a/sqlite3_test/sqltest.go b/sqlite3_test/sqltest.go index fc82782..782e15f 100644 --- a/sqlite3_test/sqltest.go +++ b/sqlite3_test/sqltest.go @@ -318,7 +318,7 @@ func BenchmarkQuery(b *testing.B) { var i int var f float64 var s string -// var t time.Time + // var t time.Time if err := db.QueryRow("select null, 1, 1.1, 'foo'").Scan(&n, &i, &f, &s); err != nil { panic(err) } @@ -331,7 +331,7 @@ func BenchmarkParams(b *testing.B) { var i int var f float64 var s string -// var t time.Time + // var t time.Time if err := db.QueryRow("select ?, ?, ?, ?", nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { panic(err) } @@ -350,7 +350,7 @@ func BenchmarkStmt(b *testing.B) { var i int var f float64 var s string -// var t time.Time + // var t time.Time if err := st.QueryRow(nil, 1, 1.1, "foo").Scan(&n, &i, &f, &s); err != nil { panic(err) }