diff --git a/callback.go b/callback.go index 938d7fe..1692106 100644 --- a/callback.go +++ b/callback.go @@ -5,12 +5,25 @@ package sqlite3 +// You can't export a Go function to C and have definitions in the C +// preamble in the same file, so we have to have callbackTrampoline in +// its own file. Because we need a separate file anyway, the support +// code for SQLite custom functions is in here. + /* #include + +void _sqlite3_result_text(sqlite3_context* ctx, const char* s); +void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l); */ import "C" -import "unsafe" +import ( + "errors" + "fmt" + "reflect" + "unsafe" +) //export callbackTrampoline func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) { @@ -18,3 +31,188 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx))) fi.Call(ctx, args) } + +// This is only here so that tests can refer to it. +type callbackArgRaw C.sqlite3_value + +type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error) + +type callbackArgCast struct { + f callbackArgConverter + typ reflect.Type +} + +func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) { + val, err := c.f(v) + if err != nil { + return reflect.Value{}, err + } + if !val.Type().ConvertibleTo(c.typ) { + return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ) + } + return val.Convert(c.typ), nil +} + +func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { + return reflect.Value{}, fmt.Errorf("argument must be an INTEGER") + } + return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil +} + +func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { + return reflect.Value{}, fmt.Errorf("argument must be an INTEGER") + } + i := int64(C.sqlite3_value_int64(v)) + val := false + if i != 0 { + val = true + } + return reflect.ValueOf(val), nil +} + +func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) { + if C.sqlite3_value_type(v) != C.SQLITE_FLOAT { + return reflect.Value{}, fmt.Errorf("argument must be a FLOAT") + } + return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil +} + +func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := C.sqlite3_value_blob(v) + return reflect.ValueOf(C.GoBytes(p, l)), nil + case C.SQLITE_TEXT: + l := C.sqlite3_value_bytes(v) + c := unsafe.Pointer(C.sqlite3_value_text(v)) + return reflect.ValueOf(C.GoBytes(c, l)), nil + default: + return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT") + } +} + +func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) { + switch C.sqlite3_value_type(v) { + case C.SQLITE_BLOB: + l := C.sqlite3_value_bytes(v) + p := (*C.char)(C.sqlite3_value_blob(v)) + return reflect.ValueOf(C.GoStringN(p, l)), nil + case C.SQLITE_TEXT: + c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v))) + return reflect.ValueOf(C.GoString(c)), nil + default: + return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT") + } +} + +func callbackArg(typ reflect.Type) (callbackArgConverter, error) { + switch typ.Kind() { + case reflect.Slice: + if typ.Elem().Kind() != reflect.Uint8 { + return nil, errors.New("the only supported slice type is []byte") + } + return callbackArgBytes, nil + case reflect.String: + return callbackArgString, nil + case reflect.Bool: + return callbackArgBool, nil + case reflect.Int64: + return callbackArgInt64, nil + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: + c := callbackArgCast{callbackArgInt64, typ} + return c.Run, nil + case reflect.Float64: + return callbackArgFloat64, nil + case reflect.Float32: + c := callbackArgCast{callbackArgFloat64, typ} + return c.Run, nil + default: + return nil, fmt.Errorf("don't know how to convert to %s", typ) + } +} + +type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error + +func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error { + switch v.Type().Kind() { + case reflect.Int64: + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: + v = v.Convert(reflect.TypeOf(int64(0))) + case reflect.Bool: + b := v.Interface().(bool) + if b { + v = reflect.ValueOf(int64(1)) + } else { + v = reflect.ValueOf(int64(0)) + } + default: + return fmt.Errorf("cannot convert %s to INTEGER", v.Type()) + } + + C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64))) + return nil +} + +func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error { + switch v.Type().Kind() { + case reflect.Float64: + case reflect.Float32: + v = v.Convert(reflect.TypeOf(float64(0))) + default: + return fmt.Errorf("cannot convert %s to FLOAT", v.Type()) + } + + C.sqlite3_result_double(ctx, C.double(v.Interface().(float64))) + return nil +} + +func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error { + if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("cannot convert %s to BLOB", v.Type()) + } + i := v.Interface() + if i == nil || len(i.([]byte)) == 0 { + C.sqlite3_result_null(ctx) + } else { + bs := i.([]byte) + C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs))) + } + return nil +} + +func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error { + if v.Type().Kind() != reflect.String { + return fmt.Errorf("cannot convert %s to TEXT", v.Type()) + } + C._sqlite3_result_text(ctx, C.CString(v.Interface().(string))) + return nil +} + +func callbackRet(typ reflect.Type) (callbackRetConverter, error) { + switch typ.Kind() { + case reflect.Slice: + if typ.Elem().Kind() != reflect.Uint8 { + return nil, errors.New("the only supported slice type is []byte") + } + return callbackRetBlob, nil + case reflect.String: + return callbackRetText, nil + case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint: + return callbackRetInteger, nil + case reflect.Float32, reflect.Float64: + return callbackRetFloat, nil + default: + return nil, fmt.Errorf("don't know how to convert to %s", typ) + } +} + +// Test support code. Tests are not allowed to import "C", so we can't +// declare any functions that use C.sqlite3_value. +func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter { + return func(*C.sqlite3_value) (reflect.Value, error) { + return v, err + } +} diff --git a/callback_test.go b/callback_test.go new file mode 100644 index 0000000..5c61f44 --- /dev/null +++ b/callback_test.go @@ -0,0 +1,97 @@ +package sqlite3 + +import ( + "errors" + "math" + "reflect" + "testing" +) + +func TestCallbackArgCast(t *testing.T) { + intConv := callbackSyntheticForTests(reflect.ValueOf(int64(math.MaxInt64)), nil) + floatConv := callbackSyntheticForTests(reflect.ValueOf(float64(math.MaxFloat64)), nil) + errConv := callbackSyntheticForTests(reflect.Value{}, errors.New("test")) + + tests := []struct { + f callbackArgConverter + o reflect.Value + }{ + {intConv, reflect.ValueOf(int8(-1))}, + {intConv, reflect.ValueOf(int16(-1))}, + {intConv, reflect.ValueOf(int32(-1))}, + {intConv, reflect.ValueOf(uint8(math.MaxUint8))}, + {intConv, reflect.ValueOf(uint16(math.MaxUint16))}, + {intConv, reflect.ValueOf(uint32(math.MaxUint32))}, + // Special case, int64->uint64 is only 1<<63 - 1, not 1<<64 - 1 + {intConv, reflect.ValueOf(uint64(math.MaxInt64))}, + {floatConv, reflect.ValueOf(float32(math.Inf(1)))}, + } + + for _, test := range tests { + conv := callbackArgCast{test.f, test.o.Type()} + val, err := conv.Run(nil) + if err != nil { + t.Errorf("Couldn't convert to %s: %s", test.o.Type(), err) + } else if !reflect.DeepEqual(val.Interface(), test.o.Interface()) { + t.Errorf("Unexpected result from converting to %s: got %v, want %v", test.o.Type(), val.Interface(), test.o.Interface()) + } + } + + conv := callbackArgCast{errConv, reflect.TypeOf(int8(0))} + _, err := conv.Run(nil) + if err == nil { + t.Errorf("Expected error during callbackArgCast, but got none") + } +} + +func TestCallbackConverters(t *testing.T) { + tests := []struct { + v interface{} + err bool + }{ + // Unfortunately, we can't tell which converter was returned, + // but we can at least check which types can be converted. + {[]byte{0}, false}, + {"text", false}, + {true, false}, + {int8(0), false}, + {int16(0), false}, + {int32(0), false}, + {int64(0), false}, + {uint8(0), false}, + {uint16(0), false}, + {uint32(0), false}, + {uint64(0), false}, + {int(0), false}, + {uint(0), false}, + {float64(0), false}, + {float32(0), false}, + + {func() {}, true}, + {complex64(complex(0, 0)), true}, + {complex128(complex(0, 0)), true}, + {struct{}{}, true}, + {map[string]string{}, true}, + {[]string{}, true}, + {(*int8)(nil), true}, + {make(chan int), true}, + } + + for _, test := range tests { + _, err := callbackArg(reflect.TypeOf(test.v)) + if test.err && err == nil { + t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v)) + } else if !test.err && err != nil { + t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err) + } + } + + for _, test := range tests { + _, err := callbackRet(reflect.TypeOf(test.v)) + if test.err && err == nil { + t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v)) + } else if !test.err && err != nil { + t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err) + } + } +} diff --git a/sqlite3.go b/sqlite3.go index f995589..174a3ee 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -166,7 +166,8 @@ type SQLiteRows struct { type functionInfo struct { f reflect.Value - argConverters []func(*C.sqlite3_value) (reflect.Value, error) + argConverters []callbackArgConverter + retConverter callbackRetConverter } func (fi *functionInfo) error(ctx *C.sqlite3_context, err error) { @@ -193,58 +194,11 @@ func (fi *functionInfo) Call(ctx *C.sqlite3_context, argv []*C.sqlite3_value) { return } - res := ret[0].Interface() - // Normalize ret to one of the types sqlite knows. - switch r := res.(type) { - case int64, float64, []byte, string: - // Already the right type - case bool: - if r { - res = int64(1) - } else { - res = int64(0) - } - case int: - res = int64(r) - case uint: - res = int64(r) - case uint8: - res = int64(r) - case uint16: - res = int64(r) - case uint32: - res = int64(r) - case uint64: - res = int64(r) - case int8: - res = int64(r) - case int16: - res = int64(r) - case int32: - res = int64(r) - case float32: - res = float64(r) - default: - fi.error(ctx, errors.New("cannot convert returned type to sqlite type")) + err := fi.retConverter(ctx, ret[0]) + if err != nil { + fi.error(ctx, err) return } - - switch r := res.(type) { - case int64: - C.sqlite3_result_int64(ctx, C.sqlite3_int64(r)) - case float64: - C.sqlite3_result_double(ctx, C.double(r)) - case []byte: - if len(r) == 0 { - C.sqlite3_result_null(ctx) - } else { - C._sqlite3_result_blob(ctx, unsafe.Pointer(&r[0]), C.int(len(r))) - } - case string: - C._sqlite3_result_text(ctx, C.CString(r)) - default: - panic("unreachable") - } } // Commit transaction. @@ -261,10 +215,10 @@ func (tx *SQLiteTx) Rollback() error { // RegisterFunc makes a Go function available as a SQLite function. // -// The function must accept only arguments of type int64, float64, -// []byte or string, and return one value of any numeric type except -// complex, bool, []byte or string. Optionally, an error can be -// provided as a second return value. +// The function can accept arguments of any real numeric type +// (i.e. not complex), as well as []byte and string. It must return a +// value of one of those types, and optionally an error as a second +// value. // // If pure is true. SQLite will assume that the function's return // value depends only on its inputs, and make more aggressive @@ -287,59 +241,19 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro } for i := 0; i < t.NumIn(); i++ { - arg := t.In(i) - var conv func(*C.sqlite3_value) (reflect.Value, error) - switch arg.Kind() { - case reflect.Int64: - conv = func(v *C.sqlite3_value) (reflect.Value, error) { - if C.sqlite3_value_type(v) != C.SQLITE_INTEGER { - return reflect.Value{}, fmt.Errorf("Argument %d to %s must be an INTEGER", i+1, name) - } - return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil - } - case reflect.Float64: - conv = func(v *C.sqlite3_value) (reflect.Value, error) { - if C.sqlite3_value_type(v) != C.SQLITE_FLOAT { - return reflect.Value{}, fmt.Errorf("Argument %d to %s must be a FLOAT", i+1, name) - } - return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil - } - case reflect.Slice: - if arg.Elem().Kind() != reflect.Uint8 { - return errors.New("The only supported slice type is []byte") - } - conv = func(v *C.sqlite3_value) (reflect.Value, error) { - switch C.sqlite3_value_type(v) { - case C.SQLITE_BLOB: - l := C.sqlite3_value_bytes(v) - p := C.sqlite3_value_blob(v) - return reflect.ValueOf(C.GoBytes(p, l)), nil - case C.SQLITE_TEXT: - l := C.sqlite3_value_bytes(v) - c := unsafe.Pointer(C.sqlite3_value_text(v)) - return reflect.ValueOf(C.GoBytes(c, l)), nil - default: - return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name) - } - } - case reflect.String: - conv = func(v *C.sqlite3_value) (reflect.Value, error) { - switch C.sqlite3_value_type(v) { - case C.SQLITE_BLOB: - l := C.sqlite3_value_bytes(v) - p := (*C.char)(C.sqlite3_value_blob(v)) - return reflect.ValueOf(C.GoStringN(p, l)), nil - case C.SQLITE_TEXT: - c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v))) - return reflect.ValueOf(C.GoString(c)), nil - default: - return reflect.Value{}, fmt.Errorf("Argument %d to %s must be BLOB or TEXT", i+1, name) - } - } + conv, err := callbackArg(t.In(i)) + if err != nil { + return err } fi.argConverters = append(fi.argConverters, conv) } + conv, err := callbackRet(t.Out(0)) + if err != nil { + return err + } + fi.retConverter = conv + // fi must outlast the database connection, or we'll have dangling pointers. c.funcs = append(c.funcs, &fi) diff --git a/sqlite3_test.go b/sqlite3_test.go index a58e373..e8dfe5c 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -15,6 +15,7 @@ import ( "net/url" "os" "path/filepath" + "reflect" "regexp" "strings" "sync" @@ -1060,25 +1061,41 @@ func TestDateTimeNow(t *testing.T) { } func TestFunctionRegistration(t *testing.T) { - custom_add := func(a, b int64) (int64, error) { - return a + b, nil - } - custom_regex := func(s, re string) bool { - matched, err := regexp.MatchString(re, s) - if err != nil { - // We should really return the error here, but this - // function is also testing single return value functions. - panic("Bad regexp") - } - return matched + addi_8_16_32 := func(a int8, b int16) int32 { return int32(a) + int32(b) } + addi_64 := func(a, b int64) int64 { return a + b } + addu_8_16_32 := func(a uint8, b uint16) uint32 { return uint32(a) + uint32(b) } + addu_64 := func(a, b uint64) uint64 { return a + b } + addiu := func(a int, b uint) int64 { return int64(a) + int64(b) } + addf_32_64 := func(a float32, b float64) float64 { return float64(a) + b } + not := func(a bool) bool { return !a } + regex := func(re, s string) (bool, error) { + return regexp.MatchString(re, s) } sql.Register("sqlite3_FunctionRegistration", &SQLiteDriver{ ConnectHook: func(conn *SQLiteConn) error { - if err := conn.RegisterFunc("custom_add", custom_add, true); err != nil { + if err := conn.RegisterFunc("addi_8_16_32", addi_8_16_32, true); err != nil { return err } - if err := conn.RegisterFunc("regexp", custom_regex, true); err != nil { + if err := conn.RegisterFunc("addi_64", addi_64, true); err != nil { + return err + } + if err := conn.RegisterFunc("addu_8_16_32", addu_8_16_32, true); err != nil { + return err + } + if err := conn.RegisterFunc("addu_64", addu_64, true); err != nil { + return err + } + if err := conn.RegisterFunc("addiu", addiu, true); err != nil { + return err + } + if err := conn.RegisterFunc("addf_32_64", addf_32_64, true); err != nil { + return err + } + if err := conn.RegisterFunc("not", not, true); err != nil { + return err + } + if err := conn.RegisterFunc("regex", regex, true); err != nil { return err } return nil @@ -1090,42 +1107,29 @@ func TestFunctionRegistration(t *testing.T) { } defer db.Close() - additions := []struct { - a, b, c int64 + ops := []struct { + query string + expected interface{} }{ - {1, 1, 2}, - {1, 3, 4}, - {1, -1, 0}, + {"SELECT addi_8_16_32(1,2)", int32(3)}, + {"SELECT addi_64(1,2)", int64(3)}, + {"SELECT addu_8_16_32(1,2)", uint32(3)}, + {"SELECT addu_64(1,2)", uint64(3)}, + {"SELECT addiu(1,2)", int64(3)}, + {"SELECT addf_32_64(1.5,1.5)", float64(3)}, + {"SELECT not(1)", false}, + {"SELECT not(0)", true}, + {`SELECT regex("^foo.*", "foobar")`, true}, + {`SELECT regex("^foo.*", "barfoobar")`, false}, } - for _, add := range additions { - var i int64 - err = db.QueryRow("SELECT custom_add($1, $2)", add.a, add.b).Scan(&i) + for _, op := range ops { + ret := reflect.New(reflect.TypeOf(op.expected)) + err = db.QueryRow(op.query).Scan(ret.Interface()) if err != nil { - t.Fatal("Failed to call custom_add:", err) - } - if i != add.c { - t.Fatalf("custom_add returned the wrong value, got %d, want %d", i, add.c) - } - } - - regexes := []struct { - re, in string - out bool - }{ - {".*", "foo", true}, - {"^foo.*", "foobar", true}, - {"^foo.*", "barfoo", false}, - } - - for _, re := range regexes { - var b bool - err = db.QueryRow("SELECT regexp($1, $2)", re.in, re.re).Scan(&b) - if err != nil { - t.Fatal("Failed to call regexp:", err) - } - if b != re.out { - t.Fatalf("regexp returned the wrong value, got %v, want %v", b, re.out) + t.Errorf("Query %q failed: %s", op.query, err) + } else if !reflect.DeepEqual(ret.Elem().Interface(), op.expected) { + t.Errorf("Query %q returned wrong value: got %v (%T), want %v (%T)", op.query, ret.Elem().Interface(), ret.Elem().Interface(), op.expected, op.expected) } } } @@ -1134,8 +1138,8 @@ var customFunctionOnce sync.Once func BenchmarkCustomFunctions(b *testing.B) { customFunctionOnce.Do(func() { - custom_add := func(a, b int64) (int64, error) { - return a + b, nil + custom_add := func(a, b int64) int64 { + return a + b } sql.Register("sqlite3_BenchmarkCustomFunctions", &SQLiteDriver{