diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000..ecb33f6 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,70 @@ +version: 2.1 + +jobs: + "test": + parameters: + version: + type: string + default: "latest" + golint: + type: boolean + default: true + modules: + type: boolean + default: true + goproxy: + type: string + default: "" + docker: + - image: "cimg/go:<< parameters.version >>" + working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket + environment: + GO111MODULE: "on" + GOPROXY: "<< parameters.goproxy >>" + steps: + - checkout + - run: + name: "Print the Go version" + command: > + go version + - run: + name: "Fetch dependencies" + command: > + if [[ << parameters.modules >> = true ]]; then + go mod download + export GO111MODULE=on + else + go get -v ./... + fi + # Only run gofmt, vet & lint against the latest Go version + - run: + name: "Run golint" + command: > + if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then + go get -u golang.org/x/lint/golint + golint ./... + fi + - run: + name: "Run gofmt" + command: > + if [[ << parameters.version >> = "latest" ]]; then + diff -u <(echo -n) <(gofmt -d -e .) + fi + - run: + name: "Run go vet" + command: > + if [[ << parameters.version >> = "latest" ]]; then + go vet -v ./... + fi + - run: + name: "Run go test (+ race detector)" + command: > + go test -v -race ./... + +workflows: + tests: + jobs: + - test: + matrix: + parameters: + version: ["1.18", "1.17", "1.16"] diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml new file mode 100644 index 0000000..0986b3e --- /dev/null +++ b/.github/release-drafter.yml @@ -0,0 +1,7 @@ +# Config for https://github.com/apps/release-drafter +template: | + + + + ## CHANGELOG + $CHANGES diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index a49db51..0000000 --- a/.travis.yml +++ /dev/null @@ -1,19 +0,0 @@ -language: go -sudo: false - -matrix: - include: - - go: 1.7.x - - go: 1.8.x - - go: 1.9.x - - go: 1.10.x - - go: 1.11.x - - go: tip - allow_failures: - - go: tip - -script: - - go get -t -v ./... - - diff -u <(echo -n) <(gofmt -d .) - - go vet $(go list ./... | grep -v /vendor/) - - go test -v -race ./... diff --git a/README.md b/README.md index 2469694..9b6956f 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,20 @@ # Gorilla WebSocket +[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) +[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket) + Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. -[![Build Status](https://travis-ci.org/gorilla/websocket.svg?branch=master)](https://travis-ci.org/gorilla/websocket) -[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) ### Documentation -* [API Reference](http://godoc.org/github.com/gorilla/websocket) +* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) * [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) * [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) * [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) * [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) +* [Write buffer pool example](https://github.com/gorilla/websocket/tree/master/examples/bufferpool) ### Status @@ -30,35 +32,3 @@ The Gorilla WebSocket package passes the server tests in the [Autobahn Test Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). -### Gorilla WebSocket compared with other packages - - - - - - - - - - - - - - - - - - -
github.com/gorillagolang.org/x/net
RFC 6455 Features
Passes Autobahn Test SuiteYesNo
Receive fragmented messageYesNo, see note 1
Send close messageYesNo
Send pings and receive pongsYesNo
Get the type of a received data messageYesYes, see note 2
Other Features
Compression ExtensionsExperimentalNo
Read message using io.ReaderYesNo, see note 3
Write message using io.WriteCloserYesNo, see note 3
- -Notes: - -1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html). -2. The application can get the type of a received data message by implementing - a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal) - function. -3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries. - Read returns when the input buffer is full or a frame boundary is - encountered. Each call to Write sends a single frame message. The Gorilla - io.Reader and io.WriteCloser operate on a single WebSocket message. - diff --git a/client.go b/client.go index 87c71ac..4d07aa8 100644 --- a/client.go +++ b/client.go @@ -10,6 +10,7 @@ import ( "context" "crypto/tls" "errors" + "fmt" "io" "io/ioutil" "net" @@ -49,15 +50,23 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS } // A Dialer contains options for connecting to WebSocket server. +// +// It is safe to call Dialer's methods concurrently. type Dialer struct { // NetDial specifies the dial function for creating TCP connections. If // NetDial is nil, net.Dial is used. NetDial func(network, addr string) (net.Conn, error) // NetDialContext specifies the dial function for creating TCP connections. If - // NetDialContext is nil, net.DialContext is used. + // NetDialContext is nil, NetDial is used. NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If + // NetDialTLSContext is nil, NetDialContext is used. + // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and + // TLSClientConfig is ignored. + NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. @@ -66,6 +75,8 @@ type Dialer struct { // TLSClientConfig specifies the TLS configuration to use with tls.Client. // If nil, the default configuration is used. + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. TLSClientConfig *tls.Config // HandshakeTimeout specifies the duration for the handshake to complete. @@ -180,7 +191,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } req := &http.Request{ - Method: "GET", + Method: http.MethodGet, URL: u, Proto: "HTTP/1.1", ProtoMajor: 1, @@ -244,13 +255,32 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h // Get network dial function. var netDial func(network, add string) (net.Conn, error) - if d.NetDialContext != nil { - netDial = func(network, addr string) (net.Conn, error) { - return d.NetDialContext(ctx, network, addr) + switch u.Scheme { + case "http": + if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial } - } else if d.NetDial != nil { - netDial = d.NetDial - } else { + case "https": + if d.NetDialTLSContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialTLSContext(ctx, network, addr) + } + } else if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial + } + default: + return nil, nil, errMalformedURL + } + + if netDial == nil { netDialer := &net.Dialer{} netDial = func(network, addr string) (net.Conn, error) { return netDialer.DialContext(ctx, network, addr) @@ -296,14 +326,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } netConn, err := netDial("tcp", hostPort) + if err != nil { + return nil, nil, err + } if trace != nil && trace.GotConn != nil { trace.GotConn(httptrace.GotConnInfo{ Conn: netConn, }) } - if err != nil { - return nil, nil, err - } defer func() { if netConn != nil { @@ -311,7 +341,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } }() - if u.Scheme == "https" { + if u.Scheme == "https" && d.NetDialTLSContext == nil { + // If NetDialTLSContext is set, assume that the TLS handshake has already been done + cfg := cloneTLSConfig(d.TLSClientConfig) if cfg.ServerName == "" { cfg.ServerName = hostNoPort @@ -319,11 +351,12 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h tlsConn := tls.Client(netConn, cfg) netConn = tlsConn - var err error - if trace != nil { - err = doHandshakeWithTrace(trace, tlsConn, cfg) - } else { - err = doHandshake(tlsConn, cfg) + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := doHandshake(ctx, tlsConn, cfg) + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) } if err != nil { @@ -345,6 +378,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h resp, err := http.ReadResponse(conn.br, req) if err != nil { + if d.TLSClientConfig != nil { + for _, proto := range d.TLSClientConfig.NextProtos { + if proto != "http/1.1" { + return nil, nil, fmt.Errorf( + "websocket: protocol %q was given but is not supported;"+ + "sharing tls.Config with net/http Transport can cause this error: %w", + proto, err, + ) + } + } + } return nil, nil, err } @@ -355,8 +399,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } if resp.StatusCode != 101 || - !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || - !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || + !tokenListContainsValue(resp.Header, "Connection", "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { // Before closing the network connection on return from this // function, slurp up some of the response to aid application @@ -400,14 +444,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h return conn, resp, nil } -func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error { - if err := tlsConn.Handshake(); err != nil { - return err +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} } - if !cfg.InsecureSkipVerify { - if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { - return err - } - } - return nil + return cfg.Clone() } diff --git a/client_clone.go b/client_clone.go deleted file mode 100644 index 4f0d943..0000000 --- a/client_clone.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.8 - -package websocket - -import "crypto/tls" - -func cloneTLSConfig(cfg *tls.Config) *tls.Config { - if cfg == nil { - return &tls.Config{} - } - return cfg.Clone() -} diff --git a/client_clone_legacy.go b/client_clone_legacy.go deleted file mode 100644 index babb007..0000000 --- a/client_clone_legacy.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !go1.8 - -package websocket - -import "crypto/tls" - -// cloneTLSConfig clones all public fields except the fields -// SessionTicketsDisabled and SessionTicketKey. This avoids copying the -// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a -// config in active use. -func cloneTLSConfig(cfg *tls.Config) *tls.Config { - if cfg == nil { - return &tls.Config{} - } - return &tls.Config{ - Rand: cfg.Rand, - Time: cfg.Time, - Certificates: cfg.Certificates, - NameToCertificate: cfg.NameToCertificate, - GetCertificate: cfg.GetCertificate, - RootCAs: cfg.RootCAs, - NextProtos: cfg.NextProtos, - ServerName: cfg.ServerName, - ClientAuth: cfg.ClientAuth, - ClientCAs: cfg.ClientCAs, - InsecureSkipVerify: cfg.InsecureSkipVerify, - CipherSuites: cfg.CipherSuites, - PreferServerCipherSuites: cfg.PreferServerCipherSuites, - ClientSessionCache: cfg.ClientSessionCache, - MinVersion: cfg.MinVersion, - MaxVersion: cfg.MaxVersion, - CurvePreferences: cfg.CurvePreferences, - } -} diff --git a/client_server_test.go b/client_server_test.go index e4079ae..73f0e33 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -11,6 +11,7 @@ import ( "crypto/x509" "encoding/base64" "encoding/binary" + "errors" "fmt" "io" "io/ioutil" @@ -216,7 +217,7 @@ func TestProxyDial(t *testing.T) { // Capture the request Host header. s.Server.Config.Handler = http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - if r.Method == "CONNECT" { + if r.Method == http.MethodConnect { connect = true w.WriteHeader(http.StatusOK) return @@ -256,7 +257,7 @@ func TestProxyAuthorizationDial(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { proxyAuth := r.Header.Get("Proxy-Authorization") expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) - if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth { + if r.Method == http.MethodConnect && proxyAuth == expectedProxyAuth { connect = true w.WriteHeader(http.StatusOK) return @@ -516,7 +517,7 @@ func TestBadMethod(t *testing.T) { })) defer s.Close() - req, err := http.NewRequest("POST", s.URL, strings.NewReader("")) + req, err := http.NewRequest(http.MethodPost, s.URL, strings.NewReader("")) if err != nil { t.Fatalf("NewRequest returned error %v", err) } @@ -534,6 +535,23 @@ func TestBadMethod(t *testing.T) { } } +func TestDialExtraTokensInRespHeaders(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + challengeKey := r.Header.Get("Sec-Websocket-Key") + w.Header().Set("Upgrade", "foo, websocket") + w.Header().Set("Connection", "upgrade, keep-alive") + w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey)) + w.WriteHeader(101) + })) + defer s.Close() + + ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() +} + func TestHandshake(t *testing.T) { s := newServer(t, cstHandlerConfig{}) defer s.Close() @@ -640,7 +658,7 @@ func TestHost(t *testing.T) { server *httptest.Server // server to use url string // host for request URI header string // optional request host header - tls string // optiona host for tls ServerName + tls string // optional host for tls ServerName wantAddr string // expected host for dial wantHeader string // expected request header on server insecureSkipVerify bool @@ -771,7 +789,7 @@ func TestHost(t *testing.T) { Dial: dialer.NetDial, TLSClientConfig: dialer.TLSClientConfig, } - req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil) + req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil) if tt.header != "" { req.Host = tt.header } @@ -972,3 +990,215 @@ func TestEmptyTracingDialWithContext(t *testing.T) { defer ws.Close() sendRecv(t, ws) } + +// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext +func TestNetDialConnect(t *testing.T) { + + upgrader := Upgrader{} + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if IsWebSocketUpgrade(r) { + c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}}) + if err != nil { + t.Fatal(err) + } + c.Close() + } else { + w.Header().Set("X-Test-Host", r.Host) + } + }) + + server := httptest.NewServer(handler) + defer server.Close() + + tlsServer := httptest.NewTLSServer(handler) + defer tlsServer.Close() + + testUrls := map[*httptest.Server]string{ + server: "ws://" + server.Listener.Addr().String() + "/", + tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/", + } + + cas := rootCAs(t, tlsServer) + tlsConfig := &tls.Config{ + RootCAs: cas, + ServerName: "example.com", + InsecureSkipVerify: false, + } + + tests := []struct { + name string + server *httptest.Server // server to use + netDial func(network, addr string) (net.Conn, error) + netDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + tlsClientConfig *tls.Config + }{ + + { + name: "HTTP server, all NetDial* defined, shall use NetDialContext", + server: server, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialTLSContext should not be called") + }, + tlsClientConfig: nil, + }, + { + name: "HTTP server, all NetDial* undefined", + server: server, + netDial: nil, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: nil, + }, + { + name: "HTTP server, NetDialContext undefined, shall fallback to NetDial", + server: server, + netDial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialContext: nil, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialTLSContext should not be called") + }, + tlsClientConfig: nil, + }, + { + name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialContext should not be called") + }, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + tlsClientConfig: nil, + }, + { + name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, all NetDial* undefined", + server: tlsServer, + netDial: nil, + netDialContext: nil, + netDialTLSContext: nil, + tlsClientConfig: tlsConfig, + }, + { + name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake", + server: tlsServer, + netDial: func(network, addr string) (net.Conn, error) { + return nil, errors.New("NetDial should not be called") + }, + netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("NetDialContext should not be called") + }, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + netConn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(netConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return nil, err + } + return tlsConn, nil + }, + tlsClientConfig: &tls.Config{ + RootCAs: nil, + ServerName: "badserver.com", + InsecureSkipVerify: false, + }, + }, + } + + for _, tc := range tests { + dialer := Dialer{ + NetDial: tc.netDial, + NetDialContext: tc.netDialContext, + NetDialTLSContext: tc.netDialTLSContext, + TLSClientConfig: tc.tlsClientConfig, + } + + // Test websocket dial + c, _, err := dialer.Dial(testUrls[tc.server], nil) + if err != nil { + t.Errorf("FAILED %s, err: %s", tc.name, err.Error()) + } else { + c.Close() + } + } +} +func TestNextProtos(t *testing.T) { + ts := httptest.NewUnstartedServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + ) + ts.EnableHTTP2 = true + ts.StartTLS() + defer ts.Close() + + d := Dialer{ + TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig, + } + + r, err := ts.Client().Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + r.Body.Close() + + // Asserts that Dialer.TLSClientConfig.NextProtos contains "h2" + // after the Client.Get call from net/http above. + var containsHTTP2 bool = false + for _, proto := range d.TLSClientConfig.NextProtos { + if proto == "h2" { + containsHTTP2 = true + } + } + if !containsHTTP2 { + t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"") + } + + _, _, err = d.Dial(makeWsProto(ts.URL), nil) + if err == nil { + t.Fatalf("Dial succeeded, expect fail ") + } +} diff --git a/conn.go b/conn.go index 62d55b1..ae4e77c 100644 --- a/conn.go +++ b/conn.go @@ -13,6 +13,7 @@ import ( "math/rand" "net" "strconv" + "strings" "sync" "time" "unicode/utf8" @@ -246,8 +247,8 @@ type Conn struct { subprotocol string // Write fields - mu chan bool // used as mutex to protect write to conn - writeBuf []byte // frame is constructed in this buffer. + mu chan struct{} // used as mutex to protect write to conn + writeBuf []byte // frame is constructed in this buffer. writePool BufferPool writeBufSize int writeDeadline time.Time @@ -262,10 +263,12 @@ type Conn struct { newCompressionWriter func(io.WriteCloser, int) io.WriteCloser // Read fields - reader io.ReadCloser // the current reader returned to the application - readErr error - br *bufio.Reader - readRemaining int64 // bytes remaining in current frame. + reader io.ReadCloser // the current reader returned to the application + readErr error + br *bufio.Reader + // bytes remaining in current frame. + // set setReadRemaining to safely update this value and prevent overflow + readRemaining int64 readFinal bool // true the current message has more frames. readLength int64 // Message size. readLimit int64 // Maximum message size. @@ -302,8 +305,8 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBuf = make([]byte, writeBufferSize) } - mu := make(chan bool, 1) - mu <- true + mu := make(chan struct{}, 1) + mu <- struct{}{} c := &Conn{ isServer: isServer, br: br, @@ -322,6 +325,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, return c } +// setReadRemaining tracks the number of bytes remaining on the connection. If n +// overflows, an ErrReadLimit is returned. +func (c *Conn) setReadRemaining(n int64) error { + if n < 0 { + return ErrReadLimit + } + + c.readRemaining = n + return nil +} + // Subprotocol returns the negotiated protocol for the connection. func (c *Conn) Subprotocol() string { return c.subprotocol @@ -366,7 +380,7 @@ func (c *Conn) read(n int) ([]byte, error) { func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { <-c.mu - defer func() { c.mu <- true }() + defer func() { c.mu <- struct{}{} }() c.writeErrMu.Lock() err := c.writeErr @@ -390,6 +404,12 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error return nil } +func (c *Conn) writeBufs(bufs ...[]byte) error { + b := net.Buffers(bufs) + _, err := b.WriteTo(c.conn) + return err +} + // WriteControl writes a control message with the given deadline. The allowed // message types are CloseMessage, PingMessage and PongMessage. func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { @@ -418,7 +438,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er maskBytes(key, 0, buf[6:]) } - d := time.Hour * 1000 + d := 1000 * time.Hour if !deadline.IsZero() { d = deadline.Sub(time.Now()) if d < 0 { @@ -433,7 +453,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er case <-timer.C: return errWriteTimeout } - defer func() { c.mu <- true }() + defer func() { c.mu <- struct{}{} }() c.writeErrMu.Lock() err := c.writeErr @@ -710,10 +730,7 @@ func (w *messageWriter) Close() error { if w.err != nil { return w.err } - if err := w.flushFrame(true, nil); err != nil { - return err - } - return nil + return w.flushFrame(true, nil) } // WritePreparedMessage writes prepared message into connection. @@ -786,50 +803,82 @@ func (c *Conn) advanceFrame() (int, error) { } // 2. Read and parse first two bytes of frame header. + // To aid debugging, collect and report all errors in the first two bytes + // of the header. + + var errors []string p, err := c.read(2) if err != nil { return noFrame, err } - final := p[0]&finalBit != 0 frameType := int(p[0] & 0xf) + final := p[0]&finalBit != 0 + rsv1 := p[0]&rsv1Bit != 0 + rsv2 := p[0]&rsv2Bit != 0 + rsv3 := p[0]&rsv3Bit != 0 mask := p[1]&maskBit != 0 - c.readRemaining = int64(p[1] & 0x7f) + c.setReadRemaining(int64(p[1] & 0x7f)) c.readDecompress = false - if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { - c.readDecompress = true - p[0] &^= rsv1Bit + if rsv1 { + if c.newDecompressionReader != nil { + c.readDecompress = true + } else { + errors = append(errors, "RSV1 set") + } } - if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 { - return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16)) + if rsv2 { + errors = append(errors, "RSV2 set") + } + + if rsv3 { + errors = append(errors, "RSV3 set") } switch frameType { case CloseMessage, PingMessage, PongMessage: if c.readRemaining > maxControlFramePayloadSize { - return noFrame, c.handleProtocolError("control frame length > 125") + errors = append(errors, "len > 125 for control") } if !final { - return noFrame, c.handleProtocolError("control frame not final") + errors = append(errors, "FIN not set on control") } case TextMessage, BinaryMessage: if !c.readFinal { - return noFrame, c.handleProtocolError("message start before final message frame") + errors = append(errors, "data before FIN") } c.readFinal = final case continuationFrame: if c.readFinal { - return noFrame, c.handleProtocolError("continuation after final message frame") + errors = append(errors, "continuation after FIN") } c.readFinal = final default: - return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) + errors = append(errors, "bad opcode "+strconv.Itoa(frameType)) } - // 3. Read and parse frame length. + if mask != c.isServer { + errors = append(errors, "bad MASK") + } + + if len(errors) > 0 { + return noFrame, c.handleProtocolError(strings.Join(errors, ", ")) + } + + // 3. Read and parse frame length as per + // https://tools.ietf.org/html/rfc6455#section-5.2 + // + // The length of the "Payload data", in bytes: if 0-125, that is the payload + // length. + // - If 126, the following 2 bytes interpreted as a 16-bit unsigned + // integer are the payload length. + // - If 127, the following 8 bytes interpreted as + // a 64-bit unsigned integer (the most significant bit MUST be 0) are the + // payload length. Multibyte length quantities are expressed in network byte + // order. switch c.readRemaining { case 126: @@ -837,21 +886,23 @@ func (c *Conn) advanceFrame() (int, error) { if err != nil { return noFrame, err } - c.readRemaining = int64(binary.BigEndian.Uint16(p)) + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil { + return noFrame, err + } case 127: p, err := c.read(8) if err != nil { return noFrame, err } - c.readRemaining = int64(binary.BigEndian.Uint64(p)) + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil { + return noFrame, err + } } // 4. Handle frame masking. - if mask != c.isServer { - return noFrame, c.handleProtocolError("incorrect mask flag") - } - if mask { c.readMaskPos = 0 p, err := c.read(len(c.readMaskKey)) @@ -866,6 +917,12 @@ func (c *Conn) advanceFrame() (int, error) { if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { c.readLength += c.readRemaining + // Don't allow readLength to overflow in the presence of a large readRemaining + // counter. + if c.readLength < 0 { + return noFrame, ErrReadLimit + } + if c.readLimit > 0 && c.readLength > c.readLimit { c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) return noFrame, ErrReadLimit @@ -879,7 +936,7 @@ func (c *Conn) advanceFrame() (int, error) { var payload []byte if c.readRemaining > 0 { payload, err = c.read(int(c.readRemaining)) - c.readRemaining = 0 + c.setReadRemaining(0) if err != nil { return noFrame, err } @@ -905,7 +962,7 @@ func (c *Conn) advanceFrame() (int, error) { if len(payload) >= 2 { closeCode = int(binary.BigEndian.Uint16(payload)) if !isValidReceivedCloseCode(closeCode) { - return noFrame, c.handleProtocolError("invalid close code") + return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode)) } closeText = string(payload[2:]) if !utf8.ValidString(closeText) { @@ -922,7 +979,11 @@ func (c *Conn) advanceFrame() (int, error) { } func (c *Conn) handleProtocolError(message string) error { - c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) + data := FormatCloseMessage(CloseProtocolError, message) + if len(data) > maxControlFramePayloadSize { + data = data[:maxControlFramePayloadSize] + } + c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) return errors.New("websocket: " + message) } @@ -952,6 +1013,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { c.readErr = hideTempErr(err) break } + if frameType == TextMessage || frameType == BinaryMessage { c.messageReader = &messageReader{c} c.reader = c.messageReader @@ -992,7 +1054,9 @@ func (r *messageReader) Read(b []byte) (int, error) { if c.isServer { c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) } - c.readRemaining -= int64(n) + rem := c.readRemaining + rem -= int64(n) + c.setReadRemaining(rem) if c.readRemaining > 0 && c.readErr == io.EOF { c.readErr = errUnexpectedEOF } @@ -1127,8 +1191,16 @@ func (c *Conn) SetPongHandler(h func(appData string) error) { c.handlePong = h } +// NetConn returns the underlying connection that is wrapped by c. +// Note that writing to or reading from this connection directly will corrupt the +// WebSocket connection. +func (c *Conn) NetConn() net.Conn { + return c.conn +} + // UnderlyingConn returns the internal net.Conn. This can be used to further // modifications to connection specific flags. +// Deprecated: Use the NetConn method. func (c *Conn) UnderlyingConn() net.Conn { return c.conn } diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go index cb88cbb..6e744fc 100644 --- a/conn_broadcast_test.go +++ b/conn_broadcast_test.go @@ -18,7 +18,6 @@ import ( // scenarios with many subscribers in one channel. type broadcastBench struct { w io.Writer - message *broadcastMessage closeCh chan struct{} doneCh chan struct{} count int32 @@ -52,14 +51,6 @@ func newBroadcastBench(usePrepared, compression bool) *broadcastBench { usePrepared: usePrepared, compression: compression, } - msg := &broadcastMessage{ - payload: textMessages(1)[0], - } - if usePrepared { - pm, _ := NewPreparedMessage(TextMessage, msg.payload) - msg.prepared = pm - } - bench.message = msg bench.makeConns(10000) return bench } @@ -78,7 +69,7 @@ func (b *broadcastBench) makeConns(numConns int) { for { select { case msg := <-c.msgCh: - if b.usePrepared { + if msg.prepared != nil { c.conn.WritePreparedMessage(msg.prepared) } else { c.conn.WriteMessage(TextMessage, msg.payload) @@ -100,9 +91,9 @@ func (b *broadcastBench) close() { close(b.closeCh) } -func (b *broadcastBench) runOnce() { +func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) { for _, c := range b.conns { - c.msgCh <- b.message + c.msgCh <- msg } <-b.doneCh } @@ -114,17 +105,25 @@ func BenchmarkBroadcast(b *testing.B) { compression bool }{ {"NoCompression", false, false}, - {"WithCompression", false, true}, + {"Compression", false, true}, {"NoCompressionPrepared", true, false}, - {"WithCompressionPrepared", true, true}, + {"CompressionPrepared", true, true}, } + payload := textMessages(1)[0] for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { bench := newBroadcastBench(bm.usePrepared, bm.compression) defer bench.close() b.ResetTimer() for i := 0; i < b.N; i++ { - bench.runOnce() + message := &broadcastMessage{ + payload: payload, + } + if bench.usePrepared { + pm, _ := NewPreparedMessage(TextMessage, message.payload) + message.prepared = pm + } + bench.broadcastOnce(message) } b.ReportAllocs() }) diff --git a/conn_test.go b/conn_test.go index 7be8a4f..fd4acc4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -56,7 +56,10 @@ func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { } func TestFraming(t *testing.T) { - frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537} + frameSizes := []int{ + 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, + // 65536, 65537 + } var readChunkers = []struct { name string f func(io.Reader) io.Reader @@ -151,6 +154,8 @@ func TestFraming(t *testing.T) { t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) continue } + + t.Logf("frame size: %d", n) rbuf, err := ioutil.ReadAll(r) if err != nil { t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) @@ -328,7 +333,7 @@ func TestWriteBufferPoolSync(t *testing.T) { // errorWriter is an io.Writer than returns an error on all writes. type errorWriter struct{} -func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("Error!") } +func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") } // TestWriteBufferPoolError ensures that buffer is returned to pool after error // on write. @@ -489,37 +494,93 @@ func TestWriteAfterMessageWriterClose(t *testing.T) { } func TestReadLimit(t *testing.T) { + t.Run("Test ReadLimit is enforced", func(t *testing.T) { + const readLimit = 512 + message := make([]byte, readLimit+1) - const readLimit = 512 - message := make([]byte, readLimit+1) + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) - var b1, b2 bytes.Buffer - wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) - rc := newTestConn(&b1, &b2, true) - rc.SetReadLimit(readLimit) + // Send message at the limit with interleaved pong. + w, _ := wc.NextWriter(BinaryMessage) + w.Write(message[:readLimit-1]) + wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) + w.Write(message[:1]) + w.Close() - // Send message at the limit with interleaved pong. - w, _ := wc.NextWriter(BinaryMessage) - w.Write(message[:readLimit-1]) - wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) - w.Write(message[:1]) - w.Close() + // Send message larger than the limit. + wc.WriteMessage(BinaryMessage, message[:readLimit+1]) - // Send message larger than the limit. - wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + op, _, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("2: NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != ErrReadLimit { + t.Fatalf("io.Copy() returned %v", err) + } + }) - op, _, err := rc.NextReader() - if op != BinaryMessage || err != nil { - t.Fatalf("1: NextReader() returned %d, %v", op, err) - } - op, r, err := rc.NextReader() - if op != BinaryMessage || err != nil { - t.Fatalf("2: NextReader() returned %d, %v", op, err) - } - _, err = io.Copy(ioutil.Discard, r) - if err != ErrReadLimit { - t.Fatalf("io.Copy() returned %v", err) - } + t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) { + const readLimit = 1 + + var b1, b2 bytes.Buffer + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) + + // First, send a non-final binary message + b1.Write([]byte("\x02\x81")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // First payload + b1.Write([]byte("A")) + + // Next, send a negative-length, non-final continuation frame + b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Next, send a too long, final continuation frame + b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Too-long payload + b1.Write([]byte("BCDEF")) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + + var buf [10]byte + var read int + n, err := r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + n, err = r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + if err == nil && read > readLimit { + t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read) + } + }) } func TestAddrs(t *testing.T) { @@ -532,7 +593,7 @@ func TestAddrs(t *testing.T) { } } -func TestUnderlyingConn(t *testing.T) { +func TestDeprecatedUnderlyingConn(t *testing.T) { var b1, b2 bytes.Buffer fc := fakeNetConn{Reader: &b1, Writer: &b2} c := newConn(fc, true, 1024, 1024, nil, nil, nil) @@ -542,6 +603,16 @@ func TestUnderlyingConn(t *testing.T) { } } +func TestNetConn(t *testing.T) { + var b1, b2 bytes.Buffer + fc := fakeNetConn{Reader: &b1, Writer: &b2} + c := newConn(fc, true, 1024, 1024, nil, nil, nil) + ul := c.NetConn() + if ul != fc { + t.Fatalf("Underlying conn is not what it should be.") + } +} + func TestBufioReadBytes(t *testing.T) { // Test calling bufio.ReadBytes for value longer than read buffer size. diff --git a/conn_write.go b/conn_write.go deleted file mode 100644 index a509a21..0000000 --- a/conn_write.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build go1.8 - -package websocket - -import "net" - -func (c *Conn) writeBufs(bufs ...[]byte) error { - b := net.Buffers(bufs) - _, err := b.WriteTo(c.conn) - return err -} diff --git a/conn_write_legacy.go b/conn_write_legacy.go deleted file mode 100644 index 37edaff..0000000 --- a/conn_write_legacy.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !go1.8 - -package websocket - -func (c *Conn) writeBufs(bufs ...[]byte) error { - for _, buf := range bufs { - if len(buf) > 0 { - if _, err := c.conn.Write(buf); err != nil { - return err - } - } - } - return nil -} diff --git a/doc.go b/doc.go index cc9c05b..9cdc8a0 100644 --- a/doc.go +++ b/doc.go @@ -151,6 +151,53 @@ // checking. The application is responsible for checking the Origin header // before calling the Upgrade function. // +// Buffers +// +// Connections buffer network input and output to reduce the number +// of system calls when reading or writing messages. +// +// Write buffers are also used for constructing WebSocket frames. See RFC 6455, +// Section 5 for a discussion of message framing. A WebSocket frame header is +// written to the network each time a write buffer is flushed to the network. +// Decreasing the size of the write buffer can increase the amount of framing +// overhead on the connection. +// +// The buffer sizes in bytes are specified by the ReadBufferSize and +// WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default +// size of 4096 when a buffer size field is set to zero. The Upgrader reuses +// buffers created by the HTTP server when a buffer size field is set to zero. +// The HTTP server buffers have a size of 4096 at the time of this writing. +// +// The buffer sizes do not limit the size of a message that can be read or +// written by a connection. +// +// Buffers are held for the lifetime of the connection by default. If the +// Dialer or Upgrader WriteBufferPool field is set, then a connection holds the +// write buffer only when writing a message. +// +// Applications should tune the buffer sizes to balance memory use and +// performance. Increasing the buffer size uses more memory, but can reduce the +// number of system calls to read or write the network. In the case of writing, +// increasing the buffer size can reduce the number of frame headers written to +// the network. +// +// Some guidelines for setting buffer parameters are: +// +// Limit the buffer sizes to the maximum expected message size. Buffers larger +// than the largest message do not provide any benefit. +// +// Depending on the distribution of message sizes, setting the buffer size to +// a value less than the maximum expected message size can greatly reduce memory +// use with a small impact on performance. Here's an example: If 99% of the +// messages are smaller than 256 bytes and the maximum message size is 512 +// bytes, then a buffer size of 256 bytes will result in 1.01 more system calls +// than a buffer size of 512 bytes. The memory savings is 50%. +// +// A write buffer pool is useful when the application has a modest number +// writes over a large number of connections. when buffers are pooled, a larger +// buffer size has a reduced impact on total memory use and has the benefit of +// reducing system calls and frame overhead. +// // Compression EXPERIMENTAL // // Per message compression extensions (RFC 7692) are experimentally supported diff --git a/example_test.go b/example_test.go index 96449ea..cd1883b 100644 --- a/example_test.go +++ b/example_test.go @@ -23,10 +23,9 @@ var ( // This server application works with a client application running in the // browser. The client application does not explicitly close the websocket. The // only expected close message from the client has the code -// websocket.CloseGoingAway. All other other close messages are likely the +// websocket.CloseGoingAway. All other close messages are likely the // result of an application or protocol error and are logged to aid debugging. func ExampleIsUnexpectedCloseError() { - for { messageType, p, err := c.ReadMessage() if err != nil { @@ -35,11 +34,11 @@ func ExampleIsUnexpectedCloseError() { } return } - processMesage(messageType, p) + processMessage(messageType, p) } } -func processMesage(mt int, p []byte) {} +func processMessage(mt int, p []byte) {} // TestX prevents godoc from showing this entire file in the example. Remove // this function when a second example is added. diff --git a/examples/autobahn/README.md b/examples/autobahn/README.md index dde8525..cc954fe 100644 --- a/examples/autobahn/README.md +++ b/examples/autobahn/README.md @@ -8,6 +8,11 @@ To test the server, run and start the client test driver - wstest -m fuzzingclient -s fuzzingclient.json + mkdir -p reports + docker run -it --rm \ + -v ${PWD}/config:/config \ + -v ${PWD}/reports:/reports \ + crossbario/autobahn-testsuite \ + wstest -m fuzzingclient -s /config/fuzzingclient.json -When the client completes, it writes a report to reports/clients/index.html. +When the client completes, it writes a report to reports/index.html. diff --git a/examples/autobahn/config/fuzzingclient.json b/examples/autobahn/config/fuzzingclient.json new file mode 100644 index 0000000..eda4e66 --- /dev/null +++ b/examples/autobahn/config/fuzzingclient.json @@ -0,0 +1,29 @@ +{ + "cases": ["*"], + "exclude-cases": [], + "exclude-agent-cases": {}, + "outdir": "/reports", + "options": {"failByDrop": false}, + "servers": [ + { + "agent": "ReadAllWriteMessage", + "url": "ws://host.docker.internal:9000/m" + }, + { + "agent": "ReadAllWritePreparedMessage", + "url": "ws://host.docker.internal:9000/p" + }, + { + "agent": "CopyFull", + "url": "ws://host.docker.internal:9000/f" + }, + { + "agent": "ReadAllWrite", + "url": "ws://host.docker.internal:9000/r" + }, + { + "agent": "CopyWriterOnly", + "url": "ws://host.docker.internal:9000/c" + } + ] +} diff --git a/examples/autobahn/fuzzingclient.json b/examples/autobahn/fuzzingclient.json deleted file mode 100644 index aa3a0bc..0000000 --- a/examples/autobahn/fuzzingclient.json +++ /dev/null @@ -1,15 +0,0 @@ - -{ - "options": {"failByDrop": false}, - "outdir": "./reports/clients", - "servers": [ - {"agent": "ReadAllWriteMessage", "url": "ws://localhost:9000/m", "options": {"version": 18}}, - {"agent": "ReadAllWritePreparedMessage", "url": "ws://localhost:9000/p", "options": {"version": 18}}, - {"agent": "ReadAllWrite", "url": "ws://localhost:9000/r", "options": {"version": 18}}, - {"agent": "CopyFull", "url": "ws://localhost:9000/f", "options": {"version": 18}}, - {"agent": "CopyWriterOnly", "url": "ws://localhost:9000/c", "options": {"version": 18}} - ], - "cases": ["*"], - "exclude-cases": [], - "exclude-agent-cases": {} -} diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go index c2d6ee5..8b17fe3 100644 --- a/examples/autobahn/server.go +++ b/examples/autobahn/server.go @@ -160,7 +160,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) { http.Error(w, "Not found.", http.StatusNotFound) return } - if r.Method != "GET" { + if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } diff --git a/examples/bufferpool/client.go b/examples/bufferpool/client.go new file mode 100644 index 0000000..a3719a9 --- /dev/null +++ b/examples/bufferpool/client.go @@ -0,0 +1,89 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "flag" + "log" + "net/url" + "os" + "os/signal" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +var addr = flag.String("addr", "localhost:8080", "http service address") + +func runNewConn(wg *sync.WaitGroup) { + defer wg.Done() + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt) + + u := url.URL{Scheme: "ws", Host: *addr, Path: "/ws"} + log.Printf("connecting to %s", u.String()) + c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Fatal("dial:", err) + } + defer c.Close() + + done := make(chan struct{}) + + go func() { + defer close(done) + for { + _, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + return + } + log.Printf("recv: %s", message) + } + }() + + ticker := time.NewTicker(time.Minute * 5) + defer ticker.Stop() + + for { + select { + case <-done: + return + case t := <-ticker.C: + err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) + if err != nil { + log.Println("write:", err) + return + } + case <-interrupt: + log.Println("interrupt") + + // Cleanly close the connection by sending a close message and then + // waiting (with timeout) for the server to close the connection. + err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + log.Println("write close:", err) + return + } + select { + case <-done: + case <-time.After(time.Second): + } + return + } + } +} + +func main() { + flag.Parse() + log.SetFlags(0) + wg := &sync.WaitGroup{} + for i := 0; i < 1000; i++ { + wg.Add(1) + go runNewConn(wg) + } + wg.Wait() +} diff --git a/examples/bufferpool/server.go b/examples/bufferpool/server.go new file mode 100644 index 0000000..25bb20f --- /dev/null +++ b/examples/bufferpool/server.go @@ -0,0 +1,55 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "flag" + "log" + "net/http" + "sync" + + _ "net/http/pprof" + + "github.com/gorilla/websocket" +) + +var addr = flag.String("addr", "localhost:8080", "http service address") + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 256, + WriteBufferSize: 256, + WriteBufferPool: &sync.Pool{}, +} + +func process(c *websocket.Conn) { + defer c.Close() + for { + _, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + break + } + log.Printf("recv: %s", message) + } +} + +func handler(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Print("upgrade:", err) + return + } + + // Process connection in a new goroutine + go process(c) + + // Let the http handler return, the 8k buffer created by it will be garbage collected +} + +func main() { + flag.Parse() + log.SetFlags(0) + http.HandleFunc("/ws", handler) + log.Fatal(http.ListenAndServe(*addr, nil)) +} diff --git a/examples/chat/home.html b/examples/chat/home.html index a39a0c2..bf866af 100644 --- a/examples/chat/home.html +++ b/examples/chat/home.html @@ -92,7 +92,7 @@ body {
- +
diff --git a/examples/chat/main.go b/examples/chat/main.go index 9d4737a..474709f 100644 --- a/examples/chat/main.go +++ b/examples/chat/main.go @@ -18,7 +18,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) { http.Error(w, "Not found", http.StatusNotFound) return } - if r.Method != "GET" { + if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } diff --git a/examples/command/main.go b/examples/command/main.go index 304f1a5..38d9f6c 100644 --- a/examples/command/main.go +++ b/examples/command/main.go @@ -170,7 +170,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) { http.Error(w, "Not found", http.StatusNotFound) return } - if r.Method != "GET" { + if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } diff --git a/examples/echo/client.go b/examples/echo/client.go index bf0e657..7d870bd 100644 --- a/examples/echo/client.go +++ b/examples/echo/client.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build ignore // +build ignore package main diff --git a/examples/echo/server.go b/examples/echo/server.go index ecc680c..f9a0b7b 100644 --- a/examples/echo/server.go +++ b/examples/echo/server.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +//go:build ignore // +build ignore package main @@ -67,8 +68,9 @@ window.addEventListener("load", function(evt) { var print = function(message) { var d = document.createElement("div"); - d.innerHTML = message; + d.textContent = message; output.appendChild(d); + output.scroll(0, output.scrollHeight); }; document.getElementById("open").onclick = function(evt) { @@ -126,7 +128,7 @@ You can change the message and send multiple times. -
+
diff --git a/examples/filewatch/main.go b/examples/filewatch/main.go index b834ed3..d4bf80e 100644 --- a/examples/filewatch/main.go +++ b/examples/filewatch/main.go @@ -133,7 +133,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) { http.Error(w, "Not found", http.StatusNotFound) return } - if r.Method != "GET" { + if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..1a7afd5 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/gorilla/websocket + +go 1.12 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/join.go b/join.go new file mode 100644 index 0000000..c64f8c8 --- /dev/null +++ b/join.go @@ -0,0 +1,42 @@ +// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "io" + "strings" +) + +// JoinMessages concatenates received messages to create a single io.Reader. +// The string term is appended to each message. The returned reader does not +// support concurrent calls to the Read method. +func JoinMessages(c *Conn, term string) io.Reader { + return &joinReader{c: c, term: term} +} + +type joinReader struct { + c *Conn + term string + r io.Reader +} + +func (r *joinReader) Read(p []byte) (int, error) { + if r.r == nil { + var err error + _, r.r, err = r.c.NextReader() + if err != nil { + return 0, err + } + if r.term != "" { + r.r = io.MultiReader(r.r, strings.NewReader(r.term)) + } + } + n, err := r.r.Read(p) + if err == io.EOF { + err = nil + r.r = nil + } + return n, err +} diff --git a/join_test.go b/join_test.go new file mode 100644 index 0000000..961ac04 --- /dev/null +++ b/join_test.go @@ -0,0 +1,36 @@ +// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "io" + "strings" + "testing" +) + +func TestJoinMessages(t *testing.T) { + messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"} + for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} { + for _, term := range []string{"", ","} { + var connBuf bytes.Buffer + wc := newTestConn(nil, &connBuf, true) + rc := newTestConn(&connBuf, nil, false) + for _, m := range messages { + wc.WriteMessage(BinaryMessage, []byte(m)) + } + + var result bytes.Buffer + _, err := io.CopyBuffer(&result, JoinMessages(rc, term), make([]byte, readChunk)) + if IsUnexpectedCloseError(err, CloseAbnormalClosure) { + t.Errorf("readChunk=%d, term=%q: unexpected error %v", readChunk, term, err) + } + want := strings.Join(messages, term) + term + if result.String() != want { + t.Errorf("readChunk=%d, term=%q, got %q, want %q", readChunk, term, result.String(), want) + } + } + } +} diff --git a/mask.go b/mask.go index 577fce9..d0742bf 100644 --- a/mask.go +++ b/mask.go @@ -2,6 +2,7 @@ // this source code is governed by a BSD-style license that can be found in the // LICENSE file. +//go:build !appengine // +build !appengine package websocket diff --git a/mask_safe.go b/mask_safe.go index 2aac060..36250ca 100644 --- a/mask_safe.go +++ b/mask_safe.go @@ -2,6 +2,7 @@ // this source code is governed by a BSD-style license that can be found in the // LICENSE file. +//go:build appengine // +build appengine package websocket diff --git a/prepared.go b/prepared.go index 74ec565..c854225 100644 --- a/prepared.go +++ b/prepared.go @@ -73,8 +73,8 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { // Prepare a frame using a 'fake' connection. // TODO: Refactor code in conn.go to allow more direct construction of // the frame. - mu := make(chan bool, 1) - mu <- true + mu := make(chan struct{}, 1) + mu <- struct{}{} var nc prepareConn c := &Conn{ conn: &nc, diff --git a/proxy.go b/proxy.go index bf2478e..e0f466b 100644 --- a/proxy.go +++ b/proxy.go @@ -22,18 +22,18 @@ func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { func init() { proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { - return &httpProxyDialer{proxyURL: proxyURL, fowardDial: forwardDialer.Dial}, nil + return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil }) } type httpProxyDialer struct { - proxyURL *url.URL - fowardDial func(network, addr string) (net.Conn, error) + proxyURL *url.URL + forwardDial func(network, addr string) (net.Conn, error) } func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { hostPort, _ := hostPortNoPort(hpd.proxyURL) - conn, err := hpd.fowardDial(network, hostPort) + conn, err := hpd.forwardDial(network, hostPort) if err != nil { return nil, err } @@ -48,7 +48,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) } connectReq := &http.Request{ - Method: "CONNECT", + Method: http.MethodConnect, URL: &url.URL{Opaque: addr}, Host: addr, Header: connectHeader, diff --git a/server.go b/server.go index 4bd0539..028f55f 100644 --- a/server.go +++ b/server.go @@ -24,6 +24,8 @@ func (e HandshakeError) Error() string { return e.message } // Upgrader specifies parameters for upgrading an HTTP connection to a // WebSocket connection. +// +// It is safe to call Upgrader's methods concurrently. type Upgrader struct { // HandshakeTimeout specifies the duration for the handshake to complete. HandshakeTimeout time.Duration @@ -119,8 +121,8 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header // Upgrade upgrades the HTTP server connection to the WebSocket protocol. // // The responseHeader is included in the response to the client's upgrade -// request. Use the responseHeader to specify cookies (Set-Cookie) and the -// application negotiated subprotocol (Sec-WebSocket-Protocol). +// request. Use the responseHeader to specify cookies (Set-Cookie). To specify +// subprotocols supported by the server, set Upgrader.Subprotocols directly. // // If the upgrade fails, then Upgrade replies to the client with an HTTP error // response. @@ -135,7 +137,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header") } - if r.Method != "GET" { + if r.Method != http.MethodGet { return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET") } @@ -156,8 +158,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } challengeKey := r.Header.Get("Sec-Websocket-Key") - if challengeKey == "" { - return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank") + if !isValidChallengeKey(challengeKey) { + return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length") } subprotocol := u.selectSubprotocol(r, responseHeader) diff --git a/server_test.go b/server_test.go index 456c1db..5804be1 100644 --- a/server_test.go +++ b/server_test.go @@ -98,7 +98,7 @@ func TestBufioReuse(t *testing.T) { } upgrader := Upgrader{} c, err := upgrader.Upgrade(resp, &http.Request{ - Method: "GET", + Method: http.MethodGet, Header: http.Header{ "Upgrade": []string{"websocket"}, "Connection": []string{"upgrade"}, @@ -111,7 +111,7 @@ func TestBufioReuse(t *testing.T) { if reuse := c.br == br; reuse != tt.reuse { t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse) } - writeBuf := bufioWriterBuffer(c.UnderlyingConn(), bw) + writeBuf := bufioWriterBuffer(c.NetConn(), bw) if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse { t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse) } diff --git a/tls_handshake.go b/tls_handshake.go new file mode 100644 index 0000000..a62b68c --- /dev/null +++ b/tls_handshake.go @@ -0,0 +1,21 @@ +//go:build go1.17 +// +build go1.17 + +package websocket + +import ( + "context" + "crypto/tls" +) + +func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.HandshakeContext(ctx); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} diff --git a/tls_handshake_116.go b/tls_handshake_116.go new file mode 100644 index 0000000..e1b2b44 --- /dev/null +++ b/tls_handshake_116.go @@ -0,0 +1,21 @@ +//go:build !go1.17 +// +build !go1.17 + +package websocket + +import ( + "context" + "crypto/tls" +) + +func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error { + if err := tlsConn.Handshake(); err != nil { + return err + } + if !cfg.InsecureSkipVerify { + if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { + return err + } + } + return nil +} diff --git a/trace.go b/trace.go deleted file mode 100644 index 834f122..0000000 --- a/trace.go +++ /dev/null @@ -1,19 +0,0 @@ -// +build go1.8 - -package websocket - -import ( - "crypto/tls" - "net/http/httptrace" -) - -func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { - if trace.TLSHandshakeStart != nil { - trace.TLSHandshakeStart() - } - err := doHandshake(tlsConn, cfg) - if trace.TLSHandshakeDone != nil { - trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) - } - return err -} diff --git a/trace_17.go b/trace_17.go deleted file mode 100644 index 77d05a0..0000000 --- a/trace_17.go +++ /dev/null @@ -1,12 +0,0 @@ -// +build !go1.8 - -package websocket - -import ( - "crypto/tls" - "net/http/httptrace" -) - -func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { - return doHandshake(tlsConn, cfg) -} diff --git a/util.go b/util.go index 7bf2f66..31a5dee 100644 --- a/util.go +++ b/util.go @@ -281,3 +281,18 @@ headers: } return result } + +// isValidChallengeKey checks if the argument meets RFC6455 specification. +func isValidChallengeKey(s string) bool { + // From RFC6455: + // + // A |Sec-WebSocket-Key| header field with a base64-encoded (see + // Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in + // length. + + if s == "" { + return false + } + decoded, err := base64.StdEncoding.DecodeString(s) + return err == nil && len(decoded) == 16 +} diff --git a/util_test.go b/util_test.go index af710ba..f14d69a 100644 --- a/util_test.go +++ b/util_test.go @@ -53,6 +53,25 @@ func TestTokenListContainsValue(t *testing.T) { } } +var isValidChallengeKeyTests = []struct { + key string + ok bool +}{ + {"dGhlIHNhbXBsZSBub25jZQ==", true}, + {"", false}, + {"InvalidKey", false}, + {"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false}, +} + +func TestIsValidChallengeKey(t *testing.T) { + for _, tt := range isValidChallengeKeyTests { + ok := isValidChallengeKey(tt.key) + if ok != tt.ok { + t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok) + } + } +} + var parseExtensionTests = []struct { value string extensions []map[string]string