Avoid write after partial write

This commit is contained in:
tidwall 2022-04-21 07:27:33 -07:00
parent 86d28423be
commit 52d396ed1e
1 changed files with 53 additions and 6 deletions

View File

@ -25,6 +25,8 @@ var (
errTooMuchData = errors.New("too much data") errTooMuchData = errors.New("too much data")
) )
const maxBufferCap = 262144
type errProtocol struct { type errProtocol struct {
msg string msg string
} }
@ -576,8 +578,9 @@ type TLSServer struct {
// Writer allows for writing RESP messages. // Writer allows for writing RESP messages.
type Writer struct { type Writer struct {
w io.Writer w io.Writer
b []byte b []byte
err error
} }
// NewWriter creates a new RESP writer. // NewWriter creates a new RESP writer.
@ -589,6 +592,9 @@ func NewWriter(wr io.Writer) *Writer {
// WriteNull writes a null to the client // WriteNull writes a null to the client
func (w *Writer) WriteNull() { func (w *Writer) WriteNull() {
if w.err != nil {
return
}
w.b = AppendNull(w.b) w.b = AppendNull(w.b)
} }
@ -600,67 +606,105 @@ func (w *Writer) WriteNull() {
// c.WriteBulkString("item 1") // c.WriteBulkString("item 1")
// c.WriteBulkString("item 2") // c.WriteBulkString("item 2")
func (w *Writer) WriteArray(count int) { func (w *Writer) WriteArray(count int) {
if w.err != nil {
return
}
w.b = AppendArray(w.b, count) w.b = AppendArray(w.b, count)
} }
// WriteBulk writes bulk bytes to the client. // WriteBulk writes bulk bytes to the client.
func (w *Writer) WriteBulk(bulk []byte) { func (w *Writer) WriteBulk(bulk []byte) {
if w.err != nil {
return
}
w.b = AppendBulk(w.b, bulk) w.b = AppendBulk(w.b, bulk)
} }
// WriteBulkString writes a bulk string to the client. // WriteBulkString writes a bulk string to the client.
func (w *Writer) WriteBulkString(bulk string) { func (w *Writer) WriteBulkString(bulk string) {
if w.err != nil {
return
}
w.b = AppendBulkString(w.b, bulk) w.b = AppendBulkString(w.b, bulk)
} }
// Buffer returns the unflushed buffer. This is a copy so changes // Buffer returns the unflushed buffer. This is a copy so changes
// to the resulting []byte will not affect the writer. // to the resulting []byte will not affect the writer.
func (w *Writer) Buffer() []byte { func (w *Writer) Buffer() []byte {
if w.err != nil {
return nil
}
return append([]byte(nil), w.b...) return append([]byte(nil), w.b...)
} }
// SetBuffer replaces the unflushed buffer with new bytes. // SetBuffer replaces the unflushed buffer with new bytes.
func (w *Writer) SetBuffer(raw []byte) { func (w *Writer) SetBuffer(raw []byte) {
if w.err != nil {
return
}
w.b = w.b[:0] w.b = w.b[:0]
w.b = append(w.b, raw...) w.b = append(w.b, raw...)
} }
// Flush writes all unflushed Write* calls to the underlying writer. // Flush writes all unflushed Write* calls to the underlying writer.
func (w *Writer) Flush() error { func (w *Writer) Flush() error {
if _, err := w.w.Write(w.b); err != nil { if w.err != nil {
return err return w.err
} }
w.b = w.b[:0] _, w.err = w.w.Write(w.b)
return nil if cap(w.b) > maxBufferCap || w.err != nil {
w.b = nil
} else {
w.b = w.b[:0]
}
return w.err
} }
// WriteError writes an error to the client. // WriteError writes an error to the client.
func (w *Writer) WriteError(msg string) { func (w *Writer) WriteError(msg string) {
if w.err != nil {
return
}
w.b = AppendError(w.b, msg) w.b = AppendError(w.b, msg)
} }
// WriteString writes a string to the client. // WriteString writes a string to the client.
func (w *Writer) WriteString(msg string) { func (w *Writer) WriteString(msg string) {
if w.err != nil {
return
}
w.b = AppendString(w.b, msg) w.b = AppendString(w.b, msg)
} }
// WriteInt writes an integer to the client. // WriteInt writes an integer to the client.
func (w *Writer) WriteInt(num int) { func (w *Writer) WriteInt(num int) {
if w.err != nil {
return
}
w.WriteInt64(int64(num)) w.WriteInt64(int64(num))
} }
// WriteInt64 writes a 64-bit signed integer to the client. // WriteInt64 writes a 64-bit signed integer to the client.
func (w *Writer) WriteInt64(num int64) { func (w *Writer) WriteInt64(num int64) {
if w.err != nil {
return
}
w.b = AppendInt(w.b, num) w.b = AppendInt(w.b, num)
} }
// WriteUint64 writes a 64-bit unsigned integer to the client. // WriteUint64 writes a 64-bit unsigned integer to the client.
func (w *Writer) WriteUint64(num uint64) { func (w *Writer) WriteUint64(num uint64) {
if w.err != nil {
return
}
w.b = AppendUint(w.b, num) w.b = AppendUint(w.b, num)
} }
// WriteRaw writes raw data to the client. // WriteRaw writes raw data to the client.
func (w *Writer) WriteRaw(data []byte) { func (w *Writer) WriteRaw(data []byte) {
if w.err != nil {
return
}
w.b = append(w.b, data...) w.b = append(w.b, data...)
} }
@ -677,6 +721,9 @@ func (w *Writer) WriteRaw(data []byte) {
// SimpleInt -> integer // SimpleInt -> integer
// everything-else -> bulk-string representation using fmt.Sprint() // everything-else -> bulk-string representation using fmt.Sprint()
func (w *Writer) WriteAny(v interface{}) { func (w *Writer) WriteAny(v interface{}) {
if w.err != nil {
return
}
w.b = AppendAny(w.b, v) w.b = AppendAny(w.b, v)
} }