From 2db2f66488d0a378f61644175cfb19a41a48fef6 Mon Sep 17 00:00:00 2001 From: Cyrus Katrak Date: Sat, 17 Dec 2016 15:33:06 -0800 Subject: [PATCH 1/2] pool flate readers --- compression.go | 47 +++++++++++++++++++++++++++++++++++++++++------ conn.go | 18 ++++++++++++++---- 2 files changed, 55 insertions(+), 10 deletions(-) diff --git a/compression.go b/compression.go index 72c166b..88dff5f 100644 --- a/compression.go +++ b/compression.go @@ -14,15 +14,22 @@ import ( var ( flateWriterPool = sync.Pool{} + flateReaderPool = sync.Pool{} ) -func decompressNoContextTakeover(r io.Reader) io.Reader { +func decompressNoContextTakeover(r io.Reader) io.ReadCloser { const tail = // Add four bytes as specified in RFC "\x00\x00\xff\xff" + // Add final block to squelch unexpected EOF error from flate reader. "\x01\x00\x00\xff\xff" - return flate.NewReader(io.MultiReader(r, strings.NewReader(tail))) + + i := flateReaderPool.Get() + if i == nil { + i = flate.NewReader(nil) + } + i.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + return &flateReadWrapper{i.(io.ReadCloser)} } func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) { @@ -36,7 +43,7 @@ func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) { fw = i.(*flate.Writer) fw.Reset(tw) } - return &flateWrapper{fw: fw, tw: tw}, err + return &flateWriteWrapper{fw: fw, tw: tw}, err } // truncWriter is an io.Writer that writes all but the last four bytes of the @@ -75,19 +82,19 @@ func (w *truncWriter) Write(p []byte) (int, error) { return n + nn, err } -type flateWrapper struct { +type flateWriteWrapper struct { fw *flate.Writer tw *truncWriter } -func (w *flateWrapper) Write(p []byte) (int, error) { +func (w *flateWriteWrapper) Write(p []byte) (int, error) { if w.fw == nil { return 0, errWriteClosed } return w.fw.Write(p) } -func (w *flateWrapper) Close() error { +func (w *flateWriteWrapper) Close() error { if w.fw == nil { return errWriteClosed } @@ -103,3 +110,31 @@ func (w *flateWrapper) Close() error { } return err2 } + +type flateReadWrapper struct { + fr io.ReadCloser +} + +func (r *flateReadWrapper) Read(p []byte) (int, error) { + if r.fr == nil { + return 0, io.ErrClosedPipe + } + n, err := r.fr.Read(p) + if err == io.EOF { + // Preemptively place the reader back in the pool. This helps with + // scenarios where the application does not call NextReader() soon after + // this final read. + r.Close() + } + return n, err +} + +func (r *flateReadWrapper) Close() error { + if r.fr == nil { + return io.ErrClosedPipe + } + err := r.fr.Close() + flateReaderPool.Put(r.fr) + r.fr = nil + return err +} diff --git a/conn.go b/conn.go index ce7f0a6..70562f9 100644 --- a/conn.go +++ b/conn.go @@ -238,6 +238,7 @@ type Conn struct { newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error) // Read fields + reader io.ReadCloser // the current reader returned to the application readErr error br *bufio.Reader readRemaining int64 // bytes remaining in current frame. @@ -253,7 +254,7 @@ type Conn struct { messageReader *messageReader // the current low-level reader readDecompress bool // whether last read frame had RSV1 set - newDecompressionReader func(io.Reader) io.Reader + newDecompressionReader func(io.Reader) io.ReadCloser } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { @@ -855,6 +856,11 @@ func (c *Conn) handleProtocolError(message string) error { // permanent. Once this method returns a non-nil error, all subsequent calls to // this method return the same error. func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { + // Close previous reader, only relevant for decompression. + if c.reader != nil { + c.reader.Close() + c.reader = nil + } c.messageReader = nil c.readLength = 0 @@ -867,11 +873,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { } if frameType == TextMessage || frameType == BinaryMessage { c.messageReader = &messageReader{c} - var r io.Reader = c.messageReader + c.reader = c.messageReader if c.readDecompress { - r = c.newDecompressionReader(r) + c.reader = c.newDecompressionReader(c.reader) } - return frameType, r, nil + return frameType, c.reader, nil } } @@ -933,6 +939,10 @@ func (r *messageReader) Read(b []byte) (int, error) { return 0, err } +func (r *messageReader) Close() error { + return nil +} + // ReadMessage is a helper method for getting a reader using NextReader and // reading from that reader to a buffer. func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { From 6c51b25bc8e504cabd0c58004acddfc25b1e6f6c Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Tue, 27 Dec 2016 17:05:16 -0500 Subject: [PATCH 2/2] Compression improvements - Remove unnecessary error return from compressNoContextTakeover. - Simplify use of sync.Pool. - Fix formatting in compression documentation. --- compression.go | 33 ++++++++++++++------------------- conn.go | 8 ++------ doc.go | 4 ++-- 3 files changed, 18 insertions(+), 27 deletions(-) diff --git a/compression.go b/compression.go index 88dff5f..d642631 100644 --- a/compression.go +++ b/compression.go @@ -13,8 +13,13 @@ import ( ) var ( - flateWriterPool = sync.Pool{} - flateReaderPool = sync.Pool{} + flateWriterPool = sync.Pool{New: func() interface{} { + fw, _ := flate.NewWriter(nil, 3) + return fw + }} + flateReaderPool = sync.Pool{New: func() interface{} { + return flate.NewReader(nil) + }} ) func decompressNoContextTakeover(r io.Reader) io.ReadCloser { @@ -24,26 +29,16 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser { // Add final block to squelch unexpected EOF error from flate reader. "\x01\x00\x00\xff\xff" - i := flateReaderPool.Get() - if i == nil { - i = flate.NewReader(nil) - } - i.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) - return &flateReadWrapper{i.(io.ReadCloser)} + fr, _ := flateReaderPool.Get().(io.ReadCloser) + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + return &flateReadWrapper{fr} } -func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) { +func compressNoContextTakeover(w io.WriteCloser) io.WriteCloser { tw := &truncWriter{w: w} - i := flateWriterPool.Get() - var fw *flate.Writer - var err error - if i == nil { - fw, err = flate.NewWriter(tw, 3) - } else { - fw = i.(*flate.Writer) - fw.Reset(tw) - } - return &flateWriteWrapper{fw: fw, tw: tw}, err + fw, _ := flateWriterPool.Get().(*flate.Writer) + fw.Reset(tw) + return &flateWriteWrapper{fw: fw, tw: tw} } // truncWriter is an io.Writer that writes all but the last four bytes of the diff --git a/conn.go b/conn.go index 70562f9..9c3645b 100644 --- a/conn.go +++ b/conn.go @@ -235,7 +235,7 @@ type Conn struct { writeErr error enableWriteCompression bool - newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error) + newCompressionWriter func(io.WriteCloser) io.WriteCloser // Read fields reader io.ReadCloser // the current reader returned to the application @@ -444,11 +444,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { } c.writer = mw if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { - w, err := c.newCompressionWriter(c.writer) - if err != nil { - c.writer = nil - return nil, err - } + w := c.newCompressionWriter(c.writer) mw.compress = true c.writer = w } diff --git a/doc.go b/doc.go index 610acf7..e046b8c 100644 --- a/doc.go +++ b/doc.go @@ -150,7 +150,7 @@ // application's responsibility to check the Origin header before calling // Upgrade. // -// Compression [Experimental] +// Compression // // Per message compression extensions (RFC 7692) are experimentally supported // by this package in a limited capacity. Setting the EnableCompression option @@ -162,7 +162,7 @@ // Per message compression of messages written to a connection can be enabled // or disabled by calling the corresponding Conn method: // -// conn.EnableWriteCompression(true) +// conn.EnableWriteCompression(true) // // Currently this package does not support compression with "context takeover". // This means that messages must be compressed and decompressed in isolation,