Add support for interface{} arguments in Go SQLite functions.

This enabled support for functions like Foo(a interface{}) and
Bar(a ...interface{}).
This commit is contained in:
David Anderson 2015-08-21 17:12:18 -07:00
parent 566f63a43a
commit b037a61690
3 changed files with 60 additions and 5 deletions

View File

@ -108,8 +108,32 @@ func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
} }
} }
func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) {
switch C.sqlite3_value_type(v) {
case C.SQLITE_INTEGER:
return callbackArgInt64(v)
case C.SQLITE_FLOAT:
return callbackArgFloat64(v)
case C.SQLITE_TEXT:
return callbackArgString(v)
case C.SQLITE_BLOB:
return callbackArgBytes(v)
case C.SQLITE_NULL:
// Interpret NULL as a nil byte slice.
var ret []byte
return reflect.ValueOf(ret), nil
default:
panic("unreachable")
}
}
func callbackArg(typ reflect.Type) (callbackArgConverter, error) { func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
switch typ.Kind() { switch typ.Kind() {
case reflect.Interface:
if typ.NumMethod() != 0 {
return nil, errors.New("the only supported interface type is interface{}")
}
return callbackArgGeneric, nil
case reflect.Slice: case reflect.Slice:
if typ.Elem().Kind() != reflect.Uint8 { if typ.Elem().Kind() != reflect.Uint8 {
return nil, errors.New("the only supported slice type is []byte") return nil, errors.New("the only supported slice type is []byte")

View File

@ -232,11 +232,14 @@ func (tx *SQLiteTx) Rollback() error {
// RegisterFunc makes a Go function available as a SQLite function. // RegisterFunc makes a Go function available as a SQLite function.
// //
// The function can accept arguments of any real numeric type // The Go function can have arguments of the following types: any
// (i.e. not complex), as well as []byte and string. It must return a // numeric type except complex, bool, []byte, string and
// value of one of those types, and optionally an error as a second // interface{}. interface{} arguments are given the direct translation
// value. Variadic functions are allowed, if the variadic argument is // of the SQLite data type: int64 for INTEGER, float64 for FLOAT,
// one of the allowed types. // []byte for BLOB, string for TEXT.
//
// The function can additionally be variadic, as long as the type of
// the variadic argument is one of the above.
// //
// If pure is true. SQLite will assume that the function's return // If pure is true. SQLite will assume that the function's return
// value depends only on its inputs, and make more aggressive // value depends only on its inputs, and make more aggressive

View File

@ -1071,6 +1071,20 @@ func TestFunctionRegistration(t *testing.T) {
regex := func(re, s string) (bool, error) { regex := func(re, s string) (bool, error) {
return regexp.MatchString(re, s) return regexp.MatchString(re, s)
} }
generic := func(a interface{}) int64 {
switch a.(type) {
case int64:
return 1
case float64:
return 2
case []byte:
return 3
case string:
return 4
default:
panic("unreachable")
}
}
variadic := func(a, b int64, c ...int64) int64 { variadic := func(a, b int64, c ...int64) int64 {
ret := a + b ret := a + b
for _, d := range c { for _, d := range c {
@ -1078,6 +1092,9 @@ func TestFunctionRegistration(t *testing.T) {
} }
return ret return ret
} }
variadicGeneric := func(a ...interface{}) int64 {
return int64(len(a))
}
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
ConnectHook: func(conn *SQLiteConn) error { ConnectHook: func(conn *SQLiteConn) error {
@ -1105,9 +1122,15 @@ func TestFunctionRegistration(t *testing.T) {
if err := conn.RegisterFunc("regex", regex, true); err != nil { if err := conn.RegisterFunc("regex", regex, true); err != nil {
return err return err
} }
if err := conn.RegisterFunc("generic", generic, true); err != nil {
return err
}
if err := conn.RegisterFunc("variadic", variadic, true); err != nil { if err := conn.RegisterFunc("variadic", variadic, true); err != nil {
return err return err
} }
if err := conn.RegisterFunc("variadicGeneric", variadicGeneric, true); err != nil {
return err
}
return nil return nil
}, },
}) })
@ -1131,9 +1154,14 @@ func TestFunctionRegistration(t *testing.T) {
{"SELECT not(0)", true}, {"SELECT not(0)", true},
{`SELECT regex("^foo.*", "foobar")`, true}, {`SELECT regex("^foo.*", "foobar")`, true},
{`SELECT regex("^foo.*", "barfoobar")`, false}, {`SELECT regex("^foo.*", "barfoobar")`, false},
{"SELECT generic(1)", int64(1)},
{"SELECT generic(1.1)", int64(2)},
{`SELECT generic(NULL)`, int64(3)},
{`SELECT generic("foo")`, int64(4)},
{"SELECT variadic(1,2)", int64(3)}, {"SELECT variadic(1,2)", int64(3)},
{"SELECT variadic(1,2,3,4)", int64(10)}, {"SELECT variadic(1,2,3,4)", int64(10)},
{"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)}, {"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)},
{`SELECT variadicGeneric(1,"foo",2.3, NULL)`, int64(4)},
} }
for _, op := range ops { for _, op := range ops {