Merge from mattn/go-sqlite3.

This commit is contained in:
xeodou 2015-05-18 13:14:17 +08:00
commit 242cae0e70
3 changed files with 143 additions and 11 deletions

View File

@ -121,6 +121,7 @@ type SQLiteDriver struct {
type SQLiteConn struct { type SQLiteConn struct {
db *C.sqlite3 db *C.sqlite3
loc *time.Location loc *time.Location
txlock string
} }
// Tx struct. // Tx struct.
@ -254,7 +255,7 @@ func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {
// Begin transaction. // Begin transaction.
func (c *SQLiteConn) Begin() (driver.Tx, error) { func (c *SQLiteConn) Begin() (driver.Tx, error) {
if _, err := c.exec("BEGIN"); err != nil { if _, err := c.exec(c.txlock); err != nil {
return nil, err return nil, err
} }
return &SQLiteTx{c}, nil return &SQLiteTx{c}, nil
@ -275,12 +276,16 @@ func errorString(err Error) string {
// Specify location of time format. It's possible to specify "auto". // Specify location of time format. It's possible to specify "auto".
// _busy_timeout=XXX // _busy_timeout=XXX
// Specify value for sqlite3_busy_timeout. // Specify value for sqlite3_busy_timeout.
// _txlock=XXX
// Specify locking behavior for transactions. XXX can be "immediate",
// "deferred", "exclusive".
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 { if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation") return nil, errors.New("sqlite library was not compiled for thread-safe operation")
} }
var loc *time.Location var loc *time.Location
txlock := "BEGIN"
busy_timeout := 5000 busy_timeout := 5000
pos := strings.IndexRune(dsn, '?') pos := strings.IndexRune(dsn, '?')
if pos >= 1 { if pos >= 1 {
@ -310,6 +315,20 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
busy_timeout = int(iv) busy_timeout = int(iv)
} }
// _txlock
if val := params.Get("_txlock"); val != "" {
switch val {
case "immediate":
txlock = "BEGIN IMMEDIATE"
case "exclusive":
txlock = "BEGIN EXCLUSIVE"
case "deferred":
txlock = "BEGIN"
default:
return nil, fmt.Errorf("Invalid _txlock: %v", val)
}
}
if !strings.HasPrefix(dsn, "file:") { if !strings.HasPrefix(dsn, "file:") {
dsn = dsn[:pos] dsn = dsn[:pos]
} }
@ -335,7 +354,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, Error{Code: ErrNo(rv)} return nil, Error{Code: ErrNo(rv)}
} }
conn := &SQLiteConn{db: db, loc: loc} conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
if len(d.Extensions) > 0 { if len(d.Extensions) > 0 {
rv = C.sqlite3_enable_load_extension(db, 1) rv = C.sqlite3_enable_load_extension(db, 1)
@ -626,11 +645,14 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
case C.SQLITE_TEXT: case C.SQLITE_TEXT:
var err error var err error
var timeVal time.Time var timeVal time.Time
s := C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))))
n := int(C.sqlite3_column_bytes(rc.s.s, C.int(i)))
s := C.GoStringN((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))), C.int(n))
switch rc.decltype[i] { switch rc.decltype[i] {
case "timestamp", "datetime", "date": case "timestamp", "datetime", "date":
var t time.Time var t time.Time
s = strings.TrimSuffix(s, "Z")
for _, format := range SQLiteTimestampFormats { for _, format := range SQLiteTimestampFormats {
if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil { if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
t = timeVal t = timeVal

View File

@ -9,6 +9,5 @@ package sqlite3
/* /*
#cgo CFLAGS: -I. #cgo CFLAGS: -I.
#cgo linux LDFLAGS: -ldl #cgo linux LDFLAGS: -ldl
#cgo LDFLAGS: -lpthread
*/ */
import "C" import "C"

View File

@ -8,7 +8,10 @@ package sqlite3
import ( import (
"crypto/rand" "crypto/rand"
"database/sql" "database/sql"
"database/sql/driver"
"encoding/hex" "encoding/hex"
"errors"
"fmt"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
@ -25,11 +28,17 @@ func TempFilename() string {
return filepath.Join(os.TempDir(), "foo"+hex.EncodeToString(randBytes)+".db") return filepath.Join(os.TempDir(), "foo"+hex.EncodeToString(randBytes)+".db")
} }
func TestOpen(t *testing.T) { func doTestOpen(t *testing.T, option string) (string, error) {
var url string
tempFilename := TempFilename() tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename) if option != "" {
url = tempFilename + option
} else {
url = tempFilename
}
db, err := sql.Open("sqlite3", url)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) return "Failed to open database:", err
} }
defer os.Remove(tempFilename) defer os.Remove(tempFilename)
defer db.Close() defer db.Close()
@ -37,11 +46,38 @@ func TestOpen(t *testing.T) {
_, err = db.Exec("drop table foo") _, err = db.Exec("drop table foo")
_, err = db.Exec("create table foo (id integer)") _, err = db.Exec("create table foo (id integer)")
if err != nil { if err != nil {
t.Fatal("Failed to create table:", err) return "Failed to create table:", err
} }
if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() { if stat, err := os.Stat(tempFilename); err != nil || stat.IsDir() {
t.Error("Failed to create ./foo.db") return "Failed to create ./foo.db", nil
}
return "", nil
}
func TestOpen(t *testing.T) {
cases := map[string]bool{
"": true,
"?_txlock=immediate": true,
"?_txlock=deferred": true,
"?_txlock=exclusive": true,
"?_txlock=bogus": false,
}
for option, expectedPass := range cases {
result, err := doTestOpen(t, option)
if result == "" {
if !expectedPass {
errmsg := fmt.Sprintf("_txlock error not caught at dbOpen with option: %s", option)
t.Fatal(errmsg)
}
} else if expectedPass {
if err == nil {
t.Fatal(result)
} else {
t.Fatal(result, err)
}
}
} }
} }
@ -1024,3 +1060,78 @@ func TestEncryptoDatabase(t *testing.T) {
} }
} }
func TestStringContainingZero(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec(`
create table foo (id integer, name, extra text);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}
const text = "foo\x00bar"
_, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, text)
if err != nil {
t.Error("Failed to call db.Exec:", err)
}
row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, text)
if row == nil {
t.Error("Failed to call db.QueryRow")
}
var id int
var extra string
err = row.Scan(&id, &extra)
if err != nil {
t.Error("Failed to db.Scan:", err)
}
if id != 1 || extra != text {
t.Error("Failed to db.QueryRow: not matched results")
}
}
const CurrentTimeStamp = "2006-01-02 15:04:05"
type TimeStamp struct{ *time.Time }
func (t TimeStamp) Scan(value interface{}) error {
var err error
switch v := value.(type) {
case string:
*t.Time, err = time.Parse(CurrentTimeStamp, v)
case []byte:
*t.Time, err = time.Parse(CurrentTimeStamp, string(v))
default:
err = errors.New("invalid type for current_timestamp")
}
return err
}
func (t TimeStamp) Value() (driver.Value, error) {
return t.Time.Format(CurrentTimeStamp), nil
}
func TestDateTimeNow(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
var d time.Time
err = db.QueryRow("SELECT datetime('now')").Scan(TimeStamp{&d})
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
}