From 2db2f66488d0a378f61644175cfb19a41a48fef6 Mon Sep 17 00:00:00 2001 From: Cyrus Katrak Date: Sat, 17 Dec 2016 15:33:06 -0800 Subject: [PATCH] pool flate readers --- compression.go | 47 +++++++++++++++++++++++++++++++++++++++++------ conn.go | 18 ++++++++++++++---- 2 files changed, 55 insertions(+), 10 deletions(-) diff --git a/compression.go b/compression.go index 72c166b..88dff5f 100644 --- a/compression.go +++ b/compression.go @@ -14,15 +14,22 @@ import ( var ( flateWriterPool = sync.Pool{} + flateReaderPool = sync.Pool{} ) -func decompressNoContextTakeover(r io.Reader) io.Reader { +func decompressNoContextTakeover(r io.Reader) io.ReadCloser { const tail = // Add four bytes as specified in RFC "\x00\x00\xff\xff" + // Add final block to squelch unexpected EOF error from flate reader. "\x01\x00\x00\xff\xff" - return flate.NewReader(io.MultiReader(r, strings.NewReader(tail))) + + i := flateReaderPool.Get() + if i == nil { + i = flate.NewReader(nil) + } + i.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + return &flateReadWrapper{i.(io.ReadCloser)} } func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) { @@ -36,7 +43,7 @@ func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) { fw = i.(*flate.Writer) fw.Reset(tw) } - return &flateWrapper{fw: fw, tw: tw}, err + return &flateWriteWrapper{fw: fw, tw: tw}, err } // truncWriter is an io.Writer that writes all but the last four bytes of the @@ -75,19 +82,19 @@ func (w *truncWriter) Write(p []byte) (int, error) { return n + nn, err } -type flateWrapper struct { +type flateWriteWrapper struct { fw *flate.Writer tw *truncWriter } -func (w *flateWrapper) Write(p []byte) (int, error) { +func (w *flateWriteWrapper) Write(p []byte) (int, error) { if w.fw == nil { return 0, errWriteClosed } return w.fw.Write(p) } -func (w *flateWrapper) Close() error { +func (w *flateWriteWrapper) Close() error { if w.fw == nil { return errWriteClosed } @@ -103,3 +110,31 @@ func (w *flateWrapper) Close() error { } return err2 } + +type flateReadWrapper struct { + fr io.ReadCloser +} + +func (r *flateReadWrapper) Read(p []byte) (int, error) { + if r.fr == nil { + return 0, io.ErrClosedPipe + } + n, err := r.fr.Read(p) + if err == io.EOF { + // Preemptively place the reader back in the pool. This helps with + // scenarios where the application does not call NextReader() soon after + // this final read. + r.Close() + } + return n, err +} + +func (r *flateReadWrapper) Close() error { + if r.fr == nil { + return io.ErrClosedPipe + } + err := r.fr.Close() + flateReaderPool.Put(r.fr) + r.fr = nil + return err +} diff --git a/conn.go b/conn.go index ce7f0a6..70562f9 100644 --- a/conn.go +++ b/conn.go @@ -238,6 +238,7 @@ type Conn struct { newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error) // Read fields + reader io.ReadCloser // the current reader returned to the application readErr error br *bufio.Reader readRemaining int64 // bytes remaining in current frame. @@ -253,7 +254,7 @@ type Conn struct { messageReader *messageReader // the current low-level reader readDecompress bool // whether last read frame had RSV1 set - newDecompressionReader func(io.Reader) io.Reader + newDecompressionReader func(io.Reader) io.ReadCloser } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { @@ -855,6 +856,11 @@ func (c *Conn) handleProtocolError(message string) error { // permanent. Once this method returns a non-nil error, all subsequent calls to // this method return the same error. func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { + // Close previous reader, only relevant for decompression. + if c.reader != nil { + c.reader.Close() + c.reader = nil + } c.messageReader = nil c.readLength = 0 @@ -867,11 +873,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { } if frameType == TextMessage || frameType == BinaryMessage { c.messageReader = &messageReader{c} - var r io.Reader = c.messageReader + c.reader = c.messageReader if c.readDecompress { - r = c.newDecompressionReader(r) + c.reader = c.newDecompressionReader(c.reader) } - return frameType, r, nil + return frameType, c.reader, nil } } @@ -933,6 +939,10 @@ func (r *messageReader) Read(b []byte) (int, error) { return 0, err } +func (r *messageReader) Close() error { + return nil +} + // ReadMessage is a helper method for getting a reader using NextReader and // reading from that reader to a buffer. func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {