Merge branch 'compress'

This commit is contained in:
Gary Burd 2016-12-27 17:08:04 -05:00
commit 5ddbd28fbd
3 changed files with 65 additions and 29 deletions

View File

@ -13,30 +13,32 @@ import (
) )
var ( var (
flateWriterPool = sync.Pool{} flateWriterPool = sync.Pool{New: func() interface{} {
fw, _ := flate.NewWriter(nil, 3)
return fw
}}
flateReaderPool = sync.Pool{New: func() interface{} {
return flate.NewReader(nil)
}}
) )
func decompressNoContextTakeover(r io.Reader) io.Reader { func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
const tail = const tail =
// Add four bytes as specified in RFC // Add four bytes as specified in RFC
"\x00\x00\xff\xff" + "\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader. // Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff" "\x01\x00\x00\xff\xff"
return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
return &flateReadWrapper{fr}
} }
func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) { func compressNoContextTakeover(w io.WriteCloser) io.WriteCloser {
tw := &truncWriter{w: w} tw := &truncWriter{w: w}
i := flateWriterPool.Get() fw, _ := flateWriterPool.Get().(*flate.Writer)
var fw *flate.Writer fw.Reset(tw)
var err error return &flateWriteWrapper{fw: fw, tw: tw}
if i == nil {
fw, err = flate.NewWriter(tw, 3)
} else {
fw = i.(*flate.Writer)
fw.Reset(tw)
}
return &flateWrapper{fw: fw, tw: tw}, err
} }
// truncWriter is an io.Writer that writes all but the last four bytes of the // truncWriter is an io.Writer that writes all but the last four bytes of the
@ -75,19 +77,19 @@ func (w *truncWriter) Write(p []byte) (int, error) {
return n + nn, err return n + nn, err
} }
type flateWrapper struct { type flateWriteWrapper struct {
fw *flate.Writer fw *flate.Writer
tw *truncWriter tw *truncWriter
} }
func (w *flateWrapper) Write(p []byte) (int, error) { func (w *flateWriteWrapper) Write(p []byte) (int, error) {
if w.fw == nil { if w.fw == nil {
return 0, errWriteClosed return 0, errWriteClosed
} }
return w.fw.Write(p) return w.fw.Write(p)
} }
func (w *flateWrapper) Close() error { func (w *flateWriteWrapper) Close() error {
if w.fw == nil { if w.fw == nil {
return errWriteClosed return errWriteClosed
} }
@ -103,3 +105,31 @@ func (w *flateWrapper) Close() error {
} }
return err2 return err2
} }
type flateReadWrapper struct {
fr io.ReadCloser
}
func (r *flateReadWrapper) Read(p []byte) (int, error) {
if r.fr == nil {
return 0, io.ErrClosedPipe
}
n, err := r.fr.Read(p)
if err == io.EOF {
// Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after
// this final read.
r.Close()
}
return n, err
}
func (r *flateReadWrapper) Close() error {
if r.fr == nil {
return io.ErrClosedPipe
}
err := r.fr.Close()
flateReaderPool.Put(r.fr)
r.fr = nil
return err
}

26
conn.go
View File

@ -235,9 +235,10 @@ type Conn struct {
writeErr error writeErr error
enableWriteCompression bool enableWriteCompression bool
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error) newCompressionWriter func(io.WriteCloser) io.WriteCloser
// Read fields // Read fields
reader io.ReadCloser // the current reader returned to the application
readErr error readErr error
br *bufio.Reader br *bufio.Reader
readRemaining int64 // bytes remaining in current frame. readRemaining int64 // bytes remaining in current frame.
@ -253,7 +254,7 @@ type Conn struct {
messageReader *messageReader // the current low-level reader messageReader *messageReader // the current low-level reader
readDecompress bool // whether last read frame had RSV1 set readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader) io.Reader newDecompressionReader func(io.Reader) io.ReadCloser
} }
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@ -443,11 +444,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
} }
c.writer = mw c.writer = mw
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
w, err := c.newCompressionWriter(c.writer) w := c.newCompressionWriter(c.writer)
if err != nil {
c.writer = nil
return nil, err
}
mw.compress = true mw.compress = true
c.writer = w c.writer = w
} }
@ -855,6 +852,11 @@ func (c *Conn) handleProtocolError(message string) error {
// permanent. Once this method returns a non-nil error, all subsequent calls to // permanent. Once this method returns a non-nil error, all subsequent calls to
// this method return the same error. // this method return the same error.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
// Close previous reader, only relevant for decompression.
if c.reader != nil {
c.reader.Close()
c.reader = nil
}
c.messageReader = nil c.messageReader = nil
c.readLength = 0 c.readLength = 0
@ -867,11 +869,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
} }
if frameType == TextMessage || frameType == BinaryMessage { if frameType == TextMessage || frameType == BinaryMessage {
c.messageReader = &messageReader{c} c.messageReader = &messageReader{c}
var r io.Reader = c.messageReader c.reader = c.messageReader
if c.readDecompress { if c.readDecompress {
r = c.newDecompressionReader(r) c.reader = c.newDecompressionReader(c.reader)
} }
return frameType, r, nil return frameType, c.reader, nil
} }
} }
@ -933,6 +935,10 @@ func (r *messageReader) Read(b []byte) (int, error) {
return 0, err return 0, err
} }
func (r *messageReader) Close() error {
return nil
}
// ReadMessage is a helper method for getting a reader using NextReader and // ReadMessage is a helper method for getting a reader using NextReader and
// reading from that reader to a buffer. // reading from that reader to a buffer.
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {

4
doc.go
View File

@ -150,7 +150,7 @@
// application's responsibility to check the Origin header before calling // application's responsibility to check the Origin header before calling
// Upgrade. // Upgrade.
// //
// Compression [Experimental] // Compression
// //
// Per message compression extensions (RFC 7692) are experimentally supported // Per message compression extensions (RFC 7692) are experimentally supported
// by this package in a limited capacity. Setting the EnableCompression option // by this package in a limited capacity. Setting the EnableCompression option
@ -162,7 +162,7 @@
// Per message compression of messages written to a connection can be enabled // Per message compression of messages written to a connection can be enabled
// or disabled by calling the corresponding Conn method: // or disabled by calling the corresponding Conn method:
// //
// conn.EnableWriteCompression(true) // conn.EnableWriteCompression(true)
// //
// Currently this package does not support compression with "context takeover". // Currently this package does not support compression with "context takeover".
// This means that messages must be compressed and decompressed in isolation, // This means that messages must be compressed and decompressed in isolation,