forked from mirror/go-sqlcipher
Add support for collation sequences implemented in Go.
This allows Go programs to register custom comparison functions with sqlite, and ORDER BY that comparator.
This commit is contained in:
parent
83772a7051
commit
0430b37250
|
@ -53,6 +53,12 @@ func doneTrampoline(ctx *C.sqlite3_context) {
|
|||
ai.Done(ctx)
|
||||
}
|
||||
|
||||
//export compareTrampoline
|
||||
func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.char) C.int {
|
||||
cmp := lookupHandle(handlePtr).(func(string, string) int)
|
||||
return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
|
||||
}
|
||||
|
||||
// Use handles to avoid passing Go pointers to C.
|
||||
|
||||
type handleVal struct {
|
||||
|
|
25
sqlite3.go
25
sqlite3.go
|
@ -100,6 +100,8 @@ int _sqlite3_create_function(
|
|||
}
|
||||
|
||||
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||
|
||||
int compareTrampoline(void*, int, char*, int, char*);
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
|
@ -313,6 +315,29 @@ func (tx *SQLiteTx) Rollback() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// RegisterCollation makes a Go function available as a collation.
|
||||
//
|
||||
// cmp receives two UTF-8 strings, a and b. The result should be 0 if
|
||||
// a==b, -1 if a < b, and +1 if a > b.
|
||||
//
|
||||
// cmp must always return the same result given the same
|
||||
// inputs. Additionally, it must have the following properties for all
|
||||
// strings A, B and C: if A==B then B==A; if A==B and B==C then A==C;
|
||||
// if A<B then B>A; if A<B and B<C then A<C.
|
||||
//
|
||||
// If cmp does not obey these constraints, sqlite3's behavior is
|
||||
// undefined when the collation is used.
|
||||
func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int) error {
|
||||
handle := newHandle(c, cmp)
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
rv := C.sqlite3_create_collation(c.db, cname, C.SQLITE_UTF8, unsafe.Pointer(handle), (*[0]byte)(unsafe.Pointer(C.compareTrampoline)))
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterFunc makes a Go function available as a SQLite function.
|
||||
//
|
||||
// The Go function can have arguments of the following types: any
|
||||
|
|
121
sqlite3_test.go
121
sqlite3_test.go
|
@ -1213,6 +1213,127 @@ func TestFunctionRegistration(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func rot13(r rune) rune {
|
||||
switch {
|
||||
case r >= 'A' && r <= 'Z':
|
||||
return 'A' + (r-'A'+13)%26
|
||||
case r >= 'a' && r <= 'z':
|
||||
return 'a' + (r-'a'+13)%26
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func TestCollationRegistration(t *testing.T) {
|
||||
collateRot13 := func(a, b string) int {
|
||||
ra, rb := strings.Map(rot13, a), strings.Map(rot13, b)
|
||||
return strings.Compare(ra, rb)
|
||||
}
|
||||
collateRot13Reverse := func(a, b string) int {
|
||||
return collateRot13(b, a)
|
||||
}
|
||||
|
||||
sql.Register("sqlite3_CollationRegistration", &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
if err := conn.RegisterCollation("rot13", collateRot13); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.RegisterCollation("rot13reverse", collateRot13Reverse); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
db, err := sql.Open("sqlite3_CollationRegistration", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
populate := []string{
|
||||
`CREATE TABLE test (s TEXT)`,
|
||||
`INSERT INTO test VALUES ("aaaa")`,
|
||||
`INSERT INTO test VALUES ("ffff")`,
|
||||
`INSERT INTO test VALUES ("qqqq")`,
|
||||
`INSERT INTO test VALUES ("tttt")`,
|
||||
`INSERT INTO test VALUES ("zzzz")`,
|
||||
}
|
||||
for _, stmt := range populate {
|
||||
if _, err := db.Exec(stmt); err != nil {
|
||||
t.Fatal("Failed to populate test DB:", err)
|
||||
}
|
||||
}
|
||||
|
||||
ops := []struct {
|
||||
query string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"SELECT * FROM test ORDER BY s COLLATE rot13 ASC",
|
||||
[]string{
|
||||
"qqqq",
|
||||
"tttt",
|
||||
"zzzz",
|
||||
"aaaa",
|
||||
"ffff",
|
||||
},
|
||||
},
|
||||
{
|
||||
"SELECT * FROM test ORDER BY s COLLATE rot13 DESC",
|
||||
[]string{
|
||||
"ffff",
|
||||
"aaaa",
|
||||
"zzzz",
|
||||
"tttt",
|
||||
"qqqq",
|
||||
},
|
||||
},
|
||||
{
|
||||
"SELECT * FROM test ORDER BY s COLLATE rot13reverse ASC",
|
||||
[]string{
|
||||
"ffff",
|
||||
"aaaa",
|
||||
"zzzz",
|
||||
"tttt",
|
||||
"qqqq",
|
||||
},
|
||||
},
|
||||
{
|
||||
"SELECT * FROM test ORDER BY s COLLATE rot13reverse DESC",
|
||||
[]string{
|
||||
"qqqq",
|
||||
"tttt",
|
||||
"zzzz",
|
||||
"aaaa",
|
||||
"ffff",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, op := range ops {
|
||||
rows, err := db.Query(op.query)
|
||||
if err != nil {
|
||||
t.Fatalf("Query %q failed: %s", op.query, err)
|
||||
}
|
||||
got := []string{}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var s string
|
||||
if err = rows.Scan(&s); err != nil {
|
||||
t.Fatalf("Reading row for %q: %s", op.query, err)
|
||||
}
|
||||
got = append(got, s)
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatalf("Reading rows for %q: %s", op.query, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, op.want) {
|
||||
t.Fatalf("Unexpected output from %q\ngot:\n%s\n\nwant:\n%s", op.query, strings.Join(got, "\n"), strings.Join(op.want, "\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeclTypes(t *testing.T) {
|
||||
|
||||
d := SQLiteDriver{}
|
||||
|
|
Loading…
Reference in New Issue