forked from mirror/websocket
Add hooks to support RFC 7692 (per-message compression extension)
Add newCompressionWriter and newDecompressionReader fields to Conn. When not nil, these functions are used to create a compression/decompression wrapper around an underlying message writer/reader. Add code to set and check for RSV1 frame header bit. Add functions compressNoContextTakeover and decompressNoContextTakeover for creating no context takeover wrappers around an underlying message writer/reader. Work remaining: - Add fields to Dialer and Upgrader for specifying compression options. - Add compression negotiation to Dialer and Upgrader. - Add function to enable/disable write compression: // EnableWriteCompression enables and disables write compression of // subsequent text and binary messages. This function is a noop if // compression was not negotiated with the peer. func (c *Conn) EnableWriteCompression(enable bool) { c.enableWriteCompression = enable }
This commit is contained in:
parent
b5389d0dc2
commit
a87eae1d6f
|
@ -0,0 +1,85 @@
|
|||
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"compress/flate"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func decompressNoContextTakeover(r io.Reader) io.Reader {
|
||||
const tail =
|
||||
// Add four bytes as specified in RFC
|
||||
"\x00\x00\xff\xff" +
|
||||
// Add final block to squelch unexpected EOF error from flate reader.
|
||||
"\x01\x00\x00\xff\xff"
|
||||
|
||||
return flate.NewReader(io.MultiReader(r, strings.NewReader(tail)))
|
||||
}
|
||||
|
||||
func compressNoContextTakeover(w io.WriteCloser) (io.WriteCloser, error) {
|
||||
tw := &truncWriter{w: w}
|
||||
fw, err := flate.NewWriter(tw, 3)
|
||||
return &flateWrapper{fw: fw, tw: tw}, err
|
||||
}
|
||||
|
||||
// truncWriter is an io.Writer that writes all but the last four bytes of the
|
||||
// stream to another io.Writer.
|
||||
type truncWriter struct {
|
||||
w io.WriteCloser
|
||||
n int
|
||||
p [4]byte
|
||||
}
|
||||
|
||||
func (w *truncWriter) Write(p []byte) (int, error) {
|
||||
n := 0
|
||||
|
||||
// fill buffer first for simplicity.
|
||||
if w.n < len(w.p) {
|
||||
n = copy(w.p[w.n:], p)
|
||||
p = p[n:]
|
||||
w.n += n
|
||||
if len(p) == 0 {
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
|
||||
m := len(p)
|
||||
if m > len(w.p) {
|
||||
m = len(w.p)
|
||||
}
|
||||
|
||||
if nn, err := w.w.Write(w.p[:m]); err != nil {
|
||||
return n + nn, err
|
||||
}
|
||||
|
||||
copy(w.p[:], w.p[m:])
|
||||
copy(w.p[len(w.p)-m:], p[len(p)-m:])
|
||||
nn, err := w.w.Write(p[:len(p)-m])
|
||||
return n + nn, err
|
||||
}
|
||||
|
||||
type flateWrapper struct {
|
||||
fw *flate.Writer
|
||||
tw *truncWriter
|
||||
}
|
||||
|
||||
func (w *flateWrapper) Write(p []byte) (int, error) {
|
||||
return w.fw.Write(p)
|
||||
}
|
||||
|
||||
func (w *flateWrapper) Close() error {
|
||||
err1 := w.fw.Flush()
|
||||
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
|
||||
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
|
||||
}
|
||||
err2 := w.tw.w.Close()
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type nopCloser struct{ io.Writer }
|
||||
|
||||
func (nopCloser) Close() error { return nil }
|
||||
|
||||
func TestTruncWriter(t *testing.T) {
|
||||
const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
|
||||
for n := 1; n <= 10; n++ {
|
||||
var b bytes.Buffer
|
||||
w := &truncWriter{w: nopCloser{&b}}
|
||||
p := []byte(data)
|
||||
for len(p) > 0 {
|
||||
m := len(p)
|
||||
if m > n {
|
||||
m = n
|
||||
}
|
||||
w.Write(p[:m])
|
||||
p = p[m:]
|
||||
}
|
||||
if b.String() != data[:len(data)-len(w.p)] {
|
||||
t.Errorf("%d: %q", n, b.String())
|
||||
}
|
||||
}
|
||||
}
|
87
conn.go
87
conn.go
|
@ -18,10 +18,18 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
// Frame header byte 0 bits from Section 5.2 of RFC 6455
|
||||
finalBit = 1 << 7
|
||||
rsv1Bit = 1 << 6
|
||||
rsv2Bit = 1 << 5
|
||||
rsv3Bit = 1 << 4
|
||||
|
||||
// Frame header byte 1 bits from Section 5.2 of RFC 6455
|
||||
maskBit = 1 << 7
|
||||
|
||||
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
|
||||
maxControlFramePayloadSize = 125
|
||||
finalBit = 1 << 7
|
||||
maskBit = 1 << 7
|
||||
|
||||
writeWait = time.Second
|
||||
|
||||
defaultReadBufferSize = 4096
|
||||
|
@ -231,16 +239,19 @@ type Conn struct {
|
|||
|
||||
// Write fields
|
||||
mu chan bool // used as mutex to protect write to conn and closeSent
|
||||
closeSent bool // true if close message was sent
|
||||
|
||||
// Message writer fields.
|
||||
closeSent bool // whether close message was sent
|
||||
writeErr error
|
||||
writeBuf []byte // frame is constructed in this buffer.
|
||||
writePos int // end of data in writeBuf.
|
||||
writeFrameType int // type of the current frame.
|
||||
writeDeadline time.Time
|
||||
messageWriter *messageWriter // the current low-level message writer
|
||||
writer io.WriteCloser // the current writer returned to the application
|
||||
isWriting bool // for best-effort concurrent write detection
|
||||
messageWriter *messageWriter // the current writer
|
||||
|
||||
enableWriteCompression bool
|
||||
writeCompress bool // whether next call to flushFrame should set RSV1
|
||||
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
|
||||
|
||||
// Read fields
|
||||
readErr error
|
||||
|
@ -254,7 +265,10 @@ type Conn struct {
|
|||
handlePong func(string) error
|
||||
handlePing func(string) error
|
||||
readErrCount int
|
||||
messageReader *messageReader // the current reader
|
||||
messageReader *messageReader // the current low-level reader
|
||||
|
||||
readDecompress bool // whether last read frame had RSV1 set
|
||||
newDecompressionReader func(io.Reader) io.Reader
|
||||
}
|
||||
|
||||
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
||||
|
@ -280,6 +294,7 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
|
|||
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
|
||||
writeFrameType: noFrame,
|
||||
writePos: maxFrameHeaderSize,
|
||||
enableWriteCompression: true,
|
||||
}
|
||||
c.SetPingHandler(nil)
|
||||
c.SetPongHandler(nil)
|
||||
|
@ -403,8 +418,12 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
|||
return nil, c.writeErr
|
||||
}
|
||||
|
||||
if c.writeFrameType != noFrame {
|
||||
if err := c.flushFrame(true, nil); err != nil {
|
||||
// Close previous writer if not already closed by the application. It's
|
||||
// probably better to return an error in this situation, but we cannot
|
||||
// change this without breaking existing applications.
|
||||
if c.writer != nil {
|
||||
err := c.writer.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
@ -414,11 +433,24 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
|||
}
|
||||
|
||||
c.writeFrameType = messageType
|
||||
w := &messageWriter{c}
|
||||
c.messageWriter = w
|
||||
c.messageWriter = &messageWriter{c}
|
||||
|
||||
var w io.WriteCloser = c.messageWriter
|
||||
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
||||
c.writeCompress = true
|
||||
var err error
|
||||
w, err = c.newCompressionWriter(w)
|
||||
if err != nil {
|
||||
c.writer.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// flushFrame writes buffered data and extra as a frame to the network. The
|
||||
// final argument indicates that this is the last frame in the message.
|
||||
func (c *Conn) flushFrame(final bool, extra []byte) error {
|
||||
length := c.writePos - maxFrameHeaderSize + len(extra)
|
||||
|
||||
|
@ -426,6 +458,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|||
if isControl(c.writeFrameType) &&
|
||||
(!final || length > maxControlFramePayloadSize) {
|
||||
c.messageWriter = nil
|
||||
c.writer = nil
|
||||
c.writeFrameType = noFrame
|
||||
c.writePos = maxFrameHeaderSize
|
||||
return errInvalidControlFrame
|
||||
|
@ -435,6 +468,11 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|||
if final {
|
||||
b0 |= finalBit
|
||||
}
|
||||
if c.writeCompress {
|
||||
b0 |= rsv1Bit
|
||||
}
|
||||
c.writeCompress = false
|
||||
|
||||
b1 := byte(0)
|
||||
if !c.isServer {
|
||||
b1 |= maskBit
|
||||
|
@ -494,6 +532,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|||
c.writeFrameType = continuationFrame
|
||||
if final {
|
||||
c.messageWriter = nil
|
||||
c.writer = nil
|
||||
c.writeFrameType = noFrame
|
||||
}
|
||||
return c.writeErr
|
||||
|
@ -526,14 +565,14 @@ func (w *messageWriter) ncopy(max int) (int, error) {
|
|||
return n, nil
|
||||
}
|
||||
|
||||
func (w *messageWriter) write(final bool, p []byte) (int, error) {
|
||||
func (w *messageWriter) Write(p []byte) (int, error) {
|
||||
if err := w.err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
|
||||
// Don't buffer large messages.
|
||||
err := w.c.flushFrame(final, p)
|
||||
err := w.c.flushFrame(false, p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -553,10 +592,6 @@ func (w *messageWriter) write(final bool, p []byte) (int, error) {
|
|||
return nn, nil
|
||||
}
|
||||
|
||||
func (w *messageWriter) Write(p []byte) (int, error) {
|
||||
return w.write(false, p)
|
||||
}
|
||||
|
||||
func (w *messageWriter) WriteString(p string) (int, error) {
|
||||
if err := w.err(); err != nil {
|
||||
return 0, err
|
||||
|
@ -658,12 +693,17 @@ func (c *Conn) advanceFrame() (int, error) {
|
|||
|
||||
final := p[0]&finalBit != 0
|
||||
frameType := int(p[0] & 0xf)
|
||||
reserved := int((p[0] >> 4) & 0x7)
|
||||
mask := p[1]&maskBit != 0
|
||||
c.readRemaining = int64(p[1] & 0x7f)
|
||||
|
||||
if reserved != 0 {
|
||||
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
|
||||
c.readDecompress = false
|
||||
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
|
||||
c.readDecompress = true
|
||||
p[0] &^= rsv1Bit
|
||||
}
|
||||
|
||||
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
|
||||
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
|
||||
}
|
||||
|
||||
switch frameType {
|
||||
|
@ -807,8 +847,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
|||
break
|
||||
}
|
||||
if frameType == TextMessage || frameType == BinaryMessage {
|
||||
r := &messageReader{c}
|
||||
c.messageReader = r
|
||||
c.messageReader = &messageReader{c}
|
||||
var r io.Reader = c.messageReader
|
||||
if c.readDecompress {
|
||||
r = c.newDecompressionReader(r)
|
||||
}
|
||||
return frameType, r, nil
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue