mirror of https://github.com/gorilla/websocket.git
Add hooks to support RFC 7692 (per-message compression extension)
Add newCompressionWriter and newDecompressionReader fields to Conn. When not nil, these functions are used to create a compression/decompression wrapper around an underlying message writer/reader. Add code to set and check for RSV1 frame header bit. Add functions compressNoContextTakeover and decompressNoContextTakeover for creating no context takeover wrappers around an underlying message writer/reader. Work remaining: - Add fields to Dialer and Upgrader for specifying compression options. - Add compression negotiation to Dialer and Upgrader. - Add function to enable/disable write compression: // EnableWriteCompression enables and disables write compression of // subsequent text and binary messages. This function is a noop if // compression was not negotiated with the peer. func (c *Conn) EnableWriteCompression(enable bool) { c.enableWriteCompression = enable }
This commit is contained in:
parent
b5389d0dc2
commit
a87eae1d6f
|
@ -0,0 +1,85 @@
|
||||||
|
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"compress/flate"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func decompressNoContextTakeover(r io.Reader) io.Reader {
|
||||||
|
const tail =
|
||||||
|
// Add four bytes as specified in RFC
|
||||||
|
"\x00\x00\xff\xff" +
|
||||||
|
// Add final block to squelch unexpected EOF error from flate reader.
|
||||||
|
"\x01\x00\x00\xff\xff"
|
||||||
|
|
||||||
|
return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
|
||||||
|
tw := &truncWriter{w: w}
|
||||||
|
fw, err := flate.NewWriter(tw, 3)
|
||||||
|
return &flateWrapper{fw: fw, tw: tw}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncWriter is an io.Writer that writes all but the last four bytes of the
|
||||||
|
// stream to another io.Writer.
|
||||||
|
type truncWriter struct {
|
||||||
|
w io.WriteCloser
|
||||||
|
n int
|
||||||
|
p [4]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *truncWriter) Write(p []byte) (int, error) {
|
||||||
|
n := 0
|
||||||
|
|
||||||
|
// fill buffer first for simplicity.
|
||||||
|
if w.n < len(w.p) {
|
||||||
|
n = copy(w.p[w.n:], p)
|
||||||
|
p = p[n:]
|
||||||
|
w.n += n
|
||||||
|
if len(p) == 0 {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m := len(p)
|
||||||
|
if m > len(w.p) {
|
||||||
|
m = len(w.p)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nn, err := w.w.Write(w.p[:m]); err != nil {
|
||||||
|
return n + nn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(w.p[:], w.p[m:])
|
||||||
|
copy(w.p[len(w.p)-m:], p[len(p)-m:])
|
||||||
|
nn, err := w.w.Write(p[:len(p)-m])
|
||||||
|
return n + nn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type flateWrapper struct {
|
||||||
|
fw *flate.Writer
|
||||||
|
tw *truncWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *flateWrapper) Write(p []byte) (int, error) {
|
||||||
|
return w.fw.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *flateWrapper) Close() error {
|
||||||
|
err1 := w.fw.Flush()
|
||||||
|
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
|
||||||
|
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
|
||||||
|
}
|
||||||
|
err2 := w.tw.w.Close()
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
|
@ -0,0 +1,31 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type nopCloser struct{ io.Writer }
|
||||||
|
|
||||||
|
func (nopCloser) Close() error { return nil }
|
||||||
|
|
||||||
|
func TestTruncWriter(t *testing.T) {
|
||||||
|
const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
|
||||||
|
for n := 1; n <= 10; n++ {
|
||||||
|
var b bytes.Buffer
|
||||||
|
w := &truncWriter{w: nopCloser{&b}}
|
||||||
|
p := []byte(data)
|
||||||
|
for len(p) > 0 {
|
||||||
|
m := len(p)
|
||||||
|
if m > n {
|
||||||
|
m = n
|
||||||
|
}
|
||||||
|
w.Write(p[:m])
|
||||||
|
p = p[m:]
|
||||||
|
}
|
||||||
|
if b.String() != data[:len(data)-len(w.p)] {
|
||||||
|
t.Errorf("%d: %q", n, b.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
107
conn.go
107
conn.go
|
@ -18,11 +18,19 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// Frame header byte 0 bits from Section 5.2 of RFC 6455
|
||||||
|
finalBit = 1 << 7
|
||||||
|
rsv1Bit = 1 << 6
|
||||||
|
rsv2Bit = 1 << 5
|
||||||
|
rsv3Bit = 1 << 4
|
||||||
|
|
||||||
|
// Frame header byte 1 bits from Section 5.2 of RFC 6455
|
||||||
|
maskBit = 1 << 7
|
||||||
|
|
||||||
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
|
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
|
||||||
maxControlFramePayloadSize = 125
|
maxControlFramePayloadSize = 125
|
||||||
finalBit = 1 << 7
|
|
||||||
maskBit = 1 << 7
|
writeWait = time.Second
|
||||||
writeWait = time.Second
|
|
||||||
|
|
||||||
defaultReadBufferSize = 4096
|
defaultReadBufferSize = 4096
|
||||||
defaultWriteBufferSize = 4096
|
defaultWriteBufferSize = 4096
|
||||||
|
@ -230,17 +238,20 @@ type Conn struct {
|
||||||
subprotocol string
|
subprotocol string
|
||||||
|
|
||||||
// Write fields
|
// Write fields
|
||||||
mu chan bool // used as mutex to protect write to conn and closeSent
|
mu chan bool // used as mutex to protect write to conn and closeSent
|
||||||
closeSent bool // true if close message was sent
|
closeSent bool // whether close message was sent
|
||||||
|
|
||||||
// Message writer fields.
|
|
||||||
writeErr error
|
writeErr error
|
||||||
writeBuf []byte // frame is constructed in this buffer.
|
writeBuf []byte // frame is constructed in this buffer.
|
||||||
writePos int // end of data in writeBuf.
|
writePos int // end of data in writeBuf.
|
||||||
writeFrameType int // type of the current frame.
|
writeFrameType int // type of the current frame.
|
||||||
writeDeadline time.Time
|
writeDeadline time.Time
|
||||||
|
messageWriter *messageWriter // the current low-level message writer
|
||||||
|
writer io.WriteCloser // the current writer returned to the application
|
||||||
isWriting bool // for best-effort concurrent write detection
|
isWriting bool // for best-effort concurrent write detection
|
||||||
messageWriter *messageWriter // the current writer
|
|
||||||
|
enableWriteCompression bool
|
||||||
|
writeCompress bool // whether next call to flushFrame should set RSV1
|
||||||
|
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
|
||||||
|
|
||||||
// Read fields
|
// Read fields
|
||||||
readErr error
|
readErr error
|
||||||
|
@ -254,7 +265,10 @@ type Conn struct {
|
||||||
handlePong func(string) error
|
handlePong func(string) error
|
||||||
handlePing func(string) error
|
handlePing func(string) error
|
||||||
readErrCount int
|
readErrCount int
|
||||||
messageReader *messageReader // the current reader
|
messageReader *messageReader // the current low-level reader
|
||||||
|
|
||||||
|
readDecompress bool // whether last read frame had RSV1 set
|
||||||
|
newDecompressionReader func(io.Reader) io.Reader
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
||||||
|
@ -272,14 +286,15 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
isServer: isServer,
|
isServer: isServer,
|
||||||
br: bufio.NewReaderSize(conn, readBufferSize),
|
br: bufio.NewReaderSize(conn, readBufferSize),
|
||||||
conn: conn,
|
conn: conn,
|
||||||
mu: mu,
|
mu: mu,
|
||||||
readFinal: true,
|
readFinal: true,
|
||||||
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
|
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
|
||||||
writeFrameType: noFrame,
|
writeFrameType: noFrame,
|
||||||
writePos: maxFrameHeaderSize,
|
writePos: maxFrameHeaderSize,
|
||||||
|
enableWriteCompression: true,
|
||||||
}
|
}
|
||||||
c.SetPingHandler(nil)
|
c.SetPingHandler(nil)
|
||||||
c.SetPongHandler(nil)
|
c.SetPongHandler(nil)
|
||||||
|
@ -403,8 +418,12 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
||||||
return nil, c.writeErr
|
return nil, c.writeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.writeFrameType != noFrame {
|
// Close previous writer if not already closed by the application. It's
|
||||||
if err := c.flushFrame(true, nil); err != nil {
|
// probably better to return an error in this situation, but we cannot
|
||||||
|
// change this without breaking existing applications.
|
||||||
|
if c.writer != nil {
|
||||||
|
err := c.writer.Close()
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -414,11 +433,24 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.writeFrameType = messageType
|
c.writeFrameType = messageType
|
||||||
w := &messageWriter{c}
|
c.messageWriter = &messageWriter{c}
|
||||||
c.messageWriter = w
|
|
||||||
|
var w io.WriteCloser = c.messageWriter
|
||||||
|
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
||||||
|
c.writeCompress = true
|
||||||
|
var err error
|
||||||
|
w, err = c.newCompressionWriter(w)
|
||||||
|
if err != nil {
|
||||||
|
c.writer.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return w, nil
|
return w, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// flushFrame writes buffered data and extra as a frame to the network. The
|
||||||
|
// final argument indicates that this is the last frame in the message.
|
||||||
func (c *Conn) flushFrame(final bool, extra []byte) error {
|
func (c *Conn) flushFrame(final bool, extra []byte) error {
|
||||||
length := c.writePos - maxFrameHeaderSize + len(extra)
|
length := c.writePos - maxFrameHeaderSize + len(extra)
|
||||||
|
|
||||||
|
@ -426,6 +458,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
||||||
if isControl(c.writeFrameType) &&
|
if isControl(c.writeFrameType) &&
|
||||||
(!final || length > maxControlFramePayloadSize) {
|
(!final || length > maxControlFramePayloadSize) {
|
||||||
c.messageWriter = nil
|
c.messageWriter = nil
|
||||||
|
c.writer = nil
|
||||||
c.writeFrameType = noFrame
|
c.writeFrameType = noFrame
|
||||||
c.writePos = maxFrameHeaderSize
|
c.writePos = maxFrameHeaderSize
|
||||||
return errInvalidControlFrame
|
return errInvalidControlFrame
|
||||||
|
@ -435,6 +468,11 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
||||||
if final {
|
if final {
|
||||||
b0 |= finalBit
|
b0 |= finalBit
|
||||||
}
|
}
|
||||||
|
if c.writeCompress {
|
||||||
|
b0 |= rsv1Bit
|
||||||
|
}
|
||||||
|
c.writeCompress = false
|
||||||
|
|
||||||
b1 := byte(0)
|
b1 := byte(0)
|
||||||
if !c.isServer {
|
if !c.isServer {
|
||||||
b1 |= maskBit
|
b1 |= maskBit
|
||||||
|
@ -494,6 +532,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
||||||
c.writeFrameType = continuationFrame
|
c.writeFrameType = continuationFrame
|
||||||
if final {
|
if final {
|
||||||
c.messageWriter = nil
|
c.messageWriter = nil
|
||||||
|
c.writer = nil
|
||||||
c.writeFrameType = noFrame
|
c.writeFrameType = noFrame
|
||||||
}
|
}
|
||||||
return c.writeErr
|
return c.writeErr
|
||||||
|
@ -526,14 +565,14 @@ func (w *messageWriter) ncopy(max int) (int, error) {
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *messageWriter) write(final bool, p []byte) (int, error) {
|
func (w *messageWriter) Write(p []byte) (int, error) {
|
||||||
if err := w.err(); err != nil {
|
if err := w.err(); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
|
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
|
||||||
// Don't buffer large messages.
|
// Don't buffer large messages.
|
||||||
err := w.c.flushFrame(final, p)
|
err := w.c.flushFrame(false, p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -553,10 +592,6 @@ func (w *messageWriter) write(final bool, p []byte) (int, error) {
|
||||||
return nn, nil
|
return nn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *messageWriter) Write(p []byte) (int, error) {
|
|
||||||
return w.write(false, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *messageWriter) WriteString(p string) (int, error) {
|
func (w *messageWriter) WriteString(p string) (int, error) {
|
||||||
if err := w.err(); err != nil {
|
if err := w.err(); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
|
@ -658,12 +693,17 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
|
|
||||||
final := p[0]&finalBit != 0
|
final := p[0]&finalBit != 0
|
||||||
frameType := int(p[0] & 0xf)
|
frameType := int(p[0] & 0xf)
|
||||||
reserved := int((p[0] >> 4) & 0x7)
|
|
||||||
mask := p[1]&maskBit != 0
|
mask := p[1]&maskBit != 0
|
||||||
c.readRemaining = int64(p[1] & 0x7f)
|
c.readRemaining = int64(p[1] & 0x7f)
|
||||||
|
|
||||||
if reserved != 0 {
|
c.readDecompress = false
|
||||||
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
|
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
|
||||||
|
c.readDecompress = true
|
||||||
|
p[0] &^= rsv1Bit
|
||||||
|
}
|
||||||
|
|
||||||
|
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
|
||||||
|
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
|
||||||
}
|
}
|
||||||
|
|
||||||
switch frameType {
|
switch frameType {
|
||||||
|
@ -807,8 +847,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if frameType == TextMessage || frameType == BinaryMessage {
|
if frameType == TextMessage || frameType == BinaryMessage {
|
||||||
r := &messageReader{c}
|
c.messageReader = &messageReader{c}
|
||||||
c.messageReader = r
|
var r io.Reader = c.messageReader
|
||||||
|
if c.readDecompress {
|
||||||
|
r = c.newDecompressionReader(r)
|
||||||
|
}
|
||||||
return frameType, r, nil
|
return frameType, r, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue