forked from mirror/go-sqlcipher
Add support for collation sequences implemented in Go.
This allows Go programs to register custom comparison functions with sqlite, and ORDER BY that comparator.
This commit is contained in:
parent
83772a7051
commit
0430b37250
|
@ -53,6 +53,12 @@ func doneTrampoline(ctx *C.sqlite3_context) {
|
||||||
ai.Done(ctx)
|
ai.Done(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//export compareTrampoline
|
||||||
|
func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.char) C.int {
|
||||||
|
cmp := lookupHandle(handlePtr).(func(string, string) int)
|
||||||
|
return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
|
||||||
|
}
|
||||||
|
|
||||||
// Use handles to avoid passing Go pointers to C.
|
// Use handles to avoid passing Go pointers to C.
|
||||||
|
|
||||||
type handleVal struct {
|
type handleVal struct {
|
||||||
|
|
25
sqlite3.go
25
sqlite3.go
|
@ -100,6 +100,8 @@ int _sqlite3_create_function(
|
||||||
}
|
}
|
||||||
|
|
||||||
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||||
|
|
||||||
|
int compareTrampoline(void*, int, char*, int, char*);
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
|
@ -313,6 +315,29 @@ func (tx *SQLiteTx) Rollback() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterCollation makes a Go function available as a collation.
|
||||||
|
//
|
||||||
|
// cmp receives two UTF-8 strings, a and b. The result should be 0 if
|
||||||
|
// a==b, -1 if a < b, and +1 if a > b.
|
||||||
|
//
|
||||||
|
// cmp must always return the same result given the same
|
||||||
|
// inputs. Additionally, it must have the following properties for all
|
||||||
|
// strings A, B and C: if A==B then B==A; if A==B and B==C then A==C;
|
||||||
|
// if A<B then B>A; if A<B and B<C then A<C.
|
||||||
|
//
|
||||||
|
// If cmp does not obey these constraints, sqlite3's behavior is
|
||||||
|
// undefined when the collation is used.
|
||||||
|
func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int) error {
|
||||||
|
handle := newHandle(c, cmp)
|
||||||
|
cname := C.CString(name)
|
||||||
|
defer C.free(unsafe.Pointer(cname))
|
||||||
|
rv := C.sqlite3_create_collation(c.db, cname, C.SQLITE_UTF8, unsafe.Pointer(handle), (*[0]byte)(unsafe.Pointer(C.compareTrampoline)))
|
||||||
|
if rv != C.SQLITE_OK {
|
||||||
|
return c.lastError()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterFunc makes a Go function available as a SQLite function.
|
// RegisterFunc makes a Go function available as a SQLite function.
|
||||||
//
|
//
|
||||||
// The Go function can have arguments of the following types: any
|
// The Go function can have arguments of the following types: any
|
||||||
|
|
121
sqlite3_test.go
121
sqlite3_test.go
|
@ -1213,6 +1213,127 @@ func TestFunctionRegistration(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func rot13(r rune) rune {
|
||||||
|
switch {
|
||||||
|
case r >= 'A' && r <= 'Z':
|
||||||
|
return 'A' + (r-'A'+13)%26
|
||||||
|
case r >= 'a' && r <= 'z':
|
||||||
|
return 'a' + (r-'a'+13)%26
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollationRegistration(t *testing.T) {
|
||||||
|
collateRot13 := func(a, b string) int {
|
||||||
|
ra, rb := strings.Map(rot13, a), strings.Map(rot13, b)
|
||||||
|
return strings.Compare(ra, rb)
|
||||||
|
}
|
||||||
|
collateRot13Reverse := func(a, b string) int {
|
||||||
|
return collateRot13(b, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql.Register("sqlite3_CollationRegistration", &SQLiteDriver{
|
||||||
|
ConnectHook: func(conn *SQLiteConn) error {
|
||||||
|
if err := conn.RegisterCollation("rot13", collateRot13); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := conn.RegisterCollation("rot13reverse", collateRot13Reverse); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite3_CollationRegistration", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
populate := []string{
|
||||||
|
`CREATE TABLE test (s TEXT)`,
|
||||||
|
`INSERT INTO test VALUES ("aaaa")`,
|
||||||
|
`INSERT INTO test VALUES ("ffff")`,
|
||||||
|
`INSERT INTO test VALUES ("qqqq")`,
|
||||||
|
`INSERT INTO test VALUES ("tttt")`,
|
||||||
|
`INSERT INTO test VALUES ("zzzz")`,
|
||||||
|
}
|
||||||
|
for _, stmt := range populate {
|
||||||
|
if _, err := db.Exec(stmt); err != nil {
|
||||||
|
t.Fatal("Failed to populate test DB:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ops := []struct {
|
||||||
|
query string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"SELECT * FROM test ORDER BY s COLLATE rot13 ASC",
|
||||||
|
[]string{
|
||||||
|
"qqqq",
|
||||||
|
"tttt",
|
||||||
|
"zzzz",
|
||||||
|
"aaaa",
|
||||||
|
"ffff",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SELECT * FROM test ORDER BY s COLLATE rot13 DESC",
|
||||||
|
[]string{
|
||||||
|
"ffff",
|
||||||
|
"aaaa",
|
||||||
|
"zzzz",
|
||||||
|
"tttt",
|
||||||
|
"qqqq",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SELECT * FROM test ORDER BY s COLLATE rot13reverse ASC",
|
||||||
|
[]string{
|
||||||
|
"ffff",
|
||||||
|
"aaaa",
|
||||||
|
"zzzz",
|
||||||
|
"tttt",
|
||||||
|
"qqqq",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SELECT * FROM test ORDER BY s COLLATE rot13reverse DESC",
|
||||||
|
[]string{
|
||||||
|
"qqqq",
|
||||||
|
"tttt",
|
||||||
|
"zzzz",
|
||||||
|
"aaaa",
|
||||||
|
"ffff",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, op := range ops {
|
||||||
|
rows, err := db.Query(op.query)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query %q failed: %s", op.query, err)
|
||||||
|
}
|
||||||
|
got := []string{}
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
var s string
|
||||||
|
if err = rows.Scan(&s); err != nil {
|
||||||
|
t.Fatalf("Reading row for %q: %s", op.query, err)
|
||||||
|
}
|
||||||
|
got = append(got, s)
|
||||||
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
t.Fatalf("Reading rows for %q: %s", op.query, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, op.want) {
|
||||||
|
t.Fatalf("Unexpected output from %q\ngot:\n%s\n\nwant:\n%s", op.query, strings.Join(got, "\n"), strings.Join(op.want, "\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDeclTypes(t *testing.T) {
|
func TestDeclTypes(t *testing.T) {
|
||||||
|
|
||||||
d := SQLiteDriver{}
|
d := SQLiteDriver{}
|
||||||
|
|
Loading…
Reference in New Issue