forked from mirror/go-sqlcipher
Merge pull request #1 from toba/missing-callback-hooks
Incorporate original PR 271 from https://github.com/brokensandals
This commit is contained in:
commit
dbaad204e9
|
@ -14,6 +14,12 @@ func main() {
|
||||||
&sqlite3.SQLiteDriver{
|
&sqlite3.SQLiteDriver{
|
||||||
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||||
sqlite3conn = append(sqlite3conn, conn)
|
sqlite3conn = append(sqlite3conn, conn)
|
||||||
|
conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
|
||||||
|
switch op {
|
||||||
|
case sqlite3.SQLITE_INSERT:
|
||||||
|
log.Println("Notified of insert on db", db, "table", table, "rowid", rowid)
|
||||||
|
}
|
||||||
|
})
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
18
callback.go
18
callback.go
|
@ -53,6 +53,24 @@ func doneTrampoline(ctx *C.sqlite3_context) {
|
||||||
ai.Done(ctx)
|
ai.Done(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//export commitHookTrampoline
|
||||||
|
func commitHookTrampoline(handle uintptr) int {
|
||||||
|
callback := lookupHandle(handle).(func() int)
|
||||||
|
return callback()
|
||||||
|
}
|
||||||
|
|
||||||
|
//export rollbackHookTrampoline
|
||||||
|
func rollbackHookTrampoline(handle uintptr) {
|
||||||
|
callback := lookupHandle(handle).(func())
|
||||||
|
callback()
|
||||||
|
}
|
||||||
|
|
||||||
|
//export updateHookTrampoline
|
||||||
|
func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) {
|
||||||
|
callback := lookupHandle(handle).(func(int, string, string, int64))
|
||||||
|
callback(op, C.GoString(db), C.GoString(table), rowid)
|
||||||
|
}
|
||||||
|
|
||||||
// Use handles to avoid passing Go pointers to C.
|
// Use handles to avoid passing Go pointers to C.
|
||||||
|
|
||||||
type handleVal struct {
|
type handleVal struct {
|
||||||
|
|
54
sqlite3.go
54
sqlite3.go
|
@ -100,6 +100,9 @@ int _sqlite3_create_function(
|
||||||
}
|
}
|
||||||
|
|
||||||
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||||
|
int commitHookTrampoline(void*);
|
||||||
|
void rollbackHookTrampoline(void*);
|
||||||
|
void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64);
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
|
@ -150,6 +153,12 @@ func Version() (libVersion string, libVersionNumber int, sourceID string) {
|
||||||
return libVersion, libVersionNumber, sourceID
|
return libVersion, libVersionNumber, sourceID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
SQLITE_DELETE = C.SQLITE_DELETE
|
||||||
|
SQLITE_INSERT = C.SQLITE_INSERT
|
||||||
|
SQLITE_UPDATE = C.SQLITE_UPDATE
|
||||||
|
)
|
||||||
|
|
||||||
// SQLiteDriver implement sql.Driver.
|
// SQLiteDriver implement sql.Driver.
|
||||||
type SQLiteDriver struct {
|
type SQLiteDriver struct {
|
||||||
Extensions []string
|
Extensions []string
|
||||||
|
@ -315,6 +324,51 @@ func (tx *SQLiteTx) Rollback() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterCommitHook sets the commit hook for a connection.
|
||||||
|
//
|
||||||
|
// If the callback returns non-zero the transaction will become a rollback.
|
||||||
|
//
|
||||||
|
// If there is an existing commit hook for this connection, it will be
|
||||||
|
// removed. If callback is nil the existing hook (if any) will be removed
|
||||||
|
// without creating a new one.
|
||||||
|
func (c *SQLiteConn) RegisterCommitHook(callback func() int) {
|
||||||
|
if callback == nil {
|
||||||
|
C.sqlite3_commit_hook(c.db, nil, nil)
|
||||||
|
} else {
|
||||||
|
C.sqlite3_commit_hook(c.db, (*[0]byte)(unsafe.Pointer(C.commitHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRollbackHook sets the rollback hook for a connection.
|
||||||
|
//
|
||||||
|
// If there is an existing rollback hook for this connection, it will be
|
||||||
|
// removed. If callback is nil the existing hook (if any) will be removed
|
||||||
|
// without creating a new one.
|
||||||
|
func (c *SQLiteConn) RegisterRollbackHook(callback func()) {
|
||||||
|
if callback == nil {
|
||||||
|
C.sqlite3_rollback_hook(c.db, nil, nil)
|
||||||
|
} else {
|
||||||
|
C.sqlite3_rollback_hook(c.db, (*[0]byte)(unsafe.Pointer(C.rollbackHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterUpdateHook sets the update hook for a connection.
|
||||||
|
//
|
||||||
|
// The parameters to the callback are the operation (one of the constants
|
||||||
|
// SQLITE_INSERT, SQLITE_DELETE, or SQLITE_UPDATE), the database name, the
|
||||||
|
// table name, and the rowid.
|
||||||
|
//
|
||||||
|
// If there is an existing update hook for this connection, it will be
|
||||||
|
// removed. If callback is nil the existing hook (if any) will be removed
|
||||||
|
// without creating a new one.
|
||||||
|
func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64)) {
|
||||||
|
if callback == nil {
|
||||||
|
C.sqlite3_update_hook(c.db, nil, nil)
|
||||||
|
} else {
|
||||||
|
C.sqlite3_update_hook(c.db, (*[0]byte)(unsafe.Pointer(C.updateHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterFunc makes a Go function available as a SQLite function.
|
// RegisterFunc makes a Go function available as a SQLite function.
|
||||||
//
|
//
|
||||||
// The Go function can have arguments of the following types: any
|
// The Go function can have arguments of the following types: any
|
||||||
|
|
|
@ -1265,6 +1265,67 @@ func TestPinger(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateAndTransactionHooks(t *testing.T) {
|
||||||
|
var events []string
|
||||||
|
var commitHookReturn = 0
|
||||||
|
|
||||||
|
sql.Register("sqlite3_UpdateHook", &SQLiteDriver{
|
||||||
|
ConnectHook: func(conn *SQLiteConn) error {
|
||||||
|
conn.RegisterCommitHook(func() int {
|
||||||
|
events = append(events, "commit")
|
||||||
|
return commitHookReturn
|
||||||
|
})
|
||||||
|
conn.RegisterRollbackHook(func() {
|
||||||
|
events = append(events, "rollback")
|
||||||
|
})
|
||||||
|
conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
|
||||||
|
events = append(events, fmt.Sprintf("update(op=%v db=%v table=%v rowid=%v)", op, db, table, rowid))
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
db, err := sql.Open("sqlite3_UpdateHook", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
statements := []string{
|
||||||
|
"create table foo (id integer primary key)",
|
||||||
|
"insert into foo values (9)",
|
||||||
|
"update foo set id = 99 where id = 9",
|
||||||
|
"delete from foo where id = 99",
|
||||||
|
}
|
||||||
|
for _, statement := range statements {
|
||||||
|
_, err = db.Exec(statement)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unable to prepare test data [%v]: %v", statement, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
commitHookReturn = 1
|
||||||
|
_, err = db.Exec("insert into foo values (5)")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Commit hook failed to rollback transaction")
|
||||||
|
}
|
||||||
|
|
||||||
|
var expected = []string{
|
||||||
|
"commit",
|
||||||
|
fmt.Sprintf("update(op=%v db=main table=foo rowid=9)", SQLITE_INSERT),
|
||||||
|
"commit",
|
||||||
|
fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_UPDATE),
|
||||||
|
"commit",
|
||||||
|
fmt.Sprintf("update(op=%v db=main table=foo rowid=99)", SQLITE_DELETE),
|
||||||
|
"commit",
|
||||||
|
fmt.Sprintf("update(op=%v db=main table=foo rowid=5)", SQLITE_INSERT),
|
||||||
|
"commit",
|
||||||
|
"rollback",
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(events, expected) {
|
||||||
|
t.Errorf("Expected notifications %v but got %v", expected, events)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var customFunctionOnce sync.Once
|
var customFunctionOnce sync.Once
|
||||||
|
|
||||||
func BenchmarkCustomFunctions(b *testing.B) {
|
func BenchmarkCustomFunctions(b *testing.B) {
|
||||||
|
|
Loading…
Reference in New Issue