forked from mirror/go-sqlite3
Support returning any from callbacks (#1046)
Support returning any from callbacks
This commit is contained in:
parent
b819467576
commit
3ccccfb4c9
19
callback.go
19
callback.go
|
@ -353,6 +353,20 @@ func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func callbackRetGeneric(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||
if v.IsNil() {
|
||||
C.sqlite3_result_null(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
cb, err := callbackRet(v.Elem().Type())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return cb(ctx, v.Elem())
|
||||
}
|
||||
|
||||
func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
|
||||
switch typ.Kind() {
|
||||
case reflect.Interface:
|
||||
|
@ -360,6 +374,11 @@ func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
|
|||
if typ.Implements(errorInterface) {
|
||||
return callbackRetNil, nil
|
||||
}
|
||||
|
||||
if typ.NumMethod() == 0 {
|
||||
return callbackRetGeneric, nil
|
||||
}
|
||||
|
||||
fallthrough
|
||||
case reflect.Slice:
|
||||
if typ.Elem().Kind() != reflect.Uint8 {
|
||||
|
|
|
@ -102,3 +102,15 @@ func TestCallbackConverters(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbackReturnAny(t *testing.T) {
|
||||
udf := func() interface{} {
|
||||
return 1
|
||||
}
|
||||
|
||||
typ := reflect.TypeOf(udf)
|
||||
_, err := callbackRet(typ.Out(0))
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid callback for any return type, got: %s", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1449,6 +1449,63 @@ func TestAggregatorRegistration(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type mode struct {
|
||||
counts map[interface{}]int
|
||||
top interface{}
|
||||
topCount int
|
||||
}
|
||||
|
||||
func newMode() *mode {
|
||||
return &mode{
|
||||
counts: map[interface{}]int{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mode) Step(x interface{}) {
|
||||
m.counts[x]++
|
||||
c := m.counts[x]
|
||||
if c > m.topCount {
|
||||
m.top = x
|
||||
m.topCount = c
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mode) Done() interface{} {
|
||||
return m.top
|
||||
}
|
||||
|
||||
func TestAggregatorRegistration_GenericReturn(t *testing.T) {
|
||||
sql.Register("sqlite3_AggregatorRegistration_GenericReturn", &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
return conn.RegisterAggregator("mode", newMode, true)
|
||||
},
|
||||
})
|
||||
db, err := sql.Open("sqlite3_AggregatorRegistration_GenericReturn", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec("create table foo (department integer, profits integer)")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to create table:", err)
|
||||
}
|
||||
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115), (2, 20)")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to insert records:", err)
|
||||
}
|
||||
|
||||
var mode int
|
||||
err = db.QueryRow("select mode(profits) from foo").Scan(&mode)
|
||||
if err != nil {
|
||||
t.Fatal("MODE query error:", err)
|
||||
}
|
||||
|
||||
if mode != 20 {
|
||||
t.Fatal("Got incorrect mode. Wanted 20, got: ", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func rot13(r rune) rune {
|
||||
switch {
|
||||
case r >= 'A' && r <= 'Z':
|
||||
|
|
Loading…
Reference in New Issue