forked from mirror/go-sqlcipher
Merge from mattn/go-sqlite3.
This commit is contained in:
commit
242cae0e70
32
sqlite3.go
32
sqlite3.go
|
@ -119,8 +119,9 @@ type SQLiteDriver struct {
|
||||||
|
|
||||||
// Conn struct.
|
// Conn struct.
|
||||||
type SQLiteConn struct {
|
type SQLiteConn struct {
|
||||||
db *C.sqlite3
|
db *C.sqlite3
|
||||||
loc *time.Location
|
loc *time.Location
|
||||||
|
txlock string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tx struct.
|
// Tx struct.
|
||||||
|
@ -254,7 +255,7 @@ func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {
|
||||||
|
|
||||||
// Begin transaction.
|
// Begin transaction.
|
||||||
func (c *SQLiteConn) Begin() (driver.Tx, error) {
|
func (c *SQLiteConn) Begin() (driver.Tx, error) {
|
||||||
if _, err := c.exec("BEGIN"); err != nil {
|
if _, err := c.exec(c.txlock); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &SQLiteTx{c}, nil
|
return &SQLiteTx{c}, nil
|
||||||
|
@ -275,12 +276,16 @@ func errorString(err Error) string {
|
||||||
// Specify location of time format. It's possible to specify "auto".
|
// Specify location of time format. It's possible to specify "auto".
|
||||||
// _busy_timeout=XXX
|
// _busy_timeout=XXX
|
||||||
// Specify value for sqlite3_busy_timeout.
|
// Specify value for sqlite3_busy_timeout.
|
||||||
|
// _txlock=XXX
|
||||||
|
// Specify locking behavior for transactions. XXX can be "immediate",
|
||||||
|
// "deferred", "exclusive".
|
||||||
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
if C.sqlite3_threadsafe() == 0 {
|
if C.sqlite3_threadsafe() == 0 {
|
||||||
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
|
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
|
||||||
}
|
}
|
||||||
|
|
||||||
var loc *time.Location
|
var loc *time.Location
|
||||||
|
txlock := "BEGIN"
|
||||||
busy_timeout := 5000
|
busy_timeout := 5000
|
||||||
pos := strings.IndexRune(dsn, '?')
|
pos := strings.IndexRune(dsn, '?')
|
||||||
if pos >= 1 {
|
if pos >= 1 {
|
||||||
|
@ -310,6 +315,20 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
busy_timeout = int(iv)
|
busy_timeout = int(iv)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// _txlock
|
||||||
|
if val := params.Get("_txlock"); val != "" {
|
||||||
|
switch val {
|
||||||
|
case "immediate":
|
||||||
|
txlock = "BEGIN IMMEDIATE"
|
||||||
|
case "exclusive":
|
||||||
|
txlock = "BEGIN EXCLUSIVE"
|
||||||
|
case "deferred":
|
||||||
|
txlock = "BEGIN"
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("Invalid _txlock: %v", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !strings.HasPrefix(dsn, "file:") {
|
if !strings.HasPrefix(dsn, "file:") {
|
||||||
dsn = dsn[:pos]
|
dsn = dsn[:pos]
|
||||||
}
|
}
|
||||||
|
@ -335,7 +354,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
return nil, Error{Code: ErrNo(rv)}
|
return nil, Error{Code: ErrNo(rv)}
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := &SQLiteConn{db: db, loc: loc}
|
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
|
||||||
|
|
||||||
if len(d.Extensions) > 0 {
|
if len(d.Extensions) > 0 {
|
||||||
rv = C.sqlite3_enable_load_extension(db, 1)
|
rv = C.sqlite3_enable_load_extension(db, 1)
|
||||||
|
@ -626,11 +645,14 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
||||||
case C.SQLITE_TEXT:
|
case C.SQLITE_TEXT:
|
||||||
var err error
|
var err error
|
||||||
var timeVal time.Time
|
var timeVal time.Time
|
||||||
s := C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))))
|
|
||||||
|
n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i)))
|
||||||
|
s := C.GoStringN((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))), C.int(n))
|
||||||
|
|
||||||
switch rc.decltype[i] {
|
switch rc.decltype[i] {
|
||||||
case "timestamp", "datetime", "date":
|
case "timestamp", "datetime", "date":
|
||||||
var t time.Time
|
var t time.Time
|
||||||
|
s = strings.TrimSuffix(s, "Z")
|
||||||
for _, format := range SQLiteTimestampFormats {
|
for _, format := range SQLiteTimestampFormats {
|
||||||
if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
|
if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
|
||||||
t = timeVal
|
t = timeVal
|
||||||
|
|
|
@ -9,6 +9,5 @@ package sqlite3
|
||||||
/*
|
/*
|
||||||
#cgo CFLAGS: -I.
|
#cgo CFLAGS: -I.
|
||||||
#cgo linux LDFLAGS: -ldl
|
#cgo linux LDFLAGS: -ldl
|
||||||
#cgo LDFLAGS: -lpthread
|
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
|
|
121
sqlite3_test.go
121
sqlite3_test.go
|
@ -8,7 +8,10 @@ package sqlite3
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -25,11 +28,17 @@ func TempFilename() string {
|
||||||
return filepath.Join(os.TempDir(), "foo"+hex.EncodeToString(randBytes)+".db")
|
return filepath.Join(os.TempDir(), "foo"+hex.EncodeToString(randBytes)+".db")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpen(t *testing.T) {
|
func doTestOpen(t *testing.T, option string) (string, error) {
|
||||||
|
var url string
|
||||||
tempFilename := TempFilename()
|
tempFilename := TempFilename()
|
||||||
db, err := sql.Open("sqlite3", tempFilename)
|
if option != "" {
|
||||||
|
url = tempFilename + option
|
||||||
|
} else {
|
||||||
|
url = tempFilename
|
||||||
|
}
|
||||||
|
db, err := sql.Open("sqlite3", url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Failed to open database:", err)
|
return "Failed to open database:", err
|
||||||
}
|
}
|
||||||
defer os.Remove(tempFilename)
|
defer os.Remove(tempFilename)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
@ -37,11 +46,38 @@ func TestOpen(t *testing.T) {
|
||||||
_, err = db.Exec("drop table foo")
|
_, err = db.Exec("drop table foo")
|
||||||
_, err = db.Exec("create table foo (id integer)")
|
_, err = db.Exec("create table foo (id integer)")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Failed to create table:", err)
|
return "Failed to create table:", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() {
|
if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() {
|
||||||
t.Error("Failed to create ./foo.db")
|
return "Failed to create ./foo.db", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpen(t *testing.T) {
|
||||||
|
cases := map[string]bool{
|
||||||
|
"": true,
|
||||||
|
"?_txlock=immediate": true,
|
||||||
|
"?_txlock=deferred": true,
|
||||||
|
"?_txlock=exclusive": true,
|
||||||
|
"?_txlock=bogus": false,
|
||||||
|
}
|
||||||
|
for option, expectedPass := range cases {
|
||||||
|
result, err := doTestOpen(t, option)
|
||||||
|
if result == "" {
|
||||||
|
if !expectedPass {
|
||||||
|
errmsg := fmt.Sprintf("_txlock error not caught at dbOpen with option: %s", option)
|
||||||
|
t.Fatal(errmsg)
|
||||||
|
}
|
||||||
|
} else if expectedPass {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal(result)
|
||||||
|
} else {
|
||||||
|
t.Fatal(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1024,3 +1060,78 @@ func TestEncryptoDatabase(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStringContainingZero(t *testing.T) {
|
||||||
|
tempFilename := TempFilename()
|
||||||
|
db, err := sql.Open("sqlite3", tempFilename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer os.Remove(tempFilename)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err = db.Exec(`
|
||||||
|
create table foo (id integer, name, extra text);
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Failed to call db.Query:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const text = "foo\x00bar"
|
||||||
|
|
||||||
|
_, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, text)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Failed to call db.Exec:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, text)
|
||||||
|
if row == nil {
|
||||||
|
t.Error("Failed to call db.QueryRow")
|
||||||
|
}
|
||||||
|
|
||||||
|
var id int
|
||||||
|
var extra string
|
||||||
|
err = row.Scan(&id, &extra)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Failed to db.Scan:", err)
|
||||||
|
}
|
||||||
|
if id != 1 || extra != text {
|
||||||
|
t.Error("Failed to db.QueryRow: not matched results")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const CurrentTimeStamp = "2006-01-02 15:04:05"
|
||||||
|
|
||||||
|
type TimeStamp struct{ *time.Time }
|
||||||
|
|
||||||
|
func (t TimeStamp) Scan(value interface{}) error {
|
||||||
|
var err error
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
*t.Time, err = time.Parse(CurrentTimeStamp, v)
|
||||||
|
case []byte:
|
||||||
|
*t.Time, err = time.Parse(CurrentTimeStamp, string(v))
|
||||||
|
default:
|
||||||
|
err = errors.New("invalid type for current_timestamp")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t TimeStamp) Value() (driver.Value, error) {
|
||||||
|
return t.Time.Format(CurrentTimeStamp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDateTimeNow(t *testing.T) {
|
||||||
|
tempFilename := TempFilename()
|
||||||
|
db, err := sql.Open("sqlite3", tempFilename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var d time.Time
|
||||||
|
err = db.QueryRow("SELECT datetime('now')").Scan(TimeStamp{&d})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to scan datetime:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue