diff --git a/.circleci/config.yml b/.circleci/config.yml
index 1240d78..ecb33f6 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -1,76 +1,70 @@
-version: 2.0
+version: 2.1
jobs:
- # Base test configuration for Go library tests Each distinct version should
- # inherit this base, and override (at least) the container image used.
- "test": &test
+ "test":
+ parameters:
+ version:
+ type: string
+ default: "latest"
+ golint:
+ type: boolean
+ default: true
+ modules:
+ type: boolean
+ default: true
+ goproxy:
+ type: string
+ default: ""
docker:
- - image: circleci/golang:latest
- working_directory: /go/src/github.com/gorilla/websocket
- steps: &steps
- - checkout
- - run: go version
- - run: go get -t -v ./...
- # Only run gofmt, vet & lint against the latest Go version
- - run: >
- if [[ "$LATEST" = true ]]; then
- go get -u golang.org/x/lint/golint
- golint ./...
- fi
- - run: >
- if [[ "$LATEST" = true ]]; then
- diff -u <(echo -n) <(gofmt -d .)
- fi
- - run: >
- if [[ "$LATEST" = true ]]; then
- go vet -v .
- fi
- - run: if [[ "$LATEST" = true ]]; then go vet -v .; fi
- - run: go test -v -race ./...
-
- "latest":
- <<: *test
+ - image: "cimg/go:<< parameters.version >>"
+ working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket
environment:
- LATEST: true
-
- "1.12":
- <<: *test
- docker:
- - image: circleci/golang:1.12
-
- "1.11":
- <<: *test
- docker:
- - image: circleci/golang:1.11
-
- "1.10":
- <<: *test
- docker:
- - image: circleci/golang:1.10
-
- "1.9":
- <<: *test
- docker:
- - image: circleci/golang:1.9
-
- "1.8":
- <<: *test
- docker:
- - image: circleci/golang:1.8
-
- "1.7":
- <<: *test
- docker:
- - image: circleci/golang:1.7
+ GO111MODULE: "on"
+ GOPROXY: "<< parameters.goproxy >>"
+ steps:
+ - checkout
+ - run:
+ name: "Print the Go version"
+ command: >
+ go version
+ - run:
+ name: "Fetch dependencies"
+ command: >
+ if [[ << parameters.modules >> = true ]]; then
+ go mod download
+ export GO111MODULE=on
+ else
+ go get -v ./...
+ fi
+ # Only run gofmt, vet & lint against the latest Go version
+ - run:
+ name: "Run golint"
+ command: >
+ if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then
+ go get -u golang.org/x/lint/golint
+ golint ./...
+ fi
+ - run:
+ name: "Run gofmt"
+ command: >
+ if [[ << parameters.version >> = "latest" ]]; then
+ diff -u <(echo -n) <(gofmt -d -e .)
+ fi
+ - run:
+ name: "Run go vet"
+ command: >
+ if [[ << parameters.version >> = "latest" ]]; then
+ go vet -v ./...
+ fi
+ - run:
+ name: "Run go test (+ race detector)"
+ command: >
+ go test -v -race ./...
workflows:
- version: 2
- build:
+ tests:
jobs:
- - "latest"
- - "1.12"
- - "1.11"
- - "1.10"
- - "1.9"
- - "1.8"
- - "1.7"
+ - test:
+ matrix:
+ parameters:
+ version: ["1.18", "1.17", "1.16"]
diff --git a/README.md b/README.md
index 19aa2e7..d33ed7f 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,7 @@
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
+
### Documentation
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
@@ -30,35 +31,3 @@ The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
-### Gorilla WebSocket compared with other packages
-
-
-
-Notes:
-
-1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
-2. The application can get the type of a received data message by implementing
- a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal)
- function.
-3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries.
- Read returns when the input buffer is full or a frame boundary is
- encountered. Each call to Write sends a single frame message. The Gorilla
- io.Reader and io.WriteCloser operate on a single WebSocket message.
-
diff --git a/client.go b/client.go
index 524e542..64f925c 100644
--- a/client.go
+++ b/client.go
@@ -9,6 +9,7 @@ import (
"context"
"crypto/tls"
"errors"
+ "fmt"
"io"
"io/ioutil"
"net"
@@ -48,15 +49,23 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
}
// A Dialer contains options for connecting to WebSocket server.
+//
+// It is safe to call Dialer's methods concurrently.
type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If
- // NetDialContext is nil, net.DialContext is used.
+ // NetDialContext is nil, NetDial is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+ // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
+ // NetDialTLSContext is nil, NetDialContext is used.
+ // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
+ // TLSClientConfig is ignored.
+ NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
+
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
@@ -65,6 +74,8 @@ type Dialer struct {
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
+ // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
+ // is done there and TLSClientConfig is ignored.
TLSClientConfig *tls.Config
// HandshakeTimeout specifies the duration for the handshake to complete.
@@ -179,7 +190,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
req := &http.Request{
- Method: "GET",
+ Method: http.MethodGet,
URL: u,
Proto: "HTTP/1.1",
ProtoMajor: 1,
@@ -240,13 +251,32 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// Get network dial function.
var netDial func(network, add string) (net.Conn, error)
- if d.NetDialContext != nil {
- netDial = func(network, addr string) (net.Conn, error) {
- return d.NetDialContext(ctx, network, addr)
+ switch u.Scheme {
+ case "http":
+ if d.NetDialContext != nil {
+ netDial = func(network, addr string) (net.Conn, error) {
+ return d.NetDialContext(ctx, network, addr)
+ }
+ } else if d.NetDial != nil {
+ netDial = d.NetDial
}
- } else if d.NetDial != nil {
- netDial = d.NetDial
- } else {
+ case "https":
+ if d.NetDialTLSContext != nil {
+ netDial = func(network, addr string) (net.Conn, error) {
+ return d.NetDialTLSContext(ctx, network, addr)
+ }
+ } else if d.NetDialContext != nil {
+ netDial = func(network, addr string) (net.Conn, error) {
+ return d.NetDialContext(ctx, network, addr)
+ }
+ } else if d.NetDial != nil {
+ netDial = d.NetDial
+ }
+ default:
+ return nil, nil, errMalformedURL
+ }
+
+ if netDial == nil {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
@@ -292,14 +322,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
netConn, err := netDial("tcp", hostPort)
+ if err != nil {
+ return nil, nil, err
+ }
if trace != nil && trace.GotConn != nil {
trace.GotConn(httptrace.GotConnInfo{
Conn: netConn,
})
}
- if err != nil {
- return nil, nil, err
- }
defer func() {
if netConn != nil {
@@ -307,7 +337,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
}()
- if u.Scheme == "https" {
+ if u.Scheme == "https" && d.NetDialTLSContext == nil {
+ // If NetDialTLSContext is set, assume that the TLS handshake has already been done
+
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
@@ -315,11 +347,12 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
- var err error
- if trace != nil {
- err = doHandshakeWithTrace(trace, tlsConn, cfg)
- } else {
- err = doHandshake(tlsConn, cfg)
+ if trace != nil && trace.TLSHandshakeStart != nil {
+ trace.TLSHandshakeStart()
+ }
+ err := doHandshake(ctx, tlsConn, cfg)
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
if err != nil {
@@ -341,6 +374,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
resp, err := http.ReadResponse(conn.br, req)
if err != nil {
+ if d.TLSClientConfig != nil {
+ for _, proto := range d.TLSClientConfig.NextProtos {
+ if proto != "http/1.1" {
+ return nil, nil, fmt.Errorf(
+ "websocket: protocol %q was given but is not supported;"+
+ "sharing tls.Config with net/http Transport can cause this error: %w",
+ proto, err,
+ )
+ }
+ }
+ }
return nil, nil, err
}
@@ -351,8 +395,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
if resp.StatusCode != 101 ||
- !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
- !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
+ !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
+ !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
// Before closing the network connection on return from this
// function, slurp up some of the response to aid application
@@ -385,14 +429,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return conn, resp, nil
}
-func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
- if err := tlsConn.Handshake(); err != nil {
- return err
+func cloneTLSConfig(cfg *tls.Config) *tls.Config {
+ if cfg == nil {
+ return &tls.Config{}
}
- if !cfg.InsecureSkipVerify {
- if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
- return err
- }
- }
- return nil
+ return cfg.Clone()
}
diff --git a/client_clone.go b/client_clone.go
deleted file mode 100644
index 4f0d943..0000000
--- a/client_clone.go
+++ /dev/null
@@ -1,16 +0,0 @@
-// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build go1.8
-
-package websocket
-
-import "crypto/tls"
-
-func cloneTLSConfig(cfg *tls.Config) *tls.Config {
- if cfg == nil {
- return &tls.Config{}
- }
- return cfg.Clone()
-}
diff --git a/client_clone_legacy.go b/client_clone_legacy.go
deleted file mode 100644
index babb007..0000000
--- a/client_clone_legacy.go
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build !go1.8
-
-package websocket
-
-import "crypto/tls"
-
-// cloneTLSConfig clones all public fields except the fields
-// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
-// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
-// config in active use.
-func cloneTLSConfig(cfg *tls.Config) *tls.Config {
- if cfg == nil {
- return &tls.Config{}
- }
- return &tls.Config{
- Rand: cfg.Rand,
- Time: cfg.Time,
- Certificates: cfg.Certificates,
- NameToCertificate: cfg.NameToCertificate,
- GetCertificate: cfg.GetCertificate,
- RootCAs: cfg.RootCAs,
- NextProtos: cfg.NextProtos,
- ServerName: cfg.ServerName,
- ClientAuth: cfg.ClientAuth,
- ClientCAs: cfg.ClientCAs,
- InsecureSkipVerify: cfg.InsecureSkipVerify,
- CipherSuites: cfg.CipherSuites,
- PreferServerCipherSuites: cfg.PreferServerCipherSuites,
- ClientSessionCache: cfg.ClientSessionCache,
- MinVersion: cfg.MinVersion,
- MaxVersion: cfg.MaxVersion,
- CurvePreferences: cfg.CurvePreferences,
- }
-}
diff --git a/client_server_test.go b/client_server_test.go
index 5c54492..6643dad 100644
--- a/client_server_test.go
+++ b/client_server_test.go
@@ -11,6 +11,7 @@ import (
"crypto/x509"
"encoding/base64"
"encoding/binary"
+ "errors"
"fmt"
"io"
"io/ioutil"
@@ -166,7 +167,7 @@ func TestProxyDial(t *testing.T) {
// Capture the request Host header.
s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
- if r.Method == "CONNECT" {
+ if r.Method == http.MethodConnect {
connect = true
w.WriteHeader(http.StatusOK)
if r.Header.Get("User-Agents") != "xxx" {
@@ -210,7 +211,7 @@ func TestProxyAuthorizationDial(t *testing.T) {
func(w http.ResponseWriter, r *http.Request) {
proxyAuth := r.Header.Get("Proxy-Authorization")
expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
- if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth {
+ if r.Method == http.MethodConnect && proxyAuth == expectedProxyAuth {
connect = true
w.WriteHeader(http.StatusOK)
return
@@ -470,7 +471,7 @@ func TestBadMethod(t *testing.T) {
}))
defer s.Close()
- req, err := http.NewRequest("POST", s.URL, strings.NewReader(""))
+ req, err := http.NewRequest(http.MethodPost, s.URL, strings.NewReader(""))
if err != nil {
t.Fatalf("NewRequest returned error %v", err)
}
@@ -488,6 +489,23 @@ func TestBadMethod(t *testing.T) {
}
}
+func TestDialExtraTokensInRespHeaders(t *testing.T) {
+ s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ challengeKey := r.Header.Get("Sec-Websocket-Key")
+ w.Header().Set("Upgrade", "foo, websocket")
+ w.Header().Set("Connection", "upgrade, keep-alive")
+ w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))
+ w.WriteHeader(101)
+ }))
+ defer s.Close()
+
+ ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil)
+ if err != nil {
+ t.Fatalf("Dial: %v", err)
+ }
+ defer ws.Close()
+}
+
func TestHandshake(t *testing.T) {
s := newServer(t)
defer s.Close()
@@ -725,7 +743,7 @@ func TestHost(t *testing.T) {
Dial: dialer.NetDial,
TLSClientConfig: dialer.TLSClientConfig,
}
- req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
+ req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil)
if tt.header != "" {
req.Host = tt.header
}
@@ -910,3 +928,215 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
defer ws.Close()
sendRecv(t, ws)
}
+
+// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
+func TestNetDialConnect(t *testing.T) {
+
+ upgrader := Upgrader{}
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if IsWebSocketUpgrade(r) {
+ c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+ } else {
+ w.Header().Set("X-Test-Host", r.Host)
+ }
+ })
+
+ server := httptest.NewServer(handler)
+ defer server.Close()
+
+ tlsServer := httptest.NewTLSServer(handler)
+ defer tlsServer.Close()
+
+ testUrls := map[*httptest.Server]string{
+ server: "ws://" + server.Listener.Addr().String() + "/",
+ tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/",
+ }
+
+ cas := rootCAs(t, tlsServer)
+ tlsConfig := &tls.Config{
+ RootCAs: cas,
+ ServerName: "example.com",
+ InsecureSkipVerify: false,
+ }
+
+ tests := []struct {
+ name string
+ server *httptest.Server // server to use
+ netDial func(network, addr string) (net.Conn, error)
+ netDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+ netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
+ tlsClientConfig *tls.Config
+ }{
+
+ {
+ name: "HTTP server, all NetDial* defined, shall use NetDialContext",
+ server: server,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDial should not be called")
+ },
+ netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDialTLSContext should not be called")
+ },
+ tlsClientConfig: nil,
+ },
+ {
+ name: "HTTP server, all NetDial* undefined",
+ server: server,
+ netDial: nil,
+ netDialContext: nil,
+ netDialTLSContext: nil,
+ tlsClientConfig: nil,
+ },
+ {
+ name: "HTTP server, NetDialContext undefined, shall fallback to NetDial",
+ server: server,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ netDialContext: nil,
+ netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDialTLSContext should not be called")
+ },
+ tlsClientConfig: nil,
+ },
+ {
+ name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext",
+ server: tlsServer,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDial should not be called")
+ },
+ netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDialContext should not be called")
+ },
+ netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ netConn, err := net.Dial(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ tlsConn := tls.Client(netConn, tlsConfig)
+ err = tlsConn.Handshake()
+ if err != nil {
+ return nil, err
+ }
+ return tlsConn, nil
+ },
+ tlsClientConfig: nil,
+ },
+ {
+ name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake",
+ server: tlsServer,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDial should not be called")
+ },
+ netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ netDialTLSContext: nil,
+ tlsClientConfig: tlsConfig,
+ },
+ {
+ name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake",
+ server: tlsServer,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return net.Dial(network, addr)
+ },
+ netDialContext: nil,
+ netDialTLSContext: nil,
+ tlsClientConfig: tlsConfig,
+ },
+ {
+ name: "HTTPS server, all NetDial* undefined",
+ server: tlsServer,
+ netDial: nil,
+ netDialContext: nil,
+ netDialTLSContext: nil,
+ tlsClientConfig: tlsConfig,
+ },
+ {
+ name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake",
+ server: tlsServer,
+ netDial: func(network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDial should not be called")
+ },
+ netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return nil, errors.New("NetDialContext should not be called")
+ },
+ netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ netConn, err := net.Dial(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ tlsConn := tls.Client(netConn, tlsConfig)
+ err = tlsConn.Handshake()
+ if err != nil {
+ return nil, err
+ }
+ return tlsConn, nil
+ },
+ tlsClientConfig: &tls.Config{
+ RootCAs: nil,
+ ServerName: "badserver.com",
+ InsecureSkipVerify: false,
+ },
+ },
+ }
+
+ for _, tc := range tests {
+ dialer := Dialer{
+ NetDial: tc.netDial,
+ NetDialContext: tc.netDialContext,
+ NetDialTLSContext: tc.netDialTLSContext,
+ TLSClientConfig: tc.tlsClientConfig,
+ }
+
+ // Test websocket dial
+ c, _, err := dialer.Dial(testUrls[tc.server], nil)
+ if err != nil {
+ t.Errorf("FAILED %s, err: %s", tc.name, err.Error())
+ } else {
+ c.Close()
+ }
+ }
+}
+func TestNextProtos(t *testing.T) {
+ ts := httptest.NewUnstartedServer(
+ http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
+ )
+ ts.EnableHTTP2 = true
+ ts.StartTLS()
+ defer ts.Close()
+
+ d := Dialer{
+ TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
+ }
+
+ r, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ r.Body.Close()
+
+ // Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
+ // after the Client.Get call from net/http above.
+ var containsHTTP2 bool = false
+ for _, proto := range d.TLSClientConfig.NextProtos {
+ if proto == "h2" {
+ containsHTTP2 = true
+ }
+ }
+ if !containsHTTP2 {
+ t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
+ }
+
+ _, _, err = d.Dial(makeWsProto(ts.URL), nil)
+ if err == nil {
+ t.Fatalf("Dial succeeded, expect fail ")
+ }
+}
diff --git a/conn.go b/conn.go
index ca46d2f..5161ef8 100644
--- a/conn.go
+++ b/conn.go
@@ -13,6 +13,7 @@ import (
"math/rand"
"net"
"strconv"
+ "strings"
"sync"
"time"
"unicode/utf8"
@@ -401,6 +402,12 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return nil
}
+func (c *Conn) writeBufs(bufs ...[]byte) error {
+ b := net.Buffers(bufs)
+ _, err := b.WriteTo(c.conn)
+ return err
+}
+
// WriteControl writes a control message with the given deadline. The allowed
// message types are CloseMessage, PingMessage and PongMessage.
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
@@ -794,47 +801,69 @@ func (c *Conn) advanceFrame() (int, error) {
}
// 2. Read and parse first two bytes of frame header.
+ // To aid debugging, collect and report all errors in the first two bytes
+ // of the header.
+
+ var errors []string
p, err := c.read(2)
if err != nil {
return noFrame, err
}
- final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
+ final := p[0]&finalBit != 0
+ rsv1 := p[0]&rsv1Bit != 0
+ rsv2 := p[0]&rsv2Bit != 0
+ rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f))
c.readDecompress = false
- if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
- c.readDecompress = true
- p[0] &^= rsv1Bit
+ if rsv1 {
+ if c.newDecompressionReader != nil {
+ c.readDecompress = true
+ } else {
+ errors = append(errors, "RSV1 set")
+ }
}
- if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
- return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
+ if rsv2 {
+ errors = append(errors, "RSV2 set")
+ }
+
+ if rsv3 {
+ errors = append(errors, "RSV3 set")
}
switch frameType {
case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize {
- return noFrame, c.handleProtocolError("control frame length > 125")
+ errors = append(errors, "len > 125 for control")
}
if !final {
- return noFrame, c.handleProtocolError("control frame not final")
+ errors = append(errors, "FIN not set on control")
}
case TextMessage, BinaryMessage:
if !c.readFinal {
- return noFrame, c.handleProtocolError("message start before final message frame")
+ errors = append(errors, "data before FIN")
}
c.readFinal = final
case continuationFrame:
if c.readFinal {
- return noFrame, c.handleProtocolError("continuation after final message frame")
+ errors = append(errors, "continuation after FIN")
}
c.readFinal = final
default:
- return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
+ errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
+ }
+
+ if mask != c.isServer {
+ errors = append(errors, "bad MASK")
+ }
+
+ if len(errors) > 0 {
+ return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
}
// 3. Read and parse frame length as per
@@ -872,10 +901,6 @@ func (c *Conn) advanceFrame() (int, error) {
// 4. Handle frame masking.
- if mask != c.isServer {
- return noFrame, c.handleProtocolError("incorrect mask flag")
- }
-
if mask {
c.readMaskPos = 0
p, err := c.read(len(c.readMaskKey))
@@ -935,7 +960,7 @@ func (c *Conn) advanceFrame() (int, error) {
if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) {
- return noFrame, c.handleProtocolError("invalid close code")
+ return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
}
closeText = string(payload[2:])
if !utf8.ValidString(closeText) {
@@ -952,7 +977,11 @@ func (c *Conn) advanceFrame() (int, error) {
}
func (c *Conn) handleProtocolError(message string) error {
- c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
+ data := FormatCloseMessage(CloseProtocolError, message)
+ if len(data) > maxControlFramePayloadSize {
+ data = data[:maxControlFramePayloadSize]
+ }
+ c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}
@@ -1160,8 +1189,16 @@ func (c *Conn) SetPongHandler(h func(appData string) error) {
c.handlePong = h
}
+// NetConn returns the underlying connection that is wrapped by c.
+// Note that writing to or reading from this connection directly will corrupt the
+// WebSocket connection.
+func (c *Conn) NetConn() net.Conn {
+ return c.conn
+}
+
// UnderlyingConn returns the internal net.Conn. This can be used to further
// modifications to connection specific flags.
+// Deprecated: Use the NetConn method.
func (c *Conn) UnderlyingConn() net.Conn {
return c.conn
}
diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go
index cb88cbb..6e744fc 100644
--- a/conn_broadcast_test.go
+++ b/conn_broadcast_test.go
@@ -18,7 +18,6 @@ import (
// scenarios with many subscribers in one channel.
type broadcastBench struct {
w io.Writer
- message *broadcastMessage
closeCh chan struct{}
doneCh chan struct{}
count int32
@@ -52,14 +51,6 @@ func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
usePrepared: usePrepared,
compression: compression,
}
- msg := &broadcastMessage{
- payload: textMessages(1)[0],
- }
- if usePrepared {
- pm, _ := NewPreparedMessage(TextMessage, msg.payload)
- msg.prepared = pm
- }
- bench.message = msg
bench.makeConns(10000)
return bench
}
@@ -78,7 +69,7 @@ func (b *broadcastBench) makeConns(numConns int) {
for {
select {
case msg := <-c.msgCh:
- if b.usePrepared {
+ if msg.prepared != nil {
c.conn.WritePreparedMessage(msg.prepared)
} else {
c.conn.WriteMessage(TextMessage, msg.payload)
@@ -100,9 +91,9 @@ func (b *broadcastBench) close() {
close(b.closeCh)
}
-func (b *broadcastBench) runOnce() {
+func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) {
for _, c := range b.conns {
- c.msgCh <- b.message
+ c.msgCh <- msg
}
<-b.doneCh
}
@@ -114,17 +105,25 @@ func BenchmarkBroadcast(b *testing.B) {
compression bool
}{
{"NoCompression", false, false},
- {"WithCompression", false, true},
+ {"Compression", false, true},
{"NoCompressionPrepared", true, false},
- {"WithCompressionPrepared", true, true},
+ {"CompressionPrepared", true, true},
}
+ payload := textMessages(1)[0]
for _, bm := range benchmarks {
b.Run(bm.name, func(b *testing.B) {
bench := newBroadcastBench(bm.usePrepared, bm.compression)
defer bench.close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
- bench.runOnce()
+ message := &broadcastMessage{
+ payload: payload,
+ }
+ if bench.usePrepared {
+ pm, _ := NewPreparedMessage(TextMessage, message.payload)
+ message.prepared = pm
+ }
+ bench.broadcastOnce(message)
}
b.ReportAllocs()
})
diff --git a/conn_test.go b/conn_test.go
index bd96e0a..06e5184 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -562,7 +562,7 @@ func TestAddrs(t *testing.T) {
}
}
-func TestUnderlyingConn(t *testing.T) {
+func TestDeprecatedUnderlyingConn(t *testing.T) {
var b1, b2 bytes.Buffer
fc := fakeNetConn{Reader: &b1, Writer: &b2}
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
@@ -572,6 +572,16 @@ func TestUnderlyingConn(t *testing.T) {
}
}
+func TestNetConn(t *testing.T) {
+ var b1, b2 bytes.Buffer
+ fc := fakeNetConn{Reader: &b1, Writer: &b2}
+ c := newConn(fc, true, 1024, 1024, nil, nil, nil)
+ ul := c.NetConn()
+ if ul != fc {
+ t.Fatalf("Underlying conn is not what it should be.")
+ }
+}
+
func TestBufioReadBytes(t *testing.T) {
// Test calling bufio.ReadBytes for value longer than read buffer size.
diff --git a/conn_write.go b/conn_write.go
deleted file mode 100644
index a509a21..0000000
--- a/conn_write.go
+++ /dev/null
@@ -1,15 +0,0 @@
-// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build go1.8
-
-package websocket
-
-import "net"
-
-func (c *Conn) writeBufs(bufs ...[]byte) error {
- b := net.Buffers(bufs)
- _, err := b.WriteTo(c.conn)
- return err
-}
diff --git a/conn_write_legacy.go b/conn_write_legacy.go
deleted file mode 100644
index 37edaff..0000000
--- a/conn_write_legacy.go
+++ /dev/null
@@ -1,18 +0,0 @@
-// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build !go1.8
-
-package websocket
-
-func (c *Conn) writeBufs(bufs ...[]byte) error {
- for _, buf := range bufs {
- if len(buf) > 0 {
- if _, err := c.conn.Write(buf); err != nil {
- return err
- }
- }
- }
- return nil
-}
diff --git a/examples/autobahn/README.md b/examples/autobahn/README.md
index dde8525..cc954fe 100644
--- a/examples/autobahn/README.md
+++ b/examples/autobahn/README.md
@@ -8,6 +8,11 @@ To test the server, run
and start the client test driver
- wstest -m fuzzingclient -s fuzzingclient.json
+ mkdir -p reports
+ docker run -it --rm \
+ -v ${PWD}/config:/config \
+ -v ${PWD}/reports:/reports \
+ crossbario/autobahn-testsuite \
+ wstest -m fuzzingclient -s /config/fuzzingclient.json
-When the client completes, it writes a report to reports/clients/index.html.
+When the client completes, it writes a report to reports/index.html.
diff --git a/examples/autobahn/config/fuzzingclient.json b/examples/autobahn/config/fuzzingclient.json
new file mode 100644
index 0000000..eda4e66
--- /dev/null
+++ b/examples/autobahn/config/fuzzingclient.json
@@ -0,0 +1,29 @@
+{
+ "cases": ["*"],
+ "exclude-cases": [],
+ "exclude-agent-cases": {},
+ "outdir": "/reports",
+ "options": {"failByDrop": false},
+ "servers": [
+ {
+ "agent": "ReadAllWriteMessage",
+ "url": "ws://host.docker.internal:9000/m"
+ },
+ {
+ "agent": "ReadAllWritePreparedMessage",
+ "url": "ws://host.docker.internal:9000/p"
+ },
+ {
+ "agent": "CopyFull",
+ "url": "ws://host.docker.internal:9000/f"
+ },
+ {
+ "agent": "ReadAllWrite",
+ "url": "ws://host.docker.internal:9000/r"
+ },
+ {
+ "agent": "CopyWriterOnly",
+ "url": "ws://host.docker.internal:9000/c"
+ }
+ ]
+}
diff --git a/examples/autobahn/fuzzingclient.json b/examples/autobahn/fuzzingclient.json
deleted file mode 100644
index aa3a0bc..0000000
--- a/examples/autobahn/fuzzingclient.json
+++ /dev/null
@@ -1,15 +0,0 @@
-
-{
- "options": {"failByDrop": false},
- "outdir": "./reports/clients",
- "servers": [
- {"agent": "ReadAllWriteMessage", "url": "ws://localhost:9000/m", "options": {"version": 18}},
- {"agent": "ReadAllWritePreparedMessage", "url": "ws://localhost:9000/p", "options": {"version": 18}},
- {"agent": "ReadAllWrite", "url": "ws://localhost:9000/r", "options": {"version": 18}},
- {"agent": "CopyFull", "url": "ws://localhost:9000/f", "options": {"version": 18}},
- {"agent": "CopyWriterOnly", "url": "ws://localhost:9000/c", "options": {"version": 18}}
- ],
- "cases": ["*"],
- "exclude-cases": [],
- "exclude-agent-cases": {}
-}
diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go
index c2d6ee5..8b17fe3 100644
--- a/examples/autobahn/server.go
+++ b/examples/autobahn/server.go
@@ -160,7 +160,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not found.", http.StatusNotFound)
return
}
- if r.Method != "GET" {
+ if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
diff --git a/examples/chat/main.go b/examples/chat/main.go
index 9d4737a..474709f 100644
--- a/examples/chat/main.go
+++ b/examples/chat/main.go
@@ -18,7 +18,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not found", http.StatusNotFound)
return
}
- if r.Method != "GET" {
+ if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
diff --git a/examples/command/main.go b/examples/command/main.go
index 304f1a5..38d9f6c 100644
--- a/examples/command/main.go
+++ b/examples/command/main.go
@@ -170,7 +170,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not found", http.StatusNotFound)
return
}
- if r.Method != "GET" {
+ if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
diff --git a/examples/echo/client.go b/examples/echo/client.go
index bf0e657..7d870bd 100644
--- a/examples/echo/client.go
+++ b/examples/echo/client.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+//go:build ignore
// +build ignore
package main
diff --git a/examples/echo/server.go b/examples/echo/server.go
index 2f5305f..f9a0b7b 100644
--- a/examples/echo/server.go
+++ b/examples/echo/server.go
@@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+//go:build ignore
// +build ignore
package main
@@ -69,6 +70,7 @@ window.addEventListener("load", function(evt) {
var d = document.createElement("div");
d.textContent = message;
output.appendChild(d);
+ output.scroll(0, output.scrollHeight);
};
document.getElementById("open").onclick = function(evt) {
@@ -126,7 +128,7 @@ You can change the message and send multiple times.
-
+
|