forked from mirror/go-sqlcipher
Support $NNN-style named parameter. Close #187
This commit is contained in:
parent
5253daf856
commit
a6c208564e
59
sqlite3.go
59
sqlite3.go
|
@ -121,6 +121,7 @@ type SQLiteTx struct {
|
||||||
type SQLiteStmt struct {
|
type SQLiteStmt struct {
|
||||||
c *SQLiteConn
|
c *SQLiteConn
|
||||||
s *C.sqlite3_stmt
|
s *C.sqlite3_stmt
|
||||||
|
nv int
|
||||||
t string
|
t string
|
||||||
closed bool
|
closed bool
|
||||||
cls bool
|
cls bool
|
||||||
|
@ -368,7 +369,19 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
|
||||||
if tail != nil && C.strlen(tail) > 0 {
|
if tail != nil && C.strlen(tail) > 0 {
|
||||||
t = strings.TrimSpace(C.GoString(tail))
|
t = strings.TrimSpace(C.GoString(tail))
|
||||||
}
|
}
|
||||||
ss := &SQLiteStmt{c: c, s: s, t: t}
|
nv := int(C.sqlite3_bind_parameter_count(s))
|
||||||
|
if nv > 0 {
|
||||||
|
pn := C.GoString(C.sqlite3_bind_parameter_name(s, 1))
|
||||||
|
/* TODO: map argument for named parameters
|
||||||
|
if len(pn) > 0 && pn[0] == '$' && pn[1] != '1' {
|
||||||
|
nv = -1
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
if len(pn) > 0 && pn[0] != '?' {
|
||||||
|
nv = -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ss := &SQLiteStmt{c: c, s: s, nv: nv, t: t}
|
||||||
runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
|
runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
|
||||||
return ss, nil
|
return ss, nil
|
||||||
}
|
}
|
||||||
|
@ -392,7 +405,12 @@ func (s *SQLiteStmt) Close() error {
|
||||||
|
|
||||||
// Return a number of parameters.
|
// Return a number of parameters.
|
||||||
func (s *SQLiteStmt) NumInput() int {
|
func (s *SQLiteStmt) NumInput() int {
|
||||||
return int(C.sqlite3_bind_parameter_count(s.s))
|
return s.nv
|
||||||
|
}
|
||||||
|
|
||||||
|
type bindArg struct {
|
||||||
|
n int
|
||||||
|
v driver.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLiteStmt) bind(args []driver.Value) error {
|
func (s *SQLiteStmt) bind(args []driver.Value) error {
|
||||||
|
@ -401,8 +419,43 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
|
||||||
return s.c.lastError()
|
return s.c.lastError()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var vargs []bindArg
|
||||||
|
narg := len(args)
|
||||||
|
if s.nv == -1 {
|
||||||
|
/* TODO: map argument for named parameters
|
||||||
|
if narg == 1 {
|
||||||
|
if m, ok := args[0].(map[string]driver.Value); ok {
|
||||||
|
for k, v := range m {
|
||||||
|
pn := C.CString(k)
|
||||||
|
if pi := int(C.sqlite3_bind_parameter_index(s.s, pn)); pi > 0 {
|
||||||
|
println(pi)
|
||||||
|
vargs = append(vargs, bindArg{pi, v})
|
||||||
|
}
|
||||||
|
C.free(unsafe.Pointer(pn))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
narg = 0
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
if narg > 0 {
|
||||||
|
for i := 0; i < narg; i++ {
|
||||||
|
pn := C.CString(fmt.Sprint(i + 1))
|
||||||
|
if pi := int(C.sqlite3_bind_parameter_index(s.s, pn)); pi > 0 {
|
||||||
|
vargs = append(vargs, bindArg{pi, args[i]})
|
||||||
|
}
|
||||||
|
C.free(unsafe.Pointer(pn))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
vargs = make([]bindArg, narg)
|
||||||
for i, v := range args {
|
for i, v := range args {
|
||||||
n := C.int(i + 1)
|
vargs[i] = bindArg{i + 1, v}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, varg := range vargs {
|
||||||
|
n := C.int(varg.n)
|
||||||
|
v := varg.v
|
||||||
switch v := v.(type) {
|
switch v := v.(type) {
|
||||||
case nil:
|
case nil:
|
||||||
rv = C.sqlite3_bind_null(s.s, n)
|
rv = C.sqlite3_bind_null(s.s, n)
|
||||||
|
|
|
@ -909,3 +909,39 @@ func TestVersion(t *testing.T) {
|
||||||
t.Errorf("Version failed %q, %d, %q\n", s, n, id)
|
t.Errorf("Version failed %q, %d, %q\n", s, n, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNumberNamedParams(t *testing.T) {
|
||||||
|
tempFilename := TempFilename()
|
||||||
|
db, err := sql.Open("sqlite3", tempFilename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer os.Remove(tempFilename)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err = db.Exec(`
|
||||||
|
create table foo (id integer, name text, extra text);
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Failed to call db.Query:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Exec(`insert into foo(id, name, extra)) values($1, $2, $2)`, 1, "foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Failed to call db.Exec:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
row := db.QueryRow(`select id, name, extra where id = $1 and extra = $2`, 1, "foo")
|
||||||
|
if row == nil {
|
||||||
|
t.Error("Failed to call db.QueryRow")
|
||||||
|
}
|
||||||
|
var id int
|
||||||
|
var extra string
|
||||||
|
err = row.Scan(&id, &extra)
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Failed to db.Scan:", err)
|
||||||
|
}
|
||||||
|
if id != 1 || extra != "foo" {
|
||||||
|
t.Error("Failed to db.QueryRow: not matched results")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue