forked from mirror/go-sqlite3
Enable all prefixes for named parameters and allow for unused named parameters (#811)
* Allow unused named parameters Try to bind all named parameters and ignore those not used. * Allow "@" and "$" for named parameters * Add tests for named parameters Co-authored-by: Guido Berhoerster <guido+go-sqlite3@berhoerster.name>
This commit is contained in:
parent
44b2a6394a
commit
db4c9426f8
64
sqlite3.go
64
sqlite3.go
|
@ -802,20 +802,29 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []namedValue)
|
||||||
}
|
}
|
||||||
var res driver.Result
|
var res driver.Result
|
||||||
if s.(*SQLiteStmt).s != nil {
|
if s.(*SQLiteStmt).s != nil {
|
||||||
|
stmtArgs := make([]namedValue, 0, len(args))
|
||||||
na := s.NumInput()
|
na := s.NumInput()
|
||||||
if len(args) < na {
|
if len(args) - start < na {
|
||||||
s.Close()
|
s.Close()
|
||||||
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
|
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
|
||||||
}
|
}
|
||||||
for i := 0; i < na; i++ {
|
// consume the number of arguments used in the current
|
||||||
args[i].Ordinal -= start
|
// statement and append all named arguments not
|
||||||
|
// contained therein
|
||||||
|
stmtArgs = append(stmtArgs, args[start:start+na]...)
|
||||||
|
for i := range args {
|
||||||
|
if (i < start || i >= na) && args[i].Name != "" {
|
||||||
|
stmtArgs = append(stmtArgs, args[i])
|
||||||
}
|
}
|
||||||
res, err = s.(*SQLiteStmt).exec(ctx, args[:na])
|
}
|
||||||
|
for i := range stmtArgs {
|
||||||
|
stmtArgs[i].Ordinal = i + 1
|
||||||
|
}
|
||||||
|
res, err = s.(*SQLiteStmt).exec(ctx, stmtArgs)
|
||||||
if err != nil && err != driver.ErrSkip {
|
if err != nil && err != driver.ErrSkip {
|
||||||
s.Close()
|
s.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
args = args[na:]
|
|
||||||
start += na
|
start += na
|
||||||
}
|
}
|
||||||
tail := s.(*SQLiteStmt).t
|
tail := s.(*SQLiteStmt).t
|
||||||
|
@ -848,24 +857,33 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
|
||||||
func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) {
|
func (c *SQLiteConn) query(ctx context.Context, query string, args []namedValue) (driver.Rows, error) {
|
||||||
start := 0
|
start := 0
|
||||||
for {
|
for {
|
||||||
|
stmtArgs := make([]namedValue, 0, len(args))
|
||||||
s, err := c.prepare(ctx, query)
|
s, err := c.prepare(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
s.(*SQLiteStmt).cls = true
|
s.(*SQLiteStmt).cls = true
|
||||||
na := s.NumInput()
|
na := s.NumInput()
|
||||||
if len(args) < na {
|
if len(args) - start < na {
|
||||||
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
|
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args) - start)
|
||||||
}
|
}
|
||||||
for i := 0; i < na; i++ {
|
// consume the number of arguments used in the current
|
||||||
args[i].Ordinal -= start
|
// statement and append all named arguments not contained
|
||||||
|
// therein
|
||||||
|
stmtArgs = append(stmtArgs, args[start:start+na]...)
|
||||||
|
for i := range args {
|
||||||
|
if (i < start || i >= na) && args[i].Name != "" {
|
||||||
|
stmtArgs = append(stmtArgs, args[i])
|
||||||
}
|
}
|
||||||
rows, err := s.(*SQLiteStmt).query(ctx, args[:na])
|
}
|
||||||
|
for i := range stmtArgs {
|
||||||
|
stmtArgs[i].Ordinal = i + 1
|
||||||
|
}
|
||||||
|
rows, err := s.(*SQLiteStmt).query(ctx, stmtArgs)
|
||||||
if err != nil && err != driver.ErrSkip {
|
if err != nil && err != driver.ErrSkip {
|
||||||
s.Close()
|
s.Close()
|
||||||
return rows, err
|
return rows, err
|
||||||
}
|
}
|
||||||
args = args[na:]
|
|
||||||
start += na
|
start += na
|
||||||
tail := s.(*SQLiteStmt).t
|
tail := s.(*SQLiteStmt).t
|
||||||
if tail == "" {
|
if tail == "" {
|
||||||
|
@ -1778,11 +1796,6 @@ func (s *SQLiteStmt) NumInput() int {
|
||||||
return int(C.sqlite3_bind_parameter_count(s.s))
|
return int(C.sqlite3_bind_parameter_count(s.s))
|
||||||
}
|
}
|
||||||
|
|
||||||
type bindArg struct {
|
|
||||||
n int
|
|
||||||
v driver.Value
|
|
||||||
}
|
|
||||||
|
|
||||||
var placeHolder = []byte{0}
|
var placeHolder = []byte{0}
|
||||||
|
|
||||||
func (s *SQLiteStmt) bind(args []namedValue) error {
|
func (s *SQLiteStmt) bind(args []namedValue) error {
|
||||||
|
@ -1791,16 +1804,26 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
|
||||||
return s.c.lastError()
|
return s.c.lastError()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bindIndices := make([][3]int, len(args))
|
||||||
|
prefixes := []string{":", "@", "$"}
|
||||||
for i, v := range args {
|
for i, v := range args {
|
||||||
|
bindIndices[i][0] = args[i].Ordinal
|
||||||
if v.Name != "" {
|
if v.Name != "" {
|
||||||
cname := C.CString(":" + v.Name)
|
for j := range prefixes {
|
||||||
args[i].Ordinal = int(C.sqlite3_bind_parameter_index(s.s, cname))
|
cname := C.CString(prefixes[j] + v.Name)
|
||||||
|
bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname))
|
||||||
C.free(unsafe.Pointer(cname))
|
C.free(unsafe.Pointer(cname))
|
||||||
}
|
}
|
||||||
|
args[i].Ordinal = bindIndices[i][0]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, arg := range args {
|
for i, arg := range args {
|
||||||
n := C.int(arg.Ordinal)
|
for j := range bindIndices[i] {
|
||||||
|
if bindIndices[i][j] == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
n := C.int(bindIndices[i][j])
|
||||||
switch v := arg.Value.(type) {
|
switch v := arg.Value.(type) {
|
||||||
case nil:
|
case nil:
|
||||||
rv = C.sqlite3_bind_null(s.s, n)
|
rv = C.sqlite3_bind_null(s.s, n)
|
||||||
|
@ -1839,6 +1862,7 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
|
||||||
return s.c.lastError()
|
return s.c.lastError()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1778,6 +1778,45 @@ func TestInsertNilByteSlice(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNamedParam(t *testing.T) {
|
||||||
|
tempFilename := TempFilename(t)
|
||||||
|
defer os.Remove(tempFilename)
|
||||||
|
db, err := sql.Open("sqlite3", tempFilename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to open database:", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err = db.Exec("drop table foo")
|
||||||
|
_, err = db.Exec("create table foo (id integer, name text, amount integer)")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to create table:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Exec("insert into foo(id, name, amount) values(:id, @name, $amount)",
|
||||||
|
sql.Named("bar", 42), sql.Named("baz", "quux"),
|
||||||
|
sql.Named("amount", 123), sql.Named("corge", "waldo"),
|
||||||
|
sql.Named("id", 2), sql.Named("name", "grault"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to insert record with named parameters:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.Query("select id, name, amount from foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Failed to select records:", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
rows.Next()
|
||||||
|
|
||||||
|
var id, amount int
|
||||||
|
var name string
|
||||||
|
rows.Scan(&id, &name, &amount)
|
||||||
|
if id != 2 || name != "grault" || amount != 123 {
|
||||||
|
t.Errorf("Expected %d, %q, %d for fetched result, but got %d, %q, %d:", 2, "grault", 123, id, name, amount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var customFunctionOnce sync.Once
|
var customFunctionOnce sync.Once
|
||||||
|
|
||||||
func BenchmarkCustomFunctions(b *testing.B) {
|
func BenchmarkCustomFunctions(b *testing.B) {
|
||||||
|
|
Loading…
Reference in New Issue