Merge pull request #167 from garyburd/master

enable compression
This commit is contained in:
Gary Burd 2016-10-17 17:39:55 -07:00 committed by GitHub
commit 8003df83ee
7 changed files with 163 additions and 55 deletions

View File

@ -23,6 +23,8 @@ import (
// invalid. // invalid.
var ErrBadHandshake = errors.New("websocket: bad handshake") var ErrBadHandshake = errors.New("websocket: bad handshake")
var errInvalidCompression = errors.New("websocket: invalid compression negotiation")
// NewClient creates a new client connection using the given net connection. // NewClient creates a new client connection using the given net connection.
// The URL u specifies the host and request URI. Use requestHeader to specify // The URL u specifies the host and request URI. Use requestHeader to specify
// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
@ -70,6 +72,12 @@ type Dialer struct {
// Subprotocols specifies the client's requested subprotocols. // Subprotocols specifies the client's requested subprotocols.
Subprotocols []string Subprotocols []string
// EnableCompression specifies if the client should attempt to negotiate
// per message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
} }
var errMalformedURL = errors.New("malformed ws or wss URL") var errMalformedURL = errors.New("malformed ws or wss URL")
@ -214,6 +222,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
k == "Connection" || k == "Connection" ||
k == "Sec-Websocket-Key" || k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" || k == "Sec-Websocket-Version" ||
k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
default: default:
@ -221,6 +230,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
} }
} }
if d.EnableCompression {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
}
hostPort, hostNoPort := hostPortNoPort(u) hostPort, hostNoPort := hostPortNoPort(u)
var proxyURL *url.URL var proxyURL *url.URL
@ -337,6 +350,20 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
return nil, resp, ErrBadHandshake return nil, resp, ErrBadHandshake
} }
for _, ext := range parseExtensions(req.Header) {
if ext[""] != "permessage-deflate" {
continue
}
_, snct := ext["server_no_context_takeover"]
_, cnct := ext["client_no_context_takeover"]
if !snct || !cnct {
return nil, resp, errInvalidCompression
}
conn.newCompressionWriter = compressNoContextTakeover
conn.newDecompressionReader = decompressNoContextTakeover
break
}
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")

View File

@ -23,6 +23,7 @@ var cstUpgrader = Upgrader{
Subprotocols: []string{"p0", "p1"}, Subprotocols: []string{"p0", "p1"},
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024,
EnableCompression: true,
Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
http.Error(w, reason.Error(), status) http.Error(w, reason.Error(), status)
}, },
@ -446,3 +447,17 @@ func TestHostHeader(t *testing.T) {
sendRecv(t, ws) sendRecv(t, ws)
} }
func TestDialCompression(t *testing.T) {
s := newServer(t)
defer s.Close()
dialer := cstDialer
dialer.EnableCompression = true
ws, _, err := dialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
}

View File

@ -985,6 +985,13 @@ func (c *Conn) UnderlyingConn() net.Conn {
return c.conn return c.conn
} }
// EnableWriteCompression enables and disables write compression of
// subsequent text and binary messages. This function is a noop if
// compression was not negotiated with the peer.
func (c *Conn) EnableWriteCompression(enable bool) {
c.enableWriteCompression = enable
}
// FormatCloseMessage formats closeCode and text as a WebSocket close message. // FormatCloseMessage formats closeCode and text as a WebSocket close message.
func FormatCloseMessage(closeCode int, text string) []byte { func FormatCloseMessage(closeCode int, text string) []byte {
buf := make([]byte, 2+len(text)) buf := make([]byte, 2+len(text))

View File

@ -48,16 +48,20 @@ func TestFraming(t *testing.T) {
writeBuf[i] = byte(i) writeBuf[i] = byte(i)
} }
for _, compress := range []bool{false, true} {
for _, isServer := range []bool{true, false} { for _, isServer := range []bool{true, false} {
for _, chunker := range readChunkers { for _, chunker := range readChunkers {
var connBuf bytes.Buffer var connBuf bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024) rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
if compress {
wc.newCompressionWriter = compressNoContextTakeover
rc.newDecompressionReader = decompressNoContextTakeover
}
for _, n := range frameSizes { for _, n := range frameSizes {
for _, iocopy := range []bool{true, false} { for _, iocopy := range []bool{true, false} {
name := fmt.Sprintf("s:%v, r:%s, n:%d c:%v", isServer, chunker.name, n, iocopy) name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d c:%v", compress, isServer, chunker.name, n, iocopy)
w, err := wc.NextWriter(TextMessage) w, err := wc.NextWriter(TextMessage)
if err != nil { if err != nil {
@ -109,6 +113,7 @@ func TestFraming(t *testing.T) {
} }
} }
} }
}
func TestControl(t *testing.T) { func TestControl(t *testing.T) {
const message = "this is a ping/pong messsage" const message = "this is a ping/pong messsage"

21
doc.go
View File

@ -149,4 +149,25 @@
// The deprecated Upgrade function does not enforce an origin policy. It's the // The deprecated Upgrade function does not enforce an origin policy. It's the
// application's responsibility to check the Origin header before calling // application's responsibility to check the Origin header before calling
// Upgrade. // Upgrade.
//
// Compression [Experimental]
//
// Per message compression extensions (RFC 7692) are experimentally supported
// by this package in a limited capacity. Setting the EnableCompression option
// to true in Dialer or Upgrader will attempt to negotiate per message deflate
// support. If compression was successfully negotiated with the connection's
// peer, any message received in compressed form will be automatically
// decompressed. All Read methods will return uncompressed bytes.
//
// Per message compression of messages written to a connection can be enabled
// or disabled by calling the corresponding Conn method:
//
// conn.EnableWriteCompression(true)
//
// Currently this package does not support compression with "context takeover".
// This means that messages must be compressed and decompressed in isolation,
// without retaining sliding window or dictionary state across messages. For
// more details refer to RFC 7692.
//
// Use of compression is experimental and may result in decreased performance.
package websocket package websocket

View File

@ -8,17 +8,19 @@ package main
import ( import (
"errors" "errors"
"flag" "flag"
"github.com/gorilla/websocket"
"io" "io"
"log" "log"
"net/http" "net/http"
"time" "time"
"unicode/utf8" "unicode/utf8"
"github.com/gorilla/websocket"
) )
var upgrader = websocket.Upgrader{ var upgrader = websocket.Upgrader{
ReadBufferSize: 4096, ReadBufferSize: 4096,
WriteBufferSize: 4096, WriteBufferSize: 4096,
EnableCompression: true,
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
return true return true
}, },

View File

@ -46,6 +46,12 @@ type Upgrader struct {
// CheckOrigin is nil, the host in the Origin header must not be set or // CheckOrigin is nil, the host in the Origin header must not be set or
// must match the host of the request. // must match the host of the request.
CheckOrigin func(r *http.Request) bool CheckOrigin func(r *http.Request) bool
// EnableCompression specify if the server should attempt to negotiate per
// message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
} }
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
@ -100,6 +106,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
if r.Method != "GET" { if r.Method != "GET" {
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET") return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET")
} }
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific Sec-Websocket-Extensions headers are unsupported")
}
if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13") return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13")
} }
@ -127,6 +138,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
subprotocol := u.selectSubprotocol(r, responseHeader) subprotocol := u.selectSubprotocol(r, responseHeader)
// Negotiate PMCE
var compress bool
if u.EnableCompression {
for _, ext := range parseExtensions(r.Header) {
if ext[""] != "permessage-deflate" {
continue
}
compress = true
break
}
}
var ( var (
netConn net.Conn netConn net.Conn
br *bufio.Reader br *bufio.Reader
@ -152,6 +175,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
c.subprotocol = subprotocol c.subprotocol = subprotocol
if compress {
c.newCompressionWriter = compressNoContextTakeover
c.newDecompressionReader = decompressNoContextTakeover
}
p := c.writeBuf[:0] p := c.writeBuf[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...) p = append(p, computeAcceptKey(challengeKey)...)
@ -161,6 +189,9 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
p = append(p, c.subprotocol...) p = append(p, c.subprotocol...)
p = append(p, "\r\n"...) p = append(p, "\r\n"...)
} }
if compress {
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
}
for k, vs := range responseHeader { for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" { if k == "Sec-Websocket-Protocol" {
continue continue