diff --git a/sqlite3.go b/sqlite3.go index e3ce711..26cd3d1 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -23,9 +23,13 @@ import ( "database/sql/driver" "errors" "io" + "strings" + "time" "unsafe" ) +const SQLiteTimestampFormat = "2006-01-02 15:04:05" + func init() { sql.Register("sqlite3", &SQLiteDriver{}) } @@ -194,6 +198,9 @@ func (s *SQLiteStmt) bind(args []driver.Value) error { p = &v[0] } rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(p), C.int(len(v))) + case time.Time: + b := []byte(v.Format(SQLiteTimestampFormat)) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) } if rv != C.SQLITE_OK { return errors.New(C.GoString(C.sqlite3_errmsg(s.c.db))) @@ -206,7 +213,7 @@ func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { if err := s.bind(args); err != nil { return nil, err } - return &SQLiteRows{s, int(C.sqlite3_column_count(s.s)), nil}, nil + return &SQLiteRows{s, int(C.sqlite3_column_count(s.s)), nil, nil}, nil } type SQLiteResult struct { @@ -233,9 +240,10 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { } type SQLiteRows struct { - s *SQLiteStmt - nc int - cols []string + s *SQLiteStmt + nc int + cols []string + decltype []string } func (rc *SQLiteRows) Close() error { @@ -264,10 +272,23 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { if rv != C.SQLITE_ROW { return errors.New(C.GoString(C.sqlite3_errmsg(rc.s.c.db))) } + + if rc.decltype == nil { + rc.decltype = make([]string, rc.nc) + for i := 0; i < rc.nc; i++ { + rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i)))) + } + } + for i := range dest { switch C.sqlite3_column_type(rc.s.s, C.int(i)) { case C.SQLITE_INTEGER: - dest[i] = int64(C.sqlite3_column_int64(rc.s.s, C.int(i))) + val := int64(C.sqlite3_column_int64(rc.s.s, C.int(i))) + if rc.decltype[i] == "timestamp" { + dest[i] = time.Unix(val, 0) + } else { + dest[i] = val + } case C.SQLITE_FLOAT: dest[i] = float64(C.sqlite3_column_double(rc.s.s, C.int(i))) case C.SQLITE_BLOB: @@ -277,7 +298,16 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error { case C.SQLITE_NULL: dest[i] = nil case C.SQLITE_TEXT: - dest[i] = C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i))))) + var err error + s := C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_text(rc.s.s, C.int(i))))) + if rc.decltype[i] == "timestamp" { + dest[i], err = time.Parse(SQLiteTimestampFormat, s) + if err != nil { + return err + } + } else { + dest[i] = s + } } } return nil diff --git a/sqlite3_test.go b/sqlite3_test.go index 60feb86..b46666b 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -3,7 +3,9 @@ package sqlite import ( "database/sql" "os" + "strings" "testing" + "time" ) func TestOpen(t *testing.T) { @@ -260,3 +262,80 @@ func TestBooleanRoundtrip(t *testing.T) { } } } + +func TestTimestamp(t *testing.T) { + db, err := sql.Open("sqlite3", "./foo.db") + if err != nil { + t.Errorf("Failed to open database:", err) + return + } + defer os.Remove("./foo.db") + + _, err = db.Exec("DROP TABLE foo") + _, err = db.Exec("CREATE TABLE foo(id INTEGER, ts timeSTAMP)") + if err != nil { + t.Errorf("Failed to create table:", err) + return + } + + timestamp1 := time.Date(2012, time.April, 6, 22, 50, 0, 0, time.UTC) + _, err = db.Exec("INSERT INTO foo(id, ts) VALUES(1, ?)", timestamp1) + if err != nil { + t.Errorf("Failed to insert timestamp:", err) + return + } + + timestamp2 := time.Date(2012, time.April, 6, 23, 22, 0, 0, time.UTC) + _, err = db.Exec("INSERT INTO foo(id, ts) VALUES(2, ?)", timestamp2.Unix()) + if err != nil { + t.Errorf("Failed to insert timestamp:", err) + return + } + + _, err = db.Exec("INSERT INTO foo(id, ts) VALUES(3, ?)", "nonsense") + if err != nil { + t.Errorf("Failed to insert nonsense:", err) + return + } + + rows, err := db.Query("SELECT id, ts FROM foo ORDER BY id ASC") + if err != nil { + t.Errorf("Unable to query foo table:", err) + return + } + + seen := 0 + for rows.Next() { + var id int + var ts time.Time + + if err := rows.Scan(&id, &ts); err != nil { + t.Errorf("Unable to scan results:", err) + continue + } + + if id == 1 { + seen += 1 + if !timestamp1.Equal(ts) { + t.Errorf("Value for id 1 should be %v, not %v", timestamp1, ts) + } + } + + if id == 2 { + seen += 1 + if !timestamp2.Equal(ts) { + t.Errorf("Value for id 2 should be %v, not %v", timestamp2, ts) + } + } + } + + if seen != 2 { + t.Errorf("Expected to see two valid timestamps") + } + + // make sure "nonsense" triggered an error + err = rows.Err() + if err == nil || !strings.Contains(err.Error(), "cannot parse \"nonsense\"") { + t.Errorf("Expected error from \"nonsense\" timestamp") + } +}