From 52d396ed1ef10ebff46177d5d2fbc059b3df8227 Mon Sep 17 00:00:00 2001 From: tidwall Date: Thu, 21 Apr 2022 07:27:33 -0700 Subject: [PATCH] Avoid write after partial write --- redcon.go | 59 +++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 53 insertions(+), 6 deletions(-) diff --git a/redcon.go b/redcon.go index 032785b..9c069bc 100644 --- a/redcon.go +++ b/redcon.go @@ -25,6 +25,8 @@ var ( errTooMuchData = errors.New("too much data") ) +const maxBufferCap = 262144 + type errProtocol struct { msg string } @@ -576,8 +578,9 @@ type TLSServer struct { // Writer allows for writing RESP messages. type Writer struct { - w io.Writer - b []byte + w io.Writer + b []byte + err error } // NewWriter creates a new RESP writer. @@ -589,6 +592,9 @@ func NewWriter(wr io.Writer) *Writer { // WriteNull writes a null to the client func (w *Writer) WriteNull() { + if w.err != nil { + return + } w.b = AppendNull(w.b) } @@ -600,67 +606,105 @@ func (w *Writer) WriteNull() { // c.WriteBulkString("item 1") // c.WriteBulkString("item 2") func (w *Writer) WriteArray(count int) { + if w.err != nil { + return + } w.b = AppendArray(w.b, count) } // WriteBulk writes bulk bytes to the client. func (w *Writer) WriteBulk(bulk []byte) { + if w.err != nil { + return + } w.b = AppendBulk(w.b, bulk) } // WriteBulkString writes a bulk string to the client. func (w *Writer) WriteBulkString(bulk string) { + if w.err != nil { + return + } w.b = AppendBulkString(w.b, bulk) } // Buffer returns the unflushed buffer. This is a copy so changes // to the resulting []byte will not affect the writer. func (w *Writer) Buffer() []byte { + if w.err != nil { + return nil + } return append([]byte(nil), w.b...) } // SetBuffer replaces the unflushed buffer with new bytes. func (w *Writer) SetBuffer(raw []byte) { + if w.err != nil { + return + } w.b = w.b[:0] w.b = append(w.b, raw...) } // Flush writes all unflushed Write* calls to the underlying writer. func (w *Writer) Flush() error { - if _, err := w.w.Write(w.b); err != nil { - return err + if w.err != nil { + return w.err } - w.b = w.b[:0] - return nil + _, w.err = w.w.Write(w.b) + if cap(w.b) > maxBufferCap || w.err != nil { + w.b = nil + } else { + w.b = w.b[:0] + } + return w.err } // WriteError writes an error to the client. func (w *Writer) WriteError(msg string) { + if w.err != nil { + return + } w.b = AppendError(w.b, msg) } // WriteString writes a string to the client. func (w *Writer) WriteString(msg string) { + if w.err != nil { + return + } w.b = AppendString(w.b, msg) } // WriteInt writes an integer to the client. func (w *Writer) WriteInt(num int) { + if w.err != nil { + return + } w.WriteInt64(int64(num)) } // WriteInt64 writes a 64-bit signed integer to the client. func (w *Writer) WriteInt64(num int64) { + if w.err != nil { + return + } w.b = AppendInt(w.b, num) } // WriteUint64 writes a 64-bit unsigned integer to the client. func (w *Writer) WriteUint64(num uint64) { + if w.err != nil { + return + } w.b = AppendUint(w.b, num) } // WriteRaw writes raw data to the client. func (w *Writer) WriteRaw(data []byte) { + if w.err != nil { + return + } w.b = append(w.b, data...) } @@ -677,6 +721,9 @@ func (w *Writer) WriteRaw(data []byte) { // SimpleInt -> integer // everything-else -> bulk-string representation using fmt.Sprint() func (w *Writer) WriteAny(v interface{}) { + if w.err != nil { + return + } w.b = AppendAny(w.b, v) }