diff --git a/sqlite3.go b/sqlite3.go index 174a3ee..8bb9826 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -165,9 +165,10 @@ type SQLiteRows struct { } type functionInfo struct { - f reflect.Value - argConverters []callbackArgConverter - retConverter callbackRetConverter + f reflect.Value + argConverters []callbackArgConverter + variadicConverter callbackArgConverter + retConverter callbackRetConverter } func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { @@ -178,7 +179,12 @@ func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { var args []reflect.Value - for i, arg := range argv { + + if len(argv) < len(fi.argConverters) { + fi.error(ctx, fmt.Errorf("function requires at least %d arguments", len(fi.argConverters))) + } + + for i, arg := range argv[:len(fi.argConverters)] { v, err := fi.argConverters[i](arg) if err != nil { fi.error(ctx, err) @@ -187,6 +193,17 @@ func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { args = append(args, v) } + if fi.variadicConverter != nil { + for _, arg := range argv[len(fi.argConverters):] { + v, err := fi.variadicConverter(arg) + if err != nil { + fi.error(ctx, err) + return + } + args = append(args, v) + } + } + ret := fi.f.Call(args) if len(ret) == 2 && ret[1].Interface() != nil { @@ -218,7 +235,8 @@ func (tx *SQLiteTx) Rollback() error { // The function can accept arguments of any real numeric type // (i.e. not complex), as well as []byte and string. It must return a // value of one of those types, and optionally an error as a second -// value. +// value. Variadic functions are allowed, if the variadic argument is +// one of the allowed types. // // If pure is true. SQLite will assume that the function's return // value depends only on its inputs, and make more aggressive @@ -230,9 +248,6 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro if t.Kind() != reflect.Func { return errors.New("Non-function passed to RegisterFunc") } - if t.IsVariadic() { - return errors.New("Variadic SQLite functions are not supported") - } if t.NumOut() != 1 && t.NumOut() != 2 { return errors.New("SQLite functions must return 1 or 2 values") } @@ -240,7 +255,12 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro return errors.New("Second return value of SQLite function must be error") } - for i := 0; i < t.NumIn(); i++ { + numArgs := t.NumIn() + if t.IsVariadic() { + numArgs-- + } + + for i := 0; i < numArgs; i++ { conv, err := callbackArg(t.In(i)) if err != nil { return err @@ -248,6 +268,18 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro fi.argConverters = append(fi.argConverters, conv) } + if t.IsVariadic() { + conv, err := callbackArg(t.In(numArgs).Elem()) + if err != nil { + return err + } + fi.variadicConverter = conv + // Pass -1 to sqlite so that it allows any number of + // arguments. The call helper verifies that the minimum number + // of arguments is present for variadic functions. + numArgs = -1 + } + conv, err := callbackRet(t.Out(0)) if err != nil { return err @@ -263,7 +295,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro if pure { opts |= C.SQLITE_DETERMINISTIC } - rv := C.sqlite3_create_function_v2(c.db, cname, C.int(t.NumIn()), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil) + rv := C.sqlite3_create_function_v2(c.db, cname, C.int(numArgs), C.int(opts), unsafe.Pointer(&fi), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil, nil) if rv != C.SQLITE_OK { return c.lastError() } diff --git a/sqlite3_test.go b/sqlite3_test.go index e8dfe5c..a563c08 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1071,6 +1071,13 @@ func TestFunctionRegistration(t *testing.T) { regex := func(re, s string) (bool, error) { return regexp.MatchString(re, s) } + variadic := func(a, b int64, c ...int64) int64 { + ret := a + b + for _, d := range c { + ret += d + } + return ret + } sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ ConnectHook: func(conn *SQLiteConn) error { @@ -1098,6 +1105,9 @@ func TestFunctionRegistration(t *testing.T) { if err := conn.RegisterFunc("regex", regex, true); err != nil { return err } + if err := conn.RegisterFunc("variadic", variadic, true); err != nil { + return err + } return nil }, }) @@ -1121,6 +1131,9 @@ func TestFunctionRegistration(t *testing.T) { {"SELECT not(0)", true}, {`SELECT regex("^foo.*", "foobar")`, true}, {`SELECT regex("^foo.*", "barfoobar")`, false}, + {"SELECT variadic(1,2)", int64(3)}, + {"SELECT variadic(1,2,3,4)", int64(10)}, + {"SELECT variadic(1,1,1,1,1,1,1,1,1,1)", int64(10)}, } for _, op := range ops {