diff --git a/sqlite3.go b/sqlite3.go index 605474c..4457798 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -65,6 +65,7 @@ import ( "errors" "fmt" "io" + "net/url" "runtime" "strconv" "strings" @@ -107,7 +108,8 @@ type SQLiteDriver struct { // Conn struct. type SQLiteConn struct { - db *C.sqlite3 + db *C.sqlite3 + loc *time.Location } // Tx struct. @@ -256,11 +258,31 @@ func errorString(err Error) string { // file:test.db?cache=shared&mode=memory // :memory: // file::memory: +// go-sqlite handle especially query parameters. +// loc=XXX +// Specify location of time format. It's possible to specify "auto". func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { if C.sqlite3_threadsafe() == 0 { return nil, errors.New("sqlite library was not compiled for thread-safe operation") } + var loc *time.Location + if u, err := url.Parse(dsn); err == nil { + for k, v := range u.Query() { + switch k { + case "loc": + if len(v) > 0 { + if v[0] == "auto" { + v[0] = time.Local.String() + } + if loc, err = time.LoadLocation(v[0]); err != nil { + return nil, fmt.Errorf("Invalid loc: %v: %v", v[0], err) + } + } + } + } + } + var db *C.sqlite3 name := C.CString(dsn) defer C.free(unsafe.Pointer(name)) @@ -281,7 +303,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) { return nil, Error{Code: ErrNo(rv)} } - conn := &SQLiteConn{db} + conn := &SQLiteConn{db: db, loc: loc} if len(d.Extensions) > 0 { rv = C.sqlite3_enable_load_extension(db, 1) @@ -401,8 +423,13 @@ func (s *SQLiteStmt) bind(args []driver.Value) error { } rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v))) case time.Time: - b := []byte(v.UTC().Format(SQLiteTimestampFormats[0])) - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) + if s.c.loc != nil { + b := []byte(v.In(s.c.loc).Format(SQLiteTimestampFormats[0])) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) + } else { + b := []byte(v.UTC().Format(SQLiteTimestampFormats[0])) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) + } } if rv != C.SQLITE_OK { return s.c.lastError() @@ -545,10 +572,19 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { switch rc.decltype[i] { case "timestamp", "datetime", "date": - for _, format := range SQLiteTimestampFormats { - if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil { - dest[i] = timeVal.Local() - break + if rc.s.c.loc != nil { + for _, format := range SQLiteTimestampFormats { + if timeVal, err = time.ParseInLocation(format, s, rc.s.c.loc); err == nil { + dest[i] = timeVal + break + } + } + } else { + for _, format := range SQLiteTimestampFormats { + if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil { + dest[i] = timeVal + break + } } } if err != nil { diff --git a/sqlite3_test.go b/sqlite3_test.go index a0adf30..325ba8e 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -744,6 +744,71 @@ func TestStress(t *testing.T) { } } +func TestDateTimeLocal(t *testing.T) { + zone := "Asia/Tokyo" + z, err := time.LoadLocation(zone) + if err != nil { + t.Skip("Failed to load timezon:", err) + } + tempFilename := TempFilename() + db, err := sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone) + if err != nil { + t.Fatal("Failed to open database:", err) + } + db.Exec("CREATE TABLE foo (id datetime);") + db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');") + + row := db.QueryRow("select * from foo") + var d time.Time + err = row.Scan(&d) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } + if d.Local().Hour() != 15 { + t.Fatal("Result should have timezone", d) + } + db.Close() + + db, err = sql.Open("sqlite3", "file:///"+tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + + row = db.QueryRow("select * from foo") + err = row.Scan(&d) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } + if d.In(z).Hour() == 15 { + t.Fatalf("Result should not have timezone %v", zone) + } + + _, err = db.Exec("DELETE FROM foo") + if err != nil { + t.Fatal("Failed to delete table:", err) + } + dt, err := time.Parse("2006/1/2 15/4/5 -0700 MST", "2015/3/5 15/16/17 +0900 JST") + if err != nil { + t.Fatal("Failed to parse datetime:", err) + } + db.Exec("INSERT INTO foo VALUES(?);", dt) + + db.Close() + db, err = sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone) + if err != nil { + t.Fatal("Failed to open database:", err) + } + + row = db.QueryRow("select * from foo") + err = row.Scan(&d) + if err != nil { + t.Fatal("Failed to scan datetime:", err) + } + if d.Hour() == 15 { + t.Fatalf("Result should have timezone %v", zone) + } +} + func TestVersion(t *testing.T) { s, n, id := Version() if s == "" || n == 0 || id == "" {