go-sqlcipher/sqlite3.go

534 lines
12 KiB
Go

package sqlite3
/*
#include <sqlite3.h>
#include <stdlib.h>
#include <string.h>
#ifdef __CYGWIN__
# include <errno.h>
#endif
#ifndef SQLITE_OPEN_READWRITE
# define SQLITE_OPEN_READWRITE 0
#endif
#ifndef SQLITE_OPEN_FULLMUTEX
# define SQLITE_OPEN_FULLMUTEX 0
#endif
// _sqlite3_open_v2 forwards to sqlite3_open_v2, additionally setting
// SQLITE_OPEN_URI when the linked SQLite headers define that flag so that
// "file:" URI filenames are accepted. Older headers without the flag fall
// back to a plain open with the caller's flags.
static int
_sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) {
#ifdef SQLITE_OPEN_URI
return sqlite3_open_v2(filename, ppDb, flags | SQLITE_OPEN_URI, zVfs);
#else
return sqlite3_open_v2(filename, ppDb, flags, zVfs);
#endif
}
// _sqlite3_bind_text binds np bytes starting at p as the text value of
// parameter n. SQLITE_TRANSIENT asks SQLite to take its own copy of the
// bytes before returning, so the caller's buffer need not outlive the call.
static int
_sqlite3_bind_text(sqlite3_stmt *stmt, int n, char *p, int np) {
return sqlite3_bind_text(stmt, n, p, np, SQLITE_TRANSIENT);
}
// _sqlite3_bind_blob binds np bytes starting at p as the blob value of
// parameter n. SQLITE_TRANSIENT asks SQLite to take its own copy of the
// bytes before returning, so the caller's buffer need not outlive the call.
static int
_sqlite3_bind_blob(sqlite3_stmt *stmt, int n, void *p, int np) {
return sqlite3_bind_blob(stmt, n, p, np, SQLITE_TRANSIENT);
}
#include <stdio.h>
#include <stdint.h>
// _sqlite3_last_insert_rowid returns the rowid of the most recent successful
// INSERT on db. The result type is long long rather than long because rowids
// are 64-bit while long is only 32 bits wide on some platforms (for example
// 64-bit Windows), where large rowids would be silently truncated.
static long long
_sqlite3_last_insert_rowid(sqlite3* db) {
  return (long long) sqlite3_last_insert_rowid(db);
}
// _sqlite3_changes returns the value of sqlite3_changes(db), widened to long.
static long
_sqlite3_changes(sqlite3* db) {
return (long) sqlite3_changes(db);
}
*/
import "C"
import (
"database/sql"
"database/sql/driver"
"errors"
"io"
"strings"
"time"
"unsafe"
)
// SQLiteTimestampFormats lists the timestamp layouts understood by both this
// module and SQLite.
// The first format in the slice will be used when saving time values
// into the database. When parsing a string from a timestamp or
// datetime column, the formats are tried in order.
var SQLiteTimestampFormats = []string{
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999",
"2006-01-02 15:04:05",
"2006-01-02T15:04:05",
"2006-01-02 15:04",
"2006-01-02T15:04",
"2006-01-02",
}
// init registers a default-configured driver with database/sql under the
// name "sqlite3".
func init() {
	drv := new(SQLiteDriver)
	sql.Register("sqlite3", drv)
}
// SQLiteDriver is the database/sql driver for SQLite databases.
type SQLiteDriver struct {
// Extensions lists the extensions that Open loads into every new
// connection via load_extension().
Extensions []string
// ConnectHook, if non-nil, is called by Open with each new connection;
// a non-nil error makes Open fail.
ConnectHook func(*SQLiteConn) error
}
// SQLiteConn is a single connection to a SQLite database.
type SQLiteConn struct {
// db is the underlying SQLite handle; Close sets it to nil.
db *C.sqlite3
}
// SQLiteTx is a transaction started by SQLiteConn.Begin.
type SQLiteTx struct {
// c is the connection on which COMMIT or ROLLBACK is executed.
c *SQLiteConn
}
// SQLiteStmt is a prepared statement created by SQLiteConn.Prepare.
type SQLiteStmt struct {
// c is the connection the statement was prepared on.
c *SQLiteConn
// s is the underlying SQLite prepared-statement handle.
s *C.sqlite3_stmt
// t is the whitespace-trimmed tail that sqlite3_prepare_v2 reported for
// the query passed to Prepare; it is empty when there was no tail.
t string
// closed is set by Close so that closing a statement twice is a no-op.
closed bool
}
// SQLiteResult is the result of SQLiteStmt.Exec.
type SQLiteResult struct {
// id is the connection's last insert rowid read right after the statement ran.
id int64
// changes is the connection's sqlite3_changes() value read right after the statement ran.
changes int64
}
// SQLiteRows iterates over the rows produced by SQLiteStmt.Query.
type SQLiteRows struct {
// s is the statement being stepped.
s *SQLiteStmt
// nc is the number of result columns.
nc int
// cols caches the column names; it is filled lazily by Columns.
cols []string
// decltype caches the lower-cased declared column types; it is filled
// lazily by Next when the first row is read.
decltype []string
}
// Commit ends the transaction by executing COMMIT on its connection and
// returns any error SQLite reports.
func (tx *SQLiteTx) Commit() error {
	return tx.c.exec("COMMIT")
}
// Rollback aborts the transaction by executing ROLLBACK on its connection
// and returns any error SQLite reports.
func (tx *SQLiteTx) Rollback() error {
	return tx.c.exec("ROLLBACK")
}
// AutoCommit reports whether the connection is currently in autocommit mode,
// as returned by sqlite3_get_autocommit.
func (c *SQLiteConn) AutoCommit() bool {
	mode := int(C.sqlite3_get_autocommit(c.db))
	return mode != 0
}
// TODO: Execer & Queryer currently disabled
// https://github.com/mattn/go-sqlite3/issues/82
//// Implements Execer
//func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
// tx, err := c.Begin()
// if err != nil {
// return nil, err
// }
// for {
// s, err := c.Prepare(query)
// if err != nil {
// tx.Rollback()
// return nil, err
// }
// na := s.NumInput()
// res, err := s.Exec(args[:na])
// if err != nil && err != driver.ErrSkip {
// tx.Rollback()
// s.Close()
// return nil, err
// }
// args = args[na:]
// tail := s.(*SQLiteStmt).t
// if tail == "" {
// tx.Commit()
// return res, nil
// }
// s.Close()
// query = tail
// }
//}
//
//// Implements Queryer
//func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) {
// tx, err := c.Begin()
// if err != nil {
// return nil, err
// }
// for {
// s, err := c.Prepare(query)
// if err != nil {
// tx.Rollback()
// return nil, err
// }
// na := s.NumInput()
// rows, err := s.Query(args[:na])
// if err != nil && err != driver.ErrSkip {
// tx.Rollback()
// s.Close()
// return nil, err
// }
// args = args[na:]
// tail := s.(*SQLiteStmt).t
// if tail == "" {
// tx.Commit()
// return rows, nil
// }
// s.Close()
// query = tail
// }
//}
// exec runs cmd on the connection with sqlite3_exec, discarding any rows it
// produces, and converts a non-OK result code into an ErrNo.
func (c *SQLiteConn) exec(cmd string) error {
	sqlText := C.CString(cmd)
	defer C.free(unsafe.Pointer(sqlText))

	if code := C.sqlite3_exec(c.db, sqlText, nil, nil, nil); code != C.SQLITE_OK {
		return ErrNo(code)
	}
	return nil
}
// Begin starts a transaction by executing BEGIN and returns a handle that
// commits or rolls back on this connection.
func (c *SQLiteConn) Begin() (driver.Tx, error) {
	err := c.exec("BEGIN")
	if err != nil {
		return nil, err
	}
	tx := &SQLiteTx{c: c}
	return tx, nil
}
// errorString returns the text sqlite3_errstr gives for the result code err.
func errorString(err ErrNo) string {
	msg := C.sqlite3_errstr(C.int(err))
	return C.GoString(msg)
}
// Open opens the database named by dsn and returns a new connection.
// The DSN may be a plain file name or a URI filename, for example:
//
//	test.db
//	file:test.db?cache=shared&mode=memory
//	:memory:
//	file::memory:
//
// The database is opened read/write with the FULLMUTEX flag, is created if
// it does not exist, and gets a busy timeout of 5000 ms. Every extension
// listed in d.Extensions is then loaded, and finally d.ConnectHook, if set,
// is called with the new connection.
//
// If any step fails, the partially initialised handle is closed before the
// error is returned so that it is not leaked.
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
	if C.sqlite3_threadsafe() == 0 {
		return nil, errors.New("sqlite library was not compiled for thread-safe operation")
	}

	var db *C.sqlite3
	name := C.CString(dsn)
	defer C.free(unsafe.Pointer(name))
	rv := C._sqlite3_open_v2(name, &db,
		C.SQLITE_OPEN_FULLMUTEX|
			C.SQLITE_OPEN_READWRITE|
			C.SQLITE_OPEN_CREATE,
		nil)
	if rv != 0 {
		// SQLite normally hands back a handle even when opening fails and
		// expects it to be released. Closing a NULL handle is a no-op.
		C.sqlite3_close(db)
		return nil, ErrNo(rv)
	}
	if db == nil {
		return nil, errors.New("sqlite succeeded without returning a database")
	}

	conn := &SQLiteConn{db}

	// abort releases the connection and reports err. The error value is
	// built by the caller before abort runs, so messages taken from
	// sqlite3_errmsg are captured while the handle is still valid.
	// conn.Close also finalizes any statement still open on the handle.
	abort := func(err error) (driver.Conn, error) {
		conn.Close() // best effort; the original error is the one to report
		return nil, err
	}

	rv = C.sqlite3_busy_timeout(db, 5000)
	if rv != C.SQLITE_OK {
		return abort(ErrNo(rv))
	}

	if len(d.Extensions) > 0 {
		// Extension loading is enabled only for the duration of this block.
		rv = C.sqlite3_enable_load_extension(db, 1)
		if rv != C.SQLITE_OK {
			return abort(errors.New(C.GoString(C.sqlite3_errmsg(db))))
		}

		stmt, err := conn.Prepare("SELECT load_extension(?);")
		if err != nil {
			return abort(err)
		}
		for _, extension := range d.Extensions {
			if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
				return abort(err)
			}
		}
		if err = stmt.Close(); err != nil {
			return abort(err)
		}

		rv = C.sqlite3_enable_load_extension(db, 0)
		if rv != C.SQLITE_OK {
			return abort(errors.New(C.GoString(C.sqlite3_errmsg(db))))
		}
	}

	if d.ConnectHook != nil {
		if err := d.ConnectHook(conn); err != nil {
			return abort(err)
		}
	}
	return conn, nil
}
// Close finalizes every prepared statement still open on the connection and
// then closes the database handle. On success the handle is cleared, and
// calling Close again on an already closed connection is a no-op.
func (c *SQLiteConn) Close() error {
	if c.db == nil {
		// Already closed. sqlite3_next_stmt must not be handed a NULL
		// database handle, so stop here.
		return nil
	}

	// After each finalize the first remaining statement is fetched afresh,
	// so the loop ends once none are left.
	for s := C.sqlite3_next_stmt(c.db, nil); s != nil; s = C.sqlite3_next_stmt(c.db, nil) {
		C.sqlite3_finalize(s)
	}

	rv := C.sqlite3_close(c.db)
	if rv != C.SQLITE_OK {
		return ErrNo(rv)
	}
	c.db = nil
	return nil
}
// Prepare compiles query with sqlite3_prepare_v2 and returns the resulting
// statement. Any tail reported by SQLite is kept, whitespace-trimmed, on the
// returned statement.
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
	cQuery := C.CString(query)
	defer C.free(unsafe.Pointer(cQuery))

	var (
		handle *C.sqlite3_stmt
		rest   *C.char
	)
	if code := C.sqlite3_prepare_v2(c.db, cQuery, -1, &handle, &rest); code != C.SQLITE_OK {
		return nil, ErrNo(code)
	}

	// rest points into cQuery, so it must be copied out before the deferred
	// free runs.
	remaining := ""
	if rest != nil && C.strlen(rest) > 0 {
		remaining = strings.TrimSpace(C.GoString(rest))
	}
	return &SQLiteStmt{c: c, s: handle, t: remaining}, nil
}
// Close finalizes the statement. Closing an already closed statement is a
// no-op; closing one whose connection is gone reports an error.
func (s *SQLiteStmt) Close() error {
	if s.closed {
		return nil
	}
	// Mark closed first so a later call returns early even if this one fails.
	s.closed = true

	connGone := s.c == nil || s.c.db == nil
	if connGone {
		return errors.New("sqlite statement with already closed database connection")
	}
	if code := C.sqlite3_finalize(s.s); code != C.SQLITE_OK {
		return ErrNo(code)
	}
	return nil
}
// NumInput returns the number of bind parameters in the statement.
func (s *SQLiteStmt) NumInput() int {
	count := C.sqlite3_bind_parameter_count(s.s)
	return int(count)
}
// bind resets the statement and binds args to its parameters in order, with
// args[0] going to parameter 1.
//
// Supported Go types are nil, string, int, int32, int64, byte, bool,
// float32, float64, []byte and time.Time. A time.Time is stored as UTC text
// in the first layout of SQLiteTimestampFormats. A nil []byte is bound as
// NULL, while an empty non-nil []byte is bound as a zero-length blob.
// A value of any other type matches no case and is left unbound.
//
// The first SQLite error encountered is returned as an ErrNo.
func (s *SQLiteStmt) bind(args []driver.Value) error {
	rv := C.sqlite3_reset(s.s)
	if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
		return ErrNo(rv)
	}

	// placeholder supplies a valid non-NULL address for zero-length text and
	// blobs; SQLite treats a NULL data pointer as a request to bind NULL.
	placeholder := []byte{0}

	for i, v := range args {
		n := C.int(i + 1)
		switch v := v.(type) {
		case nil:
			rv = C.sqlite3_bind_null(s.s, n)
		case string:
			b := placeholder
			if len(v) > 0 {
				b = []byte(v)
			}
			rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(v)))
		case int:
			rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
		case int32:
			rv = C.sqlite3_bind_int(s.s, n, C.int(v))
		case int64:
			rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
		case byte:
			rv = C.sqlite3_bind_int(s.s, n, C.int(v))
		case bool:
			if v {
				rv = C.sqlite3_bind_int(s.s, n, 1)
			} else {
				rv = C.sqlite3_bind_int(s.s, n, 0)
			}
		case float32:
			rv = C.sqlite3_bind_double(s.s, n, C.double(v))
		case float64:
			rv = C.sqlite3_bind_double(s.s, n, C.double(v))
		case []byte:
			if v == nil {
				rv = C.sqlite3_bind_null(s.s, n)
			} else {
				b := placeholder
				if len(v) > 0 {
					b = v
				}
				rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&b[0]), C.int(len(v)))
			}
		case time.Time:
			b := []byte(v.UTC().Format(SQLiteTimestampFormats[0]))
			rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
		}
		if rv != C.SQLITE_OK {
			return ErrNo(rv)
		}
	}
	return nil
}
// Query binds args to the statement and returns a row iterator over it.
// The statement is not stepped here; rows are fetched by SQLiteRows.Next.
func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
	err := s.bind(args)
	if err != nil {
		return nil, err
	}
	rows := &SQLiteRows{
		s:  s,
		nc: int(C.sqlite3_column_count(s.s)),
	}
	return rows, nil
}
// LastInsertId returns the insert rowid recorded when the statement was
// executed. The error is always nil.
func (r *SQLiteResult) LastInsertId() (int64, error) {
	rowID := r.id
	return rowID, nil
}
// RowsAffected returns the change count recorded when the statement was
// executed. The error is always nil.
func (r *SQLiteResult) RowsAffected() (int64, error) {
	affected := r.changes
	return affected, nil
}
// Exec binds args, steps the statement once and returns a result holding the
// connection's last insert rowid and change count.
//
// If the step fails, the statement is reset before the error is returned.
// Without that reset, the sqlite3_reset performed by the next bind would
// report this same failure again and the next use of the statement would
// fail spuriously.
func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
	if err := s.bind(args); err != nil {
		return nil, err
	}
	rv := C.sqlite3_step(s.s)
	if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
		// The reset result only repeats the step error already in rv.
		C.sqlite3_reset(s.s)
		return nil, ErrNo(rv)
	}
	res := &SQLiteResult{
		id:      int64(C._sqlite3_last_insert_rowid(s.c.db)),
		changes: int64(C._sqlite3_changes(s.c.db)),
	}
	return res, nil
}
// Close ends iteration by resetting the underlying statement. It does
// nothing when the statement itself has already been closed.
func (rc *SQLiteRows) Close() error {
	if !rc.s.closed {
		if code := C.sqlite3_reset(rc.s.s); code != C.SQLITE_OK {
			return ErrNo(code)
		}
	}
	return nil
}
// Columns returns the result column names, reading them from SQLite on the
// first call and serving the cached slice afterwards.
func (rc *SQLiteRows) Columns() []string {
	if len(rc.cols) == rc.nc {
		return rc.cols
	}
	names := make([]string, rc.nc)
	for idx := range names {
		names[idx] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(idx)))
	}
	rc.cols = names
	return rc.cols
}
// Next steps the statement to the next row and stores its column values in
// dest. It returns io.EOF when there are no more rows.
//
// Values are converted by SQLite storage class:
//   - INTEGER: int64, except that columns declared "timestamp" or "datetime"
//     yield time.Unix(value, 0) and columns declared "boolean" yield value > 0.
//   - FLOAT: float64.
//   - BLOB: a copied []byte (nil when SQLite returns a NULL pointer).
//   - NULL: nil.
//   - TEXT: string, except that columns declared "timestamp" or "datetime"
//     are parsed with SQLiteTimestampFormats, yielding the zero time.Time
//     when no layout matches.
//
// If stepping fails, the statement is reset and an error is returned: the
// reset's result code if it is not SQLITE_OK, otherwise the code returned by
// the step itself. Next never returns nil unless a row was actually read.
func (rc *SQLiteRows) Next(dest []driver.Value) error {
	rv := C.sqlite3_step(rc.s.s)
	if rv == C.SQLITE_DONE {
		return io.EOF
	}
	if rv != C.SQLITE_ROW {
		stepErr := ErrNo(rv)
		if rv = C.sqlite3_reset(rc.s.s); rv != C.SQLITE_OK {
			return ErrNo(rv)
		}
		// The reset succeeded but no row was produced; returning nil here
		// would make the caller believe dest had been filled.
		return stepErr
	}

	// Declared types are looked up once, when the first row arrives.
	if rc.decltype == nil {
		rc.decltype = make([]string, rc.nc)
		for i := 0; i < rc.nc; i++ {
			rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
		}
	}

	for i := range dest {
		switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
		case C.SQLITE_INTEGER:
			val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
			switch rc.decltype[i] {
			case "timestamp", "datetime":
				dest[i] = time.Unix(val, 0)
			case "boolean":
				dest[i] = val > 0
			default:
				dest[i] = val
			}
		case C.SQLITE_FLOAT:
			dest[i] = float64(C.sqlite3_column_double(rc.s.s, C.int(i)))
		case C.SQLITE_BLOB:
			p := C.sqlite3_column_blob(rc.s.s, C.int(i))
			if p == nil {
				dest[i] = nil
				continue
			}
			n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i)))
			switch dest[i].(type) {
			case sql.RawBytes:
				// Aliases SQLite's own buffer instead of copying it.
				dest[i] = (*[1 << 30]byte)(unsafe.Pointer(p))[0:n]
			default:
				slice := make([]byte, n)
				copy(slice[:], (*[1 << 30]byte)(unsafe.Pointer(p))[0:n])
				dest[i] = slice
			}
		case C.SQLITE_NULL:
			dest[i] = nil
		case C.SQLITE_TEXT:
			// Fetch the text pointer first and its byte length second, then
			// copy exactly that many bytes so that text containing NUL bytes
			// is not cut short at the first one.
			text := (*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i))))
			s := C.GoStringN(text, C.sqlite3_column_bytes(rc.s.s, C.int(i)))
			switch rc.decltype[i] {
			case "timestamp", "datetime":
				var err error
				for _, format := range SQLiteTimestampFormats {
					if dest[i], err = time.Parse(format, s); err == nil {
						break
					}
				}
				if err != nil {
					// The column is a time value, so return the zero time on parse failure.
					dest[i] = time.Time{}
				}
			default:
				dest[i] = s
			}
		}
	}
	return nil
}