diff --git a/sqlite3.go b/sqlite3.go index 6598e02..d0f02e3 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -136,50 +136,48 @@ func (c *SQLiteConn) AutoCommit() bool { // Implements Execer func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) { + var res driver.Result for { - ds, err := c.Prepare(query) + s, err := c.Prepare(query) if err != nil { return nil, err } - s := ds.(*SQLiteStmt) na := s.NumInput() - res, err := s.Exec(args[:na]) - args = args[na:] - if err != nil { + res, err = s.Exec(args[:na]) + if err != nil && err != driver.ErrSkip { s.Close() return nil, err } - if s.t == "" { + args = args[na:] + tail := s.(*SQLiteStmt).t + if tail == "" { return res, nil } s.Close() - query = s.t + query = tail } } // Implements Queryer func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { + var rows driver.Rows for { - ds, err := c.Prepare(query) + s, err := c.Prepare(query) if err != nil { return nil, err } - s := ds.(*SQLiteStmt) na := s.NumInput() - rows, err := s.Query(args[:na]) - args = args[na:] - if err != nil { + rows, err = s.Query(args[:na]) + if err != nil && err != driver.ErrSkip { s.Close() return nil, err } - if s.t == "" { + args = args[na:] + tail := s.(*SQLiteStmt).t + if tail == "" { return rows, nil } - if rows != nil { - rows.Close() - } - s.Close() - query = s.t + query = tail } } @@ -418,6 +416,9 @@ func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { // Close the rows. func (rc *SQLiteRows) Close() error { + if rc.s.closed { + return nil + } rv := C.sqlite3_reset(rc.s.s) if rv != C.SQLITE_OK { return ErrNo(rv) diff --git a/sqlite3_test.go b/sqlite3_test.go index 5da83f6..8a95314 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -591,9 +591,9 @@ func TestExecer(t *testing.T) { _, err = db.Exec(` create table foo (id integer); - insert into foo values(1); - insert into foo values(2); - insert into foo values(3); + insert into foo(id) values(1); + insert into foo(id) values(2); + insert into foo(id) values(3); `) if err != nil { t.Error("Failed to call db.Exec:", err) @@ -614,24 +614,26 @@ func TestQueryer(t *testing.T) { rows, err := db.Query(` create table foo (id integer); - insert into foo values(1); - insert into foo values(2); - insert into foo values(3); + insert into foo(id) values(?); + insert into foo(id) values(?); + insert into foo(id) values(?); select id from foo order by id; - `) + `, 3, 2, 1) if err != nil { - t.Error("Failed to call db.Exec:", err) + t.Error("Failed to call db.Query:", err) } defer rows.Close() n := 1 - for rows.Next() { - var id int - err = rows.Scan(&id) - if err != nil { - t.Error("Failed to db.Query:", err) - } - if id != n { - t.Error("Failed to db.Query: not matched results") + if rows != nil { + for rows.Next() { + var id int + err = rows.Scan(&id) + if err != nil { + t.Error("Failed to db.Query:", err) + } + if id != n { + t.Error("Failed to db.Query: not matched results") + } } } }