Move RegisterAggregator implementation

The SQLiteConn.RegisterAggregator implementation was defined in
sqlite3_trace.go file, which is guarded with a build constraint. This
change simply moves RegisterAggregator to the main sqlite3.go file,
and moves accompanying unit tests.

The rationale for this move is that it was not possible for downstream
using packages to use RegisterAggregator without also specifying (and
notifying the user) the 'trace' build tag.
This commit is contained in:
Kenneth Shaw 2017-11-05 09:18:06 +07:00
parent 615c193e01
commit 7174000f77
4 changed files with 187 additions and 201 deletions

View File

@ -100,6 +100,8 @@ int _sqlite3_create_function(
}
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void doneTrampoline(sqlite3_context*);
int compareTrampoline(void*, int, char*, int, char*);
int commitHookTrampoline(void*);
@ -477,6 +479,131 @@ func sqlite3CreateFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTe
return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(pApp), (*[0]byte)(unsafe.Pointer(xFunc)), (*[0]byte)(unsafe.Pointer(xStep)), (*[0]byte)(unsafe.Pointer(xFinal)))
}
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
//
// Because aggregation is incremental, it's implemented in Go with a
// type that has 2 methods: func Step(values) accumulates one row of
// data into the accumulator, and func Done() ret finalizes and
// returns the aggregate value. "values" and "ret" may be any type
// supported by RegisterFunc.
//
// RegisterAggregator takes as implementation a constructor function
// that constructs an instance of the aggregator type each time an
// aggregation begins. The constructor must return a pointer to a
// type, or an interface that implements Step() and Done().
//
// The constructor function and the Step/Done methods may optionally
// return an error in addition to their other return values.
//
// See _example/go_custom_funcs for a detailed example.
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
var ai aggInfo
ai.constructor = reflect.ValueOf(impl)
t := ai.constructor.Type()
if t.Kind() != reflect.Func {
return errors.New("non-function passed to RegisterAggregator")
}
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
}
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("Second return value of SQLite function must be error")
}
if t.NumIn() != 0 {
return errors.New("SQLite aggregator constructors must not have arguments")
}
agg := t.Out(0)
switch agg.Kind() {
case reflect.Ptr, reflect.Interface:
default:
return errors.New("SQlite aggregator constructor must return a pointer object")
}
stepFn, found := agg.MethodByName("Step")
if !found {
return errors.New("SQlite aggregator doesn't have a Step() function")
}
step := stepFn.Type
if step.NumOut() != 0 && step.NumOut() != 1 {
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
}
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("type of SQlite aggregator Step() return value must be error")
}
stepNArgs := step.NumIn()
start := 0
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
stepNArgs--
start++
}
if step.IsVariadic() {
stepNArgs--
}
for i := start; i < start+stepNArgs; i++ {
conv, err := callbackArg(step.In(i))
if err != nil {
return err
}
ai.stepArgConverters = append(ai.stepArgConverters, conv)
}
if step.IsVariadic() {
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
if err != nil {
return err
}
ai.stepVariadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
stepNArgs = -1
}
doneFn, found := agg.MethodByName("Done")
if !found {
return errors.New("SQlite aggregator doesn't have a Done() function")
}
done := doneFn.Type
doneNArgs := done.NumIn()
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
doneNArgs--
}
if doneNArgs != 0 {
return errors.New("SQlite aggregator Done() function must have no arguments")
}
if done.NumOut() != 1 && done.NumOut() != 2 {
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
}
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("second return value of SQLite aggregator Done() function must be error")
}
conv, err := callbackRet(done.Out(0))
if err != nil {
return err
}
ai.doneRetConverter = conv
ai.active = make(map[int64]reflect.Value)
ai.next = 1
// ai must outlast the database connection, or we'll have dangling pointers.
c.aggregators = append(c.aggregators, &ai)
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}
// AutoCommit return which currently auto commit or not.
func (c *SQLiteConn) AutoCommit() bool {
return int(C.sqlite3_get_autocommit(c.db)) != 0

View File

@ -1232,6 +1232,66 @@ func TestFunctionRegistration(t *testing.T) {
}
}
type sumAggregator int64
func (s *sumAggregator) Step(x int64) {
*s += sumAggregator(x)
}
func (s *sumAggregator) Done() int64 {
return int64(*s)
}
func TestAggregatorRegistration(t *testing.T) {
customSum := func() *sumAggregator {
var ret sumAggregator
return &ret
}
sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
return err
}
return nil
},
})
db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec("create table foo (department integer, profits integer)")
if err != nil {
// trace feature is not implemented
t.Skip("Failed to create table:", err)
}
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
if err != nil {
t.Fatal("Failed to insert records:", err)
}
tests := []struct {
dept, sum int64
}{
{1, 30},
{2, 42},
}
for _, test := range tests {
var ret int64
err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
if err != nil {
t.Fatal("Query failed:", err)
}
if ret != test.sum {
t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
}
}
}
func rot13(r rune) rune {
switch {
case r >= 'A' && r <= 'Z':

View File

@ -14,16 +14,12 @@ package sqlite3
#endif
#include <stdlib.h>
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void doneTrampoline(sqlite3_context*);
int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
*/
import "C"
import (
"errors"
"fmt"
"reflect"
"strings"
"sync"
"unsafe"
@ -239,131 +235,6 @@ func popTraceMapping(connHandle uintptr) (TraceConfig, bool) {
return entryCopy.config, found
}
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
//
// Because aggregation is incremental, it's implemented in Go with a
// type that has 2 methods: func Step(values) accumulates one row of
// data into the accumulator, and func Done() ret finalizes and
// returns the aggregate value. "values" and "ret" may be any type
// supported by RegisterFunc.
//
// RegisterAggregator takes as implementation a constructor function
// that constructs an instance of the aggregator type each time an
// aggregation begins. The constructor must return a pointer to a
// type, or an interface that implements Step() and Done().
//
// The constructor function and the Step/Done methods may optionally
// return an error in addition to their other return values.
//
// See _example/go_custom_funcs for a detailed example.
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
var ai aggInfo
ai.constructor = reflect.ValueOf(impl)
t := ai.constructor.Type()
if t.Kind() != reflect.Func {
return errors.New("non-function passed to RegisterAggregator")
}
if t.NumOut() != 1 && t.NumOut() != 2 {
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
}
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("Second return value of SQLite function must be error")
}
if t.NumIn() != 0 {
return errors.New("SQLite aggregator constructors must not have arguments")
}
agg := t.Out(0)
switch agg.Kind() {
case reflect.Ptr, reflect.Interface:
default:
return errors.New("SQlite aggregator constructor must return a pointer object")
}
stepFn, found := agg.MethodByName("Step")
if !found {
return errors.New("SQlite aggregator doesn't have a Step() function")
}
step := stepFn.Type
if step.NumOut() != 0 && step.NumOut() != 1 {
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
}
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("type of SQlite aggregator Step() return value must be error")
}
stepNArgs := step.NumIn()
start := 0
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
stepNArgs--
start++
}
if step.IsVariadic() {
stepNArgs--
}
for i := start; i < start+stepNArgs; i++ {
conv, err := callbackArg(step.In(i))
if err != nil {
return err
}
ai.stepArgConverters = append(ai.stepArgConverters, conv)
}
if step.IsVariadic() {
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
if err != nil {
return err
}
ai.stepVariadicConverter = conv
// Pass -1 to sqlite so that it allows any number of
// arguments. The call helper verifies that the minimum number
// of arguments is present for variadic functions.
stepNArgs = -1
}
doneFn, found := agg.MethodByName("Done")
if !found {
return errors.New("SQlite aggregator doesn't have a Done() function")
}
done := doneFn.Type
doneNArgs := done.NumIn()
if agg.Kind() == reflect.Ptr {
// Skip over the method receiver
doneNArgs--
}
if doneNArgs != 0 {
return errors.New("SQlite aggregator Done() function must have no arguments")
}
if done.NumOut() != 1 && done.NumOut() != 2 {
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
}
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
return errors.New("second return value of SQLite aggregator Done() function must be error")
}
conv, err := callbackRet(done.Out(0))
if err != nil {
return err
}
ai.doneRetConverter = conv
ai.active = make(map[int64]reflect.Value)
ai.next = 1
// ai must outlast the database connection, or we'll have dangling pointers.
c.aggregators = append(c.aggregators, &ai)
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
opts := C.SQLITE_UTF8
if pure {
opts |= C.SQLITE_DETERMINISTIC
}
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
if rv != C.SQLITE_OK {
return c.lastError()
}
return nil
}
// SetTrace installs or removes the trace callback for the given database connection.
// It's not named 'RegisterTrace' because only one callback can be kept and called.
// Calling SetTrace a second time on same database connection

View File

@ -1,72 +0,0 @@
// Copyright (C) 2016 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build trace
package sqlite3
import (
"database/sql"
"testing"
)
type sumAggregator int64
func (s *sumAggregator) Step(x int64) {
*s += sumAggregator(x)
}
func (s *sumAggregator) Done() int64 {
return int64(*s)
}
func TestAggregatorRegistration(t *testing.T) {
customSum := func() *sumAggregator {
var ret sumAggregator
return &ret
}
sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error {
if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
return err
}
return nil
},
})
db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec("create table foo (department integer, profits integer)")
if err != nil {
// trace feature is not implemented
t.Skip("Failed to create table:", err)
}
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
if err != nil {
t.Fatal("Failed to insert records:", err)
}
tests := []struct {
dept, sum int64
}{
{1, 30},
{2, 42},
}
for _, test := range tests {
var ret int64
err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
if err != nil {
t.Fatal("Query failed:", err)
}
if ret != test.sum {
t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
}
}
}