Support $NNN-style named parameter. Close #187

This commit is contained in:
mattn 2015-03-22 02:08:47 +09:00
parent 5253daf856
commit a6c208564e
2 changed files with 93 additions and 4 deletions

View File

@ -121,6 +121,7 @@ type SQLiteTx struct {
type SQLiteStmt struct { type SQLiteStmt struct {
c *SQLiteConn c *SQLiteConn
s *C.sqlite3_stmt s *C.sqlite3_stmt
nv int
t string t string
closed bool closed bool
cls bool cls bool
@ -368,7 +369,19 @@ func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
if tail != nil && C.strlen(tail) > 0 { if tail != nil && C.strlen(tail) > 0 {
t = strings.TrimSpace(C.GoString(tail)) t = strings.TrimSpace(C.GoString(tail))
} }
ss := &SQLiteStmt{c: c, s: s, t: t} nv := int(C.sqlite3_bind_parameter_count(s))
if nv > 0 {
pn := C.GoString(C.sqlite3_bind_parameter_name(s, 1))
/* TODO: map argument for named parameters
if len(pn) > 0 && pn[0] == '$' && pn[1] != '1' {
nv = -1
}
*/
if len(pn) > 0 && pn[0] != '?' {
nv = -1
}
}
ss := &SQLiteStmt{c: c, s: s, nv: nv, t: t}
runtime.SetFinalizer(ss, (*SQLiteStmt).Close) runtime.SetFinalizer(ss, (*SQLiteStmt).Close)
return ss, nil return ss, nil
} }
@ -392,7 +405,12 @@ func (s *SQLiteStmt) Close() error {
// Return a number of parameters. // Return a number of parameters.
func (s *SQLiteStmt) NumInput() int { func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s)) return s.nv
}
type bindArg struct {
n int
v driver.Value
} }
func (s *SQLiteStmt) bind(args []driver.Value) error { func (s *SQLiteStmt) bind(args []driver.Value) error {
@ -401,8 +419,43 @@ func (s *SQLiteStmt) bind(args []driver.Value) error {
return s.c.lastError() return s.c.lastError()
} }
for i, v := range args { var vargs []bindArg
n := C.int(i + 1) narg := len(args)
if s.nv == -1 {
/* TODO: map argument for named parameters
if narg == 1 {
if m, ok := args[0].(map[string]driver.Value); ok {
for k, v := range m {
pn := C.CString(k)
if pi := int(C.sqlite3_bind_parameter_index(s.s, pn)); pi > 0 {
println(pi)
vargs = append(vargs, bindArg{pi, v})
}
C.free(unsafe.Pointer(pn))
}
}
narg = 0
}
*/
if narg > 0 {
for i := 0; i < narg; i++ {
pn := C.CString(fmt.Sprint(i + 1))
if pi := int(C.sqlite3_bind_parameter_index(s.s, pn)); pi > 0 {
vargs = append(vargs, bindArg{pi, args[i]})
}
C.free(unsafe.Pointer(pn))
}
}
} else {
vargs = make([]bindArg, narg)
for i, v := range args {
vargs[i] = bindArg{i + 1, v}
}
}
for _, varg := range vargs {
n := C.int(varg.n)
v := varg.v
switch v := v.(type) { switch v := v.(type) {
case nil: case nil:
rv = C.sqlite3_bind_null(s.s, n) rv = C.sqlite3_bind_null(s.s, n)

View File

@ -909,3 +909,39 @@ func TestVersion(t *testing.T) {
t.Errorf("Version failed %q, %d, %q\n", s, n, id) t.Errorf("Version failed %q, %d, %q\n", s, n, id)
} }
} }
func TestNumberNamedParams(t *testing.T) {
tempFilename := TempFilename()
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer os.Remove(tempFilename)
defer db.Close()
_, err = db.Exec(`
create table foo (id integer, name text, extra text);
`)
if err != nil {
t.Error("Failed to call db.Query:", err)
}
_, err = db.Exec(`insert into foo(id, name, extra)) values($1, $2, $2)`, 1, "foo")
if err != nil {
t.Error("Failed to call db.Exec:", err)
}
row := db.QueryRow(`select id, name, extra where id = $1 and extra = $2`, 1, "foo")
if row == nil {
t.Error("Failed to call db.QueryRow")
}
var id int
var extra string
err = row.Scan(&id, &extra)
if err != nil {
t.Error("Failed to db.Scan:", err)
}
if id != 1 || extra != "foo" {
t.Error("Failed to db.QueryRow: not matched results")
}
}