Added permessage-deflate compression

This commit is contained in:
euforia 2016-04-26 14:43:57 -07:00
parent 361d4c0ffd
commit e655eaeb26
11 changed files with 466 additions and 14 deletions

View File

@ -69,6 +69,9 @@ type Dialer struct {
// Subprotocols specifies the client's requested subprotocols.
Subprotocols []string
// Extensions specifies the client requested extensions
Extensions []string
}
var errMalformedURL = errors.New("malformed ws or wss URL")
@ -196,6 +199,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
if len(d.Subprotocols) > 0 {
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
}
if len(d.Extensions) > 0 {
req.Header["Sec-WebSocket-Extensions"] = d.Extensions
}
for k, vs := range requestHeader {
switch {
case k == "Host":
@ -206,6 +213,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
k == "Connection" ||
k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" ||
(k == "Sec-WebSocket-Extensions" && len(d.Extensions) > 0) ||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
default:
@ -328,6 +336,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
if len(resp.Header.Get("Sec-WebSocket-Extensions")) > 0 {
conn.compressionNegotiated = true
}
netConn.SetDeadline(time.Time{})
netConn = nil // to avoid close in defer.
return conn, resp, nil

92
compression.go Normal file
View File

@ -0,0 +1,92 @@
package websocket
import (
//"bytes"
"compress/flate"
//"fmt"
"io"
)
const (
// Supported compression algorithm and parameters.
CompressPermessageDeflate = "permessage-deflate; server_no_context_takeover; client_no_context_takeover"
// Deflate compression level
compressDeflateLevel int = 3
)
// Sits between a flate writer and the underlying writer i.e. messageWriter
// Truncates last bytes of flate compresses message
type FlateAdaptor struct {
last5bytes []byte
msgWriter io.WriteCloser
}
func NewFlateAdaptor(w io.WriteCloser) *FlateAdaptor {
return &FlateAdaptor{
msgWriter: w,
last5bytes: []byte{},
}
}
func (aw *FlateAdaptor) Write(p []byte) (n int, err error) {
t := append(aw.last5bytes, p...)
if len(t) > 4 {
aw.last5bytes = make([]byte, 5)
copy(aw.last5bytes, t[len(t)-5:])
_, err = aw.msgWriter.Write(t[:len(t)-5])
} else {
aw.last5bytes = make([]byte, len(t))
aw.last5bytes = t
}
n = len(p)
return
}
func (aw *FlateAdaptor) writeEndBlock() (int, error) {
var t []byte
if aw.last5bytes[4] != 0x00 {
t = append(aw.last5bytes, 0x00)
}
return aw.msgWriter.Write(t[:len(t)-5])
}
func (aw *FlateAdaptor) Close() (err error) {
if _, err = aw.writeEndBlock(); err == nil {
err = aw.msgWriter.Close()
}
return
}
// FlateAdaptorWriter --> FlateAdaptor --> messageWriter
type FlateAdaptorWriter struct {
flWriter *flate.Writer
flAdaptor *FlateAdaptor
}
func NewFlateAdaptorWriter(msgWriter io.WriteCloser, level int) (faw *FlateAdaptorWriter, err error) {
faw = &FlateAdaptorWriter{
flAdaptor: NewFlateAdaptor(msgWriter),
}
faw.flWriter, err = flate.NewWriter(faw.flAdaptor, level)
return
}
func (faw *FlateAdaptorWriter) Write(p []byte) (c int, err error) {
if c, err = faw.flWriter.Write(p); err == nil {
err = faw.flWriter.Flush()
}
return
}
func (faw *FlateAdaptorWriter) Close() (err error) {
if err = faw.flWriter.Close(); err == nil {
err = faw.flAdaptor.Close()
}
return
}

26
compression_test.go Normal file
View File

@ -0,0 +1,26 @@
package websocket
import (
"bytes"
"compress/flate"
"testing"
)
func Test_NewAdaptorWriter(t *testing.T) {
backendBuff := new(bytes.Buffer)
aw := NewAdaptorWriter(backendBuff)
fw, err := flate.NewWriter(aw, -1)
if err != nil {
t.Fatal(err)
}
var n int
n, err = fw.Write([]byte("test"))
t.Log(n, err)
if err = fw.Flush(); err != nil {
t.Fatal(err)
}
}

92
conn.go
View File

@ -6,6 +6,7 @@ package websocket
import (
"bufio"
"compress/flate"
"encoding/binary"
"errors"
"io"
@ -13,6 +14,7 @@ import (
"math/rand"
"net"
"strconv"
"strings"
"time"
)
@ -21,6 +23,7 @@ const (
maxControlFramePayloadSize = 125
finalBit = 1 << 7
maskBit = 1 << 7
compressionBit = 1 << 6 // used in flushFrame on writes
writeWait = time.Second
defaultReadBufferSize = 4096
@ -144,11 +147,15 @@ type Conn struct {
isServer bool
subprotocol string
compressionNegotiated bool // negotiated compression based on handshake
// Write fields
mu chan bool // used as mutex to protect write to conn and closeSent
closeSent bool // true if close message was sent
// Message writer fields.
writeCompressionEnabled bool
writeErr error
writeBuf []byte // frame is constructed in this buffer.
writePos int // end of data in writeBuf.
@ -157,6 +164,8 @@ type Conn struct {
writeDeadline time.Time
// Read fields
readMessageCompressed bool
readErr error
br *bufio.Reader
readRemaining int64 // bytes remaining in current frame.
@ -218,6 +227,12 @@ func (c *Conn) RemoteAddr() net.Addr {
// Write methods
// EnableWriteCompression enables and disables write compression of subsequent text and
// binary messages. This function is a noop if compression was not negotiated with the peer.
func (c *Conn) EnableWriteCompression(enable bool) {
c.writeCompressionEnabled = enable
}
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
<-c.mu
defer func() { c.mu <- true }()
@ -327,7 +342,15 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
}
c.writeFrameType = messageType
return messageWriter{c, c.writeSeq}, nil
var wc io.WriteCloser = messageWriter{c, c.writeSeq}
// Return compression writer on data frame
if c.compressionNegotiated && c.writeCompressionEnabled && isData(messageType) {
return NewFlateAdaptorWriter(wc, compressDeflateLevel)
}
return wc, nil
}
func (c *Conn) flushFrame(final bool, extra []byte) error {
@ -346,6 +369,13 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
if final {
b0 |= finalBit
}
// Check compression and that it is not a continuation frame
// as those should not have compression bit set per RFC
if c.compressionNegotiated && c.writeCompressionEnabled && c.writeFrameType != continuationFrame {
b0 |= compressionBit
}
b1 := byte(0)
if !c.isServer {
b1 |= maskBit
@ -515,15 +545,30 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
if err != nil {
return err
}
w := wr.(messageWriter)
if _, err := w.write(true, data); err != nil {
return err
}
if c.writeSeq == w.seq {
if err := c.flushFrame(true, nil); err != nil {
if c.compressionNegotiated && c.writeCompressionEnabled {
fw := wr.(*FlateAdaptorWriter)
if _, err = fw.Write(data); err != nil {
return err
}
return fw.Close()
} else {
w := wr.(messageWriter)
if _, err = w.write(true, data); err != nil {
return err
}
// final flush
if c.writeSeq == w.seq {
if err = c.flushFrame(true, nil); err != nil {
return err
}
}
}
return nil
}
@ -577,7 +622,17 @@ func (c *Conn) advanceFrame() (int, error) {
mask := b[1]&maskBit != 0
c.readRemaining = int64(b[1] & 0x7f)
if reserved != 0 {
switch reserved {
case 4:
if !c.compressionNegotiated {
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
}
// Only the first frame of a compressed message has the reserved bit set.
c.readMessageCompressed = true
break
case 0:
break
default:
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
}
@ -633,7 +688,7 @@ func (c *Conn) advanceFrame() (int, error) {
// 5. For text and binary messages, enforce read limit and return.
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
if frameType == continuationFrame || isData(frameType) {
c.readLength += c.readRemaining
if c.readLimit > 0 && c.readLength > c.readLimit {
@ -696,7 +751,7 @@ func (c *Conn) handleProtocolError(message string) error {
//
// The NextReader method and the readers returned from the method cannot be
// accessed by more than one goroutine at a time.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
func (c *Conn) NextReader() (int, io.Reader, error) {
c.readSeq++
c.readLength = 0
@ -707,8 +762,14 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c.readErr = hideTempErr(err)
break
}
if frameType == TextMessage || frameType == BinaryMessage {
return frameType, messageReader{c, c.readSeq}, nil
if isData(frameType) {
var r io.Reader = messageReader{c, c.readSeq}
if c.compressionNegotiated && c.readMessageCompressed {
// Append compression bytes to output on the final read
r = flate.NewReader(io.MultiReader(r, strings.NewReader("\x00\x00\xff\xff\x01\x00\x00\xff\xff")))
}
return frameType, r, nil
}
}
return noFrame, nil, c.readErr
@ -742,6 +803,11 @@ func (r messageReader) Read(b []byte) (int, error) {
if r.c.readFinal {
r.c.readSeq++
// Reset compression for the next frame
if r.c.compressionNegotiated && r.c.readMessageCompressed {
r.c.readMessageCompressed = false
}
return 0, io.EOF
}
@ -749,7 +815,7 @@ func (r messageReader) Read(b []byte) (int, error) {
switch {
case err != nil:
r.c.readErr = hideTempErr(err)
case frameType == TextMessage || frameType == BinaryMessage:
case isData(frameType):
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
}
}

View File

@ -11,3 +11,10 @@ and start the client test driver
wstest -m fuzzingclient -s fuzzingclient.json
When the client completes, it writes a report to reports/clients/index.html.
# Install client test driver
pip install autobahntestsuite
This will install the test suite containing the `wstest` command.

View File

@ -8,7 +8,7 @@ package main
import (
"errors"
"flag"
"github.com/gorilla/websocket"
"github.com/euforia/websocket"
"io"
"log"
"net/http"
@ -22,6 +22,9 @@ var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
Extensions: []string{
"permessage-deflate; server_no_context_takeover; client_no_context_takeover",
},
}
// echoCopy echoes messages from the client using io.Copy.

View File

@ -0,0 +1,10 @@
# Compression example
This example covers enabling compression on the server. It starts a websocket server with permessage-deflate enabled for compression. You can then visit the page to send/recieve messages through the browser.
Start the server by running the following in this directory:
go run server.go
You can now navigate to the displayed address in you browser:
http://localhost:12345/

View File

@ -0,0 +1,59 @@
// +build ignore
package main
import (
"log"
"net/http"
"time"
"github.com/euforia/websocket"
)
func main() {
dialer := websocket.Dialer{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
Extensions: []string{"permessage-deflate"},
Proxy: http.ProxyFromEnvironment,
}
c, respHdr, err := dialer.Dial("ws://localhost:9001/f", nil)
if err != nil {
log.Fatal("dial:", err)
}
defer c.Close()
log.Printf("Extensions: %s\n", respHdr.Header.Get("Sec-Websocket-Extensions"))
compressEnabled := true
go func() {
defer c.Close()
for {
_, message, err := c.ReadMessage()
if err != nil {
log.Println("read:", err)
break
}
log.Printf("Received: %s", message)
}
}()
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for t := range ticker.C {
err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
if err != nil {
log.Println("write:", err)
break
}
log.Printf("Wrote: compressed=%v; value=%s\n", compressEnabled, t.String())
compressEnabled = !compressEnabled
c.EnableWriteCompression(compressEnabled)
}
}

View File

@ -0,0 +1,53 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>WebSocket Example</title>
<style type="text/css">
body { font-family: helvetica; color: #333;}
</style>
</head>
<body>
<div id="conn-details" style="padding: 10px"></div>
<table style="width:100%"><tr>
<td style="width:50%;vertical-align:top">
<div style="text-align:center">
<textarea id="inputArea" style="width:100%;min-height:250px"></textarea>
</div>
<div style="padding:10px"><button onclick="sendData()">send</button></div>
</td>
<td style="width:50%;padding:5px;">
<div style="padding:10px">Echo Response:</div>
<pre id="fileData" style="padding:10px;border:none;min-height:250px;margin:0;white-space: pre-wrap;"></pre>
</td>
</tr></table>
<script type="text/javascript">
var inputElem = document.getElementById("inputArea");
var conn = new WebSocket("ws://127.0.0.1:9001/f");
(function() {
var data = document.getElementById("fileData");
conn.onopen = function(evt) {
var connDetails = document.getElementById('conn-details');
connDetails.textContent = 'Connection Extensions: [' + conn.extensions + ']';
}
conn.onclose = function(evt) {
console.log(evt);
data.textContent = 'Connection closed';
}
conn.onmessage = function(evt) {
console.log('Message:', evt.data);
data.textContent += '[ '+ (new Date()).toString() + ' ] ' + evt.data + '\n';
}
})();
function sendData() {
if (inputElem.value.length > 0) {
conn.send(inputElem.value);
}
}
</script>
</body>
</html>

View File

@ -0,0 +1,81 @@
// +build ignore
package main
import (
"io"
"log"
"net/http"
"path/filepath"
"github.com/euforia/websocket"
)
var (
upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { return true },
Extensions: []string{websocket.CompressPermessageDeflate},
}
webroot, _ = filepath.Abs("./")
listenAddr = "0.0.0.0:9001"
)
func ServeWebSocket(w http.ResponseWriter, r *http.Request) {
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err)
return
}
defer ws.Close()
log.Printf("Client connected: %s\n", r.RemoteAddr)
//if err = ws.WriteMessage(websocket.TextMessage, []byte("Hello!")); err != nil {
// log.Println(err)
// err = nil
//}
for {
/*
if msgType, msgBytes, err = ws.ReadMessage(); err != nil {
log.Printf("Client disconnected %s: %s\n", r.RemoteAddr, err)
break
}
log.Printf("type: %d; payload: %d bytes;\n", msgType, len(msgBytes))
*/
msgType, rd, err := ws.NextReader()
if err != nil {
log.Printf("Client disconnected (%s): %s\n", r.RemoteAddr, err)
break
}
wr, err := ws.NextWriter(msgType)
if err != nil {
log.Println(err)
continue
}
if _, err = io.Copy(wr, rd); err != nil {
log.Println(err)
}
if err = wr.Close(); err != nil {
log.Println(err)
}
}
}
func main() {
// Serve index.html
http.Handle("/", http.FileServer(http.Dir(webroot)))
// Websocket endpoint
http.HandleFunc("/f", ServeWebSocket)
log.Printf("Starting server on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalln(err)
}
}

View File

@ -7,6 +7,7 @@ package websocket
import (
"bufio"
"errors"
"fmt"
"net"
"net/http"
"net/url"
@ -38,6 +39,11 @@ type Upgrader struct {
// requested by the client.
Subprotocols []string
// Server supported extensions (e.g. permessage-deflate)
// Currently only 'permessage-deflate; server_no_context_takeover; client_no_context_takeover'
// is supported. i.e. no deflate options at this time.
Extensions []string
// Error specifies the function for generating HTTP error responses. If Error
// is nil, then http.Error is used to generate the HTTP response.
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
@ -87,6 +93,32 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
return ""
}
// Check for compression support, select compression type and
// return extension with server supported options along with
// whether valid compression headers were found.
func (u *Upgrader) selectCompressionExtension(r *http.Request) (string, bool, error) {
extensions := r.Header.Get("Sec-WebSocket-Extensions")
if u.Extensions != nil && len(extensions) > 0 {
extOpts := strings.Split(extensions, " ")
if len(extOpts) > 0 && strings.HasPrefix(extOpts[0], "permessage-deflate") {
ext := strings.TrimSuffix(extOpts[0], ";")
// Find and return extension with supported options.
for _, e := range u.Extensions {
// Check if server supports supplied extension
if strings.HasPrefix(e, ext) {
if e != CompressPermessageDeflate {
return "", false, fmt.Errorf("Compression options not supported: %s", e)
}
return e, true, nil
}
}
}
}
return "", false, nil
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
@ -147,6 +179,10 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
c.subprotocol = subprotocol
_, c.compressionNegotiated, err = u.selectCompressionExtension(r)
if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
}
p := c.writeBuf[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
@ -175,6 +211,13 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
p = append(p, "\r\n"...)
}
}
// Set the selected compression header if enabled.
if c.compressionNegotiated {
// Turn compression on by default
c.writeCompressionEnabled = true
p = append(p, "Sec-WebSocket-Extensions: "+CompressPermessageDeflate+"\r\n"...)
}
p = append(p, "\r\n"...)
// Clear deadlines set by HTTP server.