forked from mirror/go-sqlcipher
Add loc=XXX parameters to handle timezone
This commit is contained in:
parent
da2bf8a0f3
commit
4c5c4e5261
52
sqlite3.go
52
sqlite3.go
|
@ -65,6 +65,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -107,7 +108,8 @@ type SQLiteDriver struct {
|
|||
|
||||
// Conn struct.
|
||||
type SQLiteConn struct {
|
||||
db *C.sqlite3
|
||||
db *C.sqlite3
|
||||
loc *time.Location
|
||||
}
|
||||
|
||||
// Tx struct.
|
||||
|
@ -256,11 +258,31 @@ func errorString(err Error) string {
|
|||
// file:test.db?cache=shared&mode=memory
|
||||
// :memory:
|
||||
// file::memory:
|
||||
// go-sqlite handle especially query parameters.
|
||||
// loc=XXX
|
||||
// Specify location of time format. It's possible to specify "auto".
|
||||
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
||||
if C.sqlite3_threadsafe() == 0 {
|
||||
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
|
||||
}
|
||||
|
||||
var loc *time.Location
|
||||
if u, err := url.Parse(dsn); err == nil {
|
||||
for k, v := range u.Query() {
|
||||
switch k {
|
||||
case "loc":
|
||||
if len(v) > 0 {
|
||||
if v[0] == "auto" {
|
||||
v[0] = time.Local.String()
|
||||
}
|
||||
if loc, err = time.LoadLocation(v[0]); err != nil {
|
||||
return nil, fmt.Errorf("Invalid loc: %v: %v", v[0], err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var db *C.sqlite3
|
||||
name := C.CString(dsn)
|
||||
defer C.free(unsafe.Pointer(name))
|
||||
|
@ -281,7 +303,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
|
|||
return nil, Error{Code: ErrNo(rv)}
|
||||
}
|
||||
|
||||
conn := &SQLiteConn{db}
|
||||
conn := &SQLiteConn{db: db, loc: loc}
|
||||
|
||||
if len(d.Extensions) > 0 {
|
||||
rv = C.sqlite3_enable_load_extension(db, 1)
|
||||
|
@ -401,8 +423,13 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
|
|||
}
|
||||
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
|
||||
case time.Time:
|
||||
b := []byte(v.UTC().Format(SQLiteTimestampFormats[0]))
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||
if s.c.loc != nil {
|
||||
b := []byte(v.In(s.c.loc).Format(SQLiteTimestampFormats[0]))
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||
} else {
|
||||
b := []byte(v.UTC().Format(SQLiteTimestampFormats[0]))
|
||||
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||
}
|
||||
}
|
||||
if rv != C.SQLITE_OK {
|
||||
return s.c.lastError()
|
||||
|
@ -545,10 +572,19 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
|||
|
||||
switch rc.decltype[i] {
|
||||
case "timestamp", "datetime", "date":
|
||||
for _, format := range SQLiteTimestampFormats {
|
||||
if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
|
||||
dest[i] = timeVal.Local()
|
||||
break
|
||||
if rc.s.c.loc != nil {
|
||||
for _, format := range SQLiteTimestampFormats {
|
||||
if timeVal, err = time.ParseInLocation(format, s, rc.s.c.loc); err == nil {
|
||||
dest[i] = timeVal
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, format := range SQLiteTimestampFormats {
|
||||
if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
|
||||
dest[i] = timeVal
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
|
|
|
@ -744,6 +744,71 @@ func TestStress(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDateTimeLocal(t *testing.T) {
|
||||
zone := "Asia/Tokyo"
|
||||
z, err := time.LoadLocation(zone)
|
||||
if err != nil {
|
||||
t.Skip("Failed to load timezon:", err)
|
||||
}
|
||||
tempFilename := TempFilename()
|
||||
db, err := sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open database:", err)
|
||||
}
|
||||
db.Exec("CREATE TABLE foo (id datetime);")
|
||||
db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');")
|
||||
|
||||
row := db.QueryRow("select * from foo")
|
||||
var d time.Time
|
||||
err = row.Scan(&d)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to scan datetime:", err)
|
||||
}
|
||||
if d.Local().Hour() != 15 {
|
||||
t.Fatal("Result should have timezone", d)
|
||||
}
|
||||
db.Close()
|
||||
|
||||
db, err = sql.Open("sqlite3", "file:///"+tempFilename)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open database:", err)
|
||||
}
|
||||
|
||||
row = db.QueryRow("select * from foo")
|
||||
err = row.Scan(&d)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to scan datetime:", err)
|
||||
}
|
||||
if d.In(z).Hour() == 15 {
|
||||
t.Fatalf("Result should not have timezone %v", zone)
|
||||
}
|
||||
|
||||
_, err = db.Exec("DELETE FROM foo")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to delete table:", err)
|
||||
}
|
||||
dt, err := time.Parse("2006/1/2 15/4/5 -0700 MST", "2015/3/5 15/16/17 +0900 JST")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to parse datetime:", err)
|
||||
}
|
||||
db.Exec("INSERT INTO foo VALUES(?);", dt)
|
||||
|
||||
db.Close()
|
||||
db, err = sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to open database:", err)
|
||||
}
|
||||
|
||||
row = db.QueryRow("select * from foo")
|
||||
err = row.Scan(&d)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to scan datetime:", err)
|
||||
}
|
||||
if d.Hour() == 15 {
|
||||
t.Fatalf("Result should have timezone %v", zone)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersion(t *testing.T) {
|
||||
s, n, id := Version()
|
||||
if s == "" || n == 0 || id == "" {
|
||||
|
|
Loading…
Reference in New Issue