Enable all prefixes for named parameters and allow for unused named parameters (#811)

* Allow unused named parameters

Try to bind all named parameters and ignore those not used.

* Allow "@" and "$" for named parameters

* Add tests for named parameters

Co-authored-by: Guido Berhoerster <guido+go-sqlite3@berhoerster.name>
This commit is contained in:
gber 2020-05-14 14:28:04 +00:00 committed by GitHub
parent 44b2a6394a
commit db4c9426f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 117 additions and 54 deletions

View File

@ -802,20 +802,29 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []namedValue)
} }
var res driver.Result var res driver.Result
if s.(*SQLiteStmt).s != nil { if s.(*SQLiteStmt).s != nil {
stmtArgs := make([]namedValue, 0, len(args))
na := s.NumInput() na := s.NumInput()
if len(args) < na { if len(args) - start < na {
s.Close() s.Close()
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
} }
for i := 0; i < na; i++ { // consume the number of arguments used in the current
args[i].Ordinal -= start // statement and append all named arguments not
// contained therein
stmtArgs = append(stmtArgs, args[start:start+na]...)
for i := range args {
if (i < start || i >= na) && args[i].Name != "" {
stmtArgs = append(stmtArgs, args[i])
}
} }
res, err = s.(*SQLiteStmt).exec(ctx, args[:na]) for i := range stmtArgs {
stmtArgs[i].Ordinal = i + 1
}
res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip { if err != nil && err != driver.ErrSkip {
s.Close() s.Close()
return nil, err return nil, err
} }
args = args[na:]
start += na start += na
} }
tail := s.(*SQLiteStmt).t tail := s.(*SQLiteStmt).t
@ -848,24 +857,33 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) { func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) {
start := 0 start := 0
for { for {
stmtArgs := make([]namedValue, 0, len(args))
s, err := c.prepare(ctx, query) s, err := c.prepare(ctx, query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.(*SQLiteStmt).cls = true s.(*SQLiteStmt).cls = true
na := s.NumInput() na := s.NumInput()
if len(args) < na { if len(args) - start < na {
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)) return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args) - start)
} }
for i := 0; i < na; i++ { // consume the number of arguments used in the current
args[i].Ordinal -= start // statement and append all named arguments not contained
// therein
stmtArgs = append(stmtArgs, args[start:start+na]...)
for i := range args {
if (i < start || i >= na) && args[i].Name != "" {
stmtArgs = append(stmtArgs, args[i])
}
} }
rows, err := s.(*SQLiteStmt).query(ctx, args[:na]) for i := range stmtArgs {
stmtArgs[i].Ordinal = i + 1
}
rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs)
if err != nil && err != driver.ErrSkip { if err != nil && err != driver.ErrSkip {
s.Close() s.Close()
return rows, err return rows, err
} }
args = args[na:]
start += na start += na
tail := s.(*SQLiteStmt).t tail := s.(*SQLiteStmt).t
if tail == "" { if tail == "" {
@ -1778,11 +1796,6 @@ func (s *SQLiteStmt) NumInput() int {
return int(C.sqlite3_bind_parameter_count(s.s)) return int(C.sqlite3_bind_parameter_count(s.s))
} }
type bindArg struct {
n int
v driver.Value
}
var placeHolder = []byte{0} var placeHolder = []byte{0}
func (s *SQLiteStmt) bind(args []namedValue) error { func (s *SQLiteStmt) bind(args []namedValue) error {
@ -1791,52 +1804,63 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
return s.c.lastError() return s.c.lastError()
} }
bindIndices := make([][3]int, len(args))
prefixes := []string{":", "@", "$"}
for i, v := range args { for i, v := range args {
bindIndices[i][0] = args[i].Ordinal
if v.Name != "" { if v.Name != "" {
cname := C.CString(":" + v.Name) for j := range prefixes {
args[i].Ordinal = int(C.sqlite3_bind_parameter_index(s.s, cname)) cname := C.CString(prefixes[j] + v.Name)
C.free(unsafe.Pointer(cname)) bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname))
C.free(unsafe.Pointer(cname))
}
args[i].Ordinal = bindIndices[i][0]
} }
} }
for _, arg := range args { for i, arg := range args {
n := C.int(arg.Ordinal) for j := range bindIndices[i] {
switch v := arg.Value.(type) { if bindIndices[i][j] == 0 {
case nil: continue
rv = C.sqlite3_bind_null(s.s, n) }
case string: n := C.int(bindIndices[i][j])
if len(v) == 0 { switch v := arg.Value.(type) {
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0)) case nil:
} else { rv = C.sqlite3_bind_null(s.s, n)
b := []byte(v) case string:
if len(v) == 0 {
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0))
} else {
b := []byte(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
case int64:
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
case bool:
if v {
rv = C.sqlite3_bind_int(s.s, n, 1)
} else {
rv = C.sqlite3_bind_int(s.s, n, 0)
}
case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte:
if v == nil {
rv = C.sqlite3_bind_null(s.s, n)
} else {
ln := len(v)
if ln == 0 {
v = placeHolder
}
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
}
case time.Time:
b := []byte(v.Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
} }
case int64: if rv != C.SQLITE_OK {
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v)) return s.c.lastError()
case bool:
if v {
rv = C.sqlite3_bind_int(s.s, n, 1)
} else {
rv = C.sqlite3_bind_int(s.s, n, 0)
} }
case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte:
if v == nil {
rv = C.sqlite3_bind_null(s.s, n)
} else {
ln := len(v)
if ln == 0 {
v = placeHolder
}
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
}
case time.Time:
b := []byte(v.Format(SQLiteTimestampFormats[0]))
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
if rv != C.SQLITE_OK {
return s.c.lastError()
} }
} }
return nil return nil

View File

@ -1778,6 +1778,45 @@ func TestInsertNilByteSlice(t *testing.T) {
} }
} }
func TestNamedParam(t *testing.T) {
tempFilename := TempFilename(t)
defer os.Remove(tempFilename)
db, err := sql.Open("sqlite3", tempFilename)
if err != nil {
t.Fatal("Failed to open database:", err)
}
defer db.Close()
_, err = db.Exec("drop table foo")
_, err = db.Exec("create table foo (id integer, name text, amount integer)")
if err != nil {
t.Fatal("Failed to create table:", err)
}
_, err = db.Exec("insert into foo(id, name, amount) values(:id, @name, $amount)",
sql.Named("bar", 42), sql.Named("baz", "quux"),
sql.Named("amount", 123), sql.Named("corge", "waldo"),
sql.Named("id", 2), sql.Named("name", "grault"))
if err != nil {
t.Fatal("Failed to insert record with named parameters:", err)
}
rows, err := db.Query("select id, name, amount from foo")
if err != nil {
t.Fatal("Failed to select records:", err)
}
defer rows.Close()
rows.Next()
var id, amount int
var name string
rows.Scan(&id, &name, &amount)
if id != 2 || name != "grault" || amount != 123 {
t.Errorf("Expected %d, %q, %d for fetched result, but got %d, %q, %d:", 2, "grault", 123, id, name, amount)
}
}
var customFunctionOnce sync.Once var customFunctionOnce sync.Once
func BenchmarkCustomFunctions(b *testing.B) { func BenchmarkCustomFunctions(b *testing.B) {