forked from mirror/go-sqlcipher
Move RegisterAggregator implementation
The SQLiteConn.RegisterAggregator implementation was defined in sqlite3_trace.go file, which is guarded with a build constraint. This change simply moves RegisterAggregator to the main sqlite3.go file, and moves accompanying unit tests. The rationale for this move is that it was not possible for downstream using packages to use RegisterAggregator without also specifying (and notifying the user) the 'trace' build tag.
This commit is contained in:
parent
615c193e01
commit
7174000f77
127
sqlite3.go
127
sqlite3.go
|
@ -100,6 +100,8 @@ int _sqlite3_create_function(
|
|||
}
|
||||
|
||||
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||
void doneTrampoline(sqlite3_context*);
|
||||
|
||||
int compareTrampoline(void*, int, char*, int, char*);
|
||||
int commitHookTrampoline(void*);
|
||||
|
@ -477,6 +479,131 @@ func sqlite3CreateFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTe
|
|||
return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(pApp), (*[0]byte)(unsafe.Pointer(xFunc)), (*[0]byte)(unsafe.Pointer(xStep)), (*[0]byte)(unsafe.Pointer(xFinal)))
|
||||
}
|
||||
|
||||
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
|
||||
//
|
||||
// Because aggregation is incremental, it's implemented in Go with a
|
||||
// type that has 2 methods: func Step(values) accumulates one row of
|
||||
// data into the accumulator, and func Done() ret finalizes and
|
||||
// returns the aggregate value. "values" and "ret" may be any type
|
||||
// supported by RegisterFunc.
|
||||
//
|
||||
// RegisterAggregator takes as implementation a constructor function
|
||||
// that constructs an instance of the aggregator type each time an
|
||||
// aggregation begins. The constructor must return a pointer to a
|
||||
// type, or an interface that implements Step() and Done().
|
||||
//
|
||||
// The constructor function and the Step/Done methods may optionally
|
||||
// return an error in addition to their other return values.
|
||||
//
|
||||
// See _example/go_custom_funcs for a detailed example.
|
||||
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
|
||||
var ai aggInfo
|
||||
ai.constructor = reflect.ValueOf(impl)
|
||||
t := ai.constructor.Type()
|
||||
if t.Kind() != reflect.Func {
|
||||
return errors.New("non-function passed to RegisterAggregator")
|
||||
}
|
||||
if t.NumOut() != 1 && t.NumOut() != 2 {
|
||||
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
|
||||
}
|
||||
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
return errors.New("Second return value of SQLite function must be error")
|
||||
}
|
||||
if t.NumIn() != 0 {
|
||||
return errors.New("SQLite aggregator constructors must not have arguments")
|
||||
}
|
||||
|
||||
agg := t.Out(0)
|
||||
switch agg.Kind() {
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
default:
|
||||
return errors.New("SQlite aggregator constructor must return a pointer object")
|
||||
}
|
||||
stepFn, found := agg.MethodByName("Step")
|
||||
if !found {
|
||||
return errors.New("SQlite aggregator doesn't have a Step() function")
|
||||
}
|
||||
step := stepFn.Type
|
||||
if step.NumOut() != 0 && step.NumOut() != 1 {
|
||||
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
|
||||
}
|
||||
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
return errors.New("type of SQlite aggregator Step() return value must be error")
|
||||
}
|
||||
|
||||
stepNArgs := step.NumIn()
|
||||
start := 0
|
||||
if agg.Kind() == reflect.Ptr {
|
||||
// Skip over the method receiver
|
||||
stepNArgs--
|
||||
start++
|
||||
}
|
||||
if step.IsVariadic() {
|
||||
stepNArgs--
|
||||
}
|
||||
for i := start; i < start+stepNArgs; i++ {
|
||||
conv, err := callbackArg(step.In(i))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ai.stepArgConverters = append(ai.stepArgConverters, conv)
|
||||
}
|
||||
if step.IsVariadic() {
|
||||
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ai.stepVariadicConverter = conv
|
||||
// Pass -1 to sqlite so that it allows any number of
|
||||
// arguments. The call helper verifies that the minimum number
|
||||
// of arguments is present for variadic functions.
|
||||
stepNArgs = -1
|
||||
}
|
||||
|
||||
doneFn, found := agg.MethodByName("Done")
|
||||
if !found {
|
||||
return errors.New("SQlite aggregator doesn't have a Done() function")
|
||||
}
|
||||
done := doneFn.Type
|
||||
doneNArgs := done.NumIn()
|
||||
if agg.Kind() == reflect.Ptr {
|
||||
// Skip over the method receiver
|
||||
doneNArgs--
|
||||
}
|
||||
if doneNArgs != 0 {
|
||||
return errors.New("SQlite aggregator Done() function must have no arguments")
|
||||
}
|
||||
if done.NumOut() != 1 && done.NumOut() != 2 {
|
||||
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
|
||||
}
|
||||
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
return errors.New("second return value of SQLite aggregator Done() function must be error")
|
||||
}
|
||||
|
||||
conv, err := callbackRet(done.Out(0))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ai.doneRetConverter = conv
|
||||
ai.active = make(map[int64]reflect.Value)
|
||||
ai.next = 1
|
||||
|
||||
// ai must outlast the database connection, or we'll have dangling pointers.
|
||||
c.aggregators = append(c.aggregators, &ai)
|
||||
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
opts := C.SQLITE_UTF8
|
||||
if pure {
|
||||
opts |= C.SQLITE_DETERMINISTIC
|
||||
}
|
||||
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AutoCommit return which currently auto commit or not.
|
||||
func (c *SQLiteConn) AutoCommit() bool {
|
||||
return int(C.sqlite3_get_autocommit(c.db)) != 0
|
||||
|
|
|
@ -1232,6 +1232,66 @@ func TestFunctionRegistration(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type sumAggregator int64
|
||||
|
||||
func (s *sumAggregator) Step(x int64) {
|
||||
*s += sumAggregator(x)
|
||||
}
|
||||
|
||||
func (s *sumAggregator) Done() int64 {
|
||||
return int64(*s)
|
||||
}
|
||||
|
||||
func TestAggregatorRegistration(t *testing.T) {
|
||||
customSum := func() *sumAggregator {
|
||||
var ret sumAggregator
|
||||
return &ret
|
||||
}
|
||||
|
||||
sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec("create table foo (department integer, profits integer)")
|
||||
if err != nil {
|
||||
// trace feature is not implemented
|
||||
t.Skip("Failed to create table:", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to insert records:", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
dept, sum int64
|
||||
}{
|
||||
{1, 30},
|
||||
{2, 42},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
var ret int64
|
||||
err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
|
||||
if err != nil {
|
||||
t.Fatal("Query failed:", err)
|
||||
}
|
||||
if ret != test.sum {
|
||||
t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func rot13(r rune) rune {
|
||||
switch {
|
||||
case r >= 'A' && r <= 'Z':
|
||||
|
|
129
sqlite3_trace.go
129
sqlite3_trace.go
|
@ -14,16 +14,12 @@ package sqlite3
|
|||
#endif
|
||||
#include <stdlib.h>
|
||||
|
||||
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||
void doneTrampoline(sqlite3_context*);
|
||||
int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
@ -239,131 +235,6 @@ func popTraceMapping(connHandle uintptr) (TraceConfig, bool) {
|
|||
return entryCopy.config, found
|
||||
}
|
||||
|
||||
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
|
||||
//
|
||||
// Because aggregation is incremental, it's implemented in Go with a
|
||||
// type that has 2 methods: func Step(values) accumulates one row of
|
||||
// data into the accumulator, and func Done() ret finalizes and
|
||||
// returns the aggregate value. "values" and "ret" may be any type
|
||||
// supported by RegisterFunc.
|
||||
//
|
||||
// RegisterAggregator takes as implementation a constructor function
|
||||
// that constructs an instance of the aggregator type each time an
|
||||
// aggregation begins. The constructor must return a pointer to a
|
||||
// type, or an interface that implements Step() and Done().
|
||||
//
|
||||
// The constructor function and the Step/Done methods may optionally
|
||||
// return an error in addition to their other return values.
|
||||
//
|
||||
// See _example/go_custom_funcs for a detailed example.
|
||||
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
|
||||
var ai aggInfo
|
||||
ai.constructor = reflect.ValueOf(impl)
|
||||
t := ai.constructor.Type()
|
||||
if t.Kind() != reflect.Func {
|
||||
return errors.New("non-function passed to RegisterAggregator")
|
||||
}
|
||||
if t.NumOut() != 1 && t.NumOut() != 2 {
|
||||
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
|
||||
}
|
||||
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
return errors.New("Second return value of SQLite function must be error")
|
||||
}
|
||||
if t.NumIn() != 0 {
|
||||
return errors.New("SQLite aggregator constructors must not have arguments")
|
||||
}
|
||||
|
||||
agg := t.Out(0)
|
||||
switch agg.Kind() {
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
default:
|
||||
return errors.New("SQlite aggregator constructor must return a pointer object")
|
||||
}
|
||||
stepFn, found := agg.MethodByName("Step")
|
||||
if !found {
|
||||
return errors.New("SQlite aggregator doesn't have a Step() function")
|
||||
}
|
||||
step := stepFn.Type
|
||||
if step.NumOut() != 0 && step.NumOut() != 1 {
|
||||
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
|
||||
}
|
||||
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
return errors.New("type of SQlite aggregator Step() return value must be error")
|
||||
}
|
||||
|
||||
stepNArgs := step.NumIn()
|
||||
start := 0
|
||||
if agg.Kind() == reflect.Ptr {
|
||||
// Skip over the method receiver
|
||||
stepNArgs--
|
||||
start++
|
||||
}
|
||||
if step.IsVariadic() {
|
||||
stepNArgs--
|
||||
}
|
||||
for i := start; i < start+stepNArgs; i++ {
|
||||
conv, err := callbackArg(step.In(i))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ai.stepArgConverters = append(ai.stepArgConverters, conv)
|
||||
}
|
||||
if step.IsVariadic() {
|
||||
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ai.stepVariadicConverter = conv
|
||||
// Pass -1 to sqlite so that it allows any number of
|
||||
// arguments. The call helper verifies that the minimum number
|
||||
// of arguments is present for variadic functions.
|
||||
stepNArgs = -1
|
||||
}
|
||||
|
||||
doneFn, found := agg.MethodByName("Done")
|
||||
if !found {
|
||||
return errors.New("SQlite aggregator doesn't have a Done() function")
|
||||
}
|
||||
done := doneFn.Type
|
||||
doneNArgs := done.NumIn()
|
||||
if agg.Kind() == reflect.Ptr {
|
||||
// Skip over the method receiver
|
||||
doneNArgs--
|
||||
}
|
||||
if doneNArgs != 0 {
|
||||
return errors.New("SQlite aggregator Done() function must have no arguments")
|
||||
}
|
||||
if done.NumOut() != 1 && done.NumOut() != 2 {
|
||||
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
|
||||
}
|
||||
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
||||
return errors.New("second return value of SQLite aggregator Done() function must be error")
|
||||
}
|
||||
|
||||
conv, err := callbackRet(done.Out(0))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ai.doneRetConverter = conv
|
||||
ai.active = make(map[int64]reflect.Value)
|
||||
ai.next = 1
|
||||
|
||||
// ai must outlast the database connection, or we'll have dangling pointers.
|
||||
c.aggregators = append(c.aggregators, &ai)
|
||||
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
opts := C.SQLITE_UTF8
|
||||
if pure {
|
||||
opts |= C.SQLITE_DETERMINISTIC
|
||||
}
|
||||
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTrace installs or removes the trace callback for the given database connection.
|
||||
// It's not named 'RegisterTrace' because only one callback can be kept and called.
|
||||
// Calling SetTrace a second time on same database connection
|
||||
|
|
|
@ -1,72 +0,0 @@
|
|||
// Copyright (C) 2016 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// +build trace
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type sumAggregator int64
|
||||
|
||||
func (s *sumAggregator) Step(x int64) {
|
||||
*s += sumAggregator(x)
|
||||
}
|
||||
|
||||
func (s *sumAggregator) Done() int64 {
|
||||
return int64(*s)
|
||||
}
|
||||
|
||||
func TestAggregatorRegistration(t *testing.T) {
|
||||
customSum := func() *sumAggregator {
|
||||
var ret sumAggregator
|
||||
return &ret
|
||||
}
|
||||
|
||||
sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec("create table foo (department integer, profits integer)")
|
||||
if err != nil {
|
||||
// trace feature is not implemented
|
||||
t.Skip("Failed to create table:", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to insert records:", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
dept, sum int64
|
||||
}{
|
||||
{1, 30},
|
||||
{2, 42},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
var ret int64
|
||||
err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
|
||||
if err != nil {
|
||||
t.Fatal("Query failed:", err)
|
||||
}
|
||||
if ret != test.sum {
|
||||
t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue