mirror of https://github.com/gorilla/websocket.git
Added permessage-deflate support
This commit is contained in:
commit
3eabc85eac
12
client.go
12
client.go
|
@ -69,6 +69,9 @@ type Dialer struct {
|
||||||
|
|
||||||
// Subprotocols specifies the client's requested subprotocols.
|
// Subprotocols specifies the client's requested subprotocols.
|
||||||
Subprotocols []string
|
Subprotocols []string
|
||||||
|
|
||||||
|
// Extensions specifies the client requested extensions
|
||||||
|
Extensions []string
|
||||||
}
|
}
|
||||||
|
|
||||||
var errMalformedURL = errors.New("malformed ws or wss URL")
|
var errMalformedURL = errors.New("malformed ws or wss URL")
|
||||||
|
@ -196,6 +199,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
if len(d.Subprotocols) > 0 {
|
if len(d.Subprotocols) > 0 {
|
||||||
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
|
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
|
||||||
}
|
}
|
||||||
|
if len(d.Extensions) > 0 {
|
||||||
|
req.Header["Sec-WebSocket-Extensions"] = d.Extensions
|
||||||
|
}
|
||||||
|
|
||||||
for k, vs := range requestHeader {
|
for k, vs := range requestHeader {
|
||||||
switch {
|
switch {
|
||||||
case k == "Host":
|
case k == "Host":
|
||||||
|
@ -206,6 +213,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
k == "Connection" ||
|
k == "Connection" ||
|
||||||
k == "Sec-Websocket-Key" ||
|
k == "Sec-Websocket-Key" ||
|
||||||
k == "Sec-Websocket-Version" ||
|
k == "Sec-Websocket-Version" ||
|
||||||
|
(k == "Sec-WebSocket-Extensions" && len(d.Extensions) > 0) ||
|
||||||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
|
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
|
||||||
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
|
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
|
||||||
default:
|
default:
|
||||||
|
@ -328,6 +336,10 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
||||||
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
||||||
|
|
||||||
|
if len(resp.Header.Get("Sec-WebSocket-Extensions")) > 0 {
|
||||||
|
conn.compressionNegotiated = true
|
||||||
|
}
|
||||||
|
|
||||||
netConn.SetDeadline(time.Time{})
|
netConn.SetDeadline(time.Time{})
|
||||||
netConn = nil // to avoid close in defer.
|
netConn = nil // to avoid close in defer.
|
||||||
return conn, resp, nil
|
return conn, resp, nil
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
//"bytes"
|
||||||
|
"compress/flate"
|
||||||
|
//"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Supported compression algorithm and parameters.
|
||||||
|
CompressPermessageDeflate = "permessage-deflate; server_no_context_takeover; client_no_context_takeover"
|
||||||
|
|
||||||
|
// Deflate compression level
|
||||||
|
compressDeflateLevel int = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
// Sits between a flate writer and the underlying writer i.e. messageWriter
|
||||||
|
// Truncates last bytes of flate compresses message
|
||||||
|
type FlateAdaptor struct {
|
||||||
|
last5bytes []byte
|
||||||
|
msgWriter io.WriteCloser
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFlateAdaptor(w io.WriteCloser) *FlateAdaptor {
|
||||||
|
return &FlateAdaptor{
|
||||||
|
msgWriter: w,
|
||||||
|
last5bytes: []byte{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aw *FlateAdaptor) Write(p []byte) (n int, err error) {
|
||||||
|
|
||||||
|
t := append(aw.last5bytes, p...)
|
||||||
|
|
||||||
|
if len(t) > 4 {
|
||||||
|
aw.last5bytes = make([]byte, 5)
|
||||||
|
copy(aw.last5bytes, t[len(t)-5:])
|
||||||
|
_, err = aw.msgWriter.Write(t[:len(t)-5])
|
||||||
|
} else {
|
||||||
|
aw.last5bytes = make([]byte, len(t))
|
||||||
|
aw.last5bytes = t
|
||||||
|
}
|
||||||
|
|
||||||
|
n = len(p)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aw *FlateAdaptor) writeEndBlock() (int, error) {
|
||||||
|
var t []byte
|
||||||
|
if aw.last5bytes[4] != 0x00 {
|
||||||
|
t = append(aw.last5bytes, 0x00)
|
||||||
|
}
|
||||||
|
|
||||||
|
return aw.msgWriter.Write(t[:len(t)-5])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (aw *FlateAdaptor) Close() (err error) {
|
||||||
|
if _, err = aw.writeEndBlock(); err == nil {
|
||||||
|
err = aw.msgWriter.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlateAdaptorWriter --> FlateAdaptor --> messageWriter
|
||||||
|
type FlateAdaptorWriter struct {
|
||||||
|
flWriter *flate.Writer
|
||||||
|
flAdaptor *FlateAdaptor
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFlateAdaptorWriter(msgWriter io.WriteCloser, level int) (faw *FlateAdaptorWriter, err error) {
|
||||||
|
faw = &FlateAdaptorWriter{
|
||||||
|
flAdaptor: NewFlateAdaptor(msgWriter),
|
||||||
|
}
|
||||||
|
faw.flWriter, err = flate.NewWriter(faw.flAdaptor, level)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (faw *FlateAdaptorWriter) Write(p []byte) (c int, err error) {
|
||||||
|
if c, err = faw.flWriter.Write(p); err == nil {
|
||||||
|
err = faw.flWriter.Flush()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (faw *FlateAdaptorWriter) Close() (err error) {
|
||||||
|
if err = faw.flWriter.Close(); err == nil {
|
||||||
|
err = faw.flAdaptor.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,26 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/flate"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_NewAdaptorWriter(t *testing.T) {
|
||||||
|
backendBuff := new(bytes.Buffer)
|
||||||
|
aw := NewAdaptorWriter(backendBuff)
|
||||||
|
|
||||||
|
fw, err := flate.NewWriter(aw, -1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var n int
|
||||||
|
n, err = fw.Write([]byte("test"))
|
||||||
|
t.Log(n, err)
|
||||||
|
|
||||||
|
if err = fw.Flush(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
92
conn.go
92
conn.go
|
@ -6,6 +6,7 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"compress/flate"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
@ -13,6 +14,7 @@ import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -21,6 +23,7 @@ const (
|
||||||
maxControlFramePayloadSize = 125
|
maxControlFramePayloadSize = 125
|
||||||
finalBit = 1 << 7
|
finalBit = 1 << 7
|
||||||
maskBit = 1 << 7
|
maskBit = 1 << 7
|
||||||
|
compressionBit = 1 << 6 // used in flushFrame on writes
|
||||||
writeWait = time.Second
|
writeWait = time.Second
|
||||||
|
|
||||||
defaultReadBufferSize = 4096
|
defaultReadBufferSize = 4096
|
||||||
|
@ -144,11 +147,15 @@ type Conn struct {
|
||||||
isServer bool
|
isServer bool
|
||||||
subprotocol string
|
subprotocol string
|
||||||
|
|
||||||
|
compressionNegotiated bool // negotiated compression based on handshake
|
||||||
|
|
||||||
// Write fields
|
// Write fields
|
||||||
mu chan bool // used as mutex to protect write to conn and closeSent
|
mu chan bool // used as mutex to protect write to conn and closeSent
|
||||||
closeSent bool // true if close message was sent
|
closeSent bool // true if close message was sent
|
||||||
|
|
||||||
// Message writer fields.
|
// Message writer fields.
|
||||||
|
writeCompressionEnabled bool
|
||||||
|
|
||||||
writeErr error
|
writeErr error
|
||||||
writeBuf []byte // frame is constructed in this buffer.
|
writeBuf []byte // frame is constructed in this buffer.
|
||||||
writePos int // end of data in writeBuf.
|
writePos int // end of data in writeBuf.
|
||||||
|
@ -157,6 +164,8 @@ type Conn struct {
|
||||||
writeDeadline time.Time
|
writeDeadline time.Time
|
||||||
|
|
||||||
// Read fields
|
// Read fields
|
||||||
|
readMessageCompressed bool
|
||||||
|
|
||||||
readErr error
|
readErr error
|
||||||
br *bufio.Reader
|
br *bufio.Reader
|
||||||
readRemaining int64 // bytes remaining in current frame.
|
readRemaining int64 // bytes remaining in current frame.
|
||||||
|
@ -218,6 +227,12 @@ func (c *Conn) RemoteAddr() net.Addr {
|
||||||
|
|
||||||
// Write methods
|
// Write methods
|
||||||
|
|
||||||
|
// EnableWriteCompression enables and disables write compression of subsequent text and
|
||||||
|
// binary messages. This function is a noop if compression was not negotiated with the peer.
|
||||||
|
func (c *Conn) EnableWriteCompression(enable bool) {
|
||||||
|
c.writeCompressionEnabled = enable
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
|
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
|
||||||
<-c.mu
|
<-c.mu
|
||||||
defer func() { c.mu <- true }()
|
defer func() { c.mu <- true }()
|
||||||
|
@ -327,7 +342,15 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.writeFrameType = messageType
|
c.writeFrameType = messageType
|
||||||
return messageWriter{c, c.writeSeq}, nil
|
|
||||||
|
var wc io.WriteCloser = messageWriter{c, c.writeSeq}
|
||||||
|
|
||||||
|
// Return compression writer on data frame
|
||||||
|
if c.compressionNegotiated && c.writeCompressionEnabled && isData(messageType) {
|
||||||
|
return NewFlateAdaptorWriter(wc, compressDeflateLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return wc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) flushFrame(final bool, extra []byte) error {
|
func (c *Conn) flushFrame(final bool, extra []byte) error {
|
||||||
|
@ -346,6 +369,13 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
||||||
if final {
|
if final {
|
||||||
b0 |= finalBit
|
b0 |= finalBit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check compression and that it is not a continuation frame
|
||||||
|
// as those should not have compression bit set per RFC
|
||||||
|
if c.compressionNegotiated && c.writeCompressionEnabled && c.writeFrameType != continuationFrame {
|
||||||
|
b0 |= compressionBit
|
||||||
|
}
|
||||||
|
|
||||||
b1 := byte(0)
|
b1 := byte(0)
|
||||||
if !c.isServer {
|
if !c.isServer {
|
||||||
b1 |= maskBit
|
b1 |= maskBit
|
||||||
|
@ -515,15 +545,30 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w := wr.(messageWriter)
|
|
||||||
if _, err := w.write(true, data); err != nil {
|
if c.compressionNegotiated && c.writeCompressionEnabled {
|
||||||
return err
|
|
||||||
}
|
fw := wr.(*FlateAdaptorWriter)
|
||||||
if c.writeSeq == w.seq {
|
if _, err = fw.Write(data); err != nil {
|
||||||
if err := c.flushFrame(true, nil); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
return fw.Close()
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
w := wr.(messageWriter)
|
||||||
|
if _, err = w.write(true, data); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// final flush
|
||||||
|
if c.writeSeq == w.seq {
|
||||||
|
if err = c.flushFrame(true, nil); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -577,7 +622,17 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
mask := b[1]&maskBit != 0
|
mask := b[1]&maskBit != 0
|
||||||
c.readRemaining = int64(b[1] & 0x7f)
|
c.readRemaining = int64(b[1] & 0x7f)
|
||||||
|
|
||||||
if reserved != 0 {
|
switch reserved {
|
||||||
|
case 4:
|
||||||
|
if !c.compressionNegotiated {
|
||||||
|
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
|
||||||
|
}
|
||||||
|
// Only the first frame of a compressed message has the reserved bit set.
|
||||||
|
c.readMessageCompressed = true
|
||||||
|
break
|
||||||
|
case 0:
|
||||||
|
break
|
||||||
|
default:
|
||||||
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
|
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -633,7 +688,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
|
|
||||||
// 5. For text and binary messages, enforce read limit and return.
|
// 5. For text and binary messages, enforce read limit and return.
|
||||||
|
|
||||||
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
|
if frameType == continuationFrame || isData(frameType) {
|
||||||
|
|
||||||
c.readLength += c.readRemaining
|
c.readLength += c.readRemaining
|
||||||
if c.readLimit > 0 && c.readLength > c.readLimit {
|
if c.readLimit > 0 && c.readLength > c.readLimit {
|
||||||
|
@ -696,7 +751,7 @@ func (c *Conn) handleProtocolError(message string) error {
|
||||||
//
|
//
|
||||||
// The NextReader method and the readers returned from the method cannot be
|
// The NextReader method and the readers returned from the method cannot be
|
||||||
// accessed by more than one goroutine at a time.
|
// accessed by more than one goroutine at a time.
|
||||||
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
func (c *Conn) NextReader() (int, io.Reader, error) {
|
||||||
|
|
||||||
c.readSeq++
|
c.readSeq++
|
||||||
c.readLength = 0
|
c.readLength = 0
|
||||||
|
@ -707,8 +762,14 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
||||||
c.readErr = hideTempErr(err)
|
c.readErr = hideTempErr(err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if frameType == TextMessage || frameType == BinaryMessage {
|
|
||||||
return frameType, messageReader{c, c.readSeq}, nil
|
if isData(frameType) {
|
||||||
|
var r io.Reader = messageReader{c, c.readSeq}
|
||||||
|
if c.compressionNegotiated && c.readMessageCompressed {
|
||||||
|
// Append compression bytes to output on the final read
|
||||||
|
r = flate.NewReader(io.MultiReader(r, strings.NewReader("\x00\x00\xff\xff\x01\x00\x00\xff\xff")))
|
||||||
|
}
|
||||||
|
return frameType, r, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return noFrame, nil, c.readErr
|
return noFrame, nil, c.readErr
|
||||||
|
@ -742,6 +803,11 @@ func (r messageReader) Read(b []byte) (int, error) {
|
||||||
|
|
||||||
if r.c.readFinal {
|
if r.c.readFinal {
|
||||||
r.c.readSeq++
|
r.c.readSeq++
|
||||||
|
// Reset compression for the next frame
|
||||||
|
if r.c.compressionNegotiated && r.c.readMessageCompressed {
|
||||||
|
r.c.readMessageCompressed = false
|
||||||
|
}
|
||||||
|
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -749,7 +815,7 @@ func (r messageReader) Read(b []byte) (int, error) {
|
||||||
switch {
|
switch {
|
||||||
case err != nil:
|
case err != nil:
|
||||||
r.c.readErr = hideTempErr(err)
|
r.c.readErr = hideTempErr(err)
|
||||||
case frameType == TextMessage || frameType == BinaryMessage:
|
case isData(frameType):
|
||||||
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,3 +11,10 @@ and start the client test driver
|
||||||
wstest -m fuzzingclient -s fuzzingclient.json
|
wstest -m fuzzingclient -s fuzzingclient.json
|
||||||
|
|
||||||
When the client completes, it writes a report to reports/clients/index.html.
|
When the client completes, it writes a report to reports/clients/index.html.
|
||||||
|
|
||||||
|
|
||||||
|
# Install client test driver
|
||||||
|
|
||||||
|
pip install autobahntestsuite
|
||||||
|
|
||||||
|
This will install the test suite containing the `wstest` command.
|
|
@ -8,7 +8,7 @@ package main
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/euforia/websocket"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -22,6 +22,9 @@ var upgrader = websocket.Upgrader{
|
||||||
CheckOrigin: func(r *http.Request) bool {
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
return true
|
return true
|
||||||
},
|
},
|
||||||
|
Extensions: []string{
|
||||||
|
"permessage-deflate; server_no_context_takeover; client_no_context_takeover",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// echoCopy echoes messages from the client using io.Copy.
|
// echoCopy echoes messages from the client using io.Copy.
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
# Compression example
|
||||||
|
This example covers enabling compression on the server. It starts a websocket server with permessage-deflate enabled for compression. You can then visit the page to send/recieve messages through the browser.
|
||||||
|
|
||||||
|
Start the server by running the following in this directory:
|
||||||
|
|
||||||
|
go run server.go
|
||||||
|
|
||||||
|
You can now navigate to the displayed address in you browser:
|
||||||
|
|
||||||
|
http://localhost:12345/
|
|
@ -0,0 +1,59 @@
|
||||||
|
// +build ignore
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/euforia/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
dialer := websocket.Dialer{
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
Extensions: []string{"permessage-deflate"},
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
}
|
||||||
|
|
||||||
|
c, respHdr, err := dialer.Dial("ws://localhost:9001/f", nil)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("dial:", err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
log.Printf("Extensions: %s\n", respHdr.Header.Get("Sec-Websocket-Extensions"))
|
||||||
|
|
||||||
|
compressEnabled := true
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer c.Close()
|
||||||
|
for {
|
||||||
|
_, message, err := c.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
log.Println("read:", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Printf("Received: %s", message)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for t := range ticker.C {
|
||||||
|
err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("write:", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Wrote: compressed=%v; value=%s\n", compressEnabled, t.String())
|
||||||
|
|
||||||
|
compressEnabled = !compressEnabled
|
||||||
|
c.EnableWriteCompression(compressEnabled)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,53 @@
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<title>WebSocket Example</title>
|
||||||
|
<style type="text/css">
|
||||||
|
body { font-family: helvetica; color: #333;}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="conn-details" style="padding: 10px"></div>
|
||||||
|
<table style="width:100%"><tr>
|
||||||
|
<td style="width:50%;vertical-align:top">
|
||||||
|
<div style="text-align:center">
|
||||||
|
<textarea id="inputArea" style="width:100%;min-height:250px"></textarea>
|
||||||
|
</div>
|
||||||
|
<div style="padding:10px"><button onclick="sendData()">send</button></div>
|
||||||
|
</td>
|
||||||
|
<td style="width:50%;padding:5px;">
|
||||||
|
<div style="padding:10px">Echo Response:</div>
|
||||||
|
<pre id="fileData" style="padding:10px;border:none;min-height:250px;margin:0;white-space: pre-wrap;"></pre>
|
||||||
|
</td>
|
||||||
|
</tr></table>
|
||||||
|
|
||||||
|
<script type="text/javascript">
|
||||||
|
|
||||||
|
var inputElem = document.getElementById("inputArea");
|
||||||
|
var conn = new WebSocket("ws://127.0.0.1:9001/f");
|
||||||
|
|
||||||
|
(function() {
|
||||||
|
var data = document.getElementById("fileData");
|
||||||
|
conn.onopen = function(evt) {
|
||||||
|
var connDetails = document.getElementById('conn-details');
|
||||||
|
connDetails.textContent = 'Connection Extensions: [' + conn.extensions + ']';
|
||||||
|
}
|
||||||
|
conn.onclose = function(evt) {
|
||||||
|
console.log(evt);
|
||||||
|
data.textContent = 'Connection closed';
|
||||||
|
}
|
||||||
|
conn.onmessage = function(evt) {
|
||||||
|
console.log('Message:', evt.data);
|
||||||
|
data.textContent += '[ '+ (new Date()).toString() + ' ] ' + evt.data + '\n';
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
|
||||||
|
function sendData() {
|
||||||
|
if (inputElem.value.length > 0) {
|
||||||
|
conn.send(inputElem.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
|
@ -0,0 +1,81 @@
|
||||||
|
// +build ignore
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/euforia/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
upgrader = websocket.Upgrader{
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
Extensions: []string{websocket.CompressPermessageDeflate},
|
||||||
|
}
|
||||||
|
webroot, _ = filepath.Abs("./")
|
||||||
|
listenAddr = "0.0.0.0:9001"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ServeWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ws, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
|
||||||
|
log.Printf("Client connected: %s\n", r.RemoteAddr)
|
||||||
|
|
||||||
|
//if err = ws.WriteMessage(websocket.TextMessage, []byte("Hello!")); err != nil {
|
||||||
|
// log.Println(err)
|
||||||
|
// err = nil
|
||||||
|
//}
|
||||||
|
|
||||||
|
for {
|
||||||
|
/*
|
||||||
|
if msgType, msgBytes, err = ws.ReadMessage(); err != nil {
|
||||||
|
log.Printf("Client disconnected %s: %s\n", r.RemoteAddr, err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Printf("type: %d; payload: %d bytes;\n", msgType, len(msgBytes))
|
||||||
|
*/
|
||||||
|
|
||||||
|
msgType, rd, err := ws.NextReader()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Client disconnected (%s): %s\n", r.RemoteAddr, err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
wr, err := ws.NextWriter(msgType)
|
||||||
|
if err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = io.Copy(wr, rd); err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = wr.Close(); err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Serve index.html
|
||||||
|
http.Handle("/", http.FileServer(http.Dir(webroot)))
|
||||||
|
// Websocket endpoint
|
||||||
|
http.HandleFunc("/f", ServeWebSocket)
|
||||||
|
|
||||||
|
log.Printf("Starting server on: %s\n", listenAddr)
|
||||||
|
if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
}
|
43
server.go
43
server.go
|
@ -7,6 +7,7 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -38,6 +39,11 @@ type Upgrader struct {
|
||||||
// requested by the client.
|
// requested by the client.
|
||||||
Subprotocols []string
|
Subprotocols []string
|
||||||
|
|
||||||
|
// Server supported extensions (e.g. permessage-deflate)
|
||||||
|
// Currently only 'permessage-deflate; server_no_context_takeover; client_no_context_takeover'
|
||||||
|
// is supported. i.e. no deflate options at this time.
|
||||||
|
Extensions []string
|
||||||
|
|
||||||
// Error specifies the function for generating HTTP error responses. If Error
|
// Error specifies the function for generating HTTP error responses. If Error
|
||||||
// is nil, then http.Error is used to generate the HTTP response.
|
// is nil, then http.Error is used to generate the HTTP response.
|
||||||
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
|
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
|
||||||
|
@ -87,6 +93,32 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for compression support, select compression type and
|
||||||
|
// return extension with server supported options along with
|
||||||
|
// whether valid compression headers were found.
|
||||||
|
func (u *Upgrader) selectCompressionExtension(r *http.Request) (string, bool, error) {
|
||||||
|
extensions := r.Header.Get("Sec-WebSocket-Extensions")
|
||||||
|
|
||||||
|
if u.Extensions != nil && len(extensions) > 0 {
|
||||||
|
|
||||||
|
extOpts := strings.Split(extensions, " ")
|
||||||
|
if len(extOpts) > 0 && strings.HasPrefix(extOpts[0], "permessage-deflate") {
|
||||||
|
ext := strings.TrimSuffix(extOpts[0], ";")
|
||||||
|
// Find and return extension with supported options.
|
||||||
|
for _, e := range u.Extensions {
|
||||||
|
// Check if server supports supplied extension
|
||||||
|
if strings.HasPrefix(e, ext) {
|
||||||
|
if e != CompressPermessageDeflate {
|
||||||
|
return "", false, fmt.Errorf("Compression options not supported: %s", e)
|
||||||
|
}
|
||||||
|
return e, true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
|
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
|
||||||
//
|
//
|
||||||
// The responseHeader is included in the response to the client's upgrade
|
// The responseHeader is included in the response to the client's upgrade
|
||||||
|
@ -147,6 +179,10 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
|
|
||||||
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
|
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
|
||||||
c.subprotocol = subprotocol
|
c.subprotocol = subprotocol
|
||||||
|
_, c.compressionNegotiated, err = u.selectCompressionExtension(r)
|
||||||
|
if err != nil {
|
||||||
|
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
p := c.writeBuf[:0]
|
p := c.writeBuf[:0]
|
||||||
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
|
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
|
||||||
|
@ -175,6 +211,13 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
p = append(p, "\r\n"...)
|
p = append(p, "\r\n"...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Set the selected compression header if enabled.
|
||||||
|
if c.compressionNegotiated {
|
||||||
|
// Turn compression on by default
|
||||||
|
c.writeCompressionEnabled = true
|
||||||
|
p = append(p, "Sec-WebSocket-Extensions: "+CompressPermessageDeflate+"\r\n"...)
|
||||||
|
}
|
||||||
|
|
||||||
p = append(p, "\r\n"...)
|
p = append(p, "\r\n"...)
|
||||||
|
|
||||||
// Clear deadlines set by HTTP server.
|
// Clear deadlines set by HTTP server.
|
||||||
|
|
Loading…
Reference in New Issue