Fixed bug for loc parameter

This commit is contained in:
mattn 2015-03-05 01:17:38 +09:00
parent 18aa166fa9
commit e273a1552e
2 changed files with 25 additions and 26 deletions

View File

@ -423,14 +423,9 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
} }
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v))) rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
case time.Time: case time.Time:
if s.c.loc != nil {
b := []byte(v.In(s.c.loc).Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
} else {
b := []byte(v.UTC().Format(SQLiteTimestampFormats[0])) b := []byte(v.UTC().Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
} }
}
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return s.c.lastError() return s.c.lastError()
} }
@ -536,11 +531,18 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
return fmt.Errorf("error parsing %s value %d, %s", rc.decltype[i], val, err) return fmt.Errorf("error parsing %s value %d, %s", rc.decltype[i], val, err)
} }
epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) epoch := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)
dest[i] = epoch.Add(duration) if rc.s.c.loc != nil {
dest[i] = epoch.Add(duration).In(rc.s.c.loc)
} else { } else {
dest[i] = time.Unix(val, 0).Local() dest[i] = epoch.Add(duration)
}
} else {
if rc.s.c.loc != nil {
dest[i] = time.Unix(val, 0).In(rc.s.c.loc)
} else {
dest[i] = time.Unix(val, 0)
}
} }
case "boolean": case "boolean":
dest[i] = val > 0 dest[i] = val > 0
default: default:
@ -572,13 +574,13 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
switch rc.decltype[i] { switch rc.decltype[i] {
case "timestamp", "datetime", "date": case "timestamp", "datetime", "date":
zone := rc.s.c.loc
if zone == nil {
zone = time.UTC
}
for _, format := range SQLiteTimestampFormats { for _, format := range SQLiteTimestampFormats {
if timeVal, err = time.ParseInLocation(format, s, zone); err == nil { if timeVal, err = time.ParseInLocation(format, s, time.UTC); err == nil {
if rc.s.c.loc != nil {
dest[i] = timeVal.In(rc.s.c.loc)
} else {
dest[i] = timeVal dest[i] = timeVal
}
break break
} }
} }

View File

@ -11,6 +11,7 @@ import (
"encoding/hex" "encoding/hex"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"time" "time"
@ -746,16 +747,12 @@ func TestStress(t *testing.T) {
func TestDateTimeLocal(t *testing.T) { func TestDateTimeLocal(t *testing.T) {
zone := "Asia/Tokyo" zone := "Asia/Tokyo"
z, err := time.LoadLocation(zone)
if err != nil {
t.Skip("Failed to load timezon:", err)
}
tempFilename := TempFilename() tempFilename := TempFilename()
db, err := sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone) db, err := sql.Open("sqlite3", "file:///"+tempFilename+"?loc="+zone)
if err != nil { if err != nil {
t.Fatal("Failed to open database:", err) t.Fatal("Failed to open database:", err)
} }
db.Exec("CREATE TABLE foo (id datetime);") db.Exec("CREATE TABLE foo (dt datetime);")
db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');") db.Exec("INSERT INTO foo VALUES('2015-03-05 15:16:17');")
row := db.QueryRow("select * from foo") row := db.QueryRow("select * from foo")
@ -764,7 +761,7 @@ func TestDateTimeLocal(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("Failed to scan datetime:", err) t.Fatal("Failed to scan datetime:", err)
} }
if d.Local().Hour() != 15 { if d.Hour() == 15 || !strings.Contains(d.String(), "JST") {
t.Fatal("Result should have timezone", d) t.Fatal("Result should have timezone", d)
} }
db.Close() db.Close()
@ -779,8 +776,8 @@ func TestDateTimeLocal(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("Failed to scan datetime:", err) t.Fatal("Failed to scan datetime:", err)
} }
if d.In(z).Hour() == 15 { if d.UTC().Hour() != 15 || !strings.Contains(d.String(), "UTC") {
t.Fatalf("Result should not have timezone %v", zone) t.Fatalf("Result should not have timezone %v %v", zone, d.String())
} }
_, err = db.Exec("DELETE FROM foo") _, err = db.Exec("DELETE FROM foo")
@ -804,8 +801,8 @@ func TestDateTimeLocal(t *testing.T) {
if err != nil { if err != nil {
t.Fatal("Failed to scan datetime:", err) t.Fatal("Failed to scan datetime:", err)
} }
if d.Hour() == 15 { if d.Hour() != 15 || !strings.Contains(d.String(), "JST") {
t.Fatalf("Result should have timezone %v", zone) t.Fatalf("Result should have timezone %v %v", zone, d.String())
} }
} }