diff --git a/sqlite3.go b/sqlite3.go index d8f7756..69aa89f 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -30,7 +30,6 @@ func init() { } type SQLiteDriver struct { - } type SQLiteConn struct { @@ -109,9 +108,10 @@ func (c *SQLiteConn) Close() error { } type SQLiteStmt struct { - c *SQLiteConn - s *C.sqlite3_stmt - t string + c *SQLiteConn + s *C.sqlite3_stmt + t string + closed bool } func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { @@ -127,10 +127,14 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { if perror != nil && C.strlen(perror) > 0 { t = C.GoString(perror) } - return &SQLiteStmt{c, s, t}, nil + return &SQLiteStmt{c: c, s: s, t: t}, nil } func (s *SQLiteStmt) Close() error { + if s.closed { + return nil + } + s.closed = true rv := C.sqlite3_finalize(s.s) if rv != C.SQLITE_OK { return errors.New(C.GoString(C.sqlite3_errmsg(s.c.db))) @@ -142,7 +146,7 @@ func (s *SQLiteStmt) NumInput() int { return int(C.sqlite3_bind_parameter_count(s.s)) } -func (s *SQLiteStmt) bind(args []interface{}) error { +func (s *SQLiteStmt) bind(args []driver.Value) error { rv := C.sqlite3_reset(s.s) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { return errors.New(C.GoString(C.sqlite3_errmsg(s.c.db))) @@ -169,7 +173,7 @@ func (s *SQLiteStmt) bind(args []interface{}) error { rv = C.sqlite3_bind_int(s.s, n, C.int(v)) case bool: if bool(v) { - rv = C.sqlite3_bind_int(s.s, n, 1) + rv = C.sqlite3_bind_int(s.s, n, -1) } else { rv = C.sqlite3_bind_int(s.s, n, 0) } @@ -191,7 +195,7 @@ func (s *SQLiteStmt) bind(args []interface{}) error { return nil } -func (s *SQLiteStmt) Query(args []interface{}) (driver.Rows, error) { +func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { if err := s.bind(args); err != nil { return nil, err } @@ -210,7 +214,7 @@ func (r *SQLiteResult) RowsAffected() (int64, error) { return int64(C.sqlite3_changes(r.s.c.db)), nil } -func (s *SQLiteStmt) Exec(args []interface{}) (driver.Result, error) { +func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { if err := s.bind(args); err != nil { return nil, err } @@ -228,7 +232,7 @@ type SQLiteRows struct { } func (rc *SQLiteRows) Close() error { - return nil + return rc.s.Close() } func (rc *SQLiteRows) Columns() []string { @@ -241,7 +245,7 @@ func (rc *SQLiteRows) Columns() []string { return rc.cols } -func (rc *SQLiteRows) Next(dest []interface{}) error { +func (rc *SQLiteRows) Next(dest []driver.Value) error { rv := C.sqlite3_step(rc.s.s) if rv != C.SQLITE_ROW { return errors.New(C.GoString(C.sqlite3_errmsg(rc.s.c.db))) diff --git a/sqlite3_test.go b/sqlite3_test.go index a76d86f..7cbe8c3 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -247,7 +247,7 @@ func TestBooleanRoundtrip(t *testing.T) { var id int var value bool - if err := rows.Scan(&id, &value) ; err != nil { + if err := rows.Scan(&id, &value); err != nil { t.Errorf("Unable to scan results:", err) continue }