From 7174000f779db8d51c898c082640cb2f475cd8d3 Mon Sep 17 00:00:00 2001 From: Kenneth Shaw Date: Sun, 5 Nov 2017 09:18:06 +0700 Subject: [PATCH] Move RegisterAggregator implementation The SQLiteConn.RegisterAggregator implementation was defined in sqlite3_trace.go file, which is guarded with a build constraint. This change simply moves RegisterAggregator to the main sqlite3.go file, and moves accompanying unit tests. The rationale for this move is that it was not possible for downstream using packages to use RegisterAggregator without also specifying (and notifying the user) the 'trace' build tag. --- sqlite3.go | 127 +++++++++++++++++++++++++++++++++++++++++ sqlite3_test.go | 60 ++++++++++++++++++++ sqlite3_trace.go | 129 ------------------------------------------ sqlite3_trace_test.go | 72 ----------------------- 4 files changed, 187 insertions(+), 201 deletions(-) delete mode 100644 sqlite3_trace_test.go diff --git a/sqlite3.go b/sqlite3.go index 1ff58c3..3b4c5fb 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -100,6 +100,8 @@ int _sqlite3_create_function( } void callbackTrampoline(sqlite3_context*, int, sqlite3_value**); +void stepTrampoline(sqlite3_context*, int, sqlite3_value**); +void doneTrampoline(sqlite3_context*); int compareTrampoline(void*, int, char*, int, char*); int commitHookTrampoline(void*); @@ -477,6 +479,131 @@ func sqlite3CreateFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTe return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(pApp), (*[0]byte)(unsafe.Pointer(xFunc)), (*[0]byte)(unsafe.Pointer(xStep)), (*[0]byte)(unsafe.Pointer(xFinal))) } +// RegisterAggregator makes a Go type available as a SQLite aggregation function. +// +// Because aggregation is incremental, it's implemented in Go with a +// type that has 2 methods: func Step(values) accumulates one row of +// data into the accumulator, and func Done() ret finalizes and +// returns the aggregate value. "values" and "ret" may be any type +// supported by RegisterFunc. +// +// RegisterAggregator takes as implementation a constructor function +// that constructs an instance of the aggregator type each time an +// aggregation begins. The constructor must return a pointer to a +// type, or an interface that implements Step() and Done(). +// +// The constructor function and the Step/Done methods may optionally +// return an error in addition to their other return values. +// +// See _example/go_custom_funcs for a detailed example. +func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error { + var ai aggInfo + ai.constructor = reflect.ValueOf(impl) + t := ai.constructor.Type() + if t.Kind() != reflect.Func { + return errors.New("non-function passed to RegisterAggregator") + } + if t.NumOut() != 1 && t.NumOut() != 2 { + return errors.New("SQLite aggregator constructors must return 1 or 2 values") + } + if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return errors.New("Second return value of SQLite function must be error") + } + if t.NumIn() != 0 { + return errors.New("SQLite aggregator constructors must not have arguments") + } + + agg := t.Out(0) + switch agg.Kind() { + case reflect.Ptr, reflect.Interface: + default: + return errors.New("SQlite aggregator constructor must return a pointer object") + } + stepFn, found := agg.MethodByName("Step") + if !found { + return errors.New("SQlite aggregator doesn't have a Step() function") + } + step := stepFn.Type + if step.NumOut() != 0 && step.NumOut() != 1 { + return errors.New("SQlite aggregator Step() function must return 0 or 1 values") + } + if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return errors.New("type of SQlite aggregator Step() return value must be error") + } + + stepNArgs := step.NumIn() + start := 0 + if agg.Kind() == reflect.Ptr { + // Skip over the method receiver + stepNArgs-- + start++ + } + if step.IsVariadic() { + stepNArgs-- + } + for i := start; i < start+stepNArgs; i++ { + conv, err := callbackArg(step.In(i)) + if err != nil { + return err + } + ai.stepArgConverters = append(ai.stepArgConverters, conv) + } + if step.IsVariadic() { + conv, err := callbackArg(t.In(start + stepNArgs).Elem()) + if err != nil { + return err + } + ai.stepVariadicConverter = conv + // Pass -1 to sqlite so that it allows any number of + // arguments. The call helper verifies that the minimum number + // of arguments is present for variadic functions. + stepNArgs = -1 + } + + doneFn, found := agg.MethodByName("Done") + if !found { + return errors.New("SQlite aggregator doesn't have a Done() function") + } + done := doneFn.Type + doneNArgs := done.NumIn() + if agg.Kind() == reflect.Ptr { + // Skip over the method receiver + doneNArgs-- + } + if doneNArgs != 0 { + return errors.New("SQlite aggregator Done() function must have no arguments") + } + if done.NumOut() != 1 && done.NumOut() != 2 { + return errors.New("SQLite aggregator Done() function must return 1 or 2 values") + } + if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + return errors.New("second return value of SQLite aggregator Done() function must be error") + } + + conv, err := callbackRet(done.Out(0)) + if err != nil { + return err + } + ai.doneRetConverter = conv + ai.active = make(map[int64]reflect.Value) + ai.next = 1 + + // ai must outlast the database connection, or we'll have dangling pointers. + c.aggregators = append(c.aggregators, &ai) + + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + opts := C.SQLITE_UTF8 + if pure { + opts |= C.SQLITE_DETERMINISTIC + } + rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline) + if rv != C.SQLITE_OK { + return c.lastError() + } + return nil +} + // AutoCommit return which currently auto commit or not. func (c *SQLiteConn) AutoCommit() bool { return int(C.sqlite3_get_autocommit(c.db)) != 0 diff --git a/sqlite3_test.go b/sqlite3_test.go index 9d4b373..84ecb5a 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1232,6 +1232,66 @@ func TestFunctionRegistration(t *testing.T) { } } +type sumAggregator int64 + +func (s *sumAggregator) Step(x int64) { + *s += sumAggregator(x) +} + +func (s *sumAggregator) Done() int64 { + return int64(*s) +} + +func TestAggregatorRegistration(t *testing.T) { + customSum := func() *sumAggregator { + var ret sumAggregator + return &ret + } + + sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + if err := conn.RegisterAggregator("customSum", customSum, true); err != nil { + return err + } + return nil + }, + }) + db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + _, err = db.Exec("create table foo (department integer, profits integer)") + if err != nil { + // trace feature is not implemented + t.Skip("Failed to create table:", err) + } + + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)") + if err != nil { + t.Fatal("Failed to insert records:", err) + } + + tests := []struct { + dept, sum int64 + }{ + {1, 30}, + {2, 42}, + } + + for _, test := range tests { + var ret int64 + err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret) + if err != nil { + t.Fatal("Query failed:", err) + } + if ret != test.sum { + t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum) + } + } +} + func rot13(r rune) rune { switch { case r >= 'A' && r <= 'Z': diff --git a/sqlite3_trace.go b/sqlite3_trace.go index a75f52a..ece6035 100644 --- a/sqlite3_trace.go +++ b/sqlite3_trace.go @@ -14,16 +14,12 @@ package sqlite3 #endif #include -void stepTrampoline(sqlite3_context*, int, sqlite3_value**); -void doneTrampoline(sqlite3_context*); int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x); */ import "C" import ( - "errors" "fmt" - "reflect" "strings" "sync" "unsafe" @@ -239,131 +235,6 @@ func popTraceMapping(connHandle uintptr) (TraceConfig, bool) { return entryCopy.config, found } -// RegisterAggregator makes a Go type available as a SQLite aggregation function. -// -// Because aggregation is incremental, it's implemented in Go with a -// type that has 2 methods: func Step(values) accumulates one row of -// data into the accumulator, and func Done() ret finalizes and -// returns the aggregate value. "values" and "ret" may be any type -// supported by RegisterFunc. -// -// RegisterAggregator takes as implementation a constructor function -// that constructs an instance of the aggregator type each time an -// aggregation begins. The constructor must return a pointer to a -// type, or an interface that implements Step() and Done(). -// -// The constructor function and the Step/Done methods may optionally -// return an error in addition to their other return values. -// -// See _example/go_custom_funcs for a detailed example. -func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error { - var ai aggInfo - ai.constructor = reflect.ValueOf(impl) - t := ai.constructor.Type() - if t.Kind() != reflect.Func { - return errors.New("non-function passed to RegisterAggregator") - } - if t.NumOut() != 1 && t.NumOut() != 2 { - return errors.New("SQLite aggregator constructors must return 1 or 2 values") - } - if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { - return errors.New("Second return value of SQLite function must be error") - } - if t.NumIn() != 0 { - return errors.New("SQLite aggregator constructors must not have arguments") - } - - agg := t.Out(0) - switch agg.Kind() { - case reflect.Ptr, reflect.Interface: - default: - return errors.New("SQlite aggregator constructor must return a pointer object") - } - stepFn, found := agg.MethodByName("Step") - if !found { - return errors.New("SQlite aggregator doesn't have a Step() function") - } - step := stepFn.Type - if step.NumOut() != 0 && step.NumOut() != 1 { - return errors.New("SQlite aggregator Step() function must return 0 or 1 values") - } - if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { - return errors.New("type of SQlite aggregator Step() return value must be error") - } - - stepNArgs := step.NumIn() - start := 0 - if agg.Kind() == reflect.Ptr { - // Skip over the method receiver - stepNArgs-- - start++ - } - if step.IsVariadic() { - stepNArgs-- - } - for i := start; i < start+stepNArgs; i++ { - conv, err := callbackArg(step.In(i)) - if err != nil { - return err - } - ai.stepArgConverters = append(ai.stepArgConverters, conv) - } - if step.IsVariadic() { - conv, err := callbackArg(t.In(start + stepNArgs).Elem()) - if err != nil { - return err - } - ai.stepVariadicConverter = conv - // Pass -1 to sqlite so that it allows any number of - // arguments. The call helper verifies that the minimum number - // of arguments is present for variadic functions. - stepNArgs = -1 - } - - doneFn, found := agg.MethodByName("Done") - if !found { - return errors.New("SQlite aggregator doesn't have a Done() function") - } - done := doneFn.Type - doneNArgs := done.NumIn() - if agg.Kind() == reflect.Ptr { - // Skip over the method receiver - doneNArgs-- - } - if doneNArgs != 0 { - return errors.New("SQlite aggregator Done() function must have no arguments") - } - if done.NumOut() != 1 && done.NumOut() != 2 { - return errors.New("SQLite aggregator Done() function must return 1 or 2 values") - } - if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { - return errors.New("second return value of SQLite aggregator Done() function must be error") - } - - conv, err := callbackRet(done.Out(0)) - if err != nil { - return err - } - ai.doneRetConverter = conv - ai.active = make(map[int64]reflect.Value) - ai.next = 1 - - // ai must outlast the database connection, or we'll have dangling pointers. - c.aggregators = append(c.aggregators, &ai) - - cname := C.CString(name) - defer C.free(unsafe.Pointer(cname)) - opts := C.SQLITE_UTF8 - if pure { - opts |= C.SQLITE_DETERMINISTIC - } - rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline) - if rv != C.SQLITE_OK { - return c.lastError() - } - return nil -} - // SetTrace installs or removes the trace callback for the given database connection. // It's not named 'RegisterTrace' because only one callback can be kept and called. // Calling SetTrace a second time on same database connection diff --git a/sqlite3_trace_test.go b/sqlite3_trace_test.go deleted file mode 100644 index 57a6757..0000000 --- a/sqlite3_trace_test.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (C) 2016 Yasuhiro Matsumoto . -// -// Use of this source code is governed by an MIT-style -// license that can be found in the LICENSE file. -// +build trace - -package sqlite3 - -import ( - "database/sql" - "testing" -) - -type sumAggregator int64 - -func (s *sumAggregator) Step(x int64) { - *s += sumAggregator(x) -} - -func (s *sumAggregator) Done() int64 { - return int64(*s) -} - -func TestAggregatorRegistration(t *testing.T) { - customSum := func() *sumAggregator { - var ret sumAggregator - return &ret - } - - sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{ - ConnectHook: func(conn *SQLiteConn) error { - if err := conn.RegisterAggregator("customSum", customSum, true); err != nil { - return err - } - return nil - }, - }) - db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:") - if err != nil { - t.Fatal("Failed to open database:", err) - } - defer db.Close() - - _, err = db.Exec("create table foo (department integer, profits integer)") - if err != nil { - // trace feature is not implemented - t.Skip("Failed to create table:", err) - } - - _, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)") - if err != nil { - t.Fatal("Failed to insert records:", err) - } - - tests := []struct { - dept, sum int64 - }{ - {1, 30}, - {2, 42}, - } - - for _, test := range tests { - var ret int64 - err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret) - if err != nil { - t.Fatal("Query failed:", err) - } - if ret != test.sum { - t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum) - } - } -}