mirror of https://github.com/mattn/go-sqlite3.git
support named params
This commit is contained in:
parent
c95a77965c
commit
b23526fb3c
48
sqlite3.go
48
sqlite3.go
|
@ -172,8 +172,6 @@ type SQLiteTx struct {
|
|||
type SQLiteStmt struct {
|
||||
c *SQLiteConn
|
||||
s *C.sqlite3_stmt
|
||||
nv int
|
||||
nn []string
|
||||
t string
|
||||
closed bool
|
||||
cls bool
|
||||
|
@ -420,6 +418,7 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
|
|||
}
|
||||
|
||||
func (c *SQLiteConn) exec(ctx context.Context, query string, args []namedValue) (driver.Result, error) {
|
||||
start := 0
|
||||
for {
|
||||
s, err := c.Prepare(query)
|
||||
if err != nil {
|
||||
|
@ -431,12 +430,16 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []namedValue)
|
|||
if len(args) < na {
|
||||
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
|
||||
}
|
||||
for i := 0; i < na; i++ {
|
||||
args[i].Ordinal -= start
|
||||
}
|
||||
res, err = s.(*SQLiteStmt).exec(ctx, args[:na])
|
||||
if err != nil && err != driver.ErrSkip {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
args = args[na:]
|
||||
start += na
|
||||
}
|
||||
tail := s.(*SQLiteStmt).t
|
||||
s.Close()
|
||||
|
@ -466,6 +469,7 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
|
|||
}
|
||||
|
||||
func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) {
|
||||
start := 0
|
||||
for {
|
||||
s, err := c.Prepare(query)
|
||||
if err != nil {
|
||||
|
@ -476,12 +480,16 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue)
|
|||
if len(args) < na {
|
||||
return nil, fmt.Errorf("Not enough args to execute query. Expected %d, got %d.", na, len(args))
|
||||
}
|
||||
for i := 0; i < na; i++ {
|
||||
args[i].Ordinal -= start
|
||||
}
|
||||
rows, err := s.(*SQLiteStmt).query(ctx, args[:na])
|
||||
if err != nil && err != driver.ErrSkip {
|
||||
s.Close()
|
||||
return nil, err
|
||||
}
|
||||
args = args[na:]
|
||||
start += na
|
||||
tail := s.(*SQLiteStmt).t
|
||||
if tail == "" {
|
||||
return rows, nil
|
||||
|
@ -648,15 +656,7 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
|
|||
if tail != nil && *tail != '\000' {
|
||||
t = strings.TrimSpace(C.GoString(tail))
|
||||
}
|
||||
nv := int(C.sqlite3_bind_parameter_count(s))
|
||||
var nn []string
|
||||
for i := 0; i < nv; i++ {
|
||||
pn := C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1)))
|
||||
if len(pn) > 1 && pn[0] == '$' && 48 <= pn[1] && pn[1] <= 57 {
|
||||
nn = append(nn, C.GoString(C.sqlite3_bind_parameter_name(s, C.int(i+1))))
|
||||
}
|
||||
}
|
||||
ss := &SQLiteStmt{c: c, s: s, nv: nv, nn: nn, t: t}
|
||||
ss := &SQLiteStmt{c: c, s: s, t: t}
|
||||
runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
|
||||
return ss, nil
|
||||
}
|
||||
|
@ -680,7 +680,7 @@ func (s *SQLiteStmt) Close() error {
|
|||
|
||||
// Return a number of parameters.
|
||||
func (s *SQLiteStmt) NumInput() int {
|
||||
return s.nv
|
||||
return int(C.sqlite3_bind_parameter_count(s.s))
|
||||
}
|
||||
|
||||
type bindArg struct {
|
||||
|
@ -694,25 +694,17 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
|
|||
return s.c.lastError()
|
||||
}
|
||||
|
||||
var vargs []bindArg
|
||||
narg := len(args)
|
||||
vargs = make([]bindArg, narg)
|
||||
if len(s.nn) > 0 {
|
||||
for i, v := range s.nn {
|
||||
if pi, err := strconv.Atoi(v[1:]); err == nil {
|
||||
vargs[i] = bindArg{pi, args[i].Value}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i, v := range args {
|
||||
vargs[i] = bindArg{i + 1, v.Value}
|
||||
for i, v := range args {
|
||||
if v.Name != "" {
|
||||
cname := C.CString(v.Name)
|
||||
args[i].Ordinal = int(C.sqlite3_bind_parameter_index(s.s, cname))
|
||||
C.free(unsafe.Pointer(cname))
|
||||
}
|
||||
}
|
||||
|
||||
for _, varg := range vargs {
|
||||
n := C.int(varg.n)
|
||||
v := varg.v
|
||||
switch v := v.(type) {
|
||||
for _, arg := range args {
|
||||
n := C.int(arg.Ordinal)
|
||||
switch v := arg.Value.(type) {
|
||||
case nil:
|
||||
rv = C.sqlite3_bind_null(s.s, n)
|
||||
case string:
|
||||
|
|
|
@ -23,6 +23,14 @@ func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driv
|
|||
return c.query(ctx, query, list)
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, nv := range args {
|
||||
list[i] = namedValue(nv)
|
||||
}
|
||||
return c.exec(ctx, query, list)
|
||||
}
|
||||
|
||||
func (s *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, nv := range args {
|
||||
|
@ -30,3 +38,11 @@ func (s *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue)
|
|||
}
|
||||
return s.query(ctx, list)
|
||||
}
|
||||
|
||||
func (s *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, nv := range args {
|
||||
list[i] = namedValue(nv)
|
||||
}
|
||||
return s.exec(ctx, list)
|
||||
}
|
||||
|
|
|
@ -993,7 +993,7 @@ func TestVersion(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNumberNamedParams(t *testing.T) {
|
||||
func TestNamedParams(t *testing.T) {
|
||||
tempFilename := TempFilename(t)
|
||||
defer os.Remove(tempFilename)
|
||||
db, err := sql.Open("sqlite3", tempFilename)
|
||||
|
@ -1009,12 +1009,12 @@ func TestNumberNamedParams(t *testing.T) {
|
|||
t.Error("Failed to call db.Query:", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`insert into foo(id, name, extra) values($1, $2, $2)`, 1, "foo")
|
||||
_, err = db.Exec(`insert into foo(id, name, extra) values(:id, :name, :name)`, sql.Param(":name", "foo"), sql.Param(":id", 1))
|
||||
if err != nil {
|
||||
t.Error("Failed to call db.Exec:", err)
|
||||
}
|
||||
|
||||
row := db.QueryRow(`select id, extra from foo where id = $1 and extra = $2`, 1, "foo")
|
||||
row := db.QueryRow(`select id, extra from foo where id = :id and extra = :extra`, sql.Param(":id", 1), sql.Param(":extra", "foo"))
|
||||
if row == nil {
|
||||
t.Error("Failed to call db.QueryRow")
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue