diff --git a/sqlite3.go b/sqlite3.go index ce985ec..6234652 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -1921,26 +1921,90 @@ func (s *SQLiteStmt) NumInput() int { var placeHolder = []byte{0} +func hasNamedArgs(args []driver.NamedValue) bool { + for _, v := range args { + if v.Name != "" { + return true + } + } + return false +} + func (s *SQLiteStmt) bind(args []driver.NamedValue) error { rv := C.sqlite3_reset(s.s) if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE { return s.c.lastError() } + if hasNamedArgs(args) { + return s.bindIndices(args) + } + + for _, arg := range args { + n := C.int(arg.Ordinal) + switch v := arg.Value.(type) { + case nil: + rv = C.sqlite3_bind_null(s.s, n) + case string: + p := stringData(v) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(v))) + case int64: + rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v)) + case bool: + val := 0 + if v { + val = 1 + } + rv = C.sqlite3_bind_int(s.s, n, C.int(val)) + case float64: + rv = C.sqlite3_bind_double(s.s, n, C.double(v)) + case []byte: + if v == nil { + rv = C.sqlite3_bind_null(s.s, n) + } else { + ln := len(v) + if ln == 0 { + v = placeHolder + } + rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln)) + } + case time.Time: + ts := v.Format(SQLiteTimestampFormats[0]) + p := stringData(ts) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(ts))) + } + if rv != C.SQLITE_OK { + return s.c.lastError() + } + } + return nil +} + +func (s *SQLiteStmt) bindIndices(args []driver.NamedValue) error { + // Find the longest named parameter name. + n := 0 + for _, v := range args { + if m := len(v.Name); m > n { + n = m + } + } + buf := make([]byte, 0, n+2) // +2 for placeholder and null terminator + bindIndices := make([][3]int, len(args)) - prefixes := []string{":", "@", "$"} for i, v := range args { bindIndices[i][0] = args[i].Ordinal if v.Name != "" { - for j := range prefixes { - cname := C.CString(prefixes[j] + v.Name) - bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, cname)) - C.free(unsafe.Pointer(cname)) + for j, c := range []byte{':', '@', '$'} { + buf = append(buf[:0], c) + buf = append(buf, v.Name...) + buf = append(buf, 0) + bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, (*C.char)(unsafe.Pointer(&buf[0])))) } args[i].Ordinal = bindIndices[i][0] } } + var rv C.int for i, arg := range args { for j := range bindIndices[i] { if bindIndices[i][j] == 0 { @@ -1951,20 +2015,16 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error { case nil: rv = C.sqlite3_bind_null(s.s, n) case string: - if len(v) == 0 { - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0)) - } else { - b := []byte(v) - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) - } + p := stringData(v) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(v))) case int64: rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v)) case bool: + val := 0 if v { - rv = C.sqlite3_bind_int(s.s, n, 1) - } else { - rv = C.sqlite3_bind_int(s.s, n, 0) + val = 1 } + rv = C.sqlite3_bind_int(s.s, n, C.int(val)) case float64: rv = C.sqlite3_bind_double(s.s, n, C.double(v)) case []byte: @@ -1978,8 +2038,9 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error { rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln)) } case time.Time: - b := []byte(v.Format(SQLiteTimestampFormats[0])) - rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b))) + ts := v.Format(SQLiteTimestampFormats[0]) + p := stringData(ts) + rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(ts))) } if rv != C.SQLITE_OK { return s.c.lastError() diff --git a/unsafe_go120.go b/unsafe_go120.go new file mode 100644 index 0000000..95d673e --- /dev/null +++ b/unsafe_go120.go @@ -0,0 +1,17 @@ +//go:build !go1.21 +// +build !go1.21 + +package sqlite3 + +import "unsafe" + +// stringData is a safe version of unsafe.StringData that handles empty strings. +func stringData(s string) *byte { + if len(s) != 0 { + b := *(*[]byte)(unsafe.Pointer(&s)) + return &b[0] + } + // The return value of unsafe.StringData + // is unspecified if the string is empty. + return &placeHolder[0] +} diff --git a/unsafe_go121.go b/unsafe_go121.go new file mode 100644 index 0000000..b9c00a1 --- /dev/null +++ b/unsafe_go121.go @@ -0,0 +1,23 @@ +//go:build go1.21 +// +build go1.21 + +// The unsafe.StringData function was made available in Go 1.20 but it +// was not until Go 1.21 that Go was changed to interpret the Go version +// in go.mod (1.19 as of writing this) as the minimum version required +// instead of the exact version. +// +// See: https://github.com/golang/go/issues/59033 + +package sqlite3 + +import "unsafe" + +// stringData is a safe version of unsafe.StringData that handles empty strings. +func stringData(s string) *byte { + if len(s) != 0 { + return unsafe.StringData(s) + } + // The return value of unsafe.StringData + // is unspecified if the string is empty. + return &placeHolder[0] +}