diff --git a/sqlite3.go b/sqlite3.go index 0ce9c30..98dcc0b 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -144,52 +144,55 @@ func (c *SQLiteConn) lastError() Error { // Implements Execer func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) { - if len(args) == 0 { - return c.exec(query) - } + if len(args) == 0 { + return c.exec(query) + } - for { - s, err := c.Prepare(query) - if err != nil { - return nil, err - } - na := s.NumInput() - res, err := s.Exec(args[:na]) - if err != nil && err != driver.ErrSkip { - s.Close() - return nil, err - } - args = args[na:] - tail := s.(*SQLiteStmt).t - if tail == "" { - return res, nil - } - s.Close() - query = tail - } + for { + s, err := c.Prepare(query) + if err != nil { + return nil, err + } + var res driver.Result + if s.(*SQLiteStmt).s != nil { + na := s.NumInput() + res, err = s.Exec(args[:na]) + if err != nil && err != driver.ErrSkip { + s.Close() + return nil, err + } + args = args[na:] + } + tail := s.(*SQLiteStmt).t + if tail == "" { + return res, nil + } + s.Close() + query = tail + } } // Implements Queryer func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { - for { - s, err := c.Prepare(query) - if err != nil { - return nil, err - } - na := s.NumInput() - rows, err := s.Query(args[:na]) - if err != nil && err != driver.ErrSkip { - s.Close() - return nil, err - } - args = args[na:] - tail := s.(*SQLiteStmt).t - if tail == "" { - return rows, nil - } - s.Close() - query = tail - } + for { + s, err := c.Prepare(query) + if err != nil { + return nil, err + } + na := s.NumInput() + rows, err := s.Query(args[:na]) + if err != nil && err != driver.ErrSkip { + s.Close() + return nil, err + } + args = args[na:] + tail := s.(*SQLiteStmt).t + if tail == "" { + return rows, nil + } + s.Close() + query = tail + } } func (c *SQLiteConn) exec(cmd string) (driver.Result, error) { diff --git a/sqlite3_test.go b/sqlite3_test.go index bcbba44..2340014 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -1,7 +1,6 @@ package sqlite3 import ( - "./sqltest" "crypto/rand" "database/sql" "encoding/hex" @@ -9,6 +8,8 @@ import ( "path/filepath" "testing" "time" + + "./sqltest" ) func TempFilename() string { @@ -639,29 +640,26 @@ func TestSuite(t *testing.T) { // TODO: Execer & Queryer currently disabled // https://github.com/mattn/go-sqlite3/issues/82 -//func TestExecer(t *testing.T) { -// tempFilename := TempFilename() -// db, err := sql.Open("sqlite3", tempFilename) -// if err != nil { -// t.Fatal("Failed to open database:", err) -// } -// defer os.Remove(tempFilename) -// defer db.Close() -// -// _, err = db.Exec(` -// create table foo (id integer); -// insert into foo(id) values(?); -// insert into foo(id) values(?); -// insert into foo(id) values(?); -// `, 1, 2, 3) -// if err != nil { -// t.Error("Failed to call db.Exec:", err) -// } -// if err != nil { -// t.Error("Failed to call res.RowsAffected:", err) -// } -//} -// +func TestExecer(t *testing.T) { + tempFilename := TempFilename() + db, err := sql.Open("sqlite3", tempFilename) + if err != nil { + t.Fatal("Failed to open database:", err) + } + defer os.Remove(tempFilename) + defer db.Close() + + _, err = db.Exec(` + create table foo (id integer); -- one comment + insert into foo(id) values(?); + insert into foo(id) values(?); + insert into foo(id) values(?); -- another comment + `, 1, 2, 3) + if err != nil { + t.Error("Failed to call db.Exec:", err) + } +} + //func TestQueryer(t *testing.T) { // tempFilename := TempFilename() // db, err := sql.Open("sqlite3", tempFilename)