go-sqlcipher/sqlite3.go

434 lines
10 KiB
Go
Raw Normal View History

2013-08-13 17:10:30 +04:00
package sqlite3
2011-11-11 16:36:22 +04:00
/*
#include <sqlite3.h>
#include <stdlib.h>
#include <string.h>
2013-04-06 18:06:48 +04:00
#ifndef SQLITE_OPEN_READWRITE
# define SQLITE_OPEN_READWRITE 0
#endif
#ifndef SQLITE_OPEN_FULLMUTEX
# define SQLITE_OPEN_FULLMUTEX 0
#endif
// Open a database through sqlite3_open_v2, additionally requesting URI
// filename handling whenever the SQLite headers define SQLITE_OPEN_URI.
static int
_sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) {
  int open_flags = flags;
#ifdef SQLITE_OPEN_URI
  open_flags |= SQLITE_OPEN_URI;
#endif
  return sqlite3_open_v2(filename, ppDb, open_flags, zVfs);
}
2011-11-11 16:36:22 +04:00
// Bind np bytes of text at p to parameter n of stmt, always passing
// SQLITE_TRANSIENT as the destructor argument.
// NOTE(review): SQLITE_TRANSIENT presumably makes SQLite copy the buffer
// before returning - confirm against the sqlite3_bind_text reference.
static int
_sqlite3_bind_text(sqlite3_stmt *stmt, int n, char *p, int np) {
return sqlite3_bind_text(stmt, n, p, np, SQLITE_TRANSIENT);
}
// Bind np bytes of blob data at p to parameter n of stmt, always passing
// SQLITE_TRANSIENT as the destructor argument.
// NOTE(review): SQLITE_TRANSIENT presumably makes SQLite copy the buffer
// before returning - confirm against the sqlite3_bind_blob reference.
static int
_sqlite3_bind_blob(sqlite3_stmt *stmt, int n, void *p, int np) {
return sqlite3_bind_blob(stmt, n, p, np, SQLITE_TRANSIENT);
}
#include <stdio.h>
#include <stdint.h>
static long
_sqlite3_last_insert_rowid(sqlite3* db) {
return (long) sqlite3_last_insert_rowid(db);
}
// Return the value of sqlite3_changes for db, converted to long.
static long
_sqlite3_changes(sqlite3* db) {
return (long) sqlite3_changes(db);
}
2011-11-11 16:36:22 +04:00
*/
import "C"
import (
2012-01-20 20:44:24 +04:00
"database/sql"
"database/sql/driver"
"errors"
"io"
"strings"
"time"
2011-11-11 16:36:22 +04:00
"unsafe"
)
// SQLiteTimestampFormats lists the timestamp layouts understood by both this
// module and SQLite. The first layout in the slice is used when saving time
// values into the database. When parsing a string from a timestamp or
// datetime column, the layouts are tried in order.
var SQLiteTimestampFormats = []string{
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999",
"2006-01-02 15:04:05",
"2006-01-02T15:04:05",
"2006-01-02 15:04",
"2006-01-02T15:04",
"2006-01-02",
}
2011-11-11 16:36:22 +04:00
func init() {
2013-08-23 08:58:54 +04:00
sql.Register("sqlite3", &SQLiteDriver{false, nil})
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Driver struct.
2011-11-11 16:36:22 +04:00
type SQLiteDriver struct {
2013-08-23 09:26:33 +04:00
EnableLoadExtension bool
ConnectHook func(*SQLiteConn) error
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Conn struct.
2011-11-11 16:36:22 +04:00
// SQLiteConn wraps a single native SQLite database handle.
type SQLiteConn struct {
db *C.sqlite3 // handle stored by Open; set to nil once Close succeeds
}
2013-01-31 11:48:30 +04:00
// Tx struct.
2011-11-11 16:36:22 +04:00
// SQLiteTx is the transaction value returned by SQLiteConn.Begin.
type SQLiteTx struct {
c *SQLiteConn // connection on which Commit and Rollback run their statement
}
2013-01-31 11:48:30 +04:00
// SQLiteStmt is a prepared statement created by SQLiteConn.Prepare.
type SQLiteStmt struct {
c *SQLiteConn // connection the statement was prepared on
s *C.sqlite3_stmt // native statement handle
t string // non-empty text from the last out-parameter of sqlite3_prepare_v2, if any
closed bool // set by Close so that repeated calls are no-ops
}
// SQLiteResult is the result value returned by SQLiteStmt.Exec.
type SQLiteResult struct {
	id      int64 // value reported by LastInsertId
	changes int64 // value reported by RowsAffected
}
// SQLiteRows is the row cursor returned by SQLiteStmt.Query.
type SQLiteRows struct {
s *SQLiteStmt // statement being stepped
nc int // column count taken from sqlite3_column_count in Query
cols []string // column names, filled in by the first Columns call
decltype []string // lower-cased declared column types, filled in by the first Next call that yields a row
}
// Commit transaction.
2011-11-11 16:36:22 +04:00
func (tx *SQLiteTx) Commit() error {
if err := tx.c.exec("COMMIT"); err != nil {
return err
}
return nil
}
2013-01-31 11:48:30 +04:00
// Rollback transaction.
2011-11-11 16:36:22 +04:00
func (tx *SQLiteTx) Rollback() error {
if err := tx.c.exec("ROLLBACK"); err != nil {
return err
}
return nil
}
2013-08-23 09:11:15 +04:00
func (c *SQLiteConn) AutoCommit() bool {
2013-08-23 09:26:33 +04:00
return int(C.sqlite3_get_autocommit(c.db)) != 0
2013-08-23 09:11:15 +04:00
}
2011-11-11 16:36:22 +04:00
func (c *SQLiteConn) exec(cmd string) error {
pcmd := C.CString(cmd)
defer C.free(unsafe.Pointer(pcmd))
rv := C.sqlite3_exec(c.db, pcmd, nil, nil, nil)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
2011-11-11 16:36:22 +04:00
}
return nil
}
2013-01-31 11:48:30 +04:00
// Begin transaction.
2011-11-11 16:36:22 +04:00
func (c *SQLiteConn) Begin() (driver.Tx, error) {
if err := c.exec("BEGIN"); err != nil {
return nil, err
}
return &SQLiteTx{c}, nil
}
2013-01-31 11:48:30 +04:00
// Open database and return a new connection.
// You can specify DSN string with URI filename.
// test.db
// file:test.db?cache=shared&mode=memory
// :memory:
// file::memory:
2011-11-11 16:36:22 +04:00
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
}
var db *C.sqlite3
name := C.CString(dsn)
defer C.free(unsafe.Pointer(name))
rv := C._sqlite3_open_v2(name, &db,
2011-11-11 16:36:22 +04:00
C.SQLITE_OPEN_FULLMUTEX|
C.SQLITE_OPEN_READWRITE|
C.SQLITE_OPEN_CREATE,
2011-11-11 16:36:22 +04:00
nil)
if rv != 0 {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
2011-11-11 16:36:22 +04:00
}
if db == nil {
return nil, errors.New("sqlite succeeded without returning a database")
}
2012-03-12 09:20:55 +04:00
rv = C.sqlite3_busy_timeout(db, 5000)
2012-03-12 09:20:55 +04:00
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
2012-03-12 09:20:55 +04:00
}
2013-08-23 09:26:33 +04:00
enableLoadExtension := 0
if d.EnableLoadExtension {
enableLoadExtension = 1
}
2013-08-23 09:26:33 +04:00
rv = C.sqlite3_enable_load_extension(db, C.int(enableLoadExtension))
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
2013-08-23 08:58:54 +04:00
conn := &SQLiteConn{db}
if d.ConnectHook != nil {
if err := d.ConnectHook(conn); err != nil {
return nil, err
}
2013-08-23 08:58:54 +04:00
}
return conn, nil
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Close the connection.
2011-11-11 16:36:22 +04:00
func (c *SQLiteConn) Close() error {
s := C.sqlite3_next_stmt(c.db, nil)
for s != nil {
C.sqlite3_finalize(s)
s = C.sqlite3_next_stmt(c.db, nil)
2011-11-11 16:36:22 +04:00
}
rv := C.sqlite3_close(c.db)
if rv != C.SQLITE_OK {
return errors.New("error while closing sqlite database connection")
2011-11-11 16:36:22 +04:00
}
c.db = nil
return nil
}
2013-01-31 11:48:30 +04:00
// Prepare query string. Return a new statement.
2011-11-11 16:36:22 +04:00
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
pquery := C.CString(query)
defer C.free(unsafe.Pointer(pquery))
var s *C.sqlite3_stmt
var perror *C.char
rv := C.sqlite3_prepare_v2(c.db, pquery, -1, &s, &perror)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
2011-11-11 16:36:22 +04:00
}
var t string
if perror != nil && C.strlen(perror) > 0 {
t = C.GoString(perror)
}
2012-02-20 11:14:49 +04:00
return &SQLiteStmt{c: c, s: s, t: t}, nil
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Close the statement.
2011-11-11 16:36:22 +04:00
func (s *SQLiteStmt) Close() error {
2012-02-20 11:14:49 +04:00
if s.closed {
return nil
}
s.closed = true
2013-02-13 05:32:40 +04:00
if s.c == nil || s.c.db == nil {
return errors.New("sqlite statement with already closed database connection")
}
2011-11-11 16:38:53 +04:00
rv := C.sqlite3_finalize(s.s)
2011-11-11 16:36:22 +04:00
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(s.c.db)))
2011-11-11 16:36:22 +04:00
}
return nil
}
2013-01-31 11:48:30 +04:00
// Return a number of parameters.
2011-11-11 16:36:22 +04:00
func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s))
}
2012-02-20 11:14:49 +04:00
func (s *SQLiteStmt) bind(args []driver.Value) error {
2011-11-11 16:36:22 +04:00
rv := C.sqlite3_reset(s.s)
2011-11-11 16:38:53 +04:00
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return errors.New(C.GoString(C.sqlite3_errmsg(s.c.db)))
2011-11-11 16:36:22 +04:00
}
for i, v := range args {
2011-11-11 16:38:53 +04:00
n := C.int(i + 1)
2011-11-11 16:36:22 +04:00
switch v := v.(type) {
case nil:
rv = C.sqlite3_bind_null(s.s, n)
case string:
if len(v) == 0 {
b := []byte{0}
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0))
} else {
b := []byte(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
2011-11-11 16:36:22 +04:00
case int:
2013-04-09 10:18:47 +04:00
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
2013-04-08 12:38:30 +04:00
case int32:
rv = C.sqlite3_bind_int(s.s, n, C.int(v))
2011-11-11 16:36:22 +04:00
case int64:
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
case byte:
rv = C.sqlite3_bind_int(s.s, n, C.int(v))
case bool:
if bool(v) {
2012-03-12 09:20:55 +04:00
rv = C.sqlite3_bind_int(s.s, n, 1)
2011-11-11 16:36:22 +04:00
} else {
rv = C.sqlite3_bind_int(s.s, n, 0)
}
case float32:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte:
var p *byte
if len(v) > 0 {
p = &v[0]
}
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
case time.Time:
b := []byte(v.UTC().Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
2011-11-11 16:36:22 +04:00
}
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(s.c.db)))
2011-11-11 16:36:22 +04:00
}
}
return nil
}
2013-01-31 11:48:30 +04:00
// Query the statment with arguments. Return records.
2012-02-20 11:14:49 +04:00
func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
2011-11-11 16:36:22 +04:00
if err := s.bind(args); err != nil {
return nil, err
}
return &SQLiteRows{s, int(C.sqlite3_column_count(s.s)), nil, nil}, nil
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Return last inserted ID.
2011-11-14 17:10:13 +04:00
func (r *SQLiteResult) LastInsertId() (int64, error) {
return r.id, nil
2011-11-14 17:10:13 +04:00
}
2013-01-31 11:48:30 +04:00
// Return how many rows affected.
2011-11-14 17:10:13 +04:00
func (r *SQLiteResult) RowsAffected() (int64, error) {
return r.changes, nil
2011-11-14 17:10:13 +04:00
}
2013-01-31 11:48:30 +04:00
// Execute the statement with arguments. Return result object.
2012-02-20 11:14:49 +04:00
func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
2011-11-11 16:36:22 +04:00
if err := s.bind(args); err != nil {
return nil, err
}
rv := C.sqlite3_step(s.s)
2011-11-11 16:38:53 +04:00
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(s.c.db)))
2011-11-11 16:36:22 +04:00
}
2013-05-11 16:45:48 +04:00
res := &SQLiteResult{
int64(C._sqlite3_last_insert_rowid(s.c.db)),
int64(C._sqlite3_changes(s.c.db)),
}
return res, nil
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Close the rows.
2011-11-11 16:36:22 +04:00
func (rc *SQLiteRows) Close() error {
2012-03-12 09:20:55 +04:00
rv := C.sqlite3_reset(rc.s.s)
if rv != C.SQLITE_OK {
return errors.New(C.GoString(C.sqlite3_errmsg(rc.s.c.db)))
2012-03-12 09:20:55 +04:00
}
return nil
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Return column names.
2011-11-11 16:36:22 +04:00
func (rc *SQLiteRows) Columns() []string {
if rc.nc != len(rc.cols) {
rc.cols = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
}
}
return rc.cols
}
2013-01-31 11:48:30 +04:00
// Move cursor to next.
2012-02-20 11:14:49 +04:00
func (rc *SQLiteRows) Next(dest []driver.Value) error {
2011-11-11 16:36:22 +04:00
rv := C.sqlite3_step(rc.s.s)
if rv == C.SQLITE_DONE {
return io.EOF
}
2011-11-11 16:36:22 +04:00
if rv != C.SQLITE_ROW {
return errors.New(C.GoString(C.sqlite3_errmsg(rc.s.c.db)))
2011-11-11 16:36:22 +04:00
}
if rc.decltype == nil {
rc.decltype = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
}
}
2011-11-11 16:36:22 +04:00
for i := range dest {
2011-11-11 16:38:53 +04:00
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
case C.SQLITE_INTEGER:
val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
switch rc.decltype[i] {
2012-12-30 02:20:27 +04:00
case "timestamp", "datetime":
dest[i] = time.Unix(val, 0)
case "boolean":
dest[i] = val > 0
default:
dest[i] = val
}
2011-11-11 16:38:53 +04:00
case C.SQLITE_FLOAT:
dest[i] = float64(C.sqlite3_column_double(rc.s.s, C.int(i)))
case C.SQLITE_BLOB:
p := C.sqlite3_column_blob(rc.s.s, C.int(i))
n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i)))
switch dest[i].(type) {
case sql.RawBytes:
dest[i] = (*[1 << 30]byte)(unsafe.Pointer(p))[0:n]
default:
slice := make([]byte, n)
copy(slice[:], (*[1 << 30]byte)(unsafe.Pointer(p))[0:n])
dest[i] = slice
}
2011-11-11 16:38:53 +04:00
case C.SQLITE_NULL:
dest[i] = nil
case C.SQLITE_TEXT:
var err error
s := C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))))
2012-12-30 02:20:27 +04:00
switch rc.decltype[i] {
case "timestamp", "datetime":
for _, format := range SQLiteTimestampFormats {
if dest[i], err = time.Parse(format, s); err == nil {
2012-12-30 02:20:27 +04:00
break
}
}
if err != nil {
2012-12-30 02:20:27 +04:00
// The column is a time value, so return the zero time on parse failure.
dest[i] = time.Time{}
}
default:
dest[i] = s
}
2011-11-11 16:36:22 +04:00
}
}
return nil
}