mirror of https://github.com/gorilla/websocket.git
Merge branch 'compress'
This commit is contained in:
commit
5ddbd28fbd
|
@ -13,30 +13,32 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
flateWriterPool = sync.Pool{}
|
flateWriterPool = sync.Pool{New: func() interface{} {
|
||||||
|
fw, _ := flate.NewWriter(nil, 3)
|
||||||
|
return fw
|
||||||
|
}}
|
||||||
|
flateReaderPool = sync.Pool{New: func() interface{} {
|
||||||
|
return flate.NewReader(nil)
|
||||||
|
}}
|
||||||
)
|
)
|
||||||
|
|
||||||
func decompressNoContextTakeover(r io.Reader) io.Reader {
|
func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
|
||||||
const tail =
|
const tail =
|
||||||
// Add four bytes as specified in RFC
|
// Add four bytes as specified in RFC
|
||||||
"\x00\x00\xff\xff" +
|
"\x00\x00\xff\xff" +
|
||||||
// Add final block to squelch unexpected EOF error from flate reader.
|
// Add final block to squelch unexpected EOF error from flate reader.
|
||||||
"\x01\x00\x00\xff\xff"
|
"\x01\x00\x00\xff\xff"
|
||||||
return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
|
|
||||||
|
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
||||||
|
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
|
||||||
|
return &flateReadWrapper{fr}
|
||||||
}
|
}
|
||||||
|
|
||||||
func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
|
func compressNoContextTakeover(w io.WriteCloser) io.WriteCloser {
|
||||||
tw := &truncWriter{w: w}
|
tw := &truncWriter{w: w}
|
||||||
i := flateWriterPool.Get()
|
fw, _ := flateWriterPool.Get().(*flate.Writer)
|
||||||
var fw *flate.Writer
|
|
||||||
var err error
|
|
||||||
if i == nil {
|
|
||||||
fw, err = flate.NewWriter(tw, 3)
|
|
||||||
} else {
|
|
||||||
fw = i.(*flate.Writer)
|
|
||||||
fw.Reset(tw)
|
fw.Reset(tw)
|
||||||
}
|
return &flateWriteWrapper{fw: fw, tw: tw}
|
||||||
return &flateWrapper{fw: fw, tw: tw}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncWriter is an io.Writer that writes all but the last four bytes of the
|
// truncWriter is an io.Writer that writes all but the last four bytes of the
|
||||||
|
@ -75,19 +77,19 @@ func (w *truncWriter) Write(p []byte) (int, error) {
|
||||||
return n + nn, err
|
return n + nn, err
|
||||||
}
|
}
|
||||||
|
|
||||||
type flateWrapper struct {
|
type flateWriteWrapper struct {
|
||||||
fw *flate.Writer
|
fw *flate.Writer
|
||||||
tw *truncWriter
|
tw *truncWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *flateWrapper) Write(p []byte) (int, error) {
|
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
|
||||||
if w.fw == nil {
|
if w.fw == nil {
|
||||||
return 0, errWriteClosed
|
return 0, errWriteClosed
|
||||||
}
|
}
|
||||||
return w.fw.Write(p)
|
return w.fw.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *flateWrapper) Close() error {
|
func (w *flateWriteWrapper) Close() error {
|
||||||
if w.fw == nil {
|
if w.fw == nil {
|
||||||
return errWriteClosed
|
return errWriteClosed
|
||||||
}
|
}
|
||||||
|
@ -103,3 +105,31 @@ func (w *flateWrapper) Close() error {
|
||||||
}
|
}
|
||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type flateReadWrapper struct {
|
||||||
|
fr io.ReadCloser
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *flateReadWrapper) Read(p []byte) (int, error) {
|
||||||
|
if r.fr == nil {
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
n, err := r.fr.Read(p)
|
||||||
|
if err == io.EOF {
|
||||||
|
// Preemptively place the reader back in the pool. This helps with
|
||||||
|
// scenarios where the application does not call NextReader() soon after
|
||||||
|
// this final read.
|
||||||
|
r.Close()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *flateReadWrapper) Close() error {
|
||||||
|
if r.fr == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
err := r.fr.Close()
|
||||||
|
flateReaderPool.Put(r.fr)
|
||||||
|
r.fr = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
26
conn.go
26
conn.go
|
@ -235,9 +235,10 @@ type Conn struct {
|
||||||
writeErr error
|
writeErr error
|
||||||
|
|
||||||
enableWriteCompression bool
|
enableWriteCompression bool
|
||||||
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
|
newCompressionWriter func(io.WriteCloser) io.WriteCloser
|
||||||
|
|
||||||
// Read fields
|
// Read fields
|
||||||
|
reader io.ReadCloser // the current reader returned to the application
|
||||||
readErr error
|
readErr error
|
||||||
br *bufio.Reader
|
br *bufio.Reader
|
||||||
readRemaining int64 // bytes remaining in current frame.
|
readRemaining int64 // bytes remaining in current frame.
|
||||||
|
@ -253,7 +254,7 @@ type Conn struct {
|
||||||
messageReader *messageReader // the current low-level reader
|
messageReader *messageReader // the current low-level reader
|
||||||
|
|
||||||
readDecompress bool // whether last read frame had RSV1 set
|
readDecompress bool // whether last read frame had RSV1 set
|
||||||
newDecompressionReader func(io.Reader) io.Reader
|
newDecompressionReader func(io.Reader) io.ReadCloser
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
||||||
|
@ -443,11 +444,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
||||||
}
|
}
|
||||||
c.writer = mw
|
c.writer = mw
|
||||||
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
||||||
w, err := c.newCompressionWriter(c.writer)
|
w := c.newCompressionWriter(c.writer)
|
||||||
if err != nil {
|
|
||||||
c.writer = nil
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
mw.compress = true
|
mw.compress = true
|
||||||
c.writer = w
|
c.writer = w
|
||||||
}
|
}
|
||||||
|
@ -855,6 +852,11 @@ func (c *Conn) handleProtocolError(message string) error {
|
||||||
// permanent. Once this method returns a non-nil error, all subsequent calls to
|
// permanent. Once this method returns a non-nil error, all subsequent calls to
|
||||||
// this method return the same error.
|
// this method return the same error.
|
||||||
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
||||||
|
// Close previous reader, only relevant for decompression.
|
||||||
|
if c.reader != nil {
|
||||||
|
c.reader.Close()
|
||||||
|
c.reader = nil
|
||||||
|
}
|
||||||
|
|
||||||
c.messageReader = nil
|
c.messageReader = nil
|
||||||
c.readLength = 0
|
c.readLength = 0
|
||||||
|
@ -867,11 +869,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
||||||
}
|
}
|
||||||
if frameType == TextMessage || frameType == BinaryMessage {
|
if frameType == TextMessage || frameType == BinaryMessage {
|
||||||
c.messageReader = &messageReader{c}
|
c.messageReader = &messageReader{c}
|
||||||
var r io.Reader = c.messageReader
|
c.reader = c.messageReader
|
||||||
if c.readDecompress {
|
if c.readDecompress {
|
||||||
r = c.newDecompressionReader(r)
|
c.reader = c.newDecompressionReader(c.reader)
|
||||||
}
|
}
|
||||||
return frameType, r, nil
|
return frameType, c.reader, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -933,6 +935,10 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *messageReader) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// ReadMessage is a helper method for getting a reader using NextReader and
|
// ReadMessage is a helper method for getting a reader using NextReader and
|
||||||
// reading from that reader to a buffer.
|
// reading from that reader to a buffer.
|
||||||
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
|
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
|
||||||
|
|
2
doc.go
2
doc.go
|
@ -150,7 +150,7 @@
|
||||||
// application's responsibility to check the Origin header before calling
|
// application's responsibility to check the Origin header before calling
|
||||||
// Upgrade.
|
// Upgrade.
|
||||||
//
|
//
|
||||||
// Compression [Experimental]
|
// Compression
|
||||||
//
|
//
|
||||||
// Per message compression extensions (RFC 7692) are experimentally supported
|
// Per message compression extensions (RFC 7692) are experimentally supported
|
||||||
// by this package in a limited capacity. Setting the EnableCompression option
|
// by this package in a limited capacity. Setting the EnableCompression option
|
||||||
|
|
Loading…
Reference in New Issue