forked from mirror/go-sqlcipher
Add loc=XXX parameters to handle timezone
This commit is contained in:
parent
da2bf8a0f3
commit
4c5c4e5261
52
sqlite3.go
52
sqlite3.go
|
@ -65,6 +65,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net/url"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -107,7 +108,8 @@ type SQLiteDriver struct {
|
||||||
|
|
||||||
// Conn struct.
|
// Conn struct.
|
||||||
type SQLiteConn struct {
|
type SQLiteConn struct {
|
||||||
db *C.sqlite3
|
db *C.sqlite3
|
||||||
|
loc *time.Location
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tx struct.
|
// Tx struct.
|
||||||
|
@ -256,11 +258,31 @@ func errorString(err Error) string {
|
||||||
// file:test.db?cache=shared&mode=memory
|
// file:test.db?cache=shared&mode=memory
|
||||||
// :memory:
|
// :memory:
|
||||||
// file::memory:
|
// file::memory:
|
||||||
|
// go-sqlite handle especially query parameters.
|
||||||
|
// loc=XXX
|
||||||
|
// Specify location of time format. It's possible to specify "auto".
|
||||||
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
if C.sqlite3_threadsafe() == 0 {
|
if C.sqlite3_threadsafe() == 0 {
|
||||||
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
|
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var loc *time.Location
|
||||||
|
if u, err := url.Parse(dsn); err == nil {
|
||||||
|
for k, v := range u.Query() {
|
||||||
|
switch k {
|
||||||
|
case "loc":
|
||||||
|
if len(v) > 0 {
|
||||||
|
if v[0] == "auto" {
|
||||||
|
v[0] = time.Local.String()
|
||||||
|
}
|
||||||
|
if loc, err = time.LoadLocation(v[0]); err != nil {
|
||||||
|
return nil, fmt.Errorf("Invalid loc: %v: %v", v[0], err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var db *C.sqlite3
|
var db *C.sqlite3
|
||||||
name := C.CString(dsn)
|
name := C.CString(dsn)
|
||||||
defer C.free(unsafe.Pointer(name))
|
defer C.free(unsafe.Pointer(name))
|
||||||
|
@ -281,7 +303,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||||
return nil, Error{Code: ErrNo(rv)}
|
return nil, Error{Code: ErrNo(rv)}
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := &SQLiteConn{db}
|
conn := &SQLiteConn{db: db, loc: loc}
|
||||||
|
|
||||||
if len(d.Extensions) > 0 {
|
if len(d.Extensions) > 0 {
|
||||||
rv = C.sqlite3_enable_load_extension(db, 1)
|
rv = C.sqlite3_enable_load_extension(db, 1)
|
||||||
|
@ -401,8 +423,13 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
|
||||||
}
|
}
|
||||||
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
|
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
|
||||||
case time.Time:
|
case time.Time:
|
||||||
b := []byte(v.UTC().Format(SQLiteTimestampFormats[0]))
|
if s.c.loc != nil {
|
||||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
b := []byte(v.In(s.c.loc).Format(SQLiteTimestampFormats[0]))
|
||||||
|
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||||
|
} else {
|
||||||
|
b := []byte(v.UTC().Format(SQLiteTimestampFormats[0]))
|
||||||
|
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if rv != C.SQLITE_OK {
|
if rv != C.SQLITE_OK {
|
||||||
return s.c.lastError()
|
return s.c.lastError()
|
||||||
|
@ -545,10 +572,19 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
||||||
|
|
||||||
switch rc.decltype[i] {
|
switch rc.decltype[i] {
|
||||||
case "timestamp", "datetime", "date":
|
case "timestamp", "datetime", "date":
|
||||||
for _, format := range SQLiteTimestampFormats {
|
if rc.s.c.loc != nil {
|
||||||
if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
|
for _, format := range SQLiteTimestampFormats {
|
||||||
dest[i] = timeVal.Local()
|
if timeVal, err = time.ParseInLocation(format, s, rc.s.c.loc); err == nil {
|
||||||
break
|
dest[i] = timeVal
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, format := range SQLiteTimestampFormats {
|
||||||
|
if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
|
||||||
|
dest[i] = timeVal
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -744,6 +744,71 @@ func TestStress(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDateTimeLocal(t *testing.T) {
|
||||||
|
zone := "Asia/Tokyo"
|
||||||
|
z, err := time.LoadLocation(zone)
|
||||||
|
if err != nil {
|
||||||
|
t.Skip("Failed to load timezon:", err)
|
||||||
|
}
|
||||||
|
tempFilename := TempFilename()
|
||||||
|
db, err := sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
db.Exec("CREATE TABLE foo (id datetime);")
|
||||||
|
db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');")
|
||||||
|
|
||||||
|
row := db.QueryRow("select * from foo")
|
||||||
|
var d time.Time
|
||||||
|
err = row.Scan(&d)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to scan datetime:", err)
|
||||||
|
}
|
||||||
|
if d.Local().Hour() != 15 {
|
||||||
|
t.Fatal("Result should have timezone", d)
|
||||||
|
}
|
||||||
|
db.Close()
|
||||||
|
|
||||||
|
db, err = sql.Open("sqlite3", "file:///"+tempFilename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
row = db.QueryRow("select * from foo")
|
||||||
|
err = row.Scan(&d)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to scan datetime:", err)
|
||||||
|
}
|
||||||
|
if d.In(z).Hour() == 15 {
|
||||||
|
t.Fatalf("Result should not have timezone %v", zone)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Exec("DELETE FROM foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to delete table:", err)
|
||||||
|
}
|
||||||
|
dt, err := time.Parse("2006/1/2 15/4/5 -0700 MST", "2015/3/5 15/16/17 +0900 JST")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to parse datetime:", err)
|
||||||
|
}
|
||||||
|
db.Exec("INSERT INTO foo VALUES(?);", dt)
|
||||||
|
|
||||||
|
db.Close()
|
||||||
|
db, err = sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
row = db.QueryRow("select * from foo")
|
||||||
|
err = row.Scan(&d)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to scan datetime:", err)
|
||||||
|
}
|
||||||
|
if d.Hour() == 15 {
|
||||||
|
t.Fatalf("Result should have timezone %v", zone)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestVersion(t *testing.T) {
|
func TestVersion(t *testing.T) {
|
||||||
s, n, id := Version()
|
s, n, id := Version()
|
||||||
if s == "" || n == 0 || id == "" {
|
if s == "" || n == 0 || id == "" {
|
||||||
|
|
Loading…
Reference in New Issue