diff --git a/client.go b/client.go index 879d33e..a9323e7 100644 --- a/client.go +++ b/client.go @@ -23,6 +23,8 @@ import ( // invalid. var ErrBadHandshake = errors.New("websocket: bad handshake") +var errInvalidCompression = errors.New("websocket: invalid compression negotiation") + // NewClient creates a new client connection using the given net connection. // The URL u specifies the host and request URI. Use requestHeader to specify // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies @@ -70,6 +72,12 @@ type Dialer struct { // Subprotocols specifies the client's requested subprotocols. Subprotocols []string + + // EnableCompression specifies if the client should attempt to negotiate + // per message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool } var errMalformedURL = errors.New("malformed ws or wss URL") @@ -214,6 +222,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re k == "Connection" || k == "Sec-Websocket-Key" || k == "Sec-Websocket-Version" || + k == "Sec-Websocket-Extensions" || (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) default: @@ -221,6 +230,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } } + if d.EnableCompression { + req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + } + hostPort, hostNoPort := hostPortNoPort(u) var proxyURL *url.URL @@ -337,6 +350,20 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re return nil, resp, ErrBadHandshake } + for _, ext := range parseExtensions(req.Header) { + if ext[""] != "permessage-deflate" { + continue + } + _, snct := ext["server_no_context_takeover"] + _, cnct := ext["client_no_context_takeover"] + if !snct || !cnct { + return nil, resp, errInvalidCompression + } + conn.newCompressionWriter = compressNoContextTakeover + conn.newDecompressionReader = decompressNoContextTakeover + break + } + resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") diff --git a/client_server_test.go b/client_server_test.go index 1cb9b64..8e1bdca 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -20,9 +20,10 @@ import ( ) var cstUpgrader = Upgrader{ - Subprotocols: []string{"p0", "p1"}, - ReadBufferSize: 1024, - WriteBufferSize: 1024, + Subprotocols: []string{"p0", "p1"}, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + EnableCompression: true, Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { http.Error(w, reason.Error(), status) }, @@ -446,3 +447,17 @@ func TestHostHeader(t *testing.T) { sendRecv(t, ws) } + +func TestDialCompression(t *testing.T) { + s := newServer(t) + defer s.Close() + + dialer := cstDialer + dialer.EnableCompression = true + ws, _, err := dialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} diff --git a/conn.go b/conn.go index eb4334e..d916c1d 100644 --- a/conn.go +++ b/conn.go @@ -985,6 +985,13 @@ func (c *Conn) UnderlyingConn() net.Conn { return c.conn } +// EnableWriteCompression enables and disables write compression of +// subsequent text and binary messages. This function is a noop if +// compression was not negotiated with the peer. +func (c *Conn) EnableWriteCompression(enable bool) { + c.enableWriteCompression = enable +} + // FormatCloseMessage formats closeCode and text as a WebSocket close message. func FormatCloseMessage(closeCode int, text string) []byte { buf := make([]byte, 2+len(text)) diff --git a/conn_test.go b/conn_test.go index 0243c11..3c938cb 100644 --- a/conn_test.go +++ b/conn_test.go @@ -48,60 +48,65 @@ func TestFraming(t *testing.T) { writeBuf[i] = byte(i) } - for _, isServer := range []bool{true, false} { - for _, chunker := range readChunkers { + for _, compress := range []bool{false, true} { + for _, isServer := range []bool{true, false} { + for _, chunker := range readChunkers { - var connBuf bytes.Buffer - wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) - rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024) + var connBuf bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) + rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024) + if compress { + wc.newCompressionWriter = compressNoContextTakeover + rc.newDecompressionReader = decompressNoContextTakeover + } + for _, n := range frameSizes { + for _, iocopy := range []bool{true, false} { + name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d c:%v", compress, isServer, chunker.name, n, iocopy) - for _, n := range frameSizes { - for _, iocopy := range []bool{true, false} { - name := fmt.Sprintf("s:%v, r:%s, n:%d c:%v", isServer, chunker.name, n, iocopy) + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + var nn int + if iocopy { + var n64 int64 + n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n])) + nn = int(n64) + } else { + nn, err = w.Write(writeBuf[:n]) + } + if err != nil || nn != n { + t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) + continue + } + err = w.Close() + if err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } - w, err := wc.NextWriter(TextMessage) - if err != nil { - t.Errorf("%s: wc.NextWriter() returned %v", name, err) - continue - } - var nn int - if iocopy { - var n64 int64 - n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n])) - nn = int(n64) - } else { - nn, err = w.Write(writeBuf[:n]) - } - if err != nil || nn != n { - t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) - continue - } - err = w.Close() - if err != nil { - t.Errorf("%s: w.Close() returned %v", name, err) - continue - } + opCode, r, err := rc.NextReader() + if err != nil || opCode != TextMessage { + t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) + continue + } + rbuf, err := ioutil.ReadAll(r) + if err != nil { + t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) + continue + } - opCode, r, err := rc.NextReader() - if err != nil || opCode != TextMessage { - t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) - continue - } - rbuf, err := ioutil.ReadAll(r) - if err != nil { - t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) - continue - } + if len(rbuf) != n { + t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) + continue + } - if len(rbuf) != n { - t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) - continue - } - - for i, b := range rbuf { - if byte(i) != b { - t.Errorf("%s: bad byte at offset %d", name, i) - break + for i, b := range rbuf { + if byte(i) != b { + t.Errorf("%s: bad byte at offset %d", name, i) + break + } } } } diff --git a/doc.go b/doc.go index c901a7a..610acf7 100644 --- a/doc.go +++ b/doc.go @@ -149,4 +149,25 @@ // The deprecated Upgrade function does not enforce an origin policy. It's the // application's responsibility to check the Origin header before calling // Upgrade. +// +// Compression [Experimental] +// +// Per message compression extensions (RFC 7692) are experimentally supported +// by this package in a limited capacity. Setting the EnableCompression option +// to true in Dialer or Upgrader will attempt to negotiate per message deflate +// support. If compression was successfully negotiated with the connection's +// peer, any message received in compressed form will be automatically +// decompressed. All Read methods will return uncompressed bytes. +// +// Per message compression of messages written to a connection can be enabled +// or disabled by calling the corresponding Conn method: +// +// conn.EnableWriteCompression(true) +// +// Currently this package does not support compression with "context takeover". +// This means that messages must be compressed and decompressed in isolation, +// without retaining sliding window or dictionary state across messages. For +// more details refer to RFC 7692. +// +// Use of compression is experimental and may result in decreased performance. package websocket diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go index d96ac84..e98563b 100644 --- a/examples/autobahn/server.go +++ b/examples/autobahn/server.go @@ -8,17 +8,19 @@ package main import ( "errors" "flag" - "github.com/gorilla/websocket" "io" "log" "net/http" "time" "unicode/utf8" + + "github.com/gorilla/websocket" ) var upgrader = websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, + ReadBufferSize: 4096, + WriteBufferSize: 4096, + EnableCompression: true, CheckOrigin: func(r *http.Request) bool { return true }, diff --git a/server.go b/server.go index 8402d20..aaedebd 100644 --- a/server.go +++ b/server.go @@ -46,6 +46,12 @@ type Upgrader struct { // CheckOrigin is nil, the host in the Origin header must not be set or // must match the host of the request. CheckOrigin func(r *http.Request) bool + + // EnableCompression specify if the server should attempt to negotiate per + // message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool } func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { @@ -100,6 +106,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade if r.Method != "GET" { return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET") } + + if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { + return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific Sec-Websocket-Extensions headers are unsupported") + } + if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13") } @@ -127,6 +138,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade subprotocol := u.selectSubprotocol(r, responseHeader) + // Negotiate PMCE + var compress bool + if u.EnableCompression { + for _, ext := range parseExtensions(r.Header) { + if ext[""] != "permessage-deflate" { + continue + } + compress = true + break + } + } + var ( netConn net.Conn br *bufio.Reader @@ -152,6 +175,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) c.subprotocol = subprotocol + if compress { + c.newCompressionWriter = compressNoContextTakeover + c.newDecompressionReader = decompressNoContextTakeover + } + p := c.writeBuf[:0] p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) p = append(p, computeAcceptKey(challengeKey)...) @@ -161,6 +189,9 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade p = append(p, c.subprotocol...) p = append(p, "\r\n"...) } + if compress { + p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) + } for k, vs := range responseHeader { if k == "Sec-Websocket-Protocol" { continue