go-sqlcipher/sqlite3.go

531 lines
12 KiB
Go
Raw Normal View History

2014-08-18 11:56:31 +04:00
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
2014-08-18 12:00:59 +04:00
2013-08-13 17:10:30 +04:00
package sqlite3
2011-11-11 16:36:22 +04:00
/*
#include <sqlite3.h>
#include <stdlib.h>
#include <string.h>
#ifdef __CYGWIN__
# include <errno.h>
#endif
2013-04-06 18:06:48 +04:00
#ifndef SQLITE_OPEN_READWRITE
# define SQLITE_OPEN_READWRITE 0
#endif
#ifndef SQLITE_OPEN_FULLMUTEX
# define SQLITE_OPEN_FULLMUTEX 0
#endif
static int
_sqlite3_open_v2(const char *filename, sqlite3 **ppDb, int flags, const char *zVfs) {
#ifdef SQLITE_OPEN_URI
return sqlite3_open_v2(filename, ppDb, flags | SQLITE_OPEN_URI, zVfs);
#else
return sqlite3_open_v2(filename, ppDb, flags, zVfs);
#endif
}
2011-11-11 16:36:22 +04:00
static int
_sqlite3_bind_text(sqlite3_stmt *stmt, int n, char *p, int np) {
return sqlite3_bind_text(stmt, n, p, np, SQLITE_TRANSIENT);
}
static int
_sqlite3_bind_blob(sqlite3_stmt *stmt, int n, void *p, int np) {
return sqlite3_bind_blob(stmt, n, p, np, SQLITE_TRANSIENT);
}
#include <stdio.h>
#include <stdint.h>
static long
_sqlite3_last_insert_rowid(sqlite3* db) {
return (long) sqlite3_last_insert_rowid(db);
}
static long
_sqlite3_changes(sqlite3* db) {
return (long) sqlite3_changes(db);
}
2011-11-11 16:36:22 +04:00
*/
import "C"
import (
2012-01-20 20:44:24 +04:00
"database/sql"
"database/sql/driver"
"errors"
"io"
"strings"
"time"
2011-11-11 16:36:22 +04:00
"unsafe"
)
// SQLiteTimestampFormats lists the timestamp layouts understood by both this
// package and SQLite.
//
// time.Time values are written to the database using the first layout.
// When a string is read back from a column declared as timestamp, datetime
// or date, the layouts are attempted in the order listed here and the first
// one that parses successfully is used.
var SQLiteTimestampFormats = []string{
	// Date and time separated by a space, with optional fractional seconds.
	"2006-01-02 15:04:05.999999999",
	// Same as above with an ISO 8601 "T" separator.
	"2006-01-02T15:04:05.999999999",
	// Whole seconds.
	"2006-01-02 15:04:05",
	"2006-01-02T15:04:05",
	// Minute precision.
	"2006-01-02 15:04",
	"2006-01-02T15:04",
	// Date only.
	"2006-01-02",
	// Whole seconds with a numeric UTC offset.
	"2006-01-02 15:04:05-07:00",
}
2011-11-11 16:36:22 +04:00
func init() {
sql.Register("sqlite3", &SQLiteDriver{})
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Driver struct.
2011-11-11 16:36:22 +04:00
type SQLiteDriver struct {
Extensions []string
ConnectHook func(*SQLiteConn) error
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Conn struct.
2011-11-11 16:36:22 +04:00
type SQLiteConn struct {
db *C.sqlite3
}
2013-01-31 11:48:30 +04:00
// Tx struct.
2011-11-11 16:36:22 +04:00
type SQLiteTx struct {
c *SQLiteConn
}
2013-01-31 11:48:30 +04:00
// Stmt struct.
type SQLiteStmt struct {
c *SQLiteConn
s *C.sqlite3_stmt
t string
closed bool
cls bool
2013-01-31 11:48:30 +04:00
}
// Result struct.
type SQLiteResult struct {
2013-05-11 16:45:48 +04:00
id int64
changes int64
2013-01-31 11:48:30 +04:00
}
// Rows struct.
type SQLiteRows struct {
s *SQLiteStmt
nc int
cols []string
decltype []string
cls bool
2013-01-31 11:48:30 +04:00
}
// Commit transaction.
2011-11-11 16:36:22 +04:00
func (tx *SQLiteTx) Commit() error {
_, err := tx.c.exec("COMMIT")
return err
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Rollback transaction.
2011-11-11 16:36:22 +04:00
func (tx *SQLiteTx) Rollback() error {
_, err := tx.c.exec("ROLLBACK")
return err
2011-11-11 16:36:22 +04:00
}
2013-09-09 05:44:44 +04:00
// AutoCommit return which currently auto commit or not.
2013-08-23 09:11:15 +04:00
func (c *SQLiteConn) AutoCommit() bool {
2013-08-23 09:26:33 +04:00
return int(C.sqlite3_get_autocommit(c.db)) != 0
2013-08-23 09:11:15 +04:00
}
func (c *SQLiteConn) lastError() Error {
2014-04-01 16:01:19 +04:00
return Error{
Code: ErrNo(C.sqlite3_errcode(c.db)),
ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(c.db)),
err: C.GoString(C.sqlite3_errmsg(c.db)),
}
}
// Implements Execer
func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
2014-06-25 22:54:09 +04:00
if len(args) == 0 {
return c.exec(query)
}
for {
s, err := c.Prepare(query)
if err != nil {
return nil, err
}
var res driver.Result
if s.(*SQLiteStmt).s != nil {
na := s.NumInput()
res, err = s.Exec(args[:na])
if err != nil && err != driver.ErrSkip {
s.Close()
return nil, err
}
args = args[na:]
}
tail := s.(*SQLiteStmt).t
s.Close()
2014-06-25 22:54:09 +04:00
if tail == "" {
return res, nil
}
query = tail
}
}
// Implements Queryer
func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) {
2014-06-25 22:54:09 +04:00
for {
s, err := c.Prepare(query)
if err != nil {
return nil, err
}
s.(*SQLiteStmt).cls = true
2014-06-25 22:54:09 +04:00
na := s.NumInput()
rows, err := s.Query(args[:na])
if err != nil && err != driver.ErrSkip {
s.Close()
return nil, err
}
args = args[na:]
tail := s.(*SQLiteStmt).t
if tail == "" {
return rows, nil
}
s.Close()
query = tail
}
}
func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {
2011-11-11 16:36:22 +04:00
pcmd := C.CString(cmd)
defer C.free(unsafe.Pointer(pcmd))
rv := C.sqlite3_exec(c.db, pcmd, nil, nil, nil)
if rv != C.SQLITE_OK {
return nil, c.lastError()
2011-11-11 16:36:22 +04:00
}
return &SQLiteResult{
int64(C._sqlite3_last_insert_rowid(c.db)),
int64(C._sqlite3_changes(c.db)),
}, nil
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Begin transaction.
2011-11-11 16:36:22 +04:00
func (c *SQLiteConn) Begin() (driver.Tx, error) {
if _, err := c.exec("BEGIN"); err != nil {
2011-11-11 16:36:22 +04:00
return nil, err
}
return &SQLiteTx{c}, nil
}
// errorString returns SQLite's generic description of err's primary result
// code, as produced by sqlite3_errstr.
func errorString(err Error) string {
	code := C.int(err.Code)
	return C.GoString(C.sqlite3_errstr(code))
}
2013-01-31 11:48:30 +04:00
// Open database and return a new connection.
// You can specify DSN string with URI filename.
// test.db
// file:test.db?cache=shared&mode=memory
// :memory:
// file::memory:
2011-11-11 16:36:22 +04:00
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
}
var db *C.sqlite3
name := C.CString(dsn)
defer C.free(unsafe.Pointer(name))
rv := C._sqlite3_open_v2(name, &db,
2011-11-11 16:36:22 +04:00
C.SQLITE_OPEN_FULLMUTEX|
C.SQLITE_OPEN_READWRITE|
C.SQLITE_OPEN_CREATE,
2011-11-11 16:36:22 +04:00
nil)
if rv != 0 {
return nil, Error{Code: ErrNo(rv)}
2011-11-11 16:36:22 +04:00
}
if db == nil {
return nil, errors.New("sqlite succeeded without returning a database")
}
2012-03-12 09:20:55 +04:00
rv = C.sqlite3_busy_timeout(db, 5000)
2012-03-12 09:20:55 +04:00
if rv != C.SQLITE_OK {
return nil, Error{Code: ErrNo(rv)}
2012-03-12 09:20:55 +04:00
}
2013-08-23 08:58:54 +04:00
conn := &SQLiteConn{db}
if len(d.Extensions) > 0 {
rv = C.sqlite3_enable_load_extension(db, 1)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
stmt, err := conn.Prepare("SELECT load_extension(?);")
if err != nil {
return nil, err
}
for _, extension := range d.Extensions {
if _, err = stmt.Exec([]driver.Value{extension}); err != nil {
return nil, err
}
}
if err = stmt.Close(); err != nil {
return nil, err
}
rv = C.sqlite3_enable_load_extension(db, 0)
if rv != C.SQLITE_OK {
return nil, errors.New(C.GoString(C.sqlite3_errmsg(db)))
}
}
2013-08-23 08:58:54 +04:00
if d.ConnectHook != nil {
if err := d.ConnectHook(conn); err != nil {
return nil, err
}
2013-08-23 08:58:54 +04:00
}
return conn, nil
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Close the connection.
2011-11-11 16:36:22 +04:00
func (c *SQLiteConn) Close() error {
2013-10-24 17:21:37 +04:00
rv := C.sqlite3_close_v2(c.db)
2011-11-11 16:36:22 +04:00
if rv != C.SQLITE_OK {
return c.lastError()
2011-11-11 16:36:22 +04:00
}
c.db = nil
return nil
}
2013-01-31 11:48:30 +04:00
// Prepare query string. Return a new statement.
2011-11-11 16:36:22 +04:00
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
pquery := C.CString(query)
defer C.free(unsafe.Pointer(pquery))
var s *C.sqlite3_stmt
2013-09-09 05:44:44 +04:00
var tail *C.char
rv := C.sqlite3_prepare_v2(c.db, pquery, -1, &s, &tail)
2011-11-11 16:36:22 +04:00
if rv != C.SQLITE_OK {
return nil, c.lastError()
2011-11-11 16:36:22 +04:00
}
var t string
2013-09-09 05:44:44 +04:00
if tail != nil && C.strlen(tail) > 0 {
t = strings.TrimSpace(C.GoString(tail))
2011-11-11 16:36:22 +04:00
}
2012-02-20 11:14:49 +04:00
return &SQLiteStmt{c: c, s: s, t: t}, nil
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Close the statement.
2011-11-11 16:36:22 +04:00
func (s *SQLiteStmt) Close() error {
2012-02-20 11:14:49 +04:00
if s.closed {
return nil
}
s.closed = true
2013-02-13 05:32:40 +04:00
if s.c == nil || s.c.db == nil {
return errors.New("sqlite statement with already closed database connection")
}
2011-11-11 16:38:53 +04:00
rv := C.sqlite3_finalize(s.s)
2011-11-11 16:36:22 +04:00
if rv != C.SQLITE_OK {
return s.c.lastError()
2011-11-11 16:36:22 +04:00
}
return nil
}
2013-01-31 11:48:30 +04:00
// Return a number of parameters.
2011-11-11 16:36:22 +04:00
func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s))
}
2012-02-20 11:14:49 +04:00
func (s *SQLiteStmt) bind(args []driver.Value) error {
2011-11-11 16:36:22 +04:00
rv := C.sqlite3_reset(s.s)
2011-11-11 16:38:53 +04:00
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return s.c.lastError()
2011-11-11 16:36:22 +04:00
}
for i, v := range args {
2011-11-11 16:38:53 +04:00
n := C.int(i + 1)
2011-11-11 16:36:22 +04:00
switch v := v.(type) {
case nil:
rv = C.sqlite3_bind_null(s.s, n)
case string:
if len(v) == 0 {
b := []byte{0}
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0))
} else {
b := []byte(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
2011-11-11 16:36:22 +04:00
case int64:
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
case bool:
if bool(v) {
2012-03-12 09:20:55 +04:00
rv = C.sqlite3_bind_int(s.s, n, 1)
2011-11-11 16:36:22 +04:00
} else {
rv = C.sqlite3_bind_int(s.s, n, 0)
}
case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte:
var p *byte
if len(v) > 0 {
p = &v[0]
}
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
case time.Time:
b := []byte(v.UTC().Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
2011-11-11 16:36:22 +04:00
}
if rv != C.SQLITE_OK {
return s.c.lastError()
2011-11-11 16:36:22 +04:00
}
}
return nil
}
2014-02-18 04:06:30 +04:00
// Query the statement with arguments. Return records.
2012-02-20 11:14:49 +04:00
func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
2011-11-11 16:36:22 +04:00
if err := s.bind(args); err != nil {
return nil, err
}
return &SQLiteRows{s, int(C.sqlite3_column_count(s.s)), nil, nil, s.cls}, nil
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Return last inserted ID.
2011-11-14 17:10:13 +04:00
func (r *SQLiteResult) LastInsertId() (int64, error) {
return r.id, nil
2011-11-14 17:10:13 +04:00
}
2013-01-31 11:48:30 +04:00
// Return how many rows affected.
2011-11-14 17:10:13 +04:00
func (r *SQLiteResult) RowsAffected() (int64, error) {
return r.changes, nil
2011-11-14 17:10:13 +04:00
}
2013-01-31 11:48:30 +04:00
// Execute the statement with arguments. Return result object.
2012-02-20 11:14:49 +04:00
func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
2011-11-11 16:36:22 +04:00
if err := s.bind(args); err != nil {
return nil, err
}
rv := C.sqlite3_step(s.s)
2011-11-11 16:38:53 +04:00
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return nil, s.c.lastError()
2011-11-11 16:36:22 +04:00
}
2013-05-11 16:45:48 +04:00
res := &SQLiteResult{
int64(C._sqlite3_last_insert_rowid(s.c.db)),
int64(C._sqlite3_changes(s.c.db)),
}
return res, nil
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Close the rows.
2011-11-11 16:36:22 +04:00
func (rc *SQLiteRows) Close() error {
2013-09-09 07:28:34 +04:00
if rc.s.closed {
return nil
}
if rc.cls {
return rc.s.Close()
}
2012-03-12 09:20:55 +04:00
rv := C.sqlite3_reset(rc.s.s)
if rv != C.SQLITE_OK {
return rc.s.c.lastError()
2012-03-12 09:20:55 +04:00
}
return nil
2011-11-11 16:36:22 +04:00
}
2013-01-31 11:48:30 +04:00
// Return column names.
2011-11-11 16:36:22 +04:00
func (rc *SQLiteRows) Columns() []string {
if rc.nc != len(rc.cols) {
rc.cols = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
}
}
return rc.cols
}
2013-01-31 11:48:30 +04:00
// Move cursor to next.
2012-02-20 11:14:49 +04:00
func (rc *SQLiteRows) Next(dest []driver.Value) error {
2011-11-11 16:36:22 +04:00
rv := C.sqlite3_step(rc.s.s)
if rv == C.SQLITE_DONE {
return io.EOF
}
2011-11-11 16:36:22 +04:00
if rv != C.SQLITE_ROW {
2013-09-09 08:44:24 +04:00
rv = C.sqlite3_reset(rc.s.s)
if rv != C.SQLITE_OK {
return rc.s.c.lastError()
2013-09-09 08:44:24 +04:00
}
return nil
2011-11-11 16:36:22 +04:00
}
if rc.decltype == nil {
rc.decltype = make([]string, rc.nc)
for i := 0; i < rc.nc; i++ {
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
}
}
2011-11-11 16:36:22 +04:00
for i := range dest {
2011-11-11 16:38:53 +04:00
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
case C.SQLITE_INTEGER:
val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
switch rc.decltype[i] {
case "timestamp", "datetime", "date":
dest[i] = time.Unix(val, 0)
case "boolean":
dest[i] = val > 0
default:
dest[i] = val
}
2011-11-11 16:38:53 +04:00
case C.SQLITE_FLOAT:
dest[i] = float64(C.sqlite3_column_double(rc.s.s, C.int(i)))
case C.SQLITE_BLOB:
p := C.sqlite3_column_blob(rc.s.s, C.int(i))
if p == nil {
dest[i] = nil
continue
}
n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i)))
switch dest[i].(type) {
case sql.RawBytes:
dest[i] = (*[1 << 30]byte)(unsafe.Pointer(p))[0:n]
default:
slice := make([]byte, n)
copy(slice[:], (*[1 << 30]byte)(unsafe.Pointer(p))[0:n])
dest[i] = slice
}
2011-11-11 16:38:53 +04:00
case C.SQLITE_NULL:
dest[i] = nil
case C.SQLITE_TEXT:
var err error
s := C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))))
2012-12-30 02:20:27 +04:00
switch rc.decltype[i] {
case "timestamp", "datetime", "date":
for _, format := range SQLiteTimestampFormats {
if dest[i], err = time.Parse(format, s); err == nil {
2012-12-30 02:20:27 +04:00
break
}
}
if err != nil {
2012-12-30 02:20:27 +04:00
// The column is a time value, so return the zero time on parse failure.
dest[i] = time.Time{}
}
default:
dest[i] = []byte(s)
}
2011-11-11 16:36:22 +04:00
}
}
return nil
}