From 5764183c0a8c5302eea3dc669128433ac529e17e Mon Sep 17 00:00:00 2001 From: Rick Branson Date: Tue, 6 Mar 2018 14:30:23 -0800 Subject: [PATCH] Add support for pre-update hooks --- callback.go | 23 +++- convert.go | 301 ++++++++++++++++++++++++++++++++++++++++++++++++ sqlite3.go | 96 +++++++++++++++ sqlite3_test.go | 114 ++++++++++++++++++ 4 files changed, 532 insertions(+), 2 deletions(-) create mode 100644 convert.go diff --git a/callback.go b/callback.go index 29ece3d..848b042 100644 --- a/callback.go +++ b/callback.go @@ -77,6 +77,21 @@ func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, row callback(op, C.GoString(db), C.GoString(table), rowid) } +//export preUpdateHookTrampoline +func preUpdateHookTrampoline(handle uintptr, dbHandle uintptr, op int, db *C.char, table *C.char, oldrowid int64, newrowid int64) { + hval := lookupHandleVal(handle) + data := SQLitePreUpdateData{ + Conn: hval.db, + Op: op, + DatabaseName: C.GoString(db), + TableName: C.GoString(table), + OldRowID: oldrowid, + NewRowID: newrowid, + } + callback := hval.val.(func(SQLitePreUpdateData)) + callback(data) +} + // Use handles to avoid passing Go pointers to C. type handleVal struct { @@ -97,7 +112,7 @@ func newHandle(db *SQLiteConn, v interface{}) uintptr { return i } -func lookupHandle(handle uintptr) interface{} { +func lookupHandleVal(handle uintptr) handleVal { handleLock.Lock() defer handleLock.Unlock() r, ok := handleVals[handle] @@ -108,7 +123,11 @@ func lookupHandle(handle uintptr) interface{} { panic("invalid handle") } } - return r.val + return r +} + +func lookupHandle(handle uintptr) interface{} { + return lookupHandleVal(handle).val } func deleteHandles(db *SQLiteConn) { diff --git a/convert.go b/convert.go new file mode 100644 index 0000000..2c39226 --- /dev/null +++ b/convert.go @@ -0,0 +1,301 @@ +// Extracted from Go database/sql source code + +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Type conversions for Scan. + +package sqlite3 + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strconv" + "time" +) + +var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error + +// convertAssign copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. +func convertAssign(dest, src interface{}) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = s + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = append((*d)[:0], s...) + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = string(s) + return nil + case *interface{}: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = cloneBytes(s) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = s + return nil + } + case time.Time: + switch d := dest.(type) { + case *time.Time: + *d = s + return nil + case *string: + *d = s.Format(time.RFC3339Nano) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s.Format(time.RFC3339Nano)) + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = s.AppendFormat((*d)[:0], time.RFC3339Nano) + return nil + } + case nil: + switch d := dest.(type) { + case *interface{}: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *sql.RawBytes: + if d == nil { + return errNilPtr + } + *d = nil + return nil + } + } + + var sv reflect.Value + + switch d := dest.(type) { + case *string: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = asString(src) + return nil + } + case *[]byte: + sv = reflect.ValueOf(src) + if b, ok := asBytes(nil, sv); ok { + *d = b + return nil + } + case *sql.RawBytes: + sv = reflect.ValueOf(src) + if b, ok := asBytes([]byte(*d)[:0], sv); ok { + *d = sql.RawBytes(b) + return nil + } + case *bool: + bv, err := driver.Bool.ConvertValue(src) + if err == nil { + *d = bv.(bool) + } + return err + case *interface{}: + *d = src + return nil + } + + if scanner, ok := dest.(sql.Scanner); ok { + return scanner.Scan(src) + } + + dpv := reflect.ValueOf(dest) + if dpv.Kind() != reflect.Ptr { + return errors.New("destination not a pointer") + } + if dpv.IsNil() { + return errNilPtr + } + + if !sv.IsValid() { + sv = reflect.ValueOf(src) + } + + dv := reflect.Indirect(dpv) + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { + switch b := src.(type) { + case []byte: + dv.Set(reflect.ValueOf(cloneBytes(b))) + default: + dv.Set(sv) + } + return nil + } + + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { + dv.Set(sv.Convert(dv.Type())) + return nil + } + + // The following conversions use a string value as an intermediate representation + // to convert between various numeric types. + // + // This also allows scanning into user defined types such as "type Int int64". + // For symmetry, also check for string destination types. + switch dv.Kind() { + case reflect.Ptr: + if src == nil { + dv.Set(reflect.Zero(dv.Type())) + return nil + } else { + dv.Set(reflect.New(dv.Type().Elem())) + return convertAssign(dv.Interface(), src) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + switch v := src.(type) { + case string: + dv.SetString(v) + return nil + case []byte: + dv.SetString(string(v)) + return nil + } + } + + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) +} + +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } else { + c := make([]byte, len(b)) + copy(c, b) + return c + } +} + +func asString(src interface{}) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.AppendInt(buf, rv.Int(), 10), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.AppendUint(buf, rv.Uint(), 10), true + case reflect.Float32: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true + case reflect.Float64: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true + case reflect.Bool: + return strconv.AppendBool(buf, rv.Bool()), true + case reflect.String: + s := rv.String() + return append(buf, s...), true + } + return +} diff --git a/sqlite3.go b/sqlite3.go index ff29d79..cb156c5 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -14,6 +14,7 @@ package sqlite3 #cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61 #cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15 #cgo CFLAGS: -DSQLITE_DISABLE_INTRINSIC +#cgo CFLAGS: -DSQLITE_ENABLE_PREUPDATE_HOOK #cgo CFLAGS: -Wno-deprecated-declarations #ifndef USE_LIBSQLITE3 #include @@ -110,6 +111,7 @@ int compareTrampoline(void*, int, char*, int, char*); int commitHookTrampoline(void*); void rollbackHookTrampoline(void*); void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64); +void preUpdateHookTrampoline(void*, sqlite3 *, int, char *, char *, sqlite3_int64, sqlite3_int64); #ifdef SQLITE_LIMIT_WORKER_THREADS # define _SQLITE_HAS_LIMIT @@ -239,6 +241,17 @@ type SQLiteRows struct { done chan struct{} } +// SQLitePreUpdateData represents all of the data available during a +// pre-update hook call. +type SQLitePreUpdateData struct { + Conn *SQLiteConn + Op int + DatabaseName string + TableName string + OldRowID int64 + NewRowID int64 +} + type functionInfo struct { f reflect.Value argConverters []callbackArgConverter @@ -426,6 +439,22 @@ func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64 } } +// RegisterPreUpdateHook sets the pre-update hook for a connection. +// +// The callback is passed a SQLitePreUpdateData struct with the data for +// the update, as well as methods for fetching copies of impacted data. +// +// If there is an existing update hook for this connection, it will be +// removed. If callback is nil the existing hook (if any) will be removed +// without creating a new one. +func (c *SQLiteConn) RegisterPreUpdateHook(callback func(SQLitePreUpdateData)) { + if callback == nil { + C.sqlite3_preupdate_hook(c.db, nil, nil) + } else { + C.sqlite3_preupdate_hook(c.db, (*[0]byte)(unsafe.Pointer(C.preUpdateHookTrampoline)), unsafe.Pointer(newHandle(c, callback))) + } +} + // RegisterFunc makes a Go function available as a SQLite function. // // The Go function can have arguments of the following types: any @@ -1351,3 +1380,70 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { } return nil } + +// Depth returns the source path of the write, see sqlite3_preupdate_depth() +func (d *SQLitePreUpdateData) Depth() int { + return int(C.sqlite3_preupdate_depth(d.Conn.db)) +} + +// Count returns the number of columns in the row +func (d *SQLitePreUpdateData) Count() int { + return int(C.sqlite3_preupdate_count(d.Conn.db)) +} + +func (d *SQLitePreUpdateData) row(dest []interface{}, new bool) error { + for i := 0; i < d.Count() && i < len(dest); i++ { + var val *C.sqlite3_value + var src interface{} + + // Initially I tried making this just a function pointer argument, but + // it's absurdly complicated to pass C function pointers. + if new { + C.sqlite3_preupdate_new(d.Conn.db, C.int(i), &val) + } else { + C.sqlite3_preupdate_old(d.Conn.db, C.int(i), &val) + } + + switch C.sqlite3_value_type(val) { + case C.SQLITE_INTEGER: + src = int64(C.sqlite3_value_int64(val)) + case C.SQLITE_FLOAT: + src = float64(C.sqlite3_value_double(val)) + case C.SQLITE_BLOB: + len := C.sqlite3_value_bytes(val) + blobptr := C.sqlite3_value_blob(val) + src = C.GoBytes(blobptr, len) + case C.SQLITE_TEXT: + len := C.sqlite3_value_bytes(val) + cstrptr := unsafe.Pointer(C.sqlite3_value_text(val)) + src = C.GoBytes(cstrptr, len) + case C.SQLITE_NULL: + src = nil + } + + err := convertAssign(&dest[i], src) + if err != nil { + return err + } + } + + return nil +} + +// Old populates dest with the row data to be replaced. This works similar to +// database/sql's Rows.Scan() +func (d *SQLitePreUpdateData) Old(dest ...interface{}) error { + if d.Op == SQLITE_INSERT { + return errors.New("There is no old row for INSERT operations") + } + return d.row(dest, false) +} + +// New populates dest with the replacement row data. This works similar to +// database/sql's Rows.Scan() +func (d *SQLitePreUpdateData) New(dest ...interface{}) error { + if d.Op == SQLITE_DELETE { + return errors.New("There is no new row for DELETE operations") + } + return d.row(dest, true) +} diff --git a/sqlite3_test.go b/sqlite3_test.go index 84ecb5a..61d2f59 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1464,6 +1464,120 @@ func TestPinger(t *testing.T) { } } +type preUpdateHookDataForTest struct { + databaseName string + tableName string + count int + op int + oldRow []interface{} + newRow []interface{} +} + +func TestPreUpdateHook(t *testing.T) { + var events []preUpdateHookDataForTest + + sql.Register("sqlite3_PreUpdateHook", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + conn.RegisterPreUpdateHook(func(data SQLitePreUpdateData) { + eval := -1 + oldRow := []interface{}{eval} + if data.Op != SQLITE_INSERT { + err := data.Old(oldRow...) + if err != nil { + t.Fatalf("Unexpected error calling SQLitePreUpdateData.Old: %v", err) + } + } + + eval2 := -1 + newRow := []interface{}{eval2} + if data.Op != SQLITE_DELETE { + err := data.New(newRow...) + if err != nil { + t.Fatalf("Unexpected error calling SQLitePreUpdateData.New: %v", err) + } + } + + // tests dest bound checks in loop + var tooSmallRow []interface{} + if data.Op != SQLITE_INSERT { + err := data.Old(tooSmallRow...) + if err != nil { + t.Fatalf("Unexpected error calling SQLitePreUpdateData.Old: %v", err) + } + if len(tooSmallRow) != 0 { + t.Errorf("Expected tooSmallRow to be empty, got: %v", tooSmallRow) + } + } + + events = append(events, preUpdateHookDataForTest{ + databaseName: data.DatabaseName, + tableName: data.TableName, + count: data.Count(), + op: data.Op, + oldRow: oldRow, + newRow: newRow, + }) + }) + return nil + }, + }) + + db, err := sql.Open("sqlite3_PreUpdateHook", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + statements := []string{ + "create table foo (id integer primary key)", + "insert into foo values (9)", + "update foo set id = 99 where id = 9", + "delete from foo where id = 99", + } + for _, statement := range statements { + _, err = db.Exec(statement) + if err != nil { + t.Fatalf("Unable to prepare test data [%v]: %v", statement, err) + } + } + + if len(events) != 3 { + t.Errorf("Events should be 3 entries, got: %d", len(events)) + } + + if events[0].op != SQLITE_INSERT { + t.Errorf("Op isn't as expected: %v", events[0].op) + } + + if events[1].op != SQLITE_UPDATE { + t.Errorf("Op isn't as expected: %v", events[1].op) + } + + if events[1].count != 1 { + t.Errorf("Expected event row 1 to have 1 column, had: %v", events[1].count) + } + + newRow_0_0 := events[0].newRow[0].(int64) + if newRow_0_0 != 9 { + t.Errorf("Expected event row 0 new column 0 to be == 9, got: %v", newRow_0_0) + } + + oldRow_1_0 := events[1].oldRow[0].(int64) + if oldRow_1_0 != 9 { + t.Errorf("Expected event row 1 old column 0 to be == 9, got: %v", oldRow_1_0) + } + + newRow_1_0 := events[1].newRow[0].(int64) + if newRow_1_0 != 99 { + t.Errorf("Expected event row 1 new column 0 to be == 99, got: %v", newRow_1_0) + } + + oldRow_2_0 := events[2].oldRow[0].(int64) + if oldRow_2_0 != 99 { + t.Errorf("Expected event row 1 new column 0 to be == 99, got: %v", oldRow_2_0) + } +} + func TestUpdateAndTransactionHooks(t *testing.T) { var events []string var commitHookReturn = 0