diff --git a/sqlite3.go b/sqlite3.go index c5afd1f..c3abc78 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -136,20 +136,27 @@ func (c *SQLiteConn) AutoCommit() bool { // Implements Execer func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) { + tx, err := c.Begin() + if err != nil { + return nil, err + } for { s, err := c.Prepare(query) if err != nil { + tx.Rollback() return nil, err } na := s.NumInput() res, err := s.Exec(args[:na]) if err != nil && err != driver.ErrSkip { + tx.Rollback() s.Close() return nil, err } args = args[na:] tail := s.(*SQLiteStmt).t if tail == "" { + tx.Commit() return res, nil } s.Close() @@ -159,20 +166,27 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err // Implements Queryer func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { + tx, err := c.Begin() + if err != nil { + return nil, err + } for { s, err := c.Prepare(query) if err != nil { + tx.Rollback() return nil, err } na := s.NumInput() rows, err := s.Query(args[:na]) if err != nil && err != driver.ErrSkip { + tx.Rollback() s.Close() return nil, err } args = args[na:] tail := s.(*SQLiteStmt).t if tail == "" { + tx.Commit() return rows, nil } s.Close()