Merge pull request #348 from mattn/ping

new go1.8 feature
This commit is contained in:
mattn 2016-11-04 23:24:02 +09:00 committed by GitHub
commit f4e49e6484
5 changed files with 266 additions and 73 deletions

View File

@ -114,6 +114,8 @@ import (
"strings" "strings"
"time" "time"
"unsafe" "unsafe"
"golang.org/x/net/context"
) )
// Timestamp formats understood by both this module and SQLite. // Timestamp formats understood by both this module and SQLite.
@ -170,8 +172,6 @@ type SQLiteTx struct {
type SQLiteStmt struct { type SQLiteStmt struct {
c *SQLiteConn c *SQLiteConn
s *C.sqlite3_stmt s *C.sqlite3_stmt
nv int
nn []string
t string t string
closed bool closed bool
cls bool cls bool
@ -295,19 +295,19 @@ func (ai *aggInfo) Done(ctx *C.sqlite3_context) {
// Commit transaction. // Commit transaction.
func (tx *SQLiteTx) Commit() error { func (tx *SQLiteTx) Commit() error {
_, err := tx.c.exec("COMMIT") _, err := tx.c.execQuery("COMMIT")
if err != nil && err.(Error).Code == C.SQLITE_BUSY { if err != nil && err.(Error).Code == C.SQLITE_BUSY {
// sqlite3 will leave the transaction open in this scenario. // sqlite3 will leave the transaction open in this scenario.
// However, database/sql considers the transaction complete once we // However, database/sql considers the transaction complete once we
// return from Commit() - we must clean up to honour its semantics. // return from Commit() - we must clean up to honour its semantics.
tx.c.exec("ROLLBACK") tx.c.execQuery("ROLLBACK")
} }
return err return err
} }
// Rollback transaction. // Rollback transaction.
func (tx *SQLiteTx) Rollback() error { func (tx *SQLiteTx) Rollback() error {
_, err := tx.c.exec("ROLLBACK") _, err := tx.c.execQuery("ROLLBACK")
return err return err
} }
@ -404,9 +404,21 @@ func (c *SQLiteConn) lastError() Error {
// Implements Execer // Implements Execer
func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) { func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if len(args) == 0 { if len(args) == 0 {
return c.exec(query) return c.execQuery(query)
} }
list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
return c.exec(context.Background(), query, list)
}
func (c *SQLiteConn) exec(ctx context.Context, query string, args []namedValue) (driver.Result, error) {
start := 0
for { for {
s, err := c.Prepare(query) s, err := c.Prepare(query)
if err != nil { if err != nil {
@ -418,12 +430,16 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
if len(args) < na { if len(args) < na {
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args)) return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
} }
res, err = s.Exec(args[:na]) for i := 0; i < na; i++ {
args[i].Ordinal -= start
}
res, err = s.(*SQLiteStmt).exec(ctx, args[:na])
if err != nil && err != driver.ErrSkip { if err != nil && err != driver.ErrSkip {
s.Close() s.Close()
return nil, err return nil, err
} }
args = args[na:] args = args[na:]
start += na
} }
tail := s.(*SQLiteStmt).t tail := s.(*SQLiteStmt).t
s.Close() s.Close()
@ -434,8 +450,26 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
} }
} }
type namedValue struct {
Name string
Ordinal int
Value driver.Value
}
// Implements Queryer // Implements Queryer
func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) { func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, error) {
list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
return c.query(context.Background(), query, list)
}
func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) {
start := 0
for { for {
s, err := c.Prepare(query) s, err := c.Prepare(query)
if err != nil { if err != nil {
@ -446,12 +480,16 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
if len(args) < na { if len(args) < na {
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args)) return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
} }
rows, err := s.Query(args[:na]) for i := 0; i < na; i++ {
args[i].Ordinal -= start
}
rows, err := s.(*SQLiteStmt).query(ctx, args[:na])
if err != nil && err != driver.ErrSkip { if err != nil && err != driver.ErrSkip {
s.Close() s.Close()
return nil, err return nil, err
} }
args = args[na:] args = args[na:]
start += na
tail := s.(*SQLiteStmt).t tail := s.(*SQLiteStmt).t
if tail == "" { if tail == "" {
return rows, nil return rows, nil
@ -462,7 +500,7 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
} }
} }
func (c *SQLiteConn) exec(cmd string) (driver.Result, error) { func (c *SQLiteConn) execQuery(cmd string) (driver.Result, error) {
pcmd := C.CString(cmd) pcmd := C.CString(cmd)
defer C.free(unsafe.Pointer(pcmd)) defer C.free(unsafe.Pointer(pcmd))
@ -476,7 +514,11 @@ func (c *SQLiteConn) exec(cmd string) (driver.Result, error) {
// Begin transaction. // Begin transaction.
func (c *SQLiteConn) Begin() (driver.Tx, error) { func (c *SQLiteConn) Begin() (driver.Tx, error) {
if _, err := c.exec(c.txlock); err != nil { return c.begin(context.Background())
}
func (c *SQLiteConn) begin(ctx context.Context) (driver.Tx, error) {
if _, err := c.execQuery(c.txlock); err != nil {
return nil, err return nil, err
} }
return &SQLiteTx{c}, nil return &SQLiteTx{c}, nil
@ -606,6 +648,10 @@ func (c *SQLiteConn) Close() error {
// Prepare the query string. Return a new statement. // Prepare the query string. Return a new statement.
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) { func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
return c.prepare(context.Background(), query)
}
func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, error) {
pquery := C.CString(query) pquery := C.CString(query)
defer C.free(unsafe.Pointer(pquery)) defer C.free(unsafe.Pointer(pquery))
var s *C.sqlite3_stmt var s *C.sqlite3_stmt
@ -618,15 +664,7 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
if tail != nil && *tail != '\000' { if tail != nil && *tail != '\000' {
t = strings.TrimSpace(C.GoString(tail)) t = strings.TrimSpace(C.GoString(tail))
} }
nv := int(C.sqlite3_bind_parameter_count(s)) ss := &SQLiteStmt{c: c, s: s, t: t}
var nn []string
for i := 0; i < nv; i++ {
pn := C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1)))
if len(pn) > 1 && pn[0] == '$' && 48 <= pn[1] && pn[1] <= 57 {
nn = append(nn, C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1))))
}
}
ss := &SQLiteStmt{c: c, s: s, nv: nv, nn: nn, t: t}
runtime.SetFinalizer(ss, (*SQLiteStmt).Close) runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
return ss, nil return ss, nil
} }
@ -650,7 +688,7 @@ func (s *SQLiteStmt) Close() error {
// Return a number of parameters. // Return a number of parameters.
func (s *SQLiteStmt) NumInput() int { func (s *SQLiteStmt) NumInput() int {
return s.nv return int(C.sqlite3_bind_parameter_count(s.s))
} }
type bindArg struct { type bindArg struct {
@ -658,31 +696,23 @@ type bindArg struct {
v driver.Value v driver.Value
} }
func (s *SQLiteStmt) bind(args []driver.Value) error { func (s *SQLiteStmt) bind(args []namedValue) error {
rv := C.sqlite3_reset(s.s) rv := C.sqlite3_reset(s.s)
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
return s.c.lastError() return s.c.lastError()
} }
var vargs []bindArg for i, v := range args {
narg := len(args) if v.Name != "" {
vargs = make([]bindArg, narg) cname := C.CString(v.Name)
if len(s.nn) > 0 { args[i].Ordinal = int(C.sqlite3_bind_parameter_index(s.s, cname))
for i, v := range s.nn { C.free(unsafe.Pointer(cname))
if pi, err := strconv.Atoi(v[1:]); err == nil {
vargs[i] = bindArg{pi, args[i]}
}
}
} else {
for i, v := range args {
vargs[i] = bindArg{i + 1, v}
} }
} }
for _, varg := range vargs { for _, arg := range args {
n := C.int(varg.n) n := C.int(arg.Ordinal)
v := varg.v switch v := arg.Value.(type) {
switch v := v.(type) {
case nil: case nil:
rv = C.sqlite3_bind_null(s.s, n) rv = C.sqlite3_bind_null(s.s, n)
case string: case string:
@ -722,6 +752,17 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
// Query the statement with arguments. Return records. // Query the statement with arguments. Return records.
func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) { func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
return s.query(context.Background(), list)
}
func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows, error) {
if err := s.bind(args); err != nil { if err := s.bind(args); err != nil {
return nil, err return nil, err
} }
@ -740,6 +781,17 @@ func (r *SQLiteResult) RowsAffected() (int64, error) {
// Execute the statement with arguments. Return result object. // Execute the statement with arguments. Return result object.
func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) { func (s *SQLiteStmt) Exec(args []driver.Value) (driver.Result, error) {
list := make([]namedValue, len(args))
for i, v := range args {
list[i] = namedValue{
Ordinal: i + 1,
Value: v,
}
}
return s.exec(context.Background(), list)
}
func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
if err := s.bind(args); err != nil { if err := s.bind(args); err != nil {
C.sqlite3_reset(s.s) C.sqlite3_reset(s.s)
C.sqlite3_clear_bindings(s.s) C.sqlite3_clear_bindings(s.s)

58
sqlite3_go18.go Normal file
View File

@ -0,0 +1,58 @@
// +build go1.8
package sqlite3
import (
"database/sql/driver"
"errors"
"golang.org/x/net/context"
)
// Ping implement Pinger.
func (c *SQLiteConn) Ping(ctx context.Context) error {
if c.db == nil {
return errors.New("Connection was closed")
}
return nil
}
func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return c.query(ctx, query, list)
}
func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return c.exec(ctx, query, list)
}
func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return c.prepare(ctx, query)
}
func (c *SQLiteConn) BeginContext(ctx context.Context) (driver.Tx, error) {
return c.begin(ctx)
}
func (s *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return s.query(ctx, list)
}
func (s *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
}
return s.exec(ctx, list)
}

49
sqlite3_go18_test.go Normal file
View File

@ -0,0 +1,49 @@
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.
// +build go1.8
package sqlite3
import (
"database/sql"
"os"
"testing"
)
func TestNamedParams(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec(`
create table foo (id integer, name text, extra text);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}
_, err = db.Exec(`insert into foo(id, name, extra) values(:id, :name, :name)`, sql.Param(":name", "foo"), sql.Param(":id", 1))
if err != nil {
t.Error("Failed to call db.Exec:", err)
}
row := db.QueryRow(`select id, extra from foo where id = :id and extra = :extra`, sql.Param(":id", 1), sql.Param(":extra", "foo"))
if row == nil {
t.Error("Failed to call db.QueryRow")
}
var id int
var extra string
err = row.Scan(&id, &extra)
if err != nil {
t.Error("Failed to db.Scan:", err)
}
if id != 1 || extra != "foo" {
t.Error("Failed to db.QueryRow: not matched results")
}
}

View File

@ -993,42 +993,6 @@ func TestVersion(t *testing.T) {
} }
} }
func TestNumberNamedParams(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec(`
create table foo (id integer, name text, extra text);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}
_, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, "foo")
if err != nil {
t.Error("Failed to call db.Exec:", err)
}
row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, "foo")
if row == nil {
t.Error("Failed to call db.QueryRow")
}
var id int
var extra string
err = row.Scan(&id, &extra)
if err != nil {
t.Error("Failed to db.Scan:", err)
}
if id != 1 || extra != "foo" {
t.Error("Failed to db.QueryRow: not matched results")
}
}
func TestStringContainingZero(t *testing.T) { func TestStringContainingZero(t *testing.T) {
tempFilename := TempFilename(t) tempFilename := TempFilename(t)
defer os.Remove(tempFilename) defer os.Remove(tempFilename)
@ -1315,6 +1279,22 @@ func TestDeclTypes(t *testing.T) {
} }
} }
func TestPinger(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
err = db.Ping()
if err != nil {
t.Fatal(err)
}
db.Close()
err = db.Ping()
if err == nil {
t.Fatal("Should be closed")
}
}
var customFunctionOnce sync.Once var customFunctionOnce sync.Once
func BenchmarkCustomFunctions(b *testing.B) { func BenchmarkCustomFunctions(b *testing.B) {

54
sqlite3_type.go Normal file
View File

@ -0,0 +1,54 @@
package sqlite3
/*
#ifndef USE_LIBSQLITE3
#include <sqlite3-binding.h>
#else
#include <sqlite3.h>
#endif
*/
import "C"
import (
"reflect"
"time"
)
func (rc *SQLiteRows) ColumnTypeDatabaseTypeName(i int) string {
return C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i)))
}
/*
func (rc *SQLiteRows) ColumnTypeLength(index int) (length int64, ok bool) {
return 0, false
}
func (rc *SQLiteRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
return 0, 0, false
}
*/
func (rc *SQLiteRows) ColumnTypeNullable(i int) (nullable, ok bool) {
return true, true
}
func (rc *SQLiteRows) ColumnTypeScanType(i int) reflect.Type {
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
case C.SQLITE_INTEGER:
switch C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))) {
case "timestamp", "datetime", "date":
return reflect.TypeOf(time.Time{})
case "boolean":
return reflect.TypeOf(false)
}
return reflect.TypeOf(int64(0))
case C.SQLITE_FLOAT:
return reflect.TypeOf(float64(0))
case C.SQLITE_BLOB:
return reflect.SliceOf(reflect.TypeOf(byte(0)))
case C.SQLITE_NULL:
return reflect.TypeOf(nil)
case C.SQLITE_TEXT:
return reflect.TypeOf("")
}
return reflect.SliceOf(reflect.TypeOf(byte(0)))
}