forked from mirror/go-sqlite3
Handle time.Time values with "timestamp" columns.
This commit is contained in:
parent
e85c34cf5c
commit
3abc337b8e
42
sqlite3.go
42
sqlite3.go
|
@ -23,9 +23,13 @@ import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const SQLiteTimestampFormat = "2006-01-02 15:04:05"
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
sql.Register("sqlite3", &SQLiteDriver{})
|
sql.Register("sqlite3", &SQLiteDriver{})
|
||||||
}
|
}
|
||||||
|
@ -194,6 +198,9 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
|
||||||
p = &v[0]
|
p = &v[0]
|
||||||
}
|
}
|
||||||
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
|
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v)))
|
||||||
|
case time.Time:
|
||||||
|
b := []byte(v.Format(SQLiteTimestampFormat))
|
||||||
|
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
|
||||||
}
|
}
|
||||||
if rv != C.SQLITE_OK {
|
if rv != C.SQLITE_OK {
|
||||||
return errors.New(C.GoString(C.sqlite3_errmsg(s.c.db)))
|
return errors.New(C.GoString(C.sqlite3_errmsg(s.c.db)))
|
||||||
|
@ -206,7 +213,7 @@ func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||||
if err := s.bind(args); err != nil {
|
if err := s.bind(args); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &SQLiteRows{s, int(C.sqlite3_column_count(s.s)), nil}, nil
|
return &SQLiteRows{s, int(C.sqlite3_column_count(s.s)), nil, nil}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type SQLiteResult struct {
|
type SQLiteResult struct {
|
||||||
|
@ -233,9 +240,10 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type SQLiteRows struct {
|
type SQLiteRows struct {
|
||||||
s *SQLiteStmt
|
s *SQLiteStmt
|
||||||
nc int
|
nc int
|
||||||
cols []string
|
cols []string
|
||||||
|
decltype []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rc *SQLiteRows) Close() error {
|
func (rc *SQLiteRows) Close() error {
|
||||||
|
@ -264,10 +272,23 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
||||||
if rv != C.SQLITE_ROW {
|
if rv != C.SQLITE_ROW {
|
||||||
return errors.New(C.GoString(C.sqlite3_errmsg(rc.s.c.db)))
|
return errors.New(C.GoString(C.sqlite3_errmsg(rc.s.c.db)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rc.decltype == nil {
|
||||||
|
rc.decltype = make([]string, rc.nc)
|
||||||
|
for i := 0; i < rc.nc; i++ {
|
||||||
|
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for i := range dest {
|
for i := range dest {
|
||||||
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
|
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
|
||||||
case C.SQLITE_INTEGER:
|
case C.SQLITE_INTEGER:
|
||||||
dest[i] = int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
|
val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i)))
|
||||||
|
if rc.decltype[i] == "timestamp" {
|
||||||
|
dest[i] = time.Unix(val, 0)
|
||||||
|
} else {
|
||||||
|
dest[i] = val
|
||||||
|
}
|
||||||
case C.SQLITE_FLOAT:
|
case C.SQLITE_FLOAT:
|
||||||
dest[i] = float64(C.sqlite3_column_double(rc.s.s, C.int(i)))
|
dest[i] = float64(C.sqlite3_column_double(rc.s.s, C.int(i)))
|
||||||
case C.SQLITE_BLOB:
|
case C.SQLITE_BLOB:
|
||||||
|
@ -277,7 +298,16 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
|
||||||
case C.SQLITE_NULL:
|
case C.SQLITE_NULL:
|
||||||
dest[i] = nil
|
dest[i] = nil
|
||||||
case C.SQLITE_TEXT:
|
case C.SQLITE_TEXT:
|
||||||
dest[i] = C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))))
|
var err error
|
||||||
|
s := C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i)))))
|
||||||
|
if rc.decltype[i] == "timestamp" {
|
||||||
|
dest[i], err = time.Parse(SQLiteTimestampFormat, s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dest[i] = s
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -3,7 +3,9 @@ package sqlite
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOpen(t *testing.T) {
|
func TestOpen(t *testing.T) {
|
||||||
|
@ -260,3 +262,80 @@ func TestBooleanRoundtrip(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTimestamp(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", "./foo.db")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to open database:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer os.Remove("./foo.db")
|
||||||
|
|
||||||
|
_, err = db.Exec("DROP TABLE foo")
|
||||||
|
_, err = db.Exec("CREATE TABLE foo(id INTEGER, ts timeSTAMP)")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to create table:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC)
|
||||||
|
_, err = db.Exec("INSERT INTO foo(id, ts) VALUES(1, ?)", timestamp1)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to insert timestamp:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
timestamp2 := time.Date(2012, time.April, 6, 23, 22, 0, 0, time.UTC)
|
||||||
|
_, err = db.Exec("INSERT INTO foo(id, ts) VALUES(2, ?)", timestamp2.Unix())
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to insert timestamp:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Exec("INSERT INTO foo(id, ts) VALUES(3, ?)", "nonsense")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to insert nonsense:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT id, ts FROM foo ORDER BY id ASC")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unable to query foo table:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := 0
|
||||||
|
for rows.Next() {
|
||||||
|
var id int
|
||||||
|
var ts time.Time
|
||||||
|
|
||||||
|
if err := rows.Scan(&id, &ts); err != nil {
|
||||||
|
t.Errorf("Unable to scan results:", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if id == 1 {
|
||||||
|
seen += 1
|
||||||
|
if !timestamp1.Equal(ts) {
|
||||||
|
t.Errorf("Value for id 1 should be %v, not %v", timestamp1, ts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if id == 2 {
|
||||||
|
seen += 1
|
||||||
|
if !timestamp2.Equal(ts) {
|
||||||
|
t.Errorf("Value for id 2 should be %v, not %v", timestamp2, ts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if seen != 2 {
|
||||||
|
t.Errorf("Expected to see two valid timestamps")
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure "nonsense" triggered an error
|
||||||
|
err = rows.Err()
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "cannot parse \"nonsense\"") {
|
||||||
|
t.Errorf("Expected error from \"nonsense\" timestamp")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue