mirror of https://github.com/gorilla/websocket.git
Added permessage-deflate support
This commit is contained in:
commit
3eabc85eac
12
client.go
12
client.go
|
@ -69,6 +69,9 @@ type Dialer struct {
|
|||
|
||||
// Subprotocols specifies the client's requested subprotocols.
|
||||
Subprotocols []string
|
||||
|
||||
// Extensions specifies the client requested extensions
|
||||
Extensions []string
|
||||
}
|
||||
|
||||
var errMalformedURL = errors.New("malformed ws or wss URL")
|
||||
|
@ -196,6 +199,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
if len(d.Subprotocols) > 0 {
|
||||
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
|
||||
}
|
||||
if len(d.Extensions) > 0 {
|
||||
req.Header["Sec-WebSocket-Extensions"] = d.Extensions
|
||||
}
|
||||
|
||||
for k, vs := range requestHeader {
|
||||
switch {
|
||||
case k == "Host":
|
||||
|
@ -206,6 +213,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
k == "Connection" ||
|
||||
k == "Sec-Websocket-Key" ||
|
||||
k == "Sec-Websocket-Version" ||
|
||||
(k == "Sec-WebSocket-Extensions" && len(d.Extensions) > 0) ||
|
||||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
|
||||
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
|
||||
default:
|
||||
|
@ -328,6 +336,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
||||
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
||||
|
||||
if len(resp.Header.Get("Sec-WebSocket-Extensions")) > 0 {
|
||||
conn.compressionNegotiated = true
|
||||
}
|
||||
|
||||
netConn.SetDeadline(time.Time{})
|
||||
netConn = nil // to avoid close in defer.
|
||||
return conn, resp, nil
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
//"bytes"
|
||||
"compress/flate"
|
||||
//"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
// Supported compression algorithm and parameters.
|
||||
CompressPermessageDeflate = "permessage-deflate; server_no_context_takeover; client_no_context_takeover"
|
||||
|
||||
// Deflate compression level
|
||||
compressDeflateLevel int = 3
|
||||
)
|
||||
|
||||
// Sits between a flate writer and the underlying writer i.e. messageWriter
|
||||
// Truncates last bytes of flate compresses message
|
||||
type FlateAdaptor struct {
|
||||
last5bytes []byte
|
||||
msgWriter io.WriteCloser
|
||||
}
|
||||
|
||||
func NewFlateAdaptor(w io.WriteCloser) *FlateAdaptor {
|
||||
return &FlateAdaptor{
|
||||
msgWriter: w,
|
||||
last5bytes: []byte{},
|
||||
}
|
||||
}
|
||||
|
||||
func (aw *FlateAdaptor) Write(p []byte) (n int, err error) {
|
||||
|
||||
t := append(aw.last5bytes, p...)
|
||||
|
||||
if len(t) > 4 {
|
||||
aw.last5bytes = make([]byte, 5)
|
||||
copy(aw.last5bytes, t[len(t)-5:])
|
||||
_, err = aw.msgWriter.Write(t[:len(t)-5])
|
||||
} else {
|
||||
aw.last5bytes = make([]byte, len(t))
|
||||
aw.last5bytes = t
|
||||
}
|
||||
|
||||
n = len(p)
|
||||
return
|
||||
}
|
||||
|
||||
func (aw *FlateAdaptor) writeEndBlock() (int, error) {
|
||||
var t []byte
|
||||
if aw.last5bytes[4] != 0x00 {
|
||||
t = append(aw.last5bytes, 0x00)
|
||||
}
|
||||
|
||||
return aw.msgWriter.Write(t[:len(t)-5])
|
||||
}
|
||||
|
||||
func (aw *FlateAdaptor) Close() (err error) {
|
||||
if _, err = aw.writeEndBlock(); err == nil {
|
||||
err = aw.msgWriter.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// FlateAdaptorWriter --> FlateAdaptor --> messageWriter
|
||||
type FlateAdaptorWriter struct {
|
||||
flWriter *flate.Writer
|
||||
flAdaptor *FlateAdaptor
|
||||
}
|
||||
|
||||
func NewFlateAdaptorWriter(msgWriter io.WriteCloser, level int) (faw *FlateAdaptorWriter, err error) {
|
||||
faw = &FlateAdaptorWriter{
|
||||
flAdaptor: NewFlateAdaptor(msgWriter),
|
||||
}
|
||||
faw.flWriter, err = flate.NewWriter(faw.flAdaptor, level)
|
||||
return
|
||||
}
|
||||
|
||||
func (faw *FlateAdaptorWriter) Write(p []byte) (c int, err error) {
|
||||
if c, err = faw.flWriter.Write(p); err == nil {
|
||||
err = faw.flWriter.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (faw *FlateAdaptorWriter) Close() (err error) {
|
||||
if err = faw.flWriter.Close(); err == nil {
|
||||
err = faw.flAdaptor.Close()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_NewAdaptorWriter(t *testing.T) {
|
||||
backendBuff := new(bytes.Buffer)
|
||||
aw := NewAdaptorWriter(backendBuff)
|
||||
|
||||
fw, err := flate.NewWriter(aw, -1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var n int
|
||||
n, err = fw.Write([]byte("test"))
|
||||
t.Log(n, err)
|
||||
|
||||
if err = fw.Flush(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
}
|
92
conn.go
92
conn.go
|
@ -6,6 +6,7 @@ package websocket
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"compress/flate"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
|
@ -13,6 +14,7 @@ import (
|
|||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -21,6 +23,7 @@ const (
|
|||
maxControlFramePayloadSize = 125
|
||||
finalBit = 1 << 7
|
||||
maskBit = 1 << 7
|
||||
compressionBit = 1 << 6 // used in flushFrame on writes
|
||||
writeWait = time.Second
|
||||
|
||||
defaultReadBufferSize = 4096
|
||||
|
@ -144,11 +147,15 @@ type Conn struct {
|
|||
isServer bool
|
||||
subprotocol string
|
||||
|
||||
compressionNegotiated bool // negotiated compression based on handshake
|
||||
|
||||
// Write fields
|
||||
mu chan bool // used as mutex to protect write to conn and closeSent
|
||||
closeSent bool // true if close message was sent
|
||||
|
||||
// Message writer fields.
|
||||
writeCompressionEnabled bool
|
||||
|
||||
writeErr error
|
||||
writeBuf []byte // frame is constructed in this buffer.
|
||||
writePos int // end of data in writeBuf.
|
||||
|
@ -157,6 +164,8 @@ type Conn struct {
|
|||
writeDeadline time.Time
|
||||
|
||||
// Read fields
|
||||
readMessageCompressed bool
|
||||
|
||||
readErr error
|
||||
br *bufio.Reader
|
||||
readRemaining int64 // bytes remaining in current frame.
|
||||
|
@ -218,6 +227,12 @@ func (c *Conn) RemoteAddr() net.Addr {
|
|||
|
||||
// Write methods
|
||||
|
||||
// EnableWriteCompression enables and disables write compression of subsequent text and
|
||||
// binary messages. This function is a noop if compression was not negotiated with the peer.
|
||||
func (c *Conn) EnableWriteCompression(enable bool) {
|
||||
c.writeCompressionEnabled = enable
|
||||
}
|
||||
|
||||
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
|
||||
<-c.mu
|
||||
defer func() { c.mu <- true }()
|
||||
|
@ -327,7 +342,15 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
|||
}
|
||||
|
||||
c.writeFrameType = messageType
|
||||
return messageWriter{c, c.writeSeq}, nil
|
||||
|
||||
var wc io.WriteCloser = messageWriter{c, c.writeSeq}
|
||||
|
||||
// Return compression writer on data frame
|
||||
if c.compressionNegotiated && c.writeCompressionEnabled && isData(messageType) {
|
||||
return NewFlateAdaptorWriter(wc, compressDeflateLevel)
|
||||
}
|
||||
|
||||
return wc, nil
|
||||
}
|
||||
|
||||
func (c *Conn) flushFrame(final bool, extra []byte) error {
|
||||
|
@ -346,6 +369,13 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|||
if final {
|
||||
b0 |= finalBit
|
||||
}
|
||||
|
||||
// Check compression and that it is not a continuation frame
|
||||
// as those should not have compression bit set per RFC
|
||||
if c.compressionNegotiated && c.writeCompressionEnabled && c.writeFrameType != continuationFrame {
|
||||
b0 |= compressionBit
|
||||
}
|
||||
|
||||
b1 := byte(0)
|
||||
if !c.isServer {
|
||||
b1 |= maskBit
|
||||
|
@ -515,15 +545,30 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w := wr.(messageWriter)
|
||||
if _, err := w.write(true, data); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.writeSeq == w.seq {
|
||||
if err := c.flushFrame(true, nil); err != nil {
|
||||
|
||||
if c.compressionNegotiated && c.writeCompressionEnabled {
|
||||
|
||||
fw := wr.(*FlateAdaptorWriter)
|
||||
if _, err = fw.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
return fw.Close()
|
||||
|
||||
} else {
|
||||
|
||||
w := wr.(messageWriter)
|
||||
if _, err = w.write(true, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// final flush
|
||||
if c.writeSeq == w.seq {
|
||||
if err = c.flushFrame(true, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -577,7 +622,17 @@ func (c *Conn) advanceFrame() (int, error) {
|
|||
mask := b[1]&maskBit != 0
|
||||
c.readRemaining = int64(b[1] & 0x7f)
|
||||
|
||||
if reserved != 0 {
|
||||
switch reserved {
|
||||
case 4:
|
||||
if !c.compressionNegotiated {
|
||||
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
|
||||
}
|
||||
// Only the first frame of a compressed message has the reserved bit set.
|
||||
c.readMessageCompressed = true
|
||||
break
|
||||
case 0:
|
||||
break
|
||||
default:
|
||||
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
|
||||
}
|
||||
|
||||
|
@ -633,7 +688,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
|||
|
||||
// 5. For text and binary messages, enforce read limit and return.
|
||||
|
||||
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
|
||||
if frameType == continuationFrame || isData(frameType) {
|
||||
|
||||
c.readLength += c.readRemaining
|
||||
if c.readLimit > 0 && c.readLength > c.readLimit {
|
||||
|
@ -696,7 +751,7 @@ func (c *Conn) handleProtocolError(message string) error {
|
|||
//
|
||||
// The NextReader method and the readers returned from the method cannot be
|
||||
// accessed by more than one goroutine at a time.
|
||||
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
||||
func (c *Conn) NextReader() (int, io.Reader, error) {
|
||||
|
||||
c.readSeq++
|
||||
c.readLength = 0
|
||||
|
@ -707,8 +762,14 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
|||
c.readErr = hideTempErr(err)
|
||||
break
|
||||
}
|
||||
if frameType == TextMessage || frameType == BinaryMessage {
|
||||
return frameType, messageReader{c, c.readSeq}, nil
|
||||
|
||||
if isData(frameType) {
|
||||
var r io.Reader = messageReader{c, c.readSeq}
|
||||
if c.compressionNegotiated && c.readMessageCompressed {
|
||||
// Append compression bytes to output on the final read
|
||||
r = flate.NewReader(io.MultiReader(r, strings.NewReader("\x00\x00\xff\xff\x01\x00\x00\xff\xff")))
|
||||
}
|
||||
return frameType, r, nil
|
||||
}
|
||||
}
|
||||
return noFrame, nil, c.readErr
|
||||
|
@ -742,6 +803,11 @@ func (r messageReader) Read(b []byte) (int, error) {
|
|||
|
||||
if r.c.readFinal {
|
||||
r.c.readSeq++
|
||||
// Reset compression for the next frame
|
||||
if r.c.compressionNegotiated && r.c.readMessageCompressed {
|
||||
r.c.readMessageCompressed = false
|
||||
}
|
||||
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
|
@ -749,7 +815,7 @@ func (r messageReader) Read(b []byte) (int, error) {
|
|||
switch {
|
||||
case err != nil:
|
||||
r.c.readErr = hideTempErr(err)
|
||||
case frameType == TextMessage || frameType == BinaryMessage:
|
||||
case isData(frameType):
|
||||
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,3 +11,10 @@ and start the client test driver
|
|||
wstest -m fuzzingclient -s fuzzingclient.json
|
||||
|
||||
When the client completes, it writes a report to reports/clients/index.html.
|
||||
|
||||
|
||||
# Install client test driver
|
||||
|
||||
pip install autobahntestsuite
|
||||
|
||||
This will install the test suite containing the `wstest` command.
|
|
@ -8,7 +8,7 @@ package main
|
|||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/euforia/websocket"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
|
@ -22,6 +22,9 @@ var upgrader = websocket.Upgrader{
|
|||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
Extensions: []string{
|
||||
"permessage-deflate; server_no_context_takeover; client_no_context_takeover",
|
||||
},
|
||||
}
|
||||
|
||||
// echoCopy echoes messages from the client using io.Copy.
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
# Compression example
|
||||
This example covers enabling compression on the server. It starts a websocket server with permessage-deflate enabled for compression. You can then visit the page to send/recieve messages through the browser.
|
||||
|
||||
Start the server by running the following in this directory:
|
||||
|
||||
go run server.go
|
||||
|
||||
You can now navigate to the displayed address in you browser:
|
||||
|
||||
http://localhost:12345/
|
|
@ -0,0 +1,59 @@
|
|||
// +build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/euforia/websocket"
|
||||
)
|
||||
|
||||
func main() {
|
||||
dialer := websocket.Dialer{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
Extensions: []string{"permessage-deflate"},
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
}
|
||||
|
||||
c, respHdr, err := dialer.Dial("ws://localhost:9001/f", nil)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal("dial:", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
log.Printf("Extensions: %s\n", respHdr.Header.Get("Sec-Websocket-Extensions"))
|
||||
|
||||
compressEnabled := true
|
||||
|
||||
go func() {
|
||||
defer c.Close()
|
||||
for {
|
||||
_, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
log.Println("read:", err)
|
||||
break
|
||||
}
|
||||
log.Printf("Received: %s", message)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for t := range ticker.C {
|
||||
err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
|
||||
if err != nil {
|
||||
log.Println("write:", err)
|
||||
break
|
||||
}
|
||||
|
||||
log.Printf("Wrote: compressed=%v; value=%s\n", compressEnabled, t.String())
|
||||
|
||||
compressEnabled = !compressEnabled
|
||||
c.EnableWriteCompression(compressEnabled)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>WebSocket Example</title>
|
||||
<style type="text/css">
|
||||
body { font-family: helvetica; color: #333;}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="conn-details" style="padding: 10px"></div>
|
||||
<table style="width:100%"><tr>
|
||||
<td style="width:50%;vertical-align:top">
|
||||
<div style="text-align:center">
|
||||
<textarea id="inputArea" style="width:100%;min-height:250px"></textarea>
|
||||
</div>
|
||||
<div style="padding:10px"><button onclick="sendData()">send</button></div>
|
||||
</td>
|
||||
<td style="width:50%;padding:5px;">
|
||||
<div style="padding:10px">Echo Response:</div>
|
||||
<pre id="fileData" style="padding:10px;border:none;min-height:250px;margin:0;white-space: pre-wrap;"></pre>
|
||||
</td>
|
||||
</tr></table>
|
||||
|
||||
<script type="text/javascript">
|
||||
|
||||
var inputElem = document.getElementById("inputArea");
|
||||
var conn = new WebSocket("ws://127.0.0.1:9001/f");
|
||||
|
||||
(function() {
|
||||
var data = document.getElementById("fileData");
|
||||
conn.onopen = function(evt) {
|
||||
var connDetails = document.getElementById('conn-details');
|
||||
connDetails.textContent = 'Connection Extensions: [' + conn.extensions + ']';
|
||||
}
|
||||
conn.onclose = function(evt) {
|
||||
console.log(evt);
|
||||
data.textContent = 'Connection closed';
|
||||
}
|
||||
conn.onmessage = function(evt) {
|
||||
console.log('Message:', evt.data);
|
||||
data.textContent += '[ '+ (new Date()).toString() + ' ] ' + evt.data + '\n';
|
||||
}
|
||||
})();
|
||||
|
||||
function sendData() {
|
||||
if (inputElem.value.length > 0) {
|
||||
conn.send(inputElem.value);
|
||||
}
|
||||
}
|
||||
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
|
@ -0,0 +1,81 @@
|
|||
// +build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/euforia/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
Extensions: []string{websocket.CompressPermessageDeflate},
|
||||
}
|
||||
webroot, _ = filepath.Abs("./")
|
||||
listenAddr = "0.0.0.0:9001"
|
||||
)
|
||||
|
||||
func ServeWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
ws, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
log.Printf("Client connected: %s\n", r.RemoteAddr)
|
||||
|
||||
//if err = ws.WriteMessage(websocket.TextMessage, []byte("Hello!")); err != nil {
|
||||
// log.Println(err)
|
||||
// err = nil
|
||||
//}
|
||||
|
||||
for {
|
||||
/*
|
||||
if msgType, msgBytes, err = ws.ReadMessage(); err != nil {
|
||||
log.Printf("Client disconnected %s: %s\n", r.RemoteAddr, err)
|
||||
break
|
||||
}
|
||||
log.Printf("type: %d; payload: %d bytes;\n", msgType, len(msgBytes))
|
||||
*/
|
||||
|
||||
msgType, rd, err := ws.NextReader()
|
||||
if err != nil {
|
||||
log.Printf("Client disconnected (%s): %s\n", r.RemoteAddr, err)
|
||||
break
|
||||
}
|
||||
|
||||
wr, err := ws.NextWriter(msgType)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err = io.Copy(wr, rd); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
||||
if err = wr.Close(); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Serve index.html
|
||||
http.Handle("/", http.FileServer(http.Dir(webroot)))
|
||||
// Websocket endpoint
|
||||
http.HandleFunc("/f", ServeWebSocket)
|
||||
|
||||
log.Printf("Starting server on: %s\n", listenAddr)
|
||||
if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
}
|
43
server.go
43
server.go
|
@ -7,6 +7,7 @@ package websocket
|
|||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -38,6 +39,11 @@ type Upgrader struct {
|
|||
// requested by the client.
|
||||
Subprotocols []string
|
||||
|
||||
// Server supported extensions (e.g. permessage-deflate)
|
||||
// Currently only 'permessage-deflate; server_no_context_takeover; client_no_context_takeover'
|
||||
// is supported. i.e. no deflate options at this time.
|
||||
Extensions []string
|
||||
|
||||
// Error specifies the function for generating HTTP error responses. If Error
|
||||
// is nil, then http.Error is used to generate the HTTP response.
|
||||
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
|
||||
|
@ -87,6 +93,32 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
|
|||
return ""
|
||||
}
|
||||
|
||||
// Check for compression support, select compression type and
|
||||
// return extension with server supported options along with
|
||||
// whether valid compression headers were found.
|
||||
func (u *Upgrader) selectCompressionExtension(r *http.Request) (string, bool, error) {
|
||||
extensions := r.Header.Get("Sec-WebSocket-Extensions")
|
||||
|
||||
if u.Extensions != nil && len(extensions) > 0 {
|
||||
|
||||
extOpts := strings.Split(extensions, " ")
|
||||
if len(extOpts) > 0 && strings.HasPrefix(extOpts[0], "permessage-deflate") {
|
||||
ext := strings.TrimSuffix(extOpts[0], ";")
|
||||
// Find and return extension with supported options.
|
||||
for _, e := range u.Extensions {
|
||||
// Check if server supports supplied extension
|
||||
if strings.HasPrefix(e, ext) {
|
||||
if e != CompressPermessageDeflate {
|
||||
return "", false, fmt.Errorf("Compression options not supported: %s", e)
|
||||
}
|
||||
return e, true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
|
||||
//
|
||||
// The responseHeader is included in the response to the client's upgrade
|
||||
|
@ -147,6 +179,10 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
|||
|
||||
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
|
||||
c.subprotocol = subprotocol
|
||||
_, c.compressionNegotiated, err = u.selectCompressionExtension(r)
|
||||
if err != nil {
|
||||
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
p := c.writeBuf[:0]
|
||||
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
|
||||
|
@ -175,6 +211,13 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
|||
p = append(p, "\r\n"...)
|
||||
}
|
||||
}
|
||||
// Set the selected compression header if enabled.
|
||||
if c.compressionNegotiated {
|
||||
// Turn compression on by default
|
||||
c.writeCompressionEnabled = true
|
||||
p = append(p, "Sec-WebSocket-Extensions: "+CompressPermessageDeflate+"\r\n"...)
|
||||
}
|
||||
|
||||
p = append(p, "\r\n"...)
|
||||
|
||||
// Clear deadlines set by HTTP server.
|
||||
|
|
Loading…
Reference in New Issue