forked from mirror/go-sqlcipher
Merge pull request #423 from danderson/master
Add support for collation sequences implemented in Go.
This commit is contained in:
commit
05548ff555
|
@ -53,6 +53,12 @@ func doneTrampoline(ctx *C.sqlite3_context) {
|
|||
ai.Done(ctx)
|
||||
}
|
||||
|
||||
//export compareTrampoline
|
||||
func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.char) C.int {
|
||||
cmp := lookupHandle(handlePtr).(func(string, string) int)
|
||||
return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
|
||||
}
|
||||
|
||||
//export commitHookTrampoline
|
||||
func commitHookTrampoline(handle uintptr) int {
|
||||
callback := lookupHandle(handle).(func() int)
|
||||
|
|
25
sqlite3.go
25
sqlite3.go
|
@ -100,6 +100,8 @@ int _sqlite3_create_function(
|
|||
}
|
||||
|
||||
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
|
||||
|
||||
int compareTrampoline(void*, int, char*, int, char*);
|
||||
int commitHookTrampoline(void*);
|
||||
void rollbackHookTrampoline(void*);
|
||||
void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64);
|
||||
|
@ -326,6 +328,29 @@ func (tx *SQLiteTx) Rollback() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// RegisterCollation makes a Go function available as a collation.
|
||||
//
|
||||
// cmp receives two UTF-8 strings, a and b. The result should be 0 if
|
||||
// a==b, -1 if a < b, and +1 if a > b.
|
||||
//
|
||||
// cmp must always return the same result given the same
|
||||
// inputs. Additionally, it must have the following properties for all
|
||||
// strings A, B and C: if A==B then B==A; if A==B and B==C then A==C;
|
||||
// if A<B then B>A; if A<B and B<C then A<C.
|
||||
//
|
||||
// If cmp does not obey these constraints, sqlite3's behavior is
|
||||
// undefined when the collation is used.
|
||||
func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int) error {
|
||||
handle := newHandle(c, cmp)
|
||||
cname := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cname))
|
||||
rv := C.sqlite3_create_collation(c.db, cname, C.SQLITE_UTF8, unsafe.Pointer(handle), (*[0]byte)(unsafe.Pointer(C.compareTrampoline)))
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterCommitHook sets the commit hook for a connection.
|
||||
//
|
||||
// If the callback returns non-zero the transaction will become a rollback.
|
||||
|
|
121
sqlite3_test.go
121
sqlite3_test.go
|
@ -1232,6 +1232,127 @@ func TestFunctionRegistration(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func rot13(r rune) rune {
|
||||
switch {
|
||||
case r >= 'A' && r <= 'Z':
|
||||
return 'A' + (r-'A'+13)%26
|
||||
case r >= 'a' && r <= 'z':
|
||||
return 'a' + (r-'a'+13)%26
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func TestCollationRegistration(t *testing.T) {
|
||||
collateRot13 := func(a, b string) int {
|
||||
ra, rb := strings.Map(rot13, a), strings.Map(rot13, b)
|
||||
return strings.Compare(ra, rb)
|
||||
}
|
||||
collateRot13Reverse := func(a, b string) int {
|
||||
return collateRot13(b, a)
|
||||
}
|
||||
|
||||
sql.Register("sqlite3_CollationRegistration", &SQLiteDriver{
|
||||
ConnectHook: func(conn *SQLiteConn) error {
|
||||
if err := conn.RegisterCollation("rot13", collateRot13); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.RegisterCollation("rot13reverse", collateRot13Reverse); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
db, err := sql.Open("sqlite3_CollationRegistration", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open database:", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
populate := []string{
|
||||
`CREATE TABLE test (s TEXT)`,
|
||||
`INSERT INTO test VALUES ("aaaa")`,
|
||||
`INSERT INTO test VALUES ("ffff")`,
|
||||
`INSERT INTO test VALUES ("qqqq")`,
|
||||
`INSERT INTO test VALUES ("tttt")`,
|
||||
`INSERT INTO test VALUES ("zzzz")`,
|
||||
}
|
||||
for _, stmt := range populate {
|
||||
if _, err := db.Exec(stmt); err != nil {
|
||||
t.Fatal("Failed to populate test DB:", err)
|
||||
}
|
||||
}
|
||||
|
||||
ops := []struct {
|
||||
query string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
"SELECT * FROM test ORDER BY s COLLATE rot13 ASC",
|
||||
[]string{
|
||||
"qqqq",
|
||||
"tttt",
|
||||
"zzzz",
|
||||
"aaaa",
|
||||
"ffff",
|
||||
},
|
||||
},
|
||||
{
|
||||
"SELECT * FROM test ORDER BY s COLLATE rot13 DESC",
|
||||
[]string{
|
||||
"ffff",
|
||||
"aaaa",
|
||||
"zzzz",
|
||||
"tttt",
|
||||
"qqqq",
|
||||
},
|
||||
},
|
||||
{
|
||||
"SELECT * FROM test ORDER BY s COLLATE rot13reverse ASC",
|
||||
[]string{
|
||||
"ffff",
|
||||
"aaaa",
|
||||
"zzzz",
|
||||
"tttt",
|
||||
"qqqq",
|
||||
},
|
||||
},
|
||||
{
|
||||
"SELECT * FROM test ORDER BY s COLLATE rot13reverse DESC",
|
||||
[]string{
|
||||
"qqqq",
|
||||
"tttt",
|
||||
"zzzz",
|
||||
"aaaa",
|
||||
"ffff",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, op := range ops {
|
||||
rows, err := db.Query(op.query)
|
||||
if err != nil {
|
||||
t.Fatalf("Query %q failed: %s", op.query, err)
|
||||
}
|
||||
got := []string{}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var s string
|
||||
if err = rows.Scan(&s); err != nil {
|
||||
t.Fatalf("Reading row for %q: %s", op.query, err)
|
||||
}
|
||||
got = append(got, s)
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatalf("Reading rows for %q: %s", op.query, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, op.want) {
|
||||
t.Fatalf("Unexpected output from %q\ngot:\n%s\n\nwant:\n%s", op.query, strings.Join(got, "\n"), strings.Join(op.want, "\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeclTypes(t *testing.T) {
|
||||
|
||||
d := SQLiteDriver{}
|
||||
|
|
Loading…
Reference in New Issue