forked from mirror/go-sqlite3
Merge pull request #479 from kenshaw/move-registeraggregator
Move RegisterAggregator implementation
This commit is contained in:
commit
ed69081a91
127
sqlite3.go
127
sqlite3.go
|
@ -100,6 +100,8 @@ int _sqlite3_create_function(
|
||||||
}
|
}
|
||||||
|
|
||||||
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||||
|
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||||
|
void doneTrampoline(sqlite3_context*);
|
||||||
|
|
||||||
int compareTrampoline(void*, int, char*, int, char*);
|
int compareTrampoline(void*, int, char*, int, char*);
|
||||||
int commitHookTrampoline(void*);
|
int commitHookTrampoline(void*);
|
||||||
|
@ -503,6 +505,131 @@ func sqlite3CreateFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTe
|
||||||
return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(pApp), (*[0]byte)(unsafe.Pointer(xFunc)), (*[0]byte)(unsafe.Pointer(xStep)), (*[0]byte)(unsafe.Pointer(xFinal)))
|
return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(pApp), (*[0]byte)(unsafe.Pointer(xFunc)), (*[0]byte)(unsafe.Pointer(xStep)), (*[0]byte)(unsafe.Pointer(xFinal)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
|
||||||
|
//
|
||||||
|
// Because aggregation is incremental, it's implemented in Go with a
|
||||||
|
// type that has 2 methods: func Step(values) accumulates one row of
|
||||||
|
// data into the accumulator, and func Done() ret finalizes and
|
||||||
|
// returns the aggregate value. "values" and "ret" may be any type
|
||||||
|
// supported by RegisterFunc.
|
||||||
|
//
|
||||||
|
// RegisterAggregator takes as implementation a constructor function
|
||||||
|
// that constructs an instance of the aggregator type each time an
|
||||||
|
// aggregation begins. The constructor must return a pointer to a
|
||||||
|
// type, or an interface that implements Step() and Done().
|
||||||
|
//
|
||||||
|
// The constructor function and the Step/Done methods may optionally
|
||||||
|
// return an error in addition to their other return values.
|
||||||
|
//
|
||||||
|
// See _example/go_custom_funcs for a detailed example.
|
||||||
|
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
|
||||||
|
var ai aggInfo
|
||||||
|
ai.constructor = reflect.ValueOf(impl)
|
||||||
|
t := ai.constructor.Type()
|
||||||
|
if t.Kind() != reflect.Func {
|
||||||
|
return errors.New("non-function passed to RegisterAggregator")
|
||||||
|
}
|
||||||
|
if t.NumOut() != 1 && t.NumOut() != 2 {
|
||||||
|
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
|
||||||
|
}
|
||||||
|
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||||
|
return errors.New("Second return value of SQLite function must be error")
|
||||||
|
}
|
||||||
|
if t.NumIn() != 0 {
|
||||||
|
return errors.New("SQLite aggregator constructors must not have arguments")
|
||||||
|
}
|
||||||
|
|
||||||
|
agg := t.Out(0)
|
||||||
|
switch agg.Kind() {
|
||||||
|
case reflect.Ptr, reflect.Interface:
|
||||||
|
default:
|
||||||
|
return errors.New("SQlite aggregator constructor must return a pointer object")
|
||||||
|
}
|
||||||
|
stepFn, found := agg.MethodByName("Step")
|
||||||
|
if !found {
|
||||||
|
return errors.New("SQlite aggregator doesn't have a Step() function")
|
||||||
|
}
|
||||||
|
step := stepFn.Type
|
||||||
|
if step.NumOut() != 0 && step.NumOut() != 1 {
|
||||||
|
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
|
||||||
|
}
|
||||||
|
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||||
|
return errors.New("type of SQlite aggregator Step() return value must be error")
|
||||||
|
}
|
||||||
|
|
||||||
|
stepNArgs := step.NumIn()
|
||||||
|
start := 0
|
||||||
|
if agg.Kind() == reflect.Ptr {
|
||||||
|
// Skip over the method receiver
|
||||||
|
stepNArgs--
|
||||||
|
start++
|
||||||
|
}
|
||||||
|
if step.IsVariadic() {
|
||||||
|
stepNArgs--
|
||||||
|
}
|
||||||
|
for i := start; i < start+stepNArgs; i++ {
|
||||||
|
conv, err := callbackArg(step.In(i))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ai.stepArgConverters = append(ai.stepArgConverters, conv)
|
||||||
|
}
|
||||||
|
if step.IsVariadic() {
|
||||||
|
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ai.stepVariadicConverter = conv
|
||||||
|
// Pass -1 to sqlite so that it allows any number of
|
||||||
|
// arguments. The call helper verifies that the minimum number
|
||||||
|
// of arguments is present for variadic functions.
|
||||||
|
stepNArgs = -1
|
||||||
|
}
|
||||||
|
|
||||||
|
doneFn, found := agg.MethodByName("Done")
|
||||||
|
if !found {
|
||||||
|
return errors.New("SQlite aggregator doesn't have a Done() function")
|
||||||
|
}
|
||||||
|
done := doneFn.Type
|
||||||
|
doneNArgs := done.NumIn()
|
||||||
|
if agg.Kind() == reflect.Ptr {
|
||||||
|
// Skip over the method receiver
|
||||||
|
doneNArgs--
|
||||||
|
}
|
||||||
|
if doneNArgs != 0 {
|
||||||
|
return errors.New("SQlite aggregator Done() function must have no arguments")
|
||||||
|
}
|
||||||
|
if done.NumOut() != 1 && done.NumOut() != 2 {
|
||||||
|
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
|
||||||
|
}
|
||||||
|
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||||
|
return errors.New("second return value of SQLite aggregator Done() function must be error")
|
||||||
|
}
|
||||||
|
|
||||||
|
conv, err := callbackRet(done.Out(0))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ai.doneRetConverter = conv
|
||||||
|
ai.active = make(map[int64]reflect.Value)
|
||||||
|
ai.next = 1
|
||||||
|
|
||||||
|
// ai must outlast the database connection, or we'll have dangling pointers.
|
||||||
|
c.aggregators = append(c.aggregators, &ai)
|
||||||
|
|
||||||
|
cname := C.CString(name)
|
||||||
|
defer C.free(unsafe.Pointer(cname))
|
||||||
|
opts := C.SQLITE_UTF8
|
||||||
|
if pure {
|
||||||
|
opts |= C.SQLITE_DETERMINISTIC
|
||||||
|
}
|
||||||
|
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
|
||||||
|
if rv != C.SQLITE_OK {
|
||||||
|
return c.lastError()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AutoCommit return which currently auto commit or not.
|
// AutoCommit return which currently auto commit or not.
|
||||||
func (c *SQLiteConn) AutoCommit() bool {
|
func (c *SQLiteConn) AutoCommit() bool {
|
||||||
return int(C.sqlite3_get_autocommit(c.db)) != 0
|
return int(C.sqlite3_get_autocommit(c.db)) != 0
|
||||||
|
|
|
@ -1232,6 +1232,66 @@ func TestFunctionRegistration(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type sumAggregator int64
|
||||||
|
|
||||||
|
func (s *sumAggregator) Step(x int64) {
|
||||||
|
*s += sumAggregator(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *sumAggregator) Done() int64 {
|
||||||
|
return int64(*s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAggregatorRegistration(t *testing.T) {
|
||||||
|
customSum := func() *sumAggregator {
|
||||||
|
var ret sumAggregator
|
||||||
|
return &ret
|
||||||
|
}
|
||||||
|
|
||||||
|
sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
|
||||||
|
ConnectHook: func(conn *SQLiteConn) error {
|
||||||
|
if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err = db.Exec("create table foo (department integer, profits integer)")
|
||||||
|
if err != nil {
|
||||||
|
// trace feature is not implemented
|
||||||
|
t.Skip("Failed to create table:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to insert records:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
dept, sum int64
|
||||||
|
}{
|
||||||
|
{1, 30},
|
||||||
|
{2, 42},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
var ret int64
|
||||||
|
err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Query failed:", err)
|
||||||
|
}
|
||||||
|
if ret != test.sum {
|
||||||
|
t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func rot13(r rune) rune {
|
func rot13(r rune) rune {
|
||||||
switch {
|
switch {
|
||||||
case r >= 'A' && r <= 'Z':
|
case r >= 'A' && r <= 'Z':
|
||||||
|
|
129
sqlite3_trace.go
129
sqlite3_trace.go
|
@ -14,16 +14,12 @@ package sqlite3
|
||||||
#endif
|
#endif
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
|
||||||
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
|
|
||||||
void doneTrampoline(sqlite3_context*);
|
|
||||||
int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
|
int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
@ -239,131 +235,6 @@ func popTraceMapping(connHandle uintptr) (TraceConfig, bool) {
|
||||||
return entryCopy.config, found
|
return entryCopy.config, found
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
|
|
||||||
//
|
|
||||||
// Because aggregation is incremental, it's implemented in Go with a
|
|
||||||
// type that has 2 methods: func Step(values) accumulates one row of
|
|
||||||
// data into the accumulator, and func Done() ret finalizes and
|
|
||||||
// returns the aggregate value. "values" and "ret" may be any type
|
|
||||||
// supported by RegisterFunc.
|
|
||||||
//
|
|
||||||
// RegisterAggregator takes as implementation a constructor function
|
|
||||||
// that constructs an instance of the aggregator type each time an
|
|
||||||
// aggregation begins. The constructor must return a pointer to a
|
|
||||||
// type, or an interface that implements Step() and Done().
|
|
||||||
//
|
|
||||||
// The constructor function and the Step/Done methods may optionally
|
|
||||||
// return an error in addition to their other return values.
|
|
||||||
//
|
|
||||||
// See _example/go_custom_funcs for a detailed example.
|
|
||||||
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
|
|
||||||
var ai aggInfo
|
|
||||||
ai.constructor = reflect.ValueOf(impl)
|
|
||||||
t := ai.constructor.Type()
|
|
||||||
if t.Kind() != reflect.Func {
|
|
||||||
return errors.New("non-function passed to RegisterAggregator")
|
|
||||||
}
|
|
||||||
if t.NumOut() != 1 && t.NumOut() != 2 {
|
|
||||||
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
|
|
||||||
}
|
|
||||||
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
|
||||||
return errors.New("Second return value of SQLite function must be error")
|
|
||||||
}
|
|
||||||
if t.NumIn() != 0 {
|
|
||||||
return errors.New("SQLite aggregator constructors must not have arguments")
|
|
||||||
}
|
|
||||||
|
|
||||||
agg := t.Out(0)
|
|
||||||
switch agg.Kind() {
|
|
||||||
case reflect.Ptr, reflect.Interface:
|
|
||||||
default:
|
|
||||||
return errors.New("SQlite aggregator constructor must return a pointer object")
|
|
||||||
}
|
|
||||||
stepFn, found := agg.MethodByName("Step")
|
|
||||||
if !found {
|
|
||||||
return errors.New("SQlite aggregator doesn't have a Step() function")
|
|
||||||
}
|
|
||||||
step := stepFn.Type
|
|
||||||
if step.NumOut() != 0 && step.NumOut() != 1 {
|
|
||||||
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
|
|
||||||
}
|
|
||||||
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
|
||||||
return errors.New("type of SQlite aggregator Step() return value must be error")
|
|
||||||
}
|
|
||||||
|
|
||||||
stepNArgs := step.NumIn()
|
|
||||||
start := 0
|
|
||||||
if agg.Kind() == reflect.Ptr {
|
|
||||||
// Skip over the method receiver
|
|
||||||
stepNArgs--
|
|
||||||
start++
|
|
||||||
}
|
|
||||||
if step.IsVariadic() {
|
|
||||||
stepNArgs--
|
|
||||||
}
|
|
||||||
for i := start; i < start+stepNArgs; i++ {
|
|
||||||
conv, err := callbackArg(step.In(i))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ai.stepArgConverters = append(ai.stepArgConverters, conv)
|
|
||||||
}
|
|
||||||
if step.IsVariadic() {
|
|
||||||
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ai.stepVariadicConverter = conv
|
|
||||||
// Pass -1 to sqlite so that it allows any number of
|
|
||||||
// arguments. The call helper verifies that the minimum number
|
|
||||||
// of arguments is present for variadic functions.
|
|
||||||
stepNArgs = -1
|
|
||||||
}
|
|
||||||
|
|
||||||
doneFn, found := agg.MethodByName("Done")
|
|
||||||
if !found {
|
|
||||||
return errors.New("SQlite aggregator doesn't have a Done() function")
|
|
||||||
}
|
|
||||||
done := doneFn.Type
|
|
||||||
doneNArgs := done.NumIn()
|
|
||||||
if agg.Kind() == reflect.Ptr {
|
|
||||||
// Skip over the method receiver
|
|
||||||
doneNArgs--
|
|
||||||
}
|
|
||||||
if doneNArgs != 0 {
|
|
||||||
return errors.New("SQlite aggregator Done() function must have no arguments")
|
|
||||||
}
|
|
||||||
if done.NumOut() != 1 && done.NumOut() != 2 {
|
|
||||||
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
|
|
||||||
}
|
|
||||||
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
|
||||||
return errors.New("second return value of SQLite aggregator Done() function must be error")
|
|
||||||
}
|
|
||||||
|
|
||||||
conv, err := callbackRet(done.Out(0))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ai.doneRetConverter = conv
|
|
||||||
ai.active = make(map[int64]reflect.Value)
|
|
||||||
ai.next = 1
|
|
||||||
|
|
||||||
// ai must outlast the database connection, or we'll have dangling pointers.
|
|
||||||
c.aggregators = append(c.aggregators, &ai)
|
|
||||||
|
|
||||||
cname := C.CString(name)
|
|
||||||
defer C.free(unsafe.Pointer(cname))
|
|
||||||
opts := C.SQLITE_UTF8
|
|
||||||
if pure {
|
|
||||||
opts |= C.SQLITE_DETERMINISTIC
|
|
||||||
}
|
|
||||||
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
|
|
||||||
if rv != C.SQLITE_OK {
|
|
||||||
return c.lastError()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTrace installs or removes the trace callback for the given database connection.
|
// SetTrace installs or removes the trace callback for the given database connection.
|
||||||
// It's not named 'RegisterTrace' because only one callback can be kept and called.
|
// It's not named 'RegisterTrace' because only one callback can be kept and called.
|
||||||
// Calling SetTrace a second time on same database connection
|
// Calling SetTrace a second time on same database connection
|
||||||
|
|
|
@ -1,72 +0,0 @@
|
||||||
// Copyright (C) 2016 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
|
||||||
//
|
|
||||||
// Use of this source code is governed by an MIT-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
// +build trace
|
|
||||||
|
|
||||||
package sqlite3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type sumAggregator int64
|
|
||||||
|
|
||||||
func (s *sumAggregator) Step(x int64) {
|
|
||||||
*s += sumAggregator(x)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sumAggregator) Done() int64 {
|
|
||||||
return int64(*s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAggregatorRegistration(t *testing.T) {
|
|
||||||
customSum := func() *sumAggregator {
|
|
||||||
var ret sumAggregator
|
|
||||||
return &ret
|
|
||||||
}
|
|
||||||
|
|
||||||
sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
|
|
||||||
ConnectHook: func(conn *SQLiteConn) error {
|
|
||||||
if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
})
|
|
||||||
db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("Failed to open database:", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
_, err = db.Exec("create table foo (department integer, profits integer)")
|
|
||||||
if err != nil {
|
|
||||||
// trace feature is not implemented
|
|
||||||
t.Skip("Failed to create table:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("Failed to insert records:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
dept, sum int64
|
|
||||||
}{
|
|
||||||
{1, 30},
|
|
||||||
{2, 42},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
var ret int64
|
|
||||||
err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("Query failed:", err)
|
|
||||||
}
|
|
||||||
if ret != test.sum {
|
|
||||||
t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue