Compare commits

..

1 Commits

Author SHA1 Message Date
Matt Silverlock e365eeef09 test: fix unintended test comment in TestFraming 2019-08-24 22:20:45 -07:00
41 changed files with 364 additions and 628 deletions

View File

@ -1,70 +1,76 @@
version: 2.1 version: 2.0
jobs: jobs:
"test": # Base test configuration for Go library tests Each distinct version should
parameters: # inherit this base, and override (at least) the container image used.
version: "test": &test
type: string
default: "latest"
golint:
type: boolean
default: true
modules:
type: boolean
default: true
goproxy:
type: string
default: ""
docker: docker:
- image: "cimg/go:<< parameters.version >>" - image: circleci/golang:latest
working_directory: /home/circleci/project/go/src/git.internal/re/websocket working_directory: /go/src/github.com/gorilla/websocket
environment: steps: &steps
GO111MODULE: "on"
GOPROXY: "<< parameters.goproxy >>"
steps:
- checkout - checkout
- run: - run: go version
name: "Print the Go version" - run: go get -t -v ./...
command: >
go version
- run:
name: "Fetch dependencies"
command: >
if [[ << parameters.modules >> = true ]]; then
go mod download
export GO111MODULE=on
else
go get -v ./...
fi
# Only run gofmt, vet & lint against the latest Go version # Only run gofmt, vet & lint against the latest Go version
- run: - run: >
name: "Run golint" if [[ "$LATEST" = true ]]; then
command: > go get -u golang.org/x/lint/golint
if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then golint ./...
go get -u golang.org/x/lint/golint fi
golint ./... - run: >
fi if [[ "$LATEST" = true ]]; then
- run: diff -u <(echo -n) <(gofmt -d .)
name: "Run gofmt" fi
command: > - run: >
if [[ << parameters.version >> = "latest" ]]; then if [[ "$LATEST" = true ]]; then
diff -u <(echo -n) <(gofmt -d -e .) go vet -v .
fi fi
- run: - run: if [[ "$LATEST" = true ]]; then go vet -v .; fi
name: "Run go vet" - run: go test -v -race ./...
command: >
if [[ << parameters.version >> = "latest" ]]; then "latest":
go vet -v ./... <<: *test
fi environment:
- run: LATEST: true
name: "Run go test (+ race detector)"
command: > "1.12":
go test -v -race ./... <<: *test
docker:
- image: circleci/golang:1.12
"1.11":
<<: *test
docker:
- image: circleci/golang:1.11
"1.10":
<<: *test
docker:
- image: circleci/golang:1.10
"1.9":
<<: *test
docker:
- image: circleci/golang:1.9
"1.8":
<<: *test
docker:
- image: circleci/golang:1.8
"1.7":
<<: *test
docker:
- image: circleci/golang:1.7
workflows: workflows:
tests: version: 2
build:
jobs: jobs:
- test: - "latest"
matrix: - "1.12"
parameters: - "1.11"
version: ["1.18", "1.17", "1.16"] - "1.10"
- "1.9"
- "1.8"
- "1.7"

View File

@ -1,24 +1,18 @@
# Gorilla WebSocket # Gorilla WebSocket
[![GoDoc](https://godoc.org/git.internal/re/websocket?status.svg)](https://godoc.org/git.internal/re/websocket) [![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket) [![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket)
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
---
**The Gorilla project has been archived, and is no longer under active maintainenance. You can read more here: https://github.com/gorilla#gorilla-toolkit**
---
### Documentation ### Documentation
* [API Reference](https://pkg.go.dev/git.internal/re/websocket?tab=doc) * [API Reference](http://godoc.org/github.com/gorilla/websocket)
* [Chat example](https://git.internal/re/websocket/tree/master/examples/chat) * [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
* [Command example](https://git.internal/re/websocket/tree/master/examples/command) * [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
* [Client and server example](https://git.internal/re/websocket/tree/master/examples/echo) * [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
* [File watch example](https://git.internal/re/websocket/tree/master/examples/filewatch) * [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
### Status ### Status
@ -28,11 +22,43 @@ package API is stable.
### Installation ### Installation
go get git.internal/re/websocket go get github.com/gorilla/websocket
### Protocol Compliance ### Protocol Compliance
The Gorilla WebSocket package passes the server tests in the [Autobahn Test The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://git.internal/re/websocket/tree/master/examples/autobahn). subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
### Gorilla WebSocket compared with other packages
<table>
<tr>
<th></th>
<th><a href="http://godoc.org/github.com/gorilla/websocket">github.com/gorilla</a></th>
<th><a href="http://godoc.org/golang.org/x/net/websocket">golang.org/x/net</a></th>
</tr>
<tr>
<tr><td colspan="3"><a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a> Features</td></tr>
<tr><td>Passes <a href="https://github.com/crossbario/autobahn-testsuite">Autobahn Test Suite</a></td><td><a href="https://github.com/gorilla/websocket/tree/master/examples/autobahn">Yes</a></td><td>No</td></tr>
<tr><td>Receive <a href="https://tools.ietf.org/html/rfc6455#section-5.4">fragmented</a> message<td>Yes</td><td><a href="https://code.google.com/p/go/issues/detail?id=7632">No</a>, see note 1</td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.1">close</a> message</td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td><a href="https://code.google.com/p/go/issues/detail?id=4588">No</a></td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.2">pings</a> and receive <a href="https://tools.ietf.org/html/rfc6455#section-5.5.3">pongs</a></td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td>No</td></tr>
<tr><td>Get the <a href="https://tools.ietf.org/html/rfc6455#section-5.6">type</a> of a received data message</td><td>Yes</td><td>Yes, see note 2</td></tr>
<tr><td colspan="3">Other Features</tr></td>
<tr><td><a href="https://tools.ietf.org/html/rfc7692">Compression Extensions</a></td><td>Experimental</td><td>No</td></tr>
<tr><td>Read message using io.Reader</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextReader">Yes</a></td><td>No, see note 3</td></tr>
<tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr>
</table>
Notes:
1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
2. The application can get the type of a received data message by implementing
a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal)
function.
3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries.
Read returns when the input buffer is full or a frame boundary is
encountered. Each call to Write sends a single frame message. The Gorilla
io.Reader and io.WriteCloser operate on a single WebSocket message.

View File

@ -9,7 +9,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
@ -49,23 +48,15 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
} }
// A Dialer contains options for connecting to WebSocket server. // A Dialer contains options for connecting to WebSocket server.
//
// It is safe to call Dialer's methods concurrently.
type Dialer struct { type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If // NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used. // NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error) NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If // NetDialContext specifies the dial function for creating TCP connections. If
// NetDialContext is nil, NetDial is used. // NetDialContext is nil, net.DialContext is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
// NetDialTLSContext is nil, NetDialContext is used.
// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
// TLSClientConfig is ignored.
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given // Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the // Request. If the function returns a non-nil error, the
// request is aborted with the provided error. // request is aborted with the provided error.
@ -74,8 +65,6 @@ type Dialer struct {
// TLSClientConfig specifies the TLS configuration to use with tls.Client. // TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used. // If nil, the default configuration is used.
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
// is done there and TLSClientConfig is ignored.
TLSClientConfig *tls.Config TLSClientConfig *tls.Config
// HandshakeTimeout specifies the duration for the handshake to complete. // HandshakeTimeout specifies the duration for the handshake to complete.
@ -187,7 +176,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
} }
req := &http.Request{ req := &http.Request{
Method: http.MethodGet, Method: "GET",
URL: u, URL: u,
Proto: "HTTP/1.1", Proto: "HTTP/1.1",
ProtoMajor: 1, ProtoMajor: 1,
@ -248,32 +237,13 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// Get network dial function. // Get network dial function.
var netDial func(network, add string) (net.Conn, error) var netDial func(network, add string) (net.Conn, error)
switch u.Scheme { if d.NetDialContext != nil {
case "http": netDial = func(network, addr string) (net.Conn, error) {
if d.NetDialContext != nil { return d.NetDialContext(ctx, network, addr)
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
} }
case "https": } else if d.NetDial != nil {
if d.NetDialTLSContext != nil { netDial = d.NetDial
netDial = func(network, addr string) (net.Conn, error) { } else {
return d.NetDialTLSContext(ctx, network, addr)
}
} else if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
}
default:
return nil, nil, errMalformedURL
}
if netDial == nil {
netDialer := &net.Dialer{} netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) { netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr) return netDialer.DialContext(ctx, network, addr)
@ -319,14 +289,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
} }
netConn, err := netDial("tcp", hostPort) netConn, err := netDial("tcp", hostPort)
if err != nil {
return nil, nil, err
}
if trace != nil && trace.GotConn != nil { if trace != nil && trace.GotConn != nil {
trace.GotConn(httptrace.GotConnInfo{ trace.GotConn(httptrace.GotConnInfo{
Conn: netConn, Conn: netConn,
}) })
} }
if err != nil {
return nil, nil, err
}
defer func() { defer func() {
if netConn != nil { if netConn != nil {
@ -334,9 +304,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
} }
}() }()
if u.Scheme == "https" && d.NetDialTLSContext == nil { if u.Scheme == "https" {
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
cfg := cloneTLSConfig(d.TLSClientConfig) cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" { if cfg.ServerName == "" {
cfg.ServerName = hostNoPort cfg.ServerName = hostNoPort
@ -344,12 +312,11 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
tlsConn := tls.Client(netConn, cfg) tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn netConn = tlsConn
if trace != nil && trace.TLSHandshakeStart != nil { var err error
trace.TLSHandshakeStart() if trace != nil {
} err = doHandshakeWithTrace(trace, tlsConn, cfg)
err := doHandshake(ctx, tlsConn, cfg) } else {
if trace != nil && trace.TLSHandshakeDone != nil { err = doHandshake(tlsConn, cfg)
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
} }
if err != nil { if err != nil {
@ -371,17 +338,6 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
resp, err := http.ReadResponse(conn.br, req) resp, err := http.ReadResponse(conn.br, req)
if err != nil { if err != nil {
if d.TLSClientConfig != nil {
for _, proto := range d.TLSClientConfig.NextProtos {
if proto != "http/1.1" {
return nil, nil, fmt.Errorf(
"websocket: protocol %q was given but is not supported;"+
"sharing tls.Config with net/http Transport can cause this error: %w",
proto, err,
)
}
}
}
return nil, nil, err return nil, nil, err
} }
@ -392,8 +348,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
} }
if resp.StatusCode != 101 || if resp.StatusCode != 101 ||
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") || !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!tokenListContainsValue(resp.Header, "Connection", "upgrade") || !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
// Before closing the network connection on return from this // Before closing the network connection on return from this
// function, slurp up some of the response to aid application // function, slurp up some of the response to aid application
@ -426,9 +382,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return conn, resp, nil return conn, resp, nil
} }
func cloneTLSConfig(cfg *tls.Config) *tls.Config { func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
if cfg == nil { if err := tlsConn.Handshake(); err != nil {
return &tls.Config{} return err
} }
return cfg.Clone() if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
} }

16
client_clone.go Normal file
View File

@ -0,0 +1,16 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}

38
client_clone_legacy.go Normal file
View File

@ -0,0 +1,38 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
import "crypto/tls"
// cloneTLSConfig clones all public fields except the fields
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
// config in active use.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

View File

@ -11,7 +11,6 @@ import (
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -164,7 +163,7 @@ func TestProxyDial(t *testing.T) {
// Capture the request Host header. // Capture the request Host header.
s.Server.Config.Handler = http.HandlerFunc( s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodConnect { if r.Method == "CONNECT" {
connect = true connect = true
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
return return
@ -204,7 +203,7 @@ func TestProxyAuthorizationDial(t *testing.T) {
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
proxyAuth := r.Header.Get("Proxy-Authorization") proxyAuth := r.Header.Get("Proxy-Authorization")
expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
if r.Method == http.MethodConnect && proxyAuth == expectedProxyAuth { if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth {
connect = true connect = true
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
return return
@ -464,7 +463,7 @@ func TestBadMethod(t *testing.T) {
})) }))
defer s.Close() defer s.Close()
req, err := http.NewRequest(http.MethodPost, s.URL, strings.NewReader("")) req, err := http.NewRequest("POST", s.URL, strings.NewReader(""))
if err != nil { if err != nil {
t.Fatalf("NewRequest returned error %v", err) t.Fatalf("NewRequest returned error %v", err)
} }
@ -482,23 +481,6 @@ func TestBadMethod(t *testing.T) {
} }
} }
func TestDialExtraTokensInRespHeaders(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
challengeKey := r.Header.Get("Sec-Websocket-Key")
w.Header().Set("Upgrade", "foo, websocket")
w.Header().Set("Connection", "upgrade, keep-alive")
w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))
w.WriteHeader(101)
}))
defer s.Close()
ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
}
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -605,7 +587,7 @@ func TestHost(t *testing.T) {
server *httptest.Server // server to use server *httptest.Server // server to use
url string // host for request URI url string // host for request URI
header string // optional request host header header string // optional request host header
tls string // optional host for tls ServerName tls string // optiona host for tls ServerName
wantAddr string // expected host for dial wantAddr string // expected host for dial
wantHeader string // expected request header on server wantHeader string // expected request header on server
insecureSkipVerify bool insecureSkipVerify bool
@ -736,7 +718,7 @@ func TestHost(t *testing.T) {
Dial: dialer.NetDial, Dial: dialer.NetDial,
TLSClientConfig: dialer.TLSClientConfig, TLSClientConfig: dialer.TLSClientConfig,
} }
req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil) req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
if tt.header != "" { if tt.header != "" {
req.Host = tt.header req.Host = tt.header
} }
@ -921,215 +903,3 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
defer ws.Close() defer ws.Close()
sendRecv(t, ws) sendRecv(t, ws)
} }
// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
func TestNetDialConnect(t *testing.T) {
upgrader := Upgrader{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if IsWebSocketUpgrade(r) {
c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
if err != nil {
t.Fatal(err)
}
c.Close()
} else {
w.Header().Set("X-Test-Host", r.Host)
}
})
server := httptest.NewServer(handler)
defer server.Close()
tlsServer := httptest.NewTLSServer(handler)
defer tlsServer.Close()
testUrls := map[*httptest.Server]string{
server: "ws://" + server.Listener.Addr().String() + "/",
tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/",
}
cas := rootCAs(t, tlsServer)
tlsConfig := &tls.Config{
RootCAs: cas,
ServerName: "example.com",
InsecureSkipVerify: false,
}
tests := []struct {
name string
server *httptest.Server // server to use
netDial func(network, addr string) (net.Conn, error)
netDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
tlsClientConfig *tls.Config
}{
{
name: "HTTP server, all NetDial* defined, shall use NetDialContext",
server: server,
netDial: func(network, addr string) (net.Conn, error) {
return nil, errors.New("NetDial should not be called")
},
netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("NetDialTLSContext should not be called")
},
tlsClientConfig: nil,
},
{
name: "HTTP server, all NetDial* undefined",
server: server,
netDial: nil,
netDialContext: nil,
netDialTLSContext: nil,
tlsClientConfig: nil,
},
{
name: "HTTP server, NetDialContext undefined, shall fallback to NetDial",
server: server,
netDial: func(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
netDialContext: nil,
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("NetDialTLSContext should not be called")
},
tlsClientConfig: nil,
},
{
name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext",
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
return nil, errors.New("NetDial should not be called")
},
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("NetDialContext should not be called")
},
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
netConn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
tlsConn := tls.Client(netConn, tlsConfig)
err = tlsConn.Handshake()
if err != nil {
return nil, err
}
return tlsConn, nil
},
tlsClientConfig: nil,
},
{
name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake",
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
return nil, errors.New("NetDial should not be called")
},
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
netDialTLSContext: nil,
tlsClientConfig: tlsConfig,
},
{
name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake",
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
netDialContext: nil,
netDialTLSContext: nil,
tlsClientConfig: tlsConfig,
},
{
name: "HTTPS server, all NetDial* undefined",
server: tlsServer,
netDial: nil,
netDialContext: nil,
netDialTLSContext: nil,
tlsClientConfig: tlsConfig,
},
{
name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake",
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
return nil, errors.New("NetDial should not be called")
},
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("NetDialContext should not be called")
},
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
netConn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
tlsConn := tls.Client(netConn, tlsConfig)
err = tlsConn.Handshake()
if err != nil {
return nil, err
}
return tlsConn, nil
},
tlsClientConfig: &tls.Config{
RootCAs: nil,
ServerName: "badserver.com",
InsecureSkipVerify: false,
},
},
}
for _, tc := range tests {
dialer := Dialer{
NetDial: tc.netDial,
NetDialContext: tc.netDialContext,
NetDialTLSContext: tc.netDialTLSContext,
TLSClientConfig: tc.tlsClientConfig,
}
// Test websocket dial
c, _, err := dialer.Dial(testUrls[tc.server], nil)
if err != nil {
t.Errorf("FAILED %s, err: %s", tc.name, err.Error())
} else {
c.Close()
}
}
}
func TestNextProtos(t *testing.T) {
ts := httptest.NewUnstartedServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
)
ts.EnableHTTP2 = true
ts.StartTLS()
defer ts.Close()
d := Dialer{
TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
}
r, err := ts.Client().Get(ts.URL)
if err != nil {
t.Fatalf("Get: %v", err)
}
r.Body.Close()
// Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
// after the Client.Get call from net/http above.
var containsHTTP2 bool = false
for _, proto := range d.TLSClientConfig.NextProtos {
if proto == "h2" {
containsHTTP2 = true
}
}
if !containsHTTP2 {
t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
}
_, _, err = d.Dial(makeWsProto(ts.URL), nil)
if err == nil {
t.Fatalf("Dial succeeded, expect fail ")
}
}

85
conn.go
View File

@ -13,7 +13,6 @@ import (
"math/rand" "math/rand"
"net" "net"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
"unicode/utf8" "unicode/utf8"
@ -245,8 +244,8 @@ type Conn struct {
subprotocol string subprotocol string
// Write fields // Write fields
mu chan struct{} // used as mutex to protect write to conn mu chan bool // used as mutex to protect write to conn
writeBuf []byte // frame is constructed in this buffer. writeBuf []byte // frame is constructed in this buffer.
writePool BufferPool writePool BufferPool
writeBufSize int writeBufSize int
writeDeadline time.Time writeDeadline time.Time
@ -303,8 +302,8 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
writeBuf = make([]byte, writeBufferSize) writeBuf = make([]byte, writeBufferSize)
} }
mu := make(chan struct{}, 1) mu := make(chan bool, 1)
mu <- struct{}{} mu <- true
c := &Conn{ c := &Conn{
isServer: isServer, isServer: isServer,
br: br, br: br,
@ -378,7 +377,7 @@ func (c *Conn) read(n int) ([]byte, error) {
func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
<-c.mu <-c.mu
defer func() { c.mu <- struct{}{} }() defer func() { c.mu <- true }()
c.writeErrMu.Lock() c.writeErrMu.Lock()
err := c.writeErr err := c.writeErr
@ -402,12 +401,6 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return nil return nil
} }
func (c *Conn) writeBufs(bufs ...[]byte) error {
b := net.Buffers(bufs)
_, err := b.WriteTo(c.conn)
return err
}
// WriteControl writes a control message with the given deadline. The allowed // WriteControl writes a control message with the given deadline. The allowed
// message types are CloseMessage, PingMessage and PongMessage. // message types are CloseMessage, PingMessage and PongMessage.
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
@ -436,7 +429,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
maskBytes(key, 0, buf[6:]) maskBytes(key, 0, buf[6:])
} }
d := 1000 * time.Hour d := time.Hour * 1000
if !deadline.IsZero() { if !deadline.IsZero() {
d = deadline.Sub(time.Now()) d = deadline.Sub(time.Now())
if d < 0 { if d < 0 {
@ -451,7 +444,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
case <-timer.C: case <-timer.C:
return errWriteTimeout return errWriteTimeout
} }
defer func() { c.mu <- struct{}{} }() defer func() { c.mu <- true }()
c.writeErrMu.Lock() c.writeErrMu.Lock()
err := c.writeErr err := c.writeErr
@ -801,69 +794,47 @@ func (c *Conn) advanceFrame() (int, error) {
} }
// 2. Read and parse first two bytes of frame header. // 2. Read and parse first two bytes of frame header.
// To aid debugging, collect and report all errors in the first two bytes
// of the header.
var errors []string
p, err := c.read(2) p, err := c.read(2)
if err != nil { if err != nil {
return noFrame, err return noFrame, err
} }
frameType := int(p[0] & 0xf)
final := p[0]&finalBit != 0 final := p[0]&finalBit != 0
rsv1 := p[0]&rsv1Bit != 0 frameType := int(p[0] & 0xf)
rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0 mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f)) c.setReadRemaining(int64(p[1] & 0x7f))
c.readDecompress = false c.readDecompress = false
if rsv1 { if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
if c.newDecompressionReader != nil { c.readDecompress = true
c.readDecompress = true p[0] &^= rsv1Bit
} else {
errors = append(errors, "RSV1 set")
}
} }
if rsv2 { if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
errors = append(errors, "RSV2 set") return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
}
if rsv3 {
errors = append(errors, "RSV3 set")
} }
switch frameType { switch frameType {
case CloseMessage, PingMessage, PongMessage: case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize { if c.readRemaining > maxControlFramePayloadSize {
errors = append(errors, "len > 125 for control") return noFrame, c.handleProtocolError("control frame length > 125")
} }
if !final { if !final {
errors = append(errors, "FIN not set on control") return noFrame, c.handleProtocolError("control frame not final")
} }
case TextMessage, BinaryMessage: case TextMessage, BinaryMessage:
if !c.readFinal { if !c.readFinal {
errors = append(errors, "data before FIN") return noFrame, c.handleProtocolError("message start before final message frame")
} }
c.readFinal = final c.readFinal = final
case continuationFrame: case continuationFrame:
if c.readFinal { if c.readFinal {
errors = append(errors, "continuation after FIN") return noFrame, c.handleProtocolError("continuation after final message frame")
} }
c.readFinal = final c.readFinal = final
default: default:
errors = append(errors, "bad opcode "+strconv.Itoa(frameType)) return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
}
if mask != c.isServer {
errors = append(errors, "bad MASK")
}
if len(errors) > 0 {
return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
} }
// 3. Read and parse frame length as per // 3. Read and parse frame length as per
@ -901,6 +872,10 @@ func (c *Conn) advanceFrame() (int, error) {
// 4. Handle frame masking. // 4. Handle frame masking.
if mask != c.isServer {
return noFrame, c.handleProtocolError("incorrect mask flag")
}
if mask { if mask {
c.readMaskPos = 0 c.readMaskPos = 0
p, err := c.read(len(c.readMaskKey)) p, err := c.read(len(c.readMaskKey))
@ -960,7 +935,7 @@ func (c *Conn) advanceFrame() (int, error) {
if len(payload) >= 2 { if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload)) closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) { if !isValidReceivedCloseCode(closeCode) {
return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode)) return noFrame, c.handleProtocolError("invalid close code")
} }
closeText = string(payload[2:]) closeText = string(payload[2:])
if !utf8.ValidString(closeText) { if !utf8.ValidString(closeText) {
@ -977,11 +952,7 @@ func (c *Conn) advanceFrame() (int, error) {
} }
func (c *Conn) handleProtocolError(message string) error { func (c *Conn) handleProtocolError(message string) error {
data := FormatCloseMessage(CloseProtocolError, message) c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize]
}
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message) return errors.New("websocket: " + message)
} }
@ -1189,16 +1160,8 @@ func (c *Conn) SetPongHandler(h func(appData string) error) {
c.handlePong = h c.handlePong = h
} }
// NetConn returns the underlying connection that is wrapped by c.
// Note that writing to or reading from this connection directly will corrupt the
// WebSocket connection.
func (c *Conn) NetConn() net.Conn {
return c.conn
}
// UnderlyingConn returns the internal net.Conn. This can be used to further // UnderlyingConn returns the internal net.Conn. This can be used to further
// modifications to connection specific flags. // modifications to connection specific flags.
// Deprecated: Use the NetConn method.
func (c *Conn) UnderlyingConn() net.Conn { func (c *Conn) UnderlyingConn() net.Conn {
return c.conn return c.conn
} }

View File

@ -18,6 +18,7 @@ import (
// scenarios with many subscribers in one channel. // scenarios with many subscribers in one channel.
type broadcastBench struct { type broadcastBench struct {
w io.Writer w io.Writer
message *broadcastMessage
closeCh chan struct{} closeCh chan struct{}
doneCh chan struct{} doneCh chan struct{}
count int32 count int32
@ -51,6 +52,14 @@ func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
usePrepared: usePrepared, usePrepared: usePrepared,
compression: compression, compression: compression,
} }
msg := &broadcastMessage{
payload: textMessages(1)[0],
}
if usePrepared {
pm, _ := NewPreparedMessage(TextMessage, msg.payload)
msg.prepared = pm
}
bench.message = msg
bench.makeConns(10000) bench.makeConns(10000)
return bench return bench
} }
@ -69,7 +78,7 @@ func (b *broadcastBench) makeConns(numConns int) {
for { for {
select { select {
case msg := <-c.msgCh: case msg := <-c.msgCh:
if msg.prepared != nil { if b.usePrepared {
c.conn.WritePreparedMessage(msg.prepared) c.conn.WritePreparedMessage(msg.prepared)
} else { } else {
c.conn.WriteMessage(TextMessage, msg.payload) c.conn.WriteMessage(TextMessage, msg.payload)
@ -91,9 +100,9 @@ func (b *broadcastBench) close() {
close(b.closeCh) close(b.closeCh)
} }
func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) { func (b *broadcastBench) runOnce() {
for _, c := range b.conns { for _, c := range b.conns {
c.msgCh <- msg c.msgCh <- b.message
} }
<-b.doneCh <-b.doneCh
} }
@ -105,25 +114,17 @@ func BenchmarkBroadcast(b *testing.B) {
compression bool compression bool
}{ }{
{"NoCompression", false, false}, {"NoCompression", false, false},
{"Compression", false, true}, {"WithCompression", false, true},
{"NoCompressionPrepared", true, false}, {"NoCompressionPrepared", true, false},
{"CompressionPrepared", true, true}, {"WithCompressionPrepared", true, true},
} }
payload := textMessages(1)[0]
for _, bm := range benchmarks { for _, bm := range benchmarks {
b.Run(bm.name, func(b *testing.B) { b.Run(bm.name, func(b *testing.B) {
bench := newBroadcastBench(bm.usePrepared, bm.compression) bench := newBroadcastBench(bm.usePrepared, bm.compression)
defer bench.close() defer bench.close()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
message := &broadcastMessage{ bench.runOnce()
payload: payload,
}
if bench.usePrepared {
pm, _ := NewPreparedMessage(TextMessage, message.payload)
message.prepared = pm
}
bench.broadcastOnce(message)
} }
b.ReportAllocs() b.ReportAllocs()
}) })

View File

@ -56,9 +56,9 @@ func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
func TestFraming(t *testing.T) { func TestFraming(t *testing.T) {
frameSizes := []int{ frameSizes := []int{
0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537,
// 65536, 65537
} }
var readChunkers = []struct { var readChunkers = []struct {
name string name string
f func(io.Reader) io.Reader f func(io.Reader) io.Reader
@ -562,7 +562,7 @@ func TestAddrs(t *testing.T) {
} }
} }
func TestDeprecatedUnderlyingConn(t *testing.T) { func TestUnderlyingConn(t *testing.T) {
var b1, b2 bytes.Buffer var b1, b2 bytes.Buffer
fc := fakeNetConn{Reader: &b1, Writer: &b2} fc := fakeNetConn{Reader: &b1, Writer: &b2}
c := newConn(fc, true, 1024, 1024, nil, nil, nil) c := newConn(fc, true, 1024, 1024, nil, nil, nil)
@ -572,16 +572,6 @@ func TestDeprecatedUnderlyingConn(t *testing.T) {
} }
} }
func TestNetConn(t *testing.T) {
var b1, b2 bytes.Buffer
fc := fakeNetConn{Reader: &b1, Writer: &b2}
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
ul := c.NetConn()
if ul != fc {
t.Fatalf("Underlying conn is not what it should be.")
}
}
func TestBufioReadBytes(t *testing.T) { func TestBufioReadBytes(t *testing.T) {
// Test calling bufio.ReadBytes for value longer than read buffer size. // Test calling bufio.ReadBytes for value longer than read buffer size.

15
conn_write.go Normal file
View File

@ -0,0 +1,15 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "net"
func (c *Conn) writeBufs(bufs ...[]byte) error {
b := net.Buffers(bufs)
_, err := b.WriteTo(c.conn)
return err
}

18
conn_write_legacy.go Normal file
View File

@ -0,0 +1,18 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
func (c *Conn) writeBufs(bufs ...[]byte) error {
for _, buf := range bufs {
if len(buf) > 0 {
if _, err := c.conn.Write(buf); err != nil {
return err
}
}
}
return nil
}

6
doc.go
View File

@ -187,9 +187,9 @@
// than the largest message do not provide any benefit. // than the largest message do not provide any benefit.
// //
// Depending on the distribution of message sizes, setting the buffer size to // Depending on the distribution of message sizes, setting the buffer size to
// a value less than the maximum expected message size can greatly reduce memory // to a value less than the maximum expected message size can greatly reduce
// use with a small impact on performance. Here's an example: If 99% of the // memory use with a small impact on performance. Here's an example: If 99% of
// messages are smaller than 256 bytes and the maximum message size is 512 // the messages are smaller than 256 bytes and the maximum message size is 512
// bytes, then a buffer size of 256 bytes will result in 1.01 more system calls // bytes, then a buffer size of 256 bytes will result in 1.01 more system calls
// than a buffer size of 512 bytes. The memory savings is 50%. // than a buffer size of 512 bytes. The memory savings is 50%.
// //

View File

@ -9,7 +9,7 @@ import (
"net/http" "net/http"
"testing" "testing"
"git.internal/re/websocket" "github.com/gorilla/websocket"
) )
var ( var (
@ -23,9 +23,10 @@ var (
// This server application works with a client application running in the // This server application works with a client application running in the
// browser. The client application does not explicitly close the websocket. The // browser. The client application does not explicitly close the websocket. The
// only expected close message from the client has the code // only expected close message from the client has the code
// websocket.CloseGoingAway. All other close messages are likely the // websocket.CloseGoingAway. All other other close messages are likely the
// result of an application or protocol error and are logged to aid debugging. // result of an application or protocol error and are logged to aid debugging.
func ExampleIsUnexpectedCloseError() { func ExampleIsUnexpectedCloseError() {
for { for {
messageType, p, err := c.ReadMessage() messageType, p, err := c.ReadMessage()
if err != nil { if err != nil {
@ -34,11 +35,11 @@ func ExampleIsUnexpectedCloseError() {
} }
return return
} }
processMessage(messageType, p) processMesage(messageType, p)
} }
} }
func processMessage(mt int, p []byte) {} func processMesage(mt int, p []byte) {}
// TestX prevents godoc from showing this entire file in the example. Remove // TestX prevents godoc from showing this entire file in the example. Remove
// this function when a second example is added. // this function when a second example is added.

View File

@ -8,11 +8,6 @@ To test the server, run
and start the client test driver and start the client test driver
mkdir -p reports wstest -m fuzzingclient -s fuzzingclient.json
docker run -it --rm \
-v ${PWD}/config:/config \
-v ${PWD}/reports:/reports \
crossbario/autobahn-testsuite \
wstest -m fuzzingclient -s /config/fuzzingclient.json
When the client completes, it writes a report to reports/index.html. When the client completes, it writes a report to reports/clients/index.html.

View File

@ -1,29 +0,0 @@
{
"cases": ["*"],
"exclude-cases": [],
"exclude-agent-cases": {},
"outdir": "/reports",
"options": {"failByDrop": false},
"servers": [
{
"agent": "ReadAllWriteMessage",
"url": "ws://host.docker.internal:9000/m"
},
{
"agent": "ReadAllWritePreparedMessage",
"url": "ws://host.docker.internal:9000/p"
},
{
"agent": "CopyFull",
"url": "ws://host.docker.internal:9000/f"
},
{
"agent": "ReadAllWrite",
"url": "ws://host.docker.internal:9000/r"
},
{
"agent": "CopyWriterOnly",
"url": "ws://host.docker.internal:9000/c"
}
]
}

View File

@ -0,0 +1,15 @@
{
"options": {"failByDrop": false},
"outdir": "./reports/clients",
"servers": [
{"agent": "ReadAllWriteMessage", "url": "ws://localhost:9000/m", "options": {"version": 18}},
{"agent": "ReadAllWritePreparedMessage", "url": "ws://localhost:9000/p", "options": {"version": 18}},
{"agent": "ReadAllWrite", "url": "ws://localhost:9000/r", "options": {"version": 18}},
{"agent": "CopyFull", "url": "ws://localhost:9000/f", "options": {"version": 18}},
{"agent": "CopyWriterOnly", "url": "ws://localhost:9000/c", "options": {"version": 18}}
],
"cases": ["*"],
"exclude-cases": [],
"exclude-agent-cases": {}
}

View File

@ -14,7 +14,7 @@ import (
"time" "time"
"unicode/utf8" "unicode/utf8"
"git.internal/re/websocket" "github.com/gorilla/websocket"
) )
var upgrader = websocket.Upgrader{ var upgrader = websocket.Upgrader{
@ -160,7 +160,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not found.", http.StatusNotFound) http.Error(w, "Not found.", http.StatusNotFound)
return return
} }
if r.Method != http.MethodGet { if r.Method != "GET" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return return
} }

View File

@ -1,7 +1,7 @@
# Chat Example # Chat Example
This application shows how to use the This application shows how to use the
[websocket](https://git.internal/re/websocket) package to implement a simple [websocket](https://github.com/gorilla/websocket) package to implement a simple
web chat application. web chat application.
## Running the example ## Running the example
@ -13,8 +13,8 @@ development environment.
Once you have Go up and running, you can download, build and run the example Once you have Go up and running, you can download, build and run the example
using the following commands. using the following commands.
$ go get git.internal/re/websocket $ go get github.com/gorilla/websocket
$ cd `go list -f '{{.Dir}}' git.internal/re/websocket/examples/chat` $ cd `go list -f '{{.Dir}}' github.com/gorilla/websocket/examples/chat`
$ go run *.go $ go run *.go
To use the chat example, open http://localhost:8080/ in your browser. To use the chat example, open http://localhost:8080/ in your browser.
@ -38,7 +38,7 @@ sends them to the hub.
### Hub ### Hub
The code for the `Hub` type is in The code for the `Hub` type is in
[hub.go](https://git.internal/re/websocket/blob/master/examples/chat/hub.go). [hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go).
The application's `main` function starts the hub's `run` method as a goroutine. The application's `main` function starts the hub's `run` method as a goroutine.
Clients send requests to the hub using the `register`, `unregister` and Clients send requests to the hub using the `register`, `unregister` and
`broadcast` channels. `broadcast` channels.
@ -57,7 +57,7 @@ unregisters the client and closes the websocket.
### Client ### Client
The code for the `Client` type is in [client.go](https://git.internal/re/websocket/blob/master/examples/chat/client.go). The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go).
The `serveWs` function is registered by the application's `main` function as The `serveWs` function is registered by the application's `main` function as
an HTTP handler. The handler upgrades the HTTP connection to the WebSocket an HTTP handler. The handler upgrades the HTTP connection to the WebSocket
@ -73,7 +73,7 @@ Finally, the HTTP handler calls the client's `readPump` method. This method
transfers inbound messages from the websocket to the hub. transfers inbound messages from the websocket to the hub.
WebSocket connections [support one concurrent reader and one concurrent WebSocket connections [support one concurrent reader and one concurrent
writer](https://godoc.org/git.internal/re/websocket#hdr-Concurrency). The writer](https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency). The
application ensures that these concurrency requirements are met by executing application ensures that these concurrency requirements are met by executing
all reads from the `readPump` goroutine and all writes from the `writePump` all reads from the `readPump` goroutine and all writes from the `writePump`
goroutine. goroutine.
@ -85,7 +85,7 @@ network.
## Frontend ## Frontend
The frontend code is in [home.html](https://git.internal/re/websocket/blob/master/examples/chat/home.html). The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html).
On document load, the script checks for websocket functionality in the browser. On document load, the script checks for websocket functionality in the browser.
If websocket functionality is available, then the script opens a connection to If websocket functionality is available, then the script opens a connection to

View File

@ -10,7 +10,7 @@ import (
"net/http" "net/http"
"time" "time"
"git.internal/re/websocket" "github.com/gorilla/websocket"
) )
const ( const (

View File

@ -92,7 +92,7 @@ body {
<div id="log"></div> <div id="log"></div>
<form id="form"> <form id="form">
<input type="submit" value="Send" /> <input type="submit" value="Send" />
<input type="text" id="msg" size="64" autofocus /> <input type="text" id="msg" size="64"/>
</form> </form>
</body> </body>
</html> </html>

View File

@ -18,7 +18,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not found", http.StatusNotFound) http.Error(w, "Not found", http.StatusNotFound)
return return
} }
if r.Method != http.MethodGet { if r.Method != "GET" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return return
} }

View File

@ -4,8 +4,8 @@ This example connects a websocket connection to stdin and stdout of a command.
Received messages are written to stdin followed by a `\n`. Each line read from Received messages are written to stdin followed by a `\n`. Each line read from
standard out is sent as a message to the client. standard out is sent as a message to the client.
$ go get git.internal/re/websocket $ go get github.com/gorilla/websocket
$ cd `go list -f '{{.Dir}}' git.internal/re/websocket/examples/command` $ cd `go list -f '{{.Dir}}' github.com/gorilla/websocket/examples/command`
$ go run main.go <command and arguments to run> $ go run main.go <command and arguments to run>
# Open http://localhost:8080/ . # Open http://localhost:8080/ .

View File

@ -14,7 +14,7 @@ import (
"os/exec" "os/exec"
"time" "time"
"git.internal/re/websocket" "github.com/gorilla/websocket"
) )
var ( var (
@ -170,7 +170,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not found", http.StatusNotFound) http.Error(w, "Not found", http.StatusNotFound)
return return
} }
if r.Method != http.MethodGet { if r.Method != "GET" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return return
} }

View File

@ -2,7 +2,6 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build ignore
// +build ignore // +build ignore
package main package main
@ -15,7 +14,7 @@ import (
"os/signal" "os/signal"
"time" "time"
"git.internal/re/websocket" "github.com/gorilla/websocket"
) )
var addr = flag.String("addr", "localhost:8080", "http service address") var addr = flag.String("addr", "localhost:8080", "http service address")

View File

@ -2,7 +2,6 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build ignore
// +build ignore // +build ignore
package main package main
@ -13,7 +12,7 @@ import (
"log" "log"
"net/http" "net/http"
"git.internal/re/websocket" "github.com/gorilla/websocket"
) )
var addr = flag.String("addr", "localhost:8080", "http service address") var addr = flag.String("addr", "localhost:8080", "http service address")
@ -68,9 +67,8 @@ window.addEventListener("load", function(evt) {
var print = function(message) { var print = function(message) {
var d = document.createElement("div"); var d = document.createElement("div");
d.textContent = message; d.innerHTML = message;
output.appendChild(d); output.appendChild(d);
output.scroll(0, output.scrollHeight);
}; };
document.getElementById("open").onclick = function(evt) { document.getElementById("open").onclick = function(evt) {
@ -128,7 +126,7 @@ You can change the message and send multiple times.
<button id="send">Send</button> <button id="send">Send</button>
</form> </form>
</td><td valign="top" width="50%"> </td><td valign="top" width="50%">
<div id="output" style="max-height: 70vh;overflow-y: scroll;"></div> <div id="output"></div>
</td></tr></table> </td></tr></table>
</body> </body>
</html> </html>

View File

@ -2,8 +2,8 @@
This example sends a file to the browser client for display whenever the file is modified. This example sends a file to the browser client for display whenever the file is modified.
$ go get git.internal/re/websocket $ go get github.com/gorilla/websocket
$ cd `go list -f '{{.Dir}}' git.internal/re/websocket/examples/filewatch` $ cd `go list -f '{{.Dir}}' github.com/gorilla/websocket/examples/filewatch`
$ go run main.go <name of file to watch> $ go run main.go <name of file to watch>
# Open http://localhost:8080/ . # Open http://localhost:8080/ .
# Modify the file to see it update in the browser. # Modify the file to see it update in the browser.

View File

@ -14,7 +14,7 @@ import (
"strconv" "strconv"
"time" "time"
"git.internal/re/websocket" "github.com/gorilla/websocket"
) )
const ( const (
@ -133,7 +133,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not found", http.StatusNotFound) http.Error(w, "Not found", http.StatusNotFound)
return return
} }
if r.Method != http.MethodGet { if r.Method != "GET" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return return
} }
@ -143,7 +143,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
p = []byte(err.Error()) p = []byte(err.Error())
lastMod = time.Unix(0, 0) lastMod = time.Unix(0, 0)
} }
v := struct { var v = struct {
Host string Host string
Data string Data string
LastMod string LastMod string

2
go.mod
View File

@ -1,3 +1,3 @@
module git.internal/re/websocket module github.com/gorilla/websocket
go 1.12 go 1.12

2
go.sum
View File

@ -0,0 +1,2 @@
github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q=
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=

View File

@ -2,7 +2,6 @@
// this source code is governed by a BSD-style license that can be found in the // this source code is governed by a BSD-style license that can be found in the
// LICENSE file. // LICENSE file.
//go:build !appengine
// +build !appengine // +build !appengine
package websocket package websocket

View File

@ -2,7 +2,6 @@
// this source code is governed by a BSD-style license that can be found in the // this source code is governed by a BSD-style license that can be found in the
// LICENSE file. // LICENSE file.
//go:build appengine
// +build appengine // +build appengine
package websocket package websocket

View File

@ -73,8 +73,8 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
// Prepare a frame using a 'fake' connection. // Prepare a frame using a 'fake' connection.
// TODO: Refactor code in conn.go to allow more direct construction of // TODO: Refactor code in conn.go to allow more direct construction of
// the frame. // the frame.
mu := make(chan struct{}, 1) mu := make(chan bool, 1)
mu <- struct{}{} mu <- true
var nc prepareConn var nc prepareConn
c := &Conn{ c := &Conn{
conn: &nc, conn: &nc,

View File

@ -48,7 +48,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
} }
connectReq := &http.Request{ connectReq := &http.Request{
Method: http.MethodConnect, Method: "CONNECT",
URL: &url.URL{Opaque: addr}, URL: &url.URL{Opaque: addr},
Host: addr, Host: addr,
Header: connectHeader, Header: connectHeader,

View File

@ -23,8 +23,6 @@ func (e HandshakeError) Error() string { return e.message }
// Upgrader specifies parameters for upgrading an HTTP connection to a // Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection. // WebSocket connection.
//
// It is safe to call Upgrader's methods concurrently.
type Upgrader struct { type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete. // HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration HandshakeTimeout time.Duration
@ -117,8 +115,8 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
// Upgrade upgrades the HTTP server connection to the WebSocket protocol. // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
// //
// The responseHeader is included in the response to the client's upgrade // The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie). To specify // request. Use the responseHeader to specify cookies (Set-Cookie) and the
// subprotocols supported by the server, set Upgrader.Subprotocols directly. // application negotiated subprotocol (Sec-WebSocket-Protocol).
// //
// If the upgrade fails, then Upgrade replies to the client with an HTTP error // If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response. // response.
@ -133,7 +131,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header") return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
} }
if r.Method != http.MethodGet { if r.Method != "GET" {
return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET") return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
} }
@ -154,8 +152,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
challengeKey := r.Header.Get("Sec-Websocket-Key") challengeKey := r.Header.Get("Sec-Websocket-Key")
if !isValidChallengeKey(challengeKey) { if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length") return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank")
} }
subprotocol := u.selectSubprotocol(r, responseHeader) subprotocol := u.selectSubprotocol(r, responseHeader)

View File

@ -98,7 +98,7 @@ func TestBufioReuse(t *testing.T) {
} }
upgrader := Upgrader{} upgrader := Upgrader{}
c, err := upgrader.Upgrade(resp, &http.Request{ c, err := upgrader.Upgrade(resp, &http.Request{
Method: http.MethodGet, Method: "GET",
Header: http.Header{ Header: http.Header{
"Upgrade": []string{"websocket"}, "Upgrade": []string{"websocket"},
"Connection": []string{"upgrade"}, "Connection": []string{"upgrade"},
@ -111,7 +111,7 @@ func TestBufioReuse(t *testing.T) {
if reuse := c.br == br; reuse != tt.reuse { if reuse := c.br == br; reuse != tt.reuse {
t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse) t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
} }
writeBuf := bufioWriterBuffer(c.NetConn(), bw) writeBuf := bufioWriterBuffer(c.UnderlyingConn(), bw)
if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse { if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse) t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
} }

View File

@ -1,21 +0,0 @@
//go:build go1.17
// +build go1.17
package websocket
import (
"context"
"crypto/tls"
)
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.HandshakeContext(ctx); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View File

@ -1,21 +0,0 @@
//go:build !go1.17
// +build !go1.17
package websocket
import (
"context"
"crypto/tls"
)
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

19
trace.go Normal file
View File

@ -0,0 +1,19 @@
// +build go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
if trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(tlsConn, cfg)
if trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
return err
}

12
trace_17.go Normal file
View File

@ -0,0 +1,12 @@
// +build !go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
return doHandshake(tlsConn, cfg)
}

15
util.go
View File

@ -281,18 +281,3 @@ headers:
} }
return result return result
} }
// isValidChallengeKey checks if the argument meets RFC6455 specification.
func isValidChallengeKey(s string) bool {
// From RFC6455:
//
// A |Sec-WebSocket-Key| header field with a base64-encoded (see
// Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
// length.
if s == "" {
return false
}
decoded, err := base64.StdEncoding.DecodeString(s)
return err == nil && len(decoded) == 16
}

View File

@ -53,25 +53,6 @@ func TestTokenListContainsValue(t *testing.T) {
} }
} }
var isValidChallengeKeyTests = []struct {
key string
ok bool
}{
{"dGhlIHNhbXBsZSBub25jZQ==", true},
{"", false},
{"InvalidKey", false},
{"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false},
}
func TestIsValidChallengeKey(t *testing.T) {
for _, tt := range isValidChallengeKeyTests {
ok := isValidChallengeKey(tt.key)
if ok != tt.ok {
t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok)
}
}
}
var parseExtensionTests = []struct { var parseExtensionTests = []struct {
value string value string
extensions []map[string]string extensions []map[string]string