Avoid write after partial write

This commit is contained in:
tidwall 2022-04-21 07:27:33 -07:00
parent 86d28423be
commit 52d396ed1e
1 changed files with 53 additions and 6 deletions

View File

@ -25,6 +25,8 @@ var (
errTooMuchData = errors.New("too much data")
)
const maxBufferCap = 262144
type errProtocol struct {
msg string
}
@ -578,6 +580,7 @@ type TLSServer struct {
type Writer struct {
w io.Writer
b []byte
err error
}
// NewWriter creates a new RESP writer.
@ -589,6 +592,9 @@ func NewWriter(wr io.Writer) *Writer {
// WriteNull writes a null to the client
func (w *Writer) WriteNull() {
if w.err != nil {
return
}
w.b = AppendNull(w.b)
}
@ -600,67 +606,105 @@ func (w *Writer) WriteNull() {
// c.WriteBulkString("item 1")
// c.WriteBulkString("item 2")
func (w *Writer) WriteArray(count int) {
if w.err != nil {
return
}
w.b = AppendArray(w.b, count)
}
// WriteBulk writes bulk bytes to the client.
func (w *Writer) WriteBulk(bulk []byte) {
if w.err != nil {
return
}
w.b = AppendBulk(w.b, bulk)
}
// WriteBulkString writes a bulk string to the client.
func (w *Writer) WriteBulkString(bulk string) {
if w.err != nil {
return
}
w.b = AppendBulkString(w.b, bulk)
}
// Buffer returns the unflushed buffer. This is a copy so changes
// to the resulting []byte will not affect the writer.
func (w *Writer) Buffer() []byte {
if w.err != nil {
return nil
}
return append([]byte(nil), w.b...)
}
// SetBuffer replaces the unflushed buffer with new bytes.
func (w *Writer) SetBuffer(raw []byte) {
if w.err != nil {
return
}
w.b = w.b[:0]
w.b = append(w.b, raw...)
}
// Flush writes all unflushed Write* calls to the underlying writer.
func (w *Writer) Flush() error {
if _, err := w.w.Write(w.b); err != nil {
return err
if w.err != nil {
return w.err
}
_, w.err = w.w.Write(w.b)
if cap(w.b) > maxBufferCap || w.err != nil {
w.b = nil
} else {
w.b = w.b[:0]
return nil
}
return w.err
}
// WriteError writes an error to the client.
func (w *Writer) WriteError(msg string) {
if w.err != nil {
return
}
w.b = AppendError(w.b, msg)
}
// WriteString writes a string to the client.
func (w *Writer) WriteString(msg string) {
if w.err != nil {
return
}
w.b = AppendString(w.b, msg)
}
// WriteInt writes an integer to the client.
func (w *Writer) WriteInt(num int) {
if w.err != nil {
return
}
w.WriteInt64(int64(num))
}
// WriteInt64 writes a 64-bit signed integer to the client.
func (w *Writer) WriteInt64(num int64) {
if w.err != nil {
return
}
w.b = AppendInt(w.b, num)
}
// WriteUint64 writes a 64-bit unsigned integer to the client.
func (w *Writer) WriteUint64(num uint64) {
if w.err != nil {
return
}
w.b = AppendUint(w.b, num)
}
// WriteRaw writes raw data to the client.
func (w *Writer) WriteRaw(data []byte) {
if w.err != nil {
return
}
w.b = append(w.b, data...)
}
@ -677,6 +721,9 @@ func (w *Writer) WriteRaw(data []byte) {
// SimpleInt -> integer
// everything-else -> bulk-string representation using fmt.Sprint()
func (w *Writer) WriteAny(v interface{}) {
if w.err != nil {
return
}
w.b = AppendAny(w.b, v)
}