diff --git a/callback.go b/callback.go index 5a735c0..2c68973 100644 --- a/callback.go +++ b/callback.go @@ -77,6 +77,12 @@ func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, row callback(op, C.GoString(db), C.GoString(table), rowid) } +//export authorizerTrampoline +func authorizerTrampoline(handle uintptr, op int, arg1 *C.char, arg2 *C.char, arg3 *C.char) int { + callback := lookupHandle(handle).(func(int, string, string, string) int) + return callback(op, C.GoString(arg1), C.GoString(arg2), C.GoString(arg3)) +} + // Use handles to avoid passing Go pointers to C. type handleVal struct { diff --git a/sqlite3.go b/sqlite3.go index b97647b..0dc9b04 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -119,6 +119,8 @@ int commitHookTrampoline(void*); void rollbackHookTrampoline(void*); void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64); +int authorizerTrampoline(void*, int, char*, char*, char*, char*); + #ifdef SQLITE_LIMIT_WORKER_THREADS # define _SQLITE_HAS_LIMIT # define SQLITE_LIMIT_LENGTH 0 @@ -200,9 +202,43 @@ func Version() (libVersion string, libVersionNumber int, sourceID string) { } const ( - SQLITE_DELETE = C.SQLITE_DELETE - SQLITE_INSERT = C.SQLITE_INSERT - SQLITE_UPDATE = C.SQLITE_UPDATE + SQLITE_OK = C.SQLITE_OK + SQLITE_IGNORE = C.SQLITE_IGNORE + SQLITE_DENY = C.SQLITE_DENY + SQLITE_CREATE_INDEX = C.SQLITE_CREATE_INDEX + SQLITE_CREATE_TABLE = C.SQLITE_CREATE_TABLE + SQLITE_CREATE_TEMP_INDEX = C.SQLITE_CREATE_TEMP_INDEX + SQLITE_CREATE_TEMP_TABLE = C.SQLITE_CREATE_TEMP_TABLE + SQLITE_CREATE_TEMP_TRIGGER = C.SQLITE_CREATE_TEMP_TRIGGER + SQLITE_CREATE_TEMP_VIEW = C.SQLITE_CREATE_TEMP_VIEW + SQLITE_CREATE_TRIGGER = C.SQLITE_CREATE_TRIGGER + SQLITE_CREATE_VIEW = C.SQLITE_CREATE_VIEW + SQLITE_DELETE = C.SQLITE_DELETE + SQLITE_DROP_INDEX = C.SQLITE_DROP_INDEX + SQLITE_DROP_TABLE = C.SQLITE_DROP_TABLE + SQLITE_DROP_TEMP_INDEX = C.SQLITE_DROP_TEMP_INDEX + SQLITE_DROP_TEMP_TABLE = C.SQLITE_DROP_TEMP_TABLE + SQLITE_DROP_TEMP_TRIGGER = C.SQLITE_DROP_TEMP_TRIGGER + SQLITE_DROP_TEMP_VIEW = C.SQLITE_DROP_TEMP_VIEW + SQLITE_DROP_TRIGGER = C.SQLITE_DROP_TRIGGER + SQLITE_DROP_VIEW = C.SQLITE_DROP_VIEW + SQLITE_INSERT = C.SQLITE_INSERT + SQLITE_PRAGMA = C.SQLITE_PRAGMA + SQLITE_READ = C.SQLITE_READ + SQLITE_SELECT = C.SQLITE_SELECT + SQLITE_TRANSACTION = C.SQLITE_TRANSACTION + SQLITE_UPDATE = C.SQLITE_UPDATE + SQLITE_ATTACH = C.SQLITE_ATTACH + SQLITE_DETACH = C.SQLITE_DETACH + SQLITE_ALTER_TABLE = C.SQLITE_ALTER_TABLE + SQLITE_REINDEX = C.SQLITE_REINDEX + SQLITE_ANALYZE = C.SQLITE_ANALYZE + SQLITE_CREATE_VTABLE = C.SQLITE_CREATE_VTABLE + SQLITE_DROP_VTABLE = C.SQLITE_DROP_VTABLE + SQLITE_FUNCTION = C.SQLITE_FUNCTION + SQLITE_SAVEPOINT = C.SQLITE_SAVEPOINT + SQLITE_COPY = C.SQLITE_COPY + SQLITE_RECURSIVE = C.SQLITE_RECURSIVE ) // SQLiteDriver implement sql.Driver. @@ -440,6 +476,20 @@ func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64 } } +// RegisterAuthorizer sets the authorizer for connection. +// +// The parameters to the callback are the operation (one of the constants +// SQLITE_INSERT, SQLITE_DELETE, or SQLITE_UPDATE), and 1 to 3 arguments, +// depending on operation. More details see: +// https://www.sqlite.org/c3ref/c_alter_table.html +func (c *SQLiteConn) RegisterAuthorizer(callback func(int, string, string, string) int) { + if callback == nil { + C.sqlite3_set_authorizer(c.db, nil, nil) + } else { + C.sqlite3_set_authorizer(c.db, (*[0]byte)(C.authorizerTrampoline), unsafe.Pointer(newHandle(c, callback))) + } +} + // RegisterFunc makes a Go function available as a SQLite function. // // The Go function can have arguments of the following types: any diff --git a/sqlite3_test.go b/sqlite3_test.go index 75d8f52..bfed027 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1574,6 +1574,47 @@ func TestUpdateAndTransactionHooks(t *testing.T) { } } +func TestAuthorizer(t *testing.T) { + var authorizerReturn = 0 + + sql.Register("sqlite3_Authorizer", &SQLiteDriver{ + ConnectHook: func(conn *SQLiteConn) error { + conn.RegisterAuthorizer(func(op int, arg1, arg2, arg3 string) int { + return authorizerReturn + }) + return nil + }, + }) + db, err := sql.Open("sqlite3_Authorizer", ":memory:") + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer db.Close() + + statements := []string{ + "create table foo (id integer primary key, name varchar)", + "insert into foo values (9, 'test9')", + "update foo set name = 'test99' where id = 9", + "select * from foo", + } + + authorizerReturn = SQLITE_OK + for _, statement := range statements { + _, err = db.Exec(statement) + if err != nil { + t.Fatalf("No error expected [%v]: %v", statement, err) + } + } + + authorizerReturn = SQLITE_DENY + for _, statement := range statements { + _, err = db.Exec(statement) + if err == nil { + t.Fatalf("Authorizer didn't worked - nil received, but error expected: [%v]", statement) + } + } +} + func TestNilAndEmptyBytes(t *testing.T) { db, err := sql.Open("sqlite3", ":memory:") if err != nil {