fix trace callback.

Close #352
This commit is contained in:
Yasuhiro Matsumoto 2016-11-08 12:19:13 +09:00
parent 3e26a9df84
commit dd2c82226b
4 changed files with 97 additions and 31 deletions

View File

@ -40,8 +40,8 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value
} }
//export stepTrampoline //export stepTrampoline
func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) { func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc] args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo) ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo)
ai.Step(ctx, args) ai.Step(ctx, args)
} }

View File

@ -191,6 +191,7 @@ type SQLiteRows struct {
decltype []string decltype []string
cls bool cls bool
done chan struct{} done chan struct{}
next *SQLiteRows
} }
type functionInfo struct { type functionInfo struct {
@ -296,19 +297,19 @@ func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
// Commit transaction. // Commit transaction.
func (tx *SQLiteTx) Commit() error { func (tx *SQLiteTx) Commit() error {
_, err := tx.c.execQuery("COMMIT") _, err := tx.c.exec(context.Background(), "COMMIT", nil)
if err != nil && err.(Error).Code == C.SQLITE_BUSY { if err != nil && err.(Error).Code == C.SQLITE_BUSY {
// sqlite3 will leave the transaction open in this scenario. // sqlite3 will leave the transaction open in this scenario.
// However, database/sql considers the transaction complete once we // However, database/sql considers the transaction complete once we
// return from Commit() - we must clean up to honour its semantics. // return from Commit() - we must clean up to honour its semantics.
tx.c.execQuery("ROLLBACK") tx.c.exec(context.Background(), "ROLLBACK", nil)
} }
return err return err
} }
// Rollback transaction. // Rollback transaction.
func (tx *SQLiteTx) Rollback() error { func (tx *SQLiteTx) Rollback() error {
_, err := tx.c.execQuery("ROLLBACK") _, err := tx.c.exec(context.Background(), "ROLLBACK", nil)
return err return err
} }
@ -382,13 +383,17 @@ func (c *SQLiteConn) RegisterFunc(name string, impl interface{}, pure bool) erro
if pure { if pure {
opts |= C.SQLITE_DETERMINISTIC opts |= C.SQLITE_DETERMINISTIC
} }
rv := C._sqlite3_create_function(c.db, cname, C.int(numArgs), C.int(opts), C.uintptr_t(newHandle(c, &fi)), (*[0]byte)(unsafe.Pointer(C.callbackTrampoline)), nil, nil) rv := sqlite3_create_function(c.db, cname, numArgs, opts, newHandle(c, &fi), C.callbackTrampoline, nil, nil)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return c.lastError() return c.lastError()
} }
return nil return nil
} }
func sqlite3_create_function(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTextRep C.int, pApp C.uintptr_t, xFunc unsafe.Pointer, xStep unsafe.Pointer, xFinal unsafe.Pointer) C.int {
return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, pApp, (*[0]byte)(unsafe.Pointer(xFunc)), (*[0]byte)(unsafe.Pointer(xStep)), (*[0]byte)(unsafe.Pointer(xFinal)))
}
// AutoCommit return which currently auto commit or not. // AutoCommit return which currently auto commit or not.
func (c *SQLiteConn) AutoCommit() bool { func (c *SQLiteConn) AutoCommit() bool {
return int(C.sqlite3_get_autocommit(c.db)) != 0 return int(C.sqlite3_get_autocommit(c.db)) != 0
@ -404,10 +409,6 @@ func (c *SQLiteConn) lastError() Error {
// Exec implements Execer. // Exec implements Execer.
func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) { func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if len(args) == 0 {
return c.execQuery(query)
}
list := make([]namedValue, len(args)) list := make([]namedValue, len(args))
for i, v := range args { for i, v := range args {
list[i] = namedValue{ list[i] = namedValue{
@ -470,6 +471,7 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
} }
func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) { func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) {
var top, cur *SQLiteRows
start := 0 start := 0
for { for {
s, err := c.Prepare(query) s, err := c.Prepare(query)
@ -487,7 +489,14 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue)
rows, err := s.(*SQLiteStmt).query(ctx, args[:na]) rows, err := s.(*SQLiteStmt).query(ctx, args[:na])
if err != nil && err != driver.ErrSkip { if err != nil && err != driver.ErrSkip {
s.Close() s.Close()
return nil, err return top, err
}
if top == nil {
top = rows.(*SQLiteRows)
cur = top
} else {
cur.next = rows.(*SQLiteRows)
cur = cur.next
} }
args = args[na:] args = args[na:]
start += na start += na
@ -501,25 +510,13 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue)
} }
} }
func (c *SQLiteConn) execQuery(cmd string) (driver.Result, error) {
pcmd := C.CString(cmd)
defer C.free(unsafe.Pointer(pcmd))
var rowid, changes C.longlong
rv := C._sqlite3_exec(c.db, pcmd, &rowid, &changes)
if rv != C.SQLITE_OK {
return nil, c.lastError()
}
return &SQLiteResult{int64(rowid), int64(changes)}, nil
}
// Begin transaction. // Begin transaction.
func (c *SQLiteConn) Begin() (driver.Tx, error) { func (c *SQLiteConn) Begin() (driver.Tx, error) {
return c.begin(context.Background()) return c.begin(context.Background())
} }
func (c *SQLiteConn) begin(ctx context.Context) (driver.Tx, error) { func (c *SQLiteConn) begin(ctx context.Context) (driver.Tx, error) {
if _, err := c.execQuery(c.txlock); err != nil { if _, err := c.exec(ctx, c.txlock, nil); err != nil {
return nil, err return nil, err
} }
return &SQLiteTx{c}, nil return &SQLiteTx{c}, nil
@ -775,6 +772,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
decltype: nil, decltype: nil,
cls: s.cls, cls: s.cls,
done: make(chan struct{}), done: make(chan struct{}),
next: nil,
} }
go func() { go func() {
@ -837,7 +835,7 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result
return nil, err return nil, err
} }
return &SQLiteResult{int64(rowid), int64(changes)}, nil return &SQLiteResult{id: int64(rowid), changes: int64(changes)}, nil
} }
// Close the rows. // Close the rows.
@ -972,3 +970,15 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
} }
return nil return nil
} }
func (rc *SQLiteRows) HasNextResultSet() bool {
return rc.next != nil
}
func (rc *SQLiteRows) NextResultSet() error {
if rc.next == nil {
return io.EOF
}
*rc = *rc.next
return nil
}

View File

@ -9,6 +9,7 @@ package sqlite3
import ( import (
"database/sql" "database/sql"
"fmt"
"os" "os"
"testing" "testing"
) )
@ -48,3 +49,58 @@ func TestNamedParams(t *testing.T) {
t.Error("Failed to db.QueryRow: not matched results") t.Error("Failed to db.QueryRow: not matched results")
} }
} }
func TestMultipleResultSet(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec(`
create table foo (id integer, name text);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}
for i := 0; i < 100; i++ {
_, err = db.Exec(`insert into foo(id, name) values(?, ?)`, i+1, fmt.Sprintf("foo%03d", i+1))
if err != nil {
t.Error("Failed to call db.Exec:", err)
}
}
rows, err := db.Query(`
select id, name from foo where id < :id1;
select id, name from foo where id = :id2;
select id, name from foo where id > :id3;
`,
sql.Param(":id1", 3),
sql.Param(":id2", 50),
sql.Param(":id3", 98),
)
if err != nil {
t.Error("Failed to call db.Query:", err)
}
var id int
var extra string
for {
for rows.Next() {
err = rows.Scan(&id, &extra)
if err != nil {
t.Error("Failed to db.Scan:", err)
}
if id != 1 || extra != "foo" {
t.Error("Failed to db.QueryRow: not matched results")
}
}
if !rows.NextResultSet() {
break
}
}
}

View File

@ -17,7 +17,7 @@ package sqlite3
void stepTrampoline(sqlite3_context*, int, sqlite3_value**); void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
void doneTrampoline(sqlite3_context*); void doneTrampoline(sqlite3_context*);
void traceCallbackTrampoline(unsigned traceEventCode, void *ctx, void *p, void *x); int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
*/ */
import "C" import "C"
@ -76,7 +76,7 @@ type TraceUserCallback func(TraceInfo) int
type TraceConfig struct { type TraceConfig struct {
Callback TraceUserCallback Callback TraceUserCallback
EventMask uint EventMask C.uint
WantExpandedSQL bool WantExpandedSQL bool
} }
@ -102,13 +102,13 @@ func fillExpandedSQL(info *TraceInfo, db *C.sqlite3, pStmt unsafe.Pointer) {
//export traceCallbackTrampoline //export traceCallbackTrampoline
func traceCallbackTrampoline( func traceCallbackTrampoline(
traceEventCode uint, traceEventCode C.uint,
// Parameter named 'C' in SQLite docs = Context given at registration: // Parameter named 'C' in SQLite docs = Context given at registration:
ctx unsafe.Pointer, ctx unsafe.Pointer,
// Parameter named 'P' in SQLite docs (Primary event data?): // Parameter named 'P' in SQLite docs (Primary event data?):
p unsafe.Pointer, p unsafe.Pointer,
// Parameter named 'X' in SQLite docs (eXtra event data?): // Parameter named 'X' in SQLite docs (eXtra event data?):
xValue unsafe.Pointer) int { xValue unsafe.Pointer) C.int {
if ctx == nil { if ctx == nil {
panic(fmt.Sprintf("No context (ev 0x%x)", traceEventCode)) panic(fmt.Sprintf("No context (ev 0x%x)", traceEventCode))
@ -196,7 +196,7 @@ func traceCallbackTrampoline(
if traceConf.Callback != nil { if traceConf.Callback != nil {
r = traceConf.Callback(info) r = traceConf.Callback(info)
} }
return r return C.int(r)
} }
type traceMapEntry struct { type traceMapEntry struct {
@ -358,7 +358,7 @@ func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool
if pure { if pure {
opts |= C.SQLITE_DETERMINISTIC opts |= C.SQLITE_DETERMINISTIC
} }
rv := C._sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(newHandle(c, &ai)), nil, (*[0]byte)(unsafe.Pointer(C.stepTrampoline)), (*[0]byte)(unsafe.Pointer(C.doneTrampoline))) rv := sqlite3_create_function(c.db, cname, C.int(stepNArgs), C.int(opts), C.uintptr_t(newHandle(c, &ai)), nil, C.stepTrampoline, C.doneTrampoline)
if rv != C.SQLITE_OK { if rv != C.SQLITE_OK {
return c.lastError() return c.lastError()
} }