Add loc=XXX parameters to handle timezone

This commit is contained in:
Yasuhiro Matsumoto 2015-03-04 22:49:17 +09:00
parent da2bf8a0f3
commit 4c5c4e5261
2 changed files with 109 additions and 8 deletions

View File

@ -65,6 +65,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/url"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@ -107,7 +108,8 @@ type SQLiteDriver struct {
// Conn struct. // Conn struct.
type SQLiteConn struct { type SQLiteConn struct {
db *C.sqlite3 db *C.sqlite3
loc *time.Location
} }
// Tx struct. // Tx struct.
@ -256,11 +258,31 @@ func errorString(err Error) string {
// file:test.db?cache=shared&mode=memory // file:test.db?cache=shared&mode=memory
// :memory: // :memory:
// file::memory: // file::memory:
// go-sqlite handle especially query parameters.
// loc=XXX
// Specify location of time format. It's possible to specify "auto".
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
if C.sqlite3_threadsafe() == 0 { if C.sqlite3_threadsafe() == 0 {
return nil, errors.New("sqlite library was not compiled for thread-safe operation") return nil, errors.New("sqlite library was not compiled for thread-safe operation")
} }
var loc *time.Location
if u, err := url.Parse(dsn); err == nil {
for k, v := range u.Query() {
switch k {
case "loc":
if len(v) > 0 {
if v[0] == "auto" {
v[0] = time.Local.String()
}
if loc, err = time.LoadLocation(v[0]); err != nil {
return nil, fmt.Errorf("Invalid loc: %v: %v", v[0], err)
}
}
}
}
}
var db *C.sqlite3 var db *C.sqlite3
name := C.CString(dsn) name := C.CString(dsn)
defer C.free(unsafe.Pointer(name)) defer C.free(unsafe.Pointer(name))
@ -281,7 +303,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
return nil, Error{Code: ErrNo(rv)} return nil, Error{Code: ErrNo(rv)}
} }
conn := &SQLiteConn{db} conn := &SQLiteConn{db: db, loc: loc}
if len(d.Extensions) > 0 { if len(d.Extensions) > 0 {
rv = C.sqlite3_enable_load_extension(db, 1) rv = C.sqlite3_enable_load_extension(db, 1)
@ -401,8 +423,13 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
} }
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v))) rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
case time.Time: case time.Time:
b := []byte(v.UTC().Format(SQLiteTimestampFormats[0])) if s.c.loc != nil {
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) b := []byte(v.In(s.c.loc).Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
} else {
b := []byte(v.UTC().Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
} }
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return s.c.lastError() return s.c.lastError()
@ -545,10 +572,19 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
switch rc.decltype[i] { switch rc.decltype[i] {
case "timestamp", "datetime", "date": case "timestamp", "datetime", "date":
for _, format := range SQLiteTimestampFormats { if rc.s.c.loc != nil {
if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil { for _, format := range SQLiteTimestampFormats {
dest[i] = timeVal.Local() if timeVal, err = time.ParseInLocation(format, s, rc.s.c.loc); err == nil {
break dest[i] = timeVal
break
}
}
} else {
for _, format := range SQLiteTimestampFormats {
if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
dest[i] = timeVal
break
}
} }
} }
if err != nil { if err != nil {

View File

@ -744,6 +744,71 @@ func TestStress(t *testing.T) {
} }
} }
func TestDateTimeLocal(t *testing.T) {
zone := "Asia/Tokyo"
z, err := time.LoadLocation(zone)
if err != nil {
t.Skip("Failed to load timezon:", err)
}
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone)
if err != nil {
t.Fatal("Failed to open database:", err)
}
db.Exec("CREATE TABLE foo (id datetime);")
db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');")
row := db.QueryRow("select * from foo")
var d time.Time
err = row.Scan(&d)
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
if d.Local().Hour() != 15 {
t.Fatal("Result should have timezone", d)
}
db.Close()
db, err = sql.Open("sqlite3", "file:///"+tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
row = db.QueryRow("select * from foo")
err = row.Scan(&d)
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
if d.In(z).Hour() == 15 {
t.Fatalf("Result should not have timezone %v", zone)
}
_, err = db.Exec("DELETE FROM foo")
if err != nil {
t.Fatal("Failed to delete table:", err)
}
dt, err := time.Parse("2006/1/2 15/4/5 -0700 MST", "2015/3/5 15/16/17 +0900 JST")
if err != nil {
t.Fatal("Failed to parse datetime:", err)
}
db.Exec("INSERT INTO foo VALUES(?);", dt)
db.Close()
db, err = sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone)
if err != nil {
t.Fatal("Failed to open database:", err)
}
row = db.QueryRow("select * from foo")
err = row.Scan(&d)
if err != nil {
t.Fatal("Failed to scan datetime:", err)
}
if d.Hour() == 15 {
t.Fatalf("Result should have timezone %v", zone)
}
}
func TestVersion(t *testing.T) { func TestVersion(t *testing.T) {
s, n, id := Version() s, n, id := Version()
if s == "" || n == 0 || id == "" { if s == "" || n == 0 || id == "" {