diff --git a/callback.go b/callback.go index ee9d40c..e2bf3c6 100644 --- a/callback.go +++ b/callback.go @@ -24,29 +24,75 @@ import ( "fmt" "math" "reflect" + "sync" "unsafe" ) //export callbackTrampoline func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) { args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc] - fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx))) + fi := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*functionInfo) fi.Call(ctx, args) } //export stepTrampoline func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) { args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc] - ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx))) + ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo) ai.Step(ctx, args) } //export doneTrampoline func doneTrampoline(ctx *C.sqlite3_context) { - ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx))) + handle := uintptr(C.sqlite3_user_data(ctx)) + ai := lookupHandle(handle).(*aggInfo) ai.Done(ctx) } +// Use handles to avoid passing Go pointers to C. + +type handleVal struct { + db *SQLiteConn + val interface{} +} + +var handleLock sync.Mutex +var handleVals = make(map[uintptr]handleVal) +var handleIndex uintptr = 100 + +func newHandle(db *SQLiteConn, v interface{}) uintptr { + handleLock.Lock() + defer handleLock.Unlock() + i := handleIndex + handleIndex++ + handleVals[i] = handleVal{db, v} + return i +} + +func lookupHandle(handle uintptr) interface{} { + handleLock.Lock() + defer handleLock.Unlock() + r, ok := handleVals[handle] + if !ok { + if handle >= 100 && handle < handleIndex { + panic("deleted handle") + } else { + panic("invalid handle") + } + } + return r.val +} + +func deleteHandles(db *SQLiteConn) { + handleLock.Lock() + defer handleLock.Unlock() + for handle, val := range handleVals { + if val.db == db { + delete(handleVals, handle) + } + } +} + // This is only here so that tests can refer to it. type callbackArgRaw C.sqlite3_value diff --git a/sqlite3.go b/sqlite3.go index f79ef2d..964acbb 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -367,7 +367,7 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro if pure { opts |= C.SQLITE_DETERMINISTIC } - rv := C._sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), C.uintptr_t(uintptr(unsafe.Pointer(&fi))), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil) + rv := C._sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), C.uintptr_t(newHandle(c, &fi)), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil) if rv != C.SQLITE_OK { return c.lastError() } @@ -492,7 +492,7 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool if pure { opts |= C.SQLITE_DETERMINISTIC } - rv := C._sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(uintptr(unsafe.Pointer(&ai))), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline))) + rv := C._sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(newHandle(c, &ai)), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline))) if rv != C.SQLITE_OK { return c.lastError() } @@ -705,6 +705,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { // Close the connection. func (c *SQLiteConn) Close() error { + deleteHandles(c) rv := C.sqlite3_close_v2(c.db) if rv != C.SQLITE_OK { return c.lastError()