forked from mirror/go-sqlcipher
Move argument converters to callback.go, and optimize return value handling.
A call now doesn't have to do any reflection, it just blindly invokes a bunch of argument and return value handlers to execute the translation, and the safety of the translation is determined at registration time.
This commit is contained in:
parent
cf8fa0af80
commit
122ddb16de
200
callback.go
200
callback.go
|
@ -5,12 +5,25 @@
|
||||||
|
|
||||||
package sqlite3
|
package sqlite3
|
||||||
|
|
||||||
|
// You can't export a Go function to C and have definitions in the C
|
||||||
|
// preamble in the same file, so we have to have callbackTrampoline in
|
||||||
|
// its own file. Because we need a separate file anyway, the support
|
||||||
|
// code for SQLite custom functions is in here.
|
||||||
|
|
||||||
/*
|
/*
|
||||||
#include <sqlite3-binding.h>
|
#include <sqlite3-binding.h>
|
||||||
|
|
||||||
|
void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
|
||||||
|
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import "unsafe"
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
//export callbackTrampoline
|
//export callbackTrampoline
|
||||||
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
|
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
|
||||||
|
@ -18,3 +31,188 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value
|
||||||
fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
|
fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
|
||||||
fi.Call(ctx, args)
|
fi.Call(ctx, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This is only here so that tests can refer to it.
|
||||||
|
type callbackArgRaw C.sqlite3_value
|
||||||
|
|
||||||
|
type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
|
||||||
|
|
||||||
|
type callbackArgCast struct {
|
||||||
|
f callbackArgConverter
|
||||||
|
typ reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
val, err := c.f(v)
|
||||||
|
if err != nil {
|
||||||
|
return reflect.Value{}, err
|
||||||
|
}
|
||||||
|
if !val.Type().ConvertibleTo(c.typ) {
|
||||||
|
return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
|
||||||
|
}
|
||||||
|
return val.Convert(c.typ), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
|
||||||
|
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
|
||||||
|
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
|
||||||
|
}
|
||||||
|
i := int64(C.sqlite3_value_int64(v))
|
||||||
|
val := false
|
||||||
|
if i != 0 {
|
||||||
|
val = true
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(val), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
|
||||||
|
return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
|
||||||
|
}
|
||||||
|
return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
switch C.sqlite3_value_type(v) {
|
||||||
|
case C.SQLITE_BLOB:
|
||||||
|
l := C.sqlite3_value_bytes(v)
|
||||||
|
p := C.sqlite3_value_blob(v)
|
||||||
|
return reflect.ValueOf(C.GoBytes(p, l)), nil
|
||||||
|
case C.SQLITE_TEXT:
|
||||||
|
l := C.sqlite3_value_bytes(v)
|
||||||
|
c := unsafe.Pointer(C.sqlite3_value_text(v))
|
||||||
|
return reflect.ValueOf(C.GoBytes(c, l)), nil
|
||||||
|
default:
|
||||||
|
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
switch C.sqlite3_value_type(v) {
|
||||||
|
case C.SQLITE_BLOB:
|
||||||
|
l := C.sqlite3_value_bytes(v)
|
||||||
|
p := (*C.char)(C.sqlite3_value_blob(v))
|
||||||
|
return reflect.ValueOf(C.GoStringN(p, l)), nil
|
||||||
|
case C.SQLITE_TEXT:
|
||||||
|
c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
|
||||||
|
return reflect.ValueOf(C.GoString(c)), nil
|
||||||
|
default:
|
||||||
|
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
|
||||||
|
switch typ.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
if typ.Elem().Kind() != reflect.Uint8 {
|
||||||
|
return nil, errors.New("the only supported slice type is []byte")
|
||||||
|
}
|
||||||
|
return callbackArgBytes, nil
|
||||||
|
case reflect.String:
|
||||||
|
return callbackArgString, nil
|
||||||
|
case reflect.Bool:
|
||||||
|
return callbackArgBool, nil
|
||||||
|
case reflect.Int64:
|
||||||
|
return callbackArgInt64, nil
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
|
||||||
|
c := callbackArgCast{callbackArgInt64, typ}
|
||||||
|
return c.Run, nil
|
||||||
|
case reflect.Float64:
|
||||||
|
return callbackArgFloat64, nil
|
||||||
|
case reflect.Float32:
|
||||||
|
c := callbackArgCast{callbackArgFloat64, typ}
|
||||||
|
return c.Run, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("don't know how to convert to %s", typ)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
|
||||||
|
|
||||||
|
func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||||
|
switch v.Type().Kind() {
|
||||||
|
case reflect.Int64:
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
|
||||||
|
v = v.Convert(reflect.TypeOf(int64(0)))
|
||||||
|
case reflect.Bool:
|
||||||
|
b := v.Interface().(bool)
|
||||||
|
if b {
|
||||||
|
v = reflect.ValueOf(int64(1))
|
||||||
|
} else {
|
||||||
|
v = reflect.ValueOf(int64(0))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||||
|
switch v.Type().Kind() {
|
||||||
|
case reflect.Float64:
|
||||||
|
case reflect.Float32:
|
||||||
|
v = v.Convert(reflect.TypeOf(float64(0)))
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||||
|
if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
|
||||||
|
return fmt.Errorf("cannot convert %s to BLOB", v.Type())
|
||||||
|
}
|
||||||
|
i := v.Interface()
|
||||||
|
if i == nil || len(i.([]byte)) == 0 {
|
||||||
|
C.sqlite3_result_null(ctx)
|
||||||
|
} else {
|
||||||
|
bs := i.([]byte)
|
||||||
|
C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||||
|
if v.Type().Kind() != reflect.String {
|
||||||
|
return fmt.Errorf("cannot convert %s to TEXT", v.Type())
|
||||||
|
}
|
||||||
|
C._sqlite3_result_text(ctx, C.CString(v.Interface().(string)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
|
||||||
|
switch typ.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
if typ.Elem().Kind() != reflect.Uint8 {
|
||||||
|
return nil, errors.New("the only supported slice type is []byte")
|
||||||
|
}
|
||||||
|
return callbackRetBlob, nil
|
||||||
|
case reflect.String:
|
||||||
|
return callbackRetText, nil
|
||||||
|
case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
|
||||||
|
return callbackRetInteger, nil
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return callbackRetFloat, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("don't know how to convert to %s", typ)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test support code. Tests are not allowed to import "C", so we can't
|
||||||
|
// declare any functions that use C.sqlite3_value.
|
||||||
|
func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
|
||||||
|
return func(*C.sqlite3_value) (reflect.Value, error) {
|
||||||
|
return v, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,97 @@
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCallbackArgCast(t *testing.T) {
|
||||||
|
intConv := callbackSyntheticForTests(reflect.ValueOf(int64(math.MaxInt64)), nil)
|
||||||
|
floatConv := callbackSyntheticForTests(reflect.ValueOf(float64(math.MaxFloat64)), nil)
|
||||||
|
errConv := callbackSyntheticForTests(reflect.Value{}, errors.New("test"))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
f callbackArgConverter
|
||||||
|
o reflect.Value
|
||||||
|
}{
|
||||||
|
{intConv, reflect.ValueOf(int8(-1))},
|
||||||
|
{intConv, reflect.ValueOf(int16(-1))},
|
||||||
|
{intConv, reflect.ValueOf(int32(-1))},
|
||||||
|
{intConv, reflect.ValueOf(uint8(math.MaxUint8))},
|
||||||
|
{intConv, reflect.ValueOf(uint16(math.MaxUint16))},
|
||||||
|
{intConv, reflect.ValueOf(uint32(math.MaxUint32))},
|
||||||
|
// Special case, int64->uint64 is only 1<<63 - 1, not 1<<64 - 1
|
||||||
|
{intConv, reflect.ValueOf(uint64(math.MaxInt64))},
|
||||||
|
{floatConv, reflect.ValueOf(float32(math.Inf(1)))},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
conv := callbackArgCast{test.f, test.o.Type()}
|
||||||
|
val, err := conv.Run(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Couldn't convert to %s: %s", test.o.Type(), err)
|
||||||
|
} else if !reflect.DeepEqual(val.Interface(), test.o.Interface()) {
|
||||||
|
t.Errorf("Unexpected result from converting to %s: got %v, want %v", test.o.Type(), val.Interface(), test.o.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conv := callbackArgCast{errConv, reflect.TypeOf(int8(0))}
|
||||||
|
_, err := conv.Run(nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error during callbackArgCast, but got none")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCallbackConverters(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
v interface{}
|
||||||
|
err bool
|
||||||
|
}{
|
||||||
|
// Unfortunately, we can't tell which converter was returned,
|
||||||
|
// but we can at least check which types can be converted.
|
||||||
|
{[]byte{0}, false},
|
||||||
|
{"text", false},
|
||||||
|
{true, false},
|
||||||
|
{int8(0), false},
|
||||||
|
{int16(0), false},
|
||||||
|
{int32(0), false},
|
||||||
|
{int64(0), false},
|
||||||
|
{uint8(0), false},
|
||||||
|
{uint16(0), false},
|
||||||
|
{uint32(0), false},
|
||||||
|
{uint64(0), false},
|
||||||
|
{int(0), false},
|
||||||
|
{uint(0), false},
|
||||||
|
{float64(0), false},
|
||||||
|
{float32(0), false},
|
||||||
|
|
||||||
|
{func() {}, true},
|
||||||
|
{complex64(complex(0, 0)), true},
|
||||||
|
{complex128(complex(0, 0)), true},
|
||||||
|
{struct{}{}, true},
|
||||||
|
{map[string]string{}, true},
|
||||||
|
{[]string{}, true},
|
||||||
|
{(*int8)(nil), true},
|
||||||
|
{make(chan int), true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
_, err := callbackArg(reflect.TypeOf(test.v))
|
||||||
|
if test.err && err == nil {
|
||||||
|
t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v))
|
||||||
|
} else if !test.err && err != nil {
|
||||||
|
t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
_, err := callbackRet(reflect.TypeOf(test.v))
|
||||||
|
if test.err && err == nil {
|
||||||
|
t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v))
|
||||||
|
} else if !test.err && err != nil {
|
||||||
|
t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
122
sqlite3.go
122
sqlite3.go
|
@ -166,7 +166,8 @@ type SQLiteRows struct {
|
||||||
|
|
||||||
type functionInfo struct {
|
type functionInfo struct {
|
||||||
f reflect.Value
|
f reflect.Value
|
||||||
argConverters []func(*C.sqlite3_value) (reflect.Value, error)
|
argConverters []callbackArgConverter
|
||||||
|
retConverter callbackRetConverter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) {
|
func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) {
|
||||||
|
@ -193,58 +194,11 @@ func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res := ret[0].Interface()
|
err := fi.retConverter(ctx, ret[0])
|
||||||
// Normalize ret to one of the types sqlite knows.
|
if err != nil {
|
||||||
switch r := res.(type) {
|
fi.error(ctx, err)
|
||||||
case int64, float64, []byte, string:
|
|
||||||
// Already the right type
|
|
||||||
case bool:
|
|
||||||
if r {
|
|
||||||
res = int64(1)
|
|
||||||
} else {
|
|
||||||
res = int64(0)
|
|
||||||
}
|
|
||||||
case int:
|
|
||||||
res = int64(r)
|
|
||||||
case uint:
|
|
||||||
res = int64(r)
|
|
||||||
case uint8:
|
|
||||||
res = int64(r)
|
|
||||||
case uint16:
|
|
||||||
res = int64(r)
|
|
||||||
case uint32:
|
|
||||||
res = int64(r)
|
|
||||||
case uint64:
|
|
||||||
res = int64(r)
|
|
||||||
case int8:
|
|
||||||
res = int64(r)
|
|
||||||
case int16:
|
|
||||||
res = int64(r)
|
|
||||||
case int32:
|
|
||||||
res = int64(r)
|
|
||||||
case float32:
|
|
||||||
res = float64(r)
|
|
||||||
default:
|
|
||||||
fi.error(ctx, errors.New("cannot convert returned type to sqlite type"))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch r := res.(type) {
|
|
||||||
case int64:
|
|
||||||
C.sqlite3_result_int64(ctx, C.sqlite3_int64(r))
|
|
||||||
case float64:
|
|
||||||
C.sqlite3_result_double(ctx, C.double(r))
|
|
||||||
case []byte:
|
|
||||||
if len(r) == 0 {
|
|
||||||
C.sqlite3_result_null(ctx)
|
|
||||||
} else {
|
|
||||||
C._sqlite3_result_blob(ctx, unsafe.Pointer(&r[0]), C.int(len(r)))
|
|
||||||
}
|
|
||||||
case string:
|
|
||||||
C._sqlite3_result_text(ctx, C.CString(r))
|
|
||||||
default:
|
|
||||||
panic("unreachable")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit transaction.
|
// Commit transaction.
|
||||||
|
@ -261,10 +215,10 @@ func (tx *SQLiteTx) Rollback() error {
|
||||||
|
|
||||||
// RegisterFunc makes a Go function available as a SQLite function.
|
// RegisterFunc makes a Go function available as a SQLite function.
|
||||||
//
|
//
|
||||||
// The function must accept only arguments of type int64, float64,
|
// The function can accept arguments of any real numeric type
|
||||||
// []byte or string, and return one value of any numeric type except
|
// (i.e. not complex), as well as []byte and string. It must return a
|
||||||
// complex, bool, []byte or string. Optionally, an error can be
|
// value of one of those types, and optionally an error as a second
|
||||||
// provided as a second return value.
|
// value.
|
||||||
//
|
//
|
||||||
// If pure is true. SQLite will assume that the function's return
|
// If pure is true. SQLite will assume that the function's return
|
||||||
// value depends only on its inputs, and make more aggressive
|
// value depends only on its inputs, and make more aggressive
|
||||||
|
@ -287,59 +241,19 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < t.NumIn(); i++ {
|
for i := 0; i < t.NumIn(); i++ {
|
||||||
arg := t.In(i)
|
conv, err := callbackArg(t.In(i))
|
||||||
var conv func(*C.sqlite3_value) (reflect.Value, error)
|
if err != nil {
|
||||||
switch arg.Kind() {
|
return err
|
||||||
case reflect.Int64:
|
|
||||||
conv = func(v *C.sqlite3_value) (reflect.Value, error) {
|
|
||||||
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
|
|
||||||
return reflect.Value{}, fmt.Errorf("Argument %d to %s must be an INTEGER", i+1, name)
|
|
||||||
}
|
|
||||||
return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
|
|
||||||
}
|
|
||||||
case reflect.Float64:
|
|
||||||
conv = func(v *C.sqlite3_value) (reflect.Value, error) {
|
|
||||||
if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
|
|
||||||
return reflect.Value{}, fmt.Errorf("Argument %d to %s must be a FLOAT", i+1, name)
|
|
||||||
}
|
|
||||||
return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
|
|
||||||
}
|
|
||||||
case reflect.Slice:
|
|
||||||
if arg.Elem().Kind() != reflect.Uint8 {
|
|
||||||
return errors.New("The only supported slice type is []byte")
|
|
||||||
}
|
|
||||||
conv = func(v *C.sqlite3_value) (reflect.Value, error) {
|
|
||||||
switch C.sqlite3_value_type(v) {
|
|
||||||
case C.SQLITE_BLOB:
|
|
||||||
l := C.sqlite3_value_bytes(v)
|
|
||||||
p := C.sqlite3_value_blob(v)
|
|
||||||
return reflect.ValueOf(C.GoBytes(p, l)), nil
|
|
||||||
case C.SQLITE_TEXT:
|
|
||||||
l := C.sqlite3_value_bytes(v)
|
|
||||||
c := unsafe.Pointer(C.sqlite3_value_text(v))
|
|
||||||
return reflect.ValueOf(C.GoBytes(c, l)), nil
|
|
||||||
default:
|
|
||||||
return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case reflect.String:
|
|
||||||
conv = func(v *C.sqlite3_value) (reflect.Value, error) {
|
|
||||||
switch C.sqlite3_value_type(v) {
|
|
||||||
case C.SQLITE_BLOB:
|
|
||||||
l := C.sqlite3_value_bytes(v)
|
|
||||||
p := (*C.char)(C.sqlite3_value_blob(v))
|
|
||||||
return reflect.ValueOf(C.GoStringN(p, l)), nil
|
|
||||||
case C.SQLITE_TEXT:
|
|
||||||
c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
|
|
||||||
return reflect.ValueOf(C.GoString(c)), nil
|
|
||||||
default:
|
|
||||||
return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
fi.argConverters = append(fi.argConverters, conv)
|
fi.argConverters = append(fi.argConverters, conv)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conv, err := callbackRet(t.Out(0))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fi.retConverter = conv
|
||||||
|
|
||||||
// fi must outlast the database connection, or we'll have dangling pointers.
|
// fi must outlast the database connection, or we'll have dangling pointers.
|
||||||
c.funcs = append(c.funcs, &fi)
|
c.funcs = append(c.funcs, &fi)
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -1060,25 +1061,41 @@ func TestDateTimeNow(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFunctionRegistration(t *testing.T) {
|
func TestFunctionRegistration(t *testing.T) {
|
||||||
custom_add := func(a, b int64) (int64, error) {
|
addi_8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) }
|
||||||
return a + b, nil
|
addi_64 := func(a, b int64) int64 { return a + b }
|
||||||
}
|
addu_8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) }
|
||||||
custom_regex := func(s, re string) bool {
|
addu_64 := func(a, b uint64) uint64 { return a + b }
|
||||||
matched, err := regexp.MatchString(re, s)
|
addiu := func(a int, b uint) int64 { return int64(a) + int64(b) }
|
||||||
if err != nil {
|
addf_32_64 := func(a float32, b float64) float64 { return float64(a) + b }
|
||||||
// We should really return the error here, but this
|
not := func(a bool) bool { return !a }
|
||||||
// function is also testing single return value functions.
|
regex := func(re, s string) (bool, error) {
|
||||||
panic("Bad regexp")
|
return regexp.MatchString(re, s)
|
||||||
}
|
|
||||||
return matched
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
|
sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{
|
||||||
ConnectHook: func(conn *SQLiteConn) error {
|
ConnectHook: func(conn *SQLiteConn) error {
|
||||||
if err := conn.RegisterFunc("custom_add", custom_add, true); err != nil {
|
if err := conn.RegisterFunc("addi_8_16_32", addi_8_16_32, true); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := conn.RegisterFunc("regexp", custom_regex, true); err != nil {
|
if err := conn.RegisterFunc("addi_64", addi_64, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("addu_8_16_32", addu_8_16_32, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("addu_64", addu_64, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("addiu", addiu, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("addf_32_64", addf_32_64, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("not", not, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterFunc("regex", regex, true); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -1090,42 +1107,29 @@ func TestFunctionRegistration(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|
||||||
additions := []struct {
|
ops := []struct {
|
||||||
a, b, c int64
|
query string
|
||||||
|
expected interface{}
|
||||||
}{
|
}{
|
||||||
{1, 1, 2},
|
{"SELECT addi_8_16_32(1,2)", int32(3)},
|
||||||
{1, 3, 4},
|
{"SELECT addi_64(1,2)", int64(3)},
|
||||||
{1, -1, 0},
|
{"SELECT addu_8_16_32(1,2)", uint32(3)},
|
||||||
|
{"SELECT addu_64(1,2)", uint64(3)},
|
||||||
|
{"SELECT addiu(1,2)", int64(3)},
|
||||||
|
{"SELECT addf_32_64(1.5,1.5)", float64(3)},
|
||||||
|
{"SELECT not(1)", false},
|
||||||
|
{"SELECT not(0)", true},
|
||||||
|
{`SELECT regex("^foo.*", "foobar")`, true},
|
||||||
|
{`SELECT regex("^foo.*", "barfoobar")`, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, add := range additions {
|
for _, op := range ops {
|
||||||
var i int64
|
ret := reflect.New(reflect.TypeOf(op.expected))
|
||||||
err = db.QueryRow("SELECT custom_add($1, $2)", add.a, add.b).Scan(&i)
|
err = db.QueryRow(op.query).Scan(ret.Interface())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Failed to call custom_add:", err)
|
t.Errorf("Query %q failed: %s", op.query, err)
|
||||||
}
|
} else if !reflect.DeepEqual(ret.Elem().Interface(), op.expected) {
|
||||||
if i != add.c {
|
t.Errorf("Query %q returned wrong value: got %v (%T), want %v (%T)", op.query, ret.Elem().Interface(), ret.Elem().Interface(), op.expected, op.expected)
|
||||||
t.Fatalf("custom_add returned the wrong value, got %d, want %d", i, add.c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
regexes := []struct {
|
|
||||||
re, in string
|
|
||||||
out bool
|
|
||||||
}{
|
|
||||||
{".*", "foo", true},
|
|
||||||
{"^foo.*", "foobar", true},
|
|
||||||
{"^foo.*", "barfoo", false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, re := range regexes {
|
|
||||||
var b bool
|
|
||||||
err = db.QueryRow("SELECT regexp($1, $2)", re.in, re.re).Scan(&b)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("Failed to call regexp:", err)
|
|
||||||
}
|
|
||||||
if b != re.out {
|
|
||||||
t.Fatalf("regexp returned the wrong value, got %v, want %v", b, re.out)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1134,8 +1138,8 @@ var customFunctionOnce sync.Once
|
||||||
|
|
||||||
func BenchmarkCustomFunctions(b *testing.B) {
|
func BenchmarkCustomFunctions(b *testing.B) {
|
||||||
customFunctionOnce.Do(func() {
|
customFunctionOnce.Do(func() {
|
||||||
custom_add := func(a, b int64) (int64, error) {
|
custom_add := func(a, b int64) int64 {
|
||||||
return a + b, nil
|
return a + b
|
||||||
}
|
}
|
||||||
|
|
||||||
sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{
|
sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{
|
||||||
|
|
Loading…
Reference in New Issue