forked from mirror/go-sqlite3
Support returning any from callbacks (#1046)
Support returning any from callbacks
This commit is contained in:
parent
b819467576
commit
3ccccfb4c9
19
callback.go
19
callback.go
|
@ -353,6 +353,20 @@ func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func callbackRetGeneric(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||||
|
if v.IsNil() {
|
||||||
|
C.sqlite3_result_null(ctx)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cb, err := callbackRet(v.Elem().Type())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return cb(ctx, v.Elem())
|
||||||
|
}
|
||||||
|
|
||||||
func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
|
func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
|
||||||
switch typ.Kind() {
|
switch typ.Kind() {
|
||||||
case reflect.Interface:
|
case reflect.Interface:
|
||||||
|
@ -360,6 +374,11 @@ func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
|
||||||
if typ.Implements(errorInterface) {
|
if typ.Implements(errorInterface) {
|
||||||
return callbackRetNil, nil
|
return callbackRetNil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if typ.NumMethod() == 0 {
|
||||||
|
return callbackRetGeneric, nil
|
||||||
|
}
|
||||||
|
|
||||||
fallthrough
|
fallthrough
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
if typ.Elem().Kind() != reflect.Uint8 {
|
if typ.Elem().Kind() != reflect.Uint8 {
|
||||||
|
|
|
@ -102,3 +102,15 @@ func TestCallbackConverters(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCallbackReturnAny(t *testing.T) {
|
||||||
|
udf := func() interface{} {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := reflect.TypeOf(udf)
|
||||||
|
_, err := callbackRet(typ.Out(0))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected valid callback for any return type, got: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1449,6 +1449,63 @@ func TestAggregatorRegistration(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mode struct {
|
||||||
|
counts map[interface{}]int
|
||||||
|
top interface{}
|
||||||
|
topCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMode() *mode {
|
||||||
|
return &mode{
|
||||||
|
counts: map[interface{}]int{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mode) Step(x interface{}) {
|
||||||
|
m.counts[x]++
|
||||||
|
c := m.counts[x]
|
||||||
|
if c > m.topCount {
|
||||||
|
m.top = x
|
||||||
|
m.topCount = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mode) Done() interface{} {
|
||||||
|
return m.top
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAggregatorRegistration_GenericReturn(t *testing.T) {
|
||||||
|
sql.Register("sqlite3_AggregatorRegistration_GenericReturn", &SQLiteDriver{
|
||||||
|
ConnectHook: func(conn *SQLiteConn) error {
|
||||||
|
return conn.RegisterAggregator("mode", newMode, true)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
db, err := sql.Open("sqlite3_AggregatorRegistration_GenericReturn", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err = db.Exec("create table foo (department integer, profits integer)")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to create table:", err)
|
||||||
|
}
|
||||||
|
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115), (2, 20)")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to insert records:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var mode int
|
||||||
|
err = db.QueryRow("select mode(profits) from foo").Scan(&mode)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("MODE query error:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mode != 20 {
|
||||||
|
t.Fatal("Got incorrect mode. Wanted 20, got: ", mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func rot13(r rune) rune {
|
func rot13(r rune) rune {
|
||||||
switch {
|
switch {
|
||||||
case r >= 'A' && r <= 'Z':
|
case r >= 'A' && r <= 'Z':
|
||||||
|
|
Loading…
Reference in New Issue