Add hooks to support RFC 7692 (per-message compression extension)

Add newCompressionWriter and newDecompressionReader fields to Conn. When
not nil, these functions are used to create a compression/decompression
wrapper around an underlying message writer/reader.

Add code to set and check for RSV1 frame header bit.

Add functions compressNoContextTakeover and decompressNoContextTakeover
for creating no context takeover wrappers around an underlying message
writer/reader.

Work remaining:

- Add fields to Dialer and Upgrader for specifying compression options.
- Add compression negotiation to Dialer and Upgrader.
- Add function to enable/disable write compression:

    // EnableWriteCompression enables and disables write compression of
    // subsequent text and binary messages. This function is a noop if
    // compression was not negotiated with the peer.
    func (c *Conn) EnableWriteCompression(enable bool) {
            c.enableWriteCompression = enable
    }
This commit is contained in:
Gary Burd 2016-06-29 17:03:55 -07:00
parent b5389d0dc2
commit a87eae1d6f
3 changed files with 191 additions and 32 deletions

85
compression.go Normal file
View File

@ -0,0 +1,85 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"compress/flate"
"errors"
"io"
"strings"
)
func decompressNoContextTakeover(r io.Reader) io.Reader {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"
return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
}
func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
tw := &truncWriter{w: w}
fw, err := flate.NewWriter(tw, 3)
return &flateWrapper{fw: fw, tw: tw}, err
}
// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
type truncWriter struct {
w io.WriteCloser
n int
p [4]byte
}
func (w *truncWriter) Write(p []byte) (int, error) {
n := 0
// fill buffer first for simplicity.
if w.n < len(w.p) {
n = copy(w.p[w.n:], p)
p = p[n:]
w.n += n
if len(p) == 0 {
return n, nil
}
}
m := len(p)
if m > len(w.p) {
m = len(w.p)
}
if nn, err := w.w.Write(w.p[:m]); err != nil {
return n + nn, err
}
copy(w.p[:], w.p[m:])
copy(w.p[len(w.p)-m:], p[len(p)-m:])
nn, err := w.w.Write(p[:len(p)-m])
return n + nn, err
}
type flateWrapper struct {
fw *flate.Writer
tw *truncWriter
}
func (w *flateWrapper) Write(p []byte) (int, error) {
return w.fw.Write(p)
}
func (w *flateWrapper) Close() error {
err1 := w.fw.Flush()
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
}
err2 := w.tw.w.Close()
if err1 != nil {
return err1
}
return err2
}

31
compression_test.go Normal file
View File

@ -0,0 +1,31 @@
package websocket
import (
"bytes"
"io"
"testing"
)
type nopCloser struct{ io.Writer }
func (nopCloser) Close() error { return nil }
func TestTruncWriter(t *testing.T) {
const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
for n := 1; n <= 10; n++ {
var b bytes.Buffer
w := &truncWriter{w: nopCloser{&b}}
p := []byte(data)
for len(p) > 0 {
m := len(p)
if m > n {
m = n
}
w.Write(p[:m])
p = p[m:]
}
if b.String() != data[:len(data)-len(w.p)] {
t.Errorf("%d: %q", n, b.String())
}
}
}

107
conn.go
View File

@ -18,11 +18,19 @@ import (
) )
const ( const (
// Frame header byte 0 bits from Section 5.2 of RFC 6455
finalBit = 1 << 7
rsv1Bit = 1 << 6
rsv2Bit = 1 << 5
rsv3Bit = 1 << 4
// Frame header byte 1 bits from Section 5.2 of RFC 6455
maskBit = 1 << 7
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
maxControlFramePayloadSize = 125 maxControlFramePayloadSize = 125
finalBit = 1 << 7
maskBit = 1 << 7 writeWait = time.Second
writeWait = time.Second
defaultReadBufferSize = 4096 defaultReadBufferSize = 4096
defaultWriteBufferSize = 4096 defaultWriteBufferSize = 4096
@ -230,17 +238,20 @@ type Conn struct {
subprotocol string subprotocol string
// Write fields // Write fields
mu chan bool // used as mutex to protect write to conn and closeSent mu chan bool // used as mutex to protect write to conn and closeSent
closeSent bool // true if close message was sent closeSent bool // whether close message was sent
// Message writer fields.
writeErr error writeErr error
writeBuf []byte // frame is constructed in this buffer. writeBuf []byte // frame is constructed in this buffer.
writePos int // end of data in writeBuf. writePos int // end of data in writeBuf.
writeFrameType int // type of the current frame. writeFrameType int // type of the current frame.
writeDeadline time.Time writeDeadline time.Time
messageWriter *messageWriter // the current low-level message writer
writer io.WriteCloser // the current writer returned to the application
isWriting bool // for best-effort concurrent write detection isWriting bool // for best-effort concurrent write detection
messageWriter *messageWriter // the current writer
enableWriteCompression bool
writeCompress bool // whether next call to flushFrame should set RSV1
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
// Read fields // Read fields
readErr error readErr error
@ -254,7 +265,10 @@ type Conn struct {
handlePong func(string) error handlePong func(string) error
handlePing func(string) error handlePing func(string) error
readErrCount int readErrCount int
messageReader *messageReader // the current reader messageReader *messageReader // the current low-level reader
readDecompress bool // whether last read frame had RSV1 set
newDecompressionReader func(io.Reader) io.Reader
} }
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@ -272,14 +286,15 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
} }
c := &Conn{ c := &Conn{
isServer: isServer, isServer: isServer,
br: bufio.NewReaderSize(conn, readBufferSize), br: bufio.NewReaderSize(conn, readBufferSize),
conn: conn, conn: conn,
mu: mu, mu: mu,
readFinal: true, readFinal: true,
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize), writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
writeFrameType: noFrame, writeFrameType: noFrame,
writePos: maxFrameHeaderSize, writePos: maxFrameHeaderSize,
enableWriteCompression: true,
} }
c.SetPingHandler(nil) c.SetPingHandler(nil)
c.SetPongHandler(nil) c.SetPongHandler(nil)
@ -403,8 +418,12 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
return nil, c.writeErr return nil, c.writeErr
} }
if c.writeFrameType != noFrame { // Close previous writer if not already closed by the application. It's
if err := c.flushFrame(true, nil); err != nil { // probably better to return an error in this situation, but we cannot
// change this without breaking existing applications.
if c.writer != nil {
err := c.writer.Close()
if err != nil {
return nil, err return nil, err
} }
} }
@ -414,11 +433,24 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
} }
c.writeFrameType = messageType c.writeFrameType = messageType
w := &messageWriter{c} c.messageWriter = &messageWriter{c}
c.messageWriter = w
var w io.WriteCloser = c.messageWriter
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
c.writeCompress = true
var err error
w, err = c.newCompressionWriter(w)
if err != nil {
c.writer.Close()
return nil, err
}
}
return w, nil return w, nil
} }
// flushFrame writes buffered data and extra as a frame to the network. The
// final argument indicates that this is the last frame in the message.
func (c *Conn) flushFrame(final bool, extra []byte) error { func (c *Conn) flushFrame(final bool, extra []byte) error {
length := c.writePos - maxFrameHeaderSize + len(extra) length := c.writePos - maxFrameHeaderSize + len(extra)
@ -426,6 +458,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
if isControl(c.writeFrameType) && if isControl(c.writeFrameType) &&
(!final || length > maxControlFramePayloadSize) { (!final || length > maxControlFramePayloadSize) {
c.messageWriter = nil c.messageWriter = nil
c.writer = nil
c.writeFrameType = noFrame c.writeFrameType = noFrame
c.writePos = maxFrameHeaderSize c.writePos = maxFrameHeaderSize
return errInvalidControlFrame return errInvalidControlFrame
@ -435,6 +468,11 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
if final { if final {
b0 |= finalBit b0 |= finalBit
} }
if c.writeCompress {
b0 |= rsv1Bit
}
c.writeCompress = false
b1 := byte(0) b1 := byte(0)
if !c.isServer { if !c.isServer {
b1 |= maskBit b1 |= maskBit
@ -494,6 +532,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
c.writeFrameType = continuationFrame c.writeFrameType = continuationFrame
if final { if final {
c.messageWriter = nil c.messageWriter = nil
c.writer = nil
c.writeFrameType = noFrame c.writeFrameType = noFrame
} }
return c.writeErr return c.writeErr
@ -526,14 +565,14 @@ func (w *messageWriter) ncopy(max int) (int, error) {
return n, nil return n, nil
} }
func (w *messageWriter) write(final bool, p []byte) (int, error) { func (w *messageWriter) Write(p []byte) (int, error) {
if err := w.err(); err != nil { if err := w.err(); err != nil {
return 0, err return 0, err
} }
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
// Don't buffer large messages. // Don't buffer large messages.
err := w.c.flushFrame(final, p) err := w.c.flushFrame(false, p)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -553,10 +592,6 @@ func (w *messageWriter) write(final bool, p []byte) (int, error) {
return nn, nil return nn, nil
} }
func (w *messageWriter) Write(p []byte) (int, error) {
return w.write(false, p)
}
func (w *messageWriter) WriteString(p string) (int, error) { func (w *messageWriter) WriteString(p string) (int, error) {
if err := w.err(); err != nil { if err := w.err(); err != nil {
return 0, err return 0, err
@ -658,12 +693,17 @@ func (c *Conn) advanceFrame() (int, error) {
final := p[0]&finalBit != 0 final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf) frameType := int(p[0] & 0xf)
reserved := int((p[0] >> 4) & 0x7)
mask := p[1]&maskBit != 0 mask := p[1]&maskBit != 0
c.readRemaining = int64(p[1] & 0x7f) c.readRemaining = int64(p[1] & 0x7f)
if reserved != 0 { c.readDecompress = false
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved)) if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
c.readDecompress = true
p[0] &^= rsv1Bit
}
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
} }
switch frameType { switch frameType {
@ -807,8 +847,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
break break
} }
if frameType == TextMessage || frameType == BinaryMessage { if frameType == TextMessage || frameType == BinaryMessage {
r := &messageReader{c} c.messageReader = &messageReader{c}
c.messageReader = r var r io.Reader = c.messageReader
if c.readDecompress {
r = c.newDecompressionReader(r)
}
return frameType, r, nil return frameType, r, nil
} }
} }