diff --git a/client.go b/client.go index 3bf9b2e..8242c66 100644 --- a/client.go +++ b/client.go @@ -69,6 +69,9 @@ type Dialer struct { // Subprotocols specifies the client's requested subprotocols. Subprotocols []string + + // Extensions specifies the client requested extensions + Extensions []string } var errMalformedURL = errors.New("malformed ws or wss URL") @@ -196,6 +199,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re if len(d.Subprotocols) > 0 { req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} } + if len(d.Extensions) > 0 { + req.Header["Sec-WebSocket-Extensions"] = d.Extensions + } + for k, vs := range requestHeader { switch { case k == "Host": @@ -206,6 +213,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re k == "Connection" || k == "Sec-Websocket-Key" || k == "Sec-Websocket-Version" || + (k == "Sec-WebSocket-Extensions" && len(d.Extensions) > 0) || (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) default: @@ -328,6 +336,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") + if len(resp.Header.Get("Sec-WebSocket-Extensions")) > 0 { + conn.compressionNegotiated = true + } + netConn.SetDeadline(time.Time{}) netConn = nil // to avoid close in defer. return conn, resp, nil diff --git a/compression.go b/compression.go new file mode 100644 index 0000000..4290cd7 --- /dev/null +++ b/compression.go @@ -0,0 +1,92 @@ +package websocket + +import ( + //"bytes" + "compress/flate" + //"fmt" + "io" +) + +const ( + // Supported compression algorithm and parameters. + CompressPermessageDeflate = "permessage-deflate; server_no_context_takeover; client_no_context_takeover" + + // Deflate compression level + compressDeflateLevel int = 3 +) + +// Sits between a flate writer and the underlying writer i.e. messageWriter +// Truncates last bytes of flate compresses message +type FlateAdaptor struct { + last5bytes []byte + msgWriter io.WriteCloser +} + +func NewFlateAdaptor(w io.WriteCloser) *FlateAdaptor { + return &FlateAdaptor{ + msgWriter: w, + last5bytes: []byte{}, + } +} + +func (aw *FlateAdaptor) Write(p []byte) (n int, err error) { + + t := append(aw.last5bytes, p...) + + if len(t) > 4 { + aw.last5bytes = make([]byte, 5) + copy(aw.last5bytes, t[len(t)-5:]) + _, err = aw.msgWriter.Write(t[:len(t)-5]) + } else { + aw.last5bytes = make([]byte, len(t)) + aw.last5bytes = t + } + + n = len(p) + return +} + +func (aw *FlateAdaptor) writeEndBlock() (int, error) { + var t []byte + if aw.last5bytes[4] != 0x00 { + t = append(aw.last5bytes, 0x00) + } + + return aw.msgWriter.Write(t[:len(t)-5]) +} + +func (aw *FlateAdaptor) Close() (err error) { + if _, err = aw.writeEndBlock(); err == nil { + err = aw.msgWriter.Close() + } + return +} + +// FlateAdaptorWriter --> FlateAdaptor --> messageWriter +type FlateAdaptorWriter struct { + flWriter *flate.Writer + flAdaptor *FlateAdaptor +} + +func NewFlateAdaptorWriter(msgWriter io.WriteCloser, level int) (faw *FlateAdaptorWriter, err error) { + faw = &FlateAdaptorWriter{ + flAdaptor: NewFlateAdaptor(msgWriter), + } + faw.flWriter, err = flate.NewWriter(faw.flAdaptor, level) + return +} + +func (faw *FlateAdaptorWriter) Write(p []byte) (c int, err error) { + if c, err = faw.flWriter.Write(p); err == nil { + err = faw.flWriter.Flush() + } + return +} + +func (faw *FlateAdaptorWriter) Close() (err error) { + if err = faw.flWriter.Close(); err == nil { + err = faw.flAdaptor.Close() + } + + return +} diff --git a/compression_test.go b/compression_test.go new file mode 100644 index 0000000..d58bb3b --- /dev/null +++ b/compression_test.go @@ -0,0 +1,26 @@ +package websocket + +import ( + "bytes" + "compress/flate" + "testing" +) + +func Test_NewAdaptorWriter(t *testing.T) { + backendBuff := new(bytes.Buffer) + aw := NewAdaptorWriter(backendBuff) + + fw, err := flate.NewWriter(aw, -1) + if err != nil { + t.Fatal(err) + } + + var n int + n, err = fw.Write([]byte("test")) + t.Log(n, err) + + if err = fw.Flush(); err != nil { + t.Fatal(err) + } + +} diff --git a/conn.go b/conn.go index e8b6b3e..a7c5345 100644 --- a/conn.go +++ b/conn.go @@ -6,6 +6,7 @@ package websocket import ( "bufio" + "compress/flate" "encoding/binary" "errors" "io" @@ -13,6 +14,7 @@ import ( "math/rand" "net" "strconv" + "strings" "time" ) @@ -21,6 +23,7 @@ const ( maxControlFramePayloadSize = 125 finalBit = 1 << 7 maskBit = 1 << 7 + compressionBit = 1 << 6 // used in flushFrame on writes writeWait = time.Second defaultReadBufferSize = 4096 @@ -144,11 +147,15 @@ type Conn struct { isServer bool subprotocol string + compressionNegotiated bool // negotiated compression based on handshake + // Write fields mu chan bool // used as mutex to protect write to conn and closeSent closeSent bool // true if close message was sent // Message writer fields. + writeCompressionEnabled bool + writeErr error writeBuf []byte // frame is constructed in this buffer. writePos int // end of data in writeBuf. @@ -157,6 +164,8 @@ type Conn struct { writeDeadline time.Time // Read fields + readMessageCompressed bool + readErr error br *bufio.Reader readRemaining int64 // bytes remaining in current frame. @@ -218,6 +227,12 @@ func (c *Conn) RemoteAddr() net.Addr { // Write methods +// EnableWriteCompression enables and disables write compression of subsequent text and +// binary messages. This function is a noop if compression was not negotiated with the peer. +func (c *Conn) EnableWriteCompression(enable bool) { + c.writeCompressionEnabled = enable +} + func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { <-c.mu defer func() { c.mu <- true }() @@ -327,7 +342,15 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { } c.writeFrameType = messageType - return messageWriter{c, c.writeSeq}, nil + + var wc io.WriteCloser = messageWriter{c, c.writeSeq} + + // Return compression writer on data frame + if c.compressionNegotiated && c.writeCompressionEnabled && isData(messageType) { + return NewFlateAdaptorWriter(wc, compressDeflateLevel) + } + + return wc, nil } func (c *Conn) flushFrame(final bool, extra []byte) error { @@ -346,6 +369,13 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { if final { b0 |= finalBit } + + // Check compression and that it is not a continuation frame + // as those should not have compression bit set per RFC + if c.compressionNegotiated && c.writeCompressionEnabled && c.writeFrameType != continuationFrame { + b0 |= compressionBit + } + b1 := byte(0) if !c.isServer { b1 |= maskBit @@ -515,15 +545,30 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { if err != nil { return err } - w := wr.(messageWriter) - if _, err := w.write(true, data); err != nil { - return err - } - if c.writeSeq == w.seq { - if err := c.flushFrame(true, nil); err != nil { + + if c.compressionNegotiated && c.writeCompressionEnabled { + + fw := wr.(*FlateAdaptorWriter) + if _, err = fw.Write(data); err != nil { return err } + return fw.Close() + + } else { + + w := wr.(messageWriter) + if _, err = w.write(true, data); err != nil { + return err + } + + // final flush + if c.writeSeq == w.seq { + if err = c.flushFrame(true, nil); err != nil { + return err + } + } } + return nil } @@ -577,7 +622,17 @@ func (c *Conn) advanceFrame() (int, error) { mask := b[1]&maskBit != 0 c.readRemaining = int64(b[1] & 0x7f) - if reserved != 0 { + switch reserved { + case 4: + if !c.compressionNegotiated { + return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved)) + } + // Only the first frame of a compressed message has the reserved bit set. + c.readMessageCompressed = true + break + case 0: + break + default: return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved)) } @@ -633,7 +688,7 @@ func (c *Conn) advanceFrame() (int, error) { // 5. For text and binary messages, enforce read limit and return. - if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { + if frameType == continuationFrame || isData(frameType) { c.readLength += c.readRemaining if c.readLimit > 0 && c.readLength > c.readLimit { @@ -696,7 +751,7 @@ func (c *Conn) handleProtocolError(message string) error { // // The NextReader method and the readers returned from the method cannot be // accessed by more than one goroutine at a time. -func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { +func (c *Conn) NextReader() (int, io.Reader, error) { c.readSeq++ c.readLength = 0 @@ -707,8 +762,14 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { c.readErr = hideTempErr(err) break } - if frameType == TextMessage || frameType == BinaryMessage { - return frameType, messageReader{c, c.readSeq}, nil + + if isData(frameType) { + var r io.Reader = messageReader{c, c.readSeq} + if c.compressionNegotiated && c.readMessageCompressed { + // Append compression bytes to output on the final read + r = flate.NewReader(io.MultiReader(r, strings.NewReader("\x00\x00\xff\xff\x01\x00\x00\xff\xff"))) + } + return frameType, r, nil } } return noFrame, nil, c.readErr @@ -742,6 +803,11 @@ func (r messageReader) Read(b []byte) (int, error) { if r.c.readFinal { r.c.readSeq++ + // Reset compression for the next frame + if r.c.compressionNegotiated && r.c.readMessageCompressed { + r.c.readMessageCompressed = false + } + return 0, io.EOF } @@ -749,7 +815,7 @@ func (r messageReader) Read(b []byte) (int, error) { switch { case err != nil: r.c.readErr = hideTempErr(err) - case frameType == TextMessage || frameType == BinaryMessage: + case isData(frameType): r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") } } diff --git a/examples/autobahn/README.md b/examples/autobahn/README.md index 075ac15..bc784fa 100644 --- a/examples/autobahn/README.md +++ b/examples/autobahn/README.md @@ -11,3 +11,10 @@ and start the client test driver wstest -m fuzzingclient -s fuzzingclient.json When the client completes, it writes a report to reports/clients/index.html. + + +# Install client test driver + + pip install autobahntestsuite + +This will install the test suite containing the `wstest` command. \ No newline at end of file diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go index d96ac84..f20ed51 100644 --- a/examples/autobahn/server.go +++ b/examples/autobahn/server.go @@ -8,7 +8,7 @@ package main import ( "errors" "flag" - "github.com/gorilla/websocket" + "github.com/euforia/websocket" "io" "log" "net/http" @@ -22,6 +22,9 @@ var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, + Extensions: []string{ + "permessage-deflate; server_no_context_takeover; client_no_context_takeover", + }, } // echoCopy echoes messages from the client using io.Copy. diff --git a/examples/compression/README.md b/examples/compression/README.md new file mode 100644 index 0000000..9a5b774 --- /dev/null +++ b/examples/compression/README.md @@ -0,0 +1,10 @@ +# Compression example +This example covers enabling compression on the server. It starts a websocket server with permessage-deflate enabled for compression. You can then visit the page to send/recieve messages through the browser. + +Start the server by running the following in this directory: + + go run server.go + +You can now navigate to the displayed address in you browser: + + http://localhost:12345/ diff --git a/examples/compression/client.go b/examples/compression/client.go new file mode 100644 index 0000000..de5624a --- /dev/null +++ b/examples/compression/client.go @@ -0,0 +1,59 @@ +// +build ignore + +package main + +import ( + "log" + "net/http" + "time" + + "github.com/euforia/websocket" +) + +func main() { + dialer := websocket.Dialer{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + Extensions: []string{"permessage-deflate"}, + Proxy: http.ProxyFromEnvironment, + } + + c, respHdr, err := dialer.Dial("ws://localhost:9001/f", nil) + + if err != nil { + log.Fatal("dial:", err) + } + defer c.Close() + + log.Printf("Extensions: %s\n", respHdr.Header.Get("Sec-Websocket-Extensions")) + + compressEnabled := true + + go func() { + defer c.Close() + for { + _, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + break + } + log.Printf("Received: %s", message) + } + }() + + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for t := range ticker.C { + err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) + if err != nil { + log.Println("write:", err) + break + } + + log.Printf("Wrote: compressed=%v; value=%s\n", compressEnabled, t.String()) + + compressEnabled = !compressEnabled + c.EnableWriteCompression(compressEnabled) + } +} diff --git a/examples/compression/index.html b/examples/compression/index.html new file mode 100644 index 0000000..df3f598 --- /dev/null +++ b/examples/compression/index.html @@ -0,0 +1,53 @@ + + +
+
+
+
+
+
+ |
+
+ Echo Response:
+
+ |
+