forked from mirror/go-sqlcipher
Implement support for variadic functions.
Currently, the variadic part must all be the same type, because there's no "generic" arg converter.
This commit is contained in:
parent
122ddb16de
commit
566f63a43a
46
sqlite3.go
46
sqlite3.go
|
@ -167,6 +167,7 @@ type SQLiteRows struct {
|
||||||
type functionInfo struct {
|
type functionInfo struct {
|
||||||
f reflect.Value
|
f reflect.Value
|
||||||
argConverters []callbackArgConverter
|
argConverters []callbackArgConverter
|
||||||
|
variadicConverter callbackArgConverter
|
||||||
retConverter callbackRetConverter
|
retConverter callbackRetConverter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -178,7 +179,12 @@ func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) {
|
||||||
|
|
||||||
func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
|
func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
|
||||||
var args []reflect.Value
|
var args []reflect.Value
|
||||||
for i, arg := range argv {
|
|
||||||
|
if len(argv) < len(fi.argConverters) {
|
||||||
|
fi.error(ctx, fmt.Errorf("function requires at least %d arguments", len(fi.argConverters)))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, arg := range argv[:len(fi.argConverters)] {
|
||||||
v, err := fi.argConverters[i](arg)
|
v, err := fi.argConverters[i](arg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fi.error(ctx, err)
|
fi.error(ctx, err)
|
||||||
|
@ -187,6 +193,17 @@ func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
|
||||||
args = append(args, v)
|
args = append(args, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if fi.variadicConverter != nil {
|
||||||
|
for _, arg := range argv[len(fi.argConverters):] {
|
||||||
|
v, err := fi.variadicConverter(arg)
|
||||||
|
if err != nil {
|
||||||
|
fi.error(ctx, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
args = append(args, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ret := fi.f.Call(args)
|
ret := fi.f.Call(args)
|
||||||
|
|
||||||
if len(ret) == 2 && ret[1].Interface() != nil {
|
if len(ret) == 2 && ret[1].Interface() != nil {
|
||||||
|
@ -218,7 +235,8 @@ func (tx *SQLiteTx) Rollback() error {
|
||||||
// The function can accept arguments of any real numeric type
|
// The function can accept arguments of any real numeric type
|
||||||
// (i.e. not complex), as well as []byte and string. It must return a
|
// (i.e. not complex), as well as []byte and string. It must return a
|
||||||
// value of one of those types, and optionally an error as a second
|
// value of one of those types, and optionally an error as a second
|
||||||
// value.
|
// value. Variadic functions are allowed, if the variadic argument is
|
||||||
|
// one of the allowed types.
|
||||||
//
|
//
|
||||||
// If pure is true. SQLite will assume that the function's return
|
// If pure is true. SQLite will assume that the function's return
|
||||||
// value depends only on its inputs, and make more aggressive
|
// value depends only on its inputs, and make more aggressive
|
||||||
|
@ -230,9 +248,6 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
|
||||||
if t.Kind() != reflect.Func {
|
if t.Kind() != reflect.Func {
|
||||||
return errors.New("Non-function passed to RegisterFunc")
|
return errors.New("Non-function passed to RegisterFunc")
|
||||||
}
|
}
|
||||||
if t.IsVariadic() {
|
|
||||||
return errors.New("Variadic SQLite functions are not supported")
|
|
||||||
}
|
|
||||||
if t.NumOut() != 1 && t.NumOut() != 2 {
|
if t.NumOut() != 1 && t.NumOut() != 2 {
|
||||||
return errors.New("SQLite functions must return 1 or 2 values")
|
return errors.New("SQLite functions must return 1 or 2 values")
|
||||||
}
|
}
|
||||||
|
@ -240,7 +255,12 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
|
||||||
return errors.New("Second return value of SQLite function must be error")
|
return errors.New("Second return value of SQLite function must be error")
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < t.NumIn(); i++ {
|
numArgs := t.NumIn()
|
||||||
|
if t.IsVariadic() {
|
||||||
|
numArgs--
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < numArgs; i++ {
|
||||||
conv, err := callbackArg(t.In(i))
|
conv, err := callbackArg(t.In(i))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -248,6 +268,18 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
|
||||||
fi.argConverters = append(fi.argConverters, conv)
|
fi.argConverters = append(fi.argConverters, conv)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if t.IsVariadic() {
|
||||||
|
conv, err := callbackArg(t.In(numArgs).Elem())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fi.variadicConverter = conv
|
||||||
|
// Pass -1 to sqlite so that it allows any number of
|
||||||
|
// arguments. The call helper verifies that the minimum number
|
||||||
|
// of arguments is present for variadic functions.
|
||||||
|
numArgs = -1
|
||||||
|
}
|
||||||
|
|
||||||
conv, err := callbackRet(t.Out(0))
|
conv, err := callbackRet(t.Out(0))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -263,7 +295,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
|
||||||
if pure {
|
if pure {
|
||||||
opts |= C.SQLITE_DETERMINISTIC
|
opts |= C.SQLITE_DETERMINISTIC
|
||||||
}
|
}
|
||||||
rv := C.sqlite3_create_function_v2(c.db, cname, C.int(t.NumIn()), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil)
|
rv := C.sqlite3_create_function_v2(c.db, cname, C.int(numArgs), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil)
|
||||||
if rv != C.SQLITE_OK {
|
if rv != C.SQLITE_OK {
|
||||||
return c.lastError()
|
return c.lastError()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1071,6 +1071,13 @@ func TestFunctionRegistration(t *testing.T) {
|
||||||
regex := func(re, s string) (bool, error) {
|
regex := func(re, s string) (bool, error) {
|
||||||
return regexp.MatchString(re, s)
|
return regexp.MatchString(re, s)
|
||||||
}
|
}
|
||||||
|
variadic := func(a, b int64, c ...int64) int64 {
|
||||||
|
ret := a + b
|
||||||
|
for _, d := range c {
|
||||||
|
ret += d
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
|
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
|
||||||
ConnectHook: func(conn *SQLiteConn) error {
|
ConnectHook: func(conn *SQLiteConn) error {
|
||||||
|
@ -1098,6 +1105,9 @@ func TestFunctionRegistration(t *testing.T) {
|
||||||
if err := conn.RegisterFunc("regex", regex, true); err != nil {
|
if err := conn.RegisterFunc("regex", regex, true); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if err := conn.RegisterFunc("variadic", variadic, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -1121,6 +1131,9 @@ func TestFunctionRegistration(t *testing.T) {
|
||||||
{"SELECT not(0)", true},
|
{"SELECT not(0)", true},
|
||||||
{`SELECT regex("^foo.*", "foobar")`, true},
|
{`SELECT regex("^foo.*", "foobar")`, true},
|
||||||
{`SELECT regex("^foo.*", "barfoobar")`, false},
|
{`SELECT regex("^foo.*", "barfoobar")`, false},
|
||||||
|
{"SELECT variadic(1,2)", int64(3)},
|
||||||
|
{"SELECT variadic(1,2,3,4)", int64(10)},
|
||||||
|
{"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, op := range ops {
|
for _, op := range ops {
|
||||||
|
|
Loading…
Reference in New Issue