Merge branch 'main' into fix-issue-479

This commit is contained in:
Corey Daley 2023-07-30 14:21:29 -04:00 committed by GitHub
commit a889672aa4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 576 additions and 314 deletions

View File

@ -1,76 +1,70 @@
version: 2.0 version: 2.1
jobs: jobs:
# Base test configuration for Go library tests Each distinct version should "test":
# inherit this base, and override (at least) the container image used. parameters:
"test": &test version:
type: string
default: "latest"
golint:
type: boolean
default: true
modules:
type: boolean
default: true
goproxy:
type: string
default: ""
docker: docker:
- image: circleci/golang:latest - image: "cimg/go:<< parameters.version >>"
working_directory: /go/src/github.com/gorilla/websocket working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket
steps: &steps
- checkout
- run: go version
- run: go get -t -v ./...
# Only run gofmt, vet & lint against the latest Go version
- run: >
if [[ "$LATEST" = true ]]; then
go get -u golang.org/x/lint/golint
golint ./...
fi
- run: >
if [[ "$LATEST" = true ]]; then
diff -u <(echo -n) <(gofmt -d .)
fi
- run: >
if [[ "$LATEST" = true ]]; then
go vet -v .
fi
- run: if [[ "$LATEST" = true ]]; then go vet -v .; fi
- run: go test -v -race ./...
"latest":
<<: *test
environment: environment:
LATEST: true GO111MODULE: "on"
GOPROXY: "<< parameters.goproxy >>"
"1.12": steps:
<<: *test - checkout
docker: - run:
- image: circleci/golang:1.12 name: "Print the Go version"
command: >
"1.11": go version
<<: *test - run:
docker: name: "Fetch dependencies"
- image: circleci/golang:1.11 command: >
if [[ << parameters.modules >> = true ]]; then
"1.10": go mod download
<<: *test export GO111MODULE=on
docker: else
- image: circleci/golang:1.10 go get -v ./...
fi
"1.9": # Only run gofmt, vet & lint against the latest Go version
<<: *test - run:
docker: name: "Run golint"
- image: circleci/golang:1.9 command: >
if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then
"1.8": go get -u golang.org/x/lint/golint
<<: *test golint ./...
docker: fi
- image: circleci/golang:1.8 - run:
name: "Run gofmt"
"1.7": command: >
<<: *test if [[ << parameters.version >> = "latest" ]]; then
docker: diff -u <(echo -n) <(gofmt -d -e .)
- image: circleci/golang:1.7 fi
- run:
name: "Run go vet"
command: >
if [[ << parameters.version >> = "latest" ]]; then
go vet -v ./...
fi
- run:
name: "Run go test (+ race detector)"
command: >
go test -v -race ./...
workflows: workflows:
version: 2 tests:
build:
jobs: jobs:
- "latest" - test:
- "1.12" matrix:
- "1.11" parameters:
- "1.10" version: ["1.18", "1.17", "1.16"]
- "1.9"
- "1.8"
- "1.7"

View File

@ -6,6 +6,7 @@
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
### Documentation ### Documentation
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) * [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
@ -30,35 +31,3 @@ The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
### Gorilla WebSocket compared with other packages
<table>
<tr>
<th></th>
<th><a href="http://godoc.org/github.com/gorilla/websocket">github.com/gorilla</a></th>
<th><a href="http://godoc.org/golang.org/x/net/websocket">golang.org/x/net</a></th>
</tr>
<tr>
<tr><td colspan="3"><a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a> Features</td></tr>
<tr><td>Passes <a href="https://github.com/crossbario/autobahn-testsuite">Autobahn Test Suite</a></td><td><a href="https://github.com/gorilla/websocket/tree/master/examples/autobahn">Yes</a></td><td>No</td></tr>
<tr><td>Receive <a href="https://tools.ietf.org/html/rfc6455#section-5.4">fragmented</a> message<td>Yes</td><td><a href="https://code.google.com/p/go/issues/detail?id=7632">No</a>, see note 1</td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.1">close</a> message</td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td><a href="https://code.google.com/p/go/issues/detail?id=4588">No</a></td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.2">pings</a> and receive <a href="https://tools.ietf.org/html/rfc6455#section-5.5.3">pongs</a></td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td>No</td></tr>
<tr><td>Get the <a href="https://tools.ietf.org/html/rfc6455#section-5.6">type</a> of a received data message</td><td>Yes</td><td>Yes, see note 2</td></tr>
<tr><td colspan="3">Other Features</tr></td>
<tr><td><a href="https://tools.ietf.org/html/rfc7692">Compression Extensions</a></td><td>Experimental</td><td>No</td></tr>
<tr><td>Read message using io.Reader</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextReader">Yes</a></td><td>No, see note 3</td></tr>
<tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr>
</table>
Notes:
1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
2. The application can get the type of a received data message by implementing
a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal)
function.
3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries.
Read returns when the input buffer is full or a frame boundary is
encountered. Each call to Write sends a single frame message. The Gorilla
io.Reader and io.WriteCloser operate on a single WebSocket message.

View File

@ -9,6 +9,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
@ -48,15 +49,23 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
} }
// A Dialer contains options for connecting to WebSocket server. // A Dialer contains options for connecting to WebSocket server.
//
// It is safe to call Dialer's methods concurrently.
type Dialer struct { type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If // NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used. // NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error) NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If // NetDialContext specifies the dial function for creating TCP connections. If
// NetDialContext is nil, net.DialContext is used. // NetDialContext is nil, NetDial is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
// NetDialTLSContext is nil, NetDialContext is used.
// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
// TLSClientConfig is ignored.
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given // Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the // Request. If the function returns a non-nil error, the
// request is aborted with the provided error. // request is aborted with the provided error.
@ -65,6 +74,8 @@ type Dialer struct {
// TLSClientConfig specifies the TLS configuration to use with tls.Client. // TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used. // If nil, the default configuration is used.
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
// is done there and TLSClientConfig is ignored.
TLSClientConfig *tls.Config TLSClientConfig *tls.Config
// HandshakeTimeout specifies the duration for the handshake to complete. // HandshakeTimeout specifies the duration for the handshake to complete.
@ -179,7 +190,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
} }
req := &http.Request{ req := &http.Request{
Method: "GET", Method: http.MethodGet,
URL: u, URL: u,
Proto: "HTTP/1.1", Proto: "HTTP/1.1",
ProtoMajor: 1, ProtoMajor: 1,
@ -240,13 +251,32 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// Get network dial function. // Get network dial function.
var netDial func(network, add string) (net.Conn, error) var netDial func(network, add string) (net.Conn, error)
if d.NetDialContext != nil { switch u.Scheme {
netDial = func(network, addr string) (net.Conn, error) { case "http":
return d.NetDialContext(ctx, network, addr) if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
} }
} else if d.NetDial != nil { case "https":
netDial = d.NetDial if d.NetDialTLSContext != nil {
} else { netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialTLSContext(ctx, network, addr)
}
} else if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
}
default:
return nil, nil, errMalformedURL
}
if netDial == nil {
netDialer := &net.Dialer{} netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) { netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr) return netDialer.DialContext(ctx, network, addr)
@ -292,14 +322,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
} }
netConn, err := netDial("tcp", hostPort) netConn, err := netDial("tcp", hostPort)
if err != nil {
return nil, nil, err
}
if trace != nil && trace.GotConn != nil { if trace != nil && trace.GotConn != nil {
trace.GotConn(httptrace.GotConnInfo{ trace.GotConn(httptrace.GotConnInfo{
Conn: netConn, Conn: netConn,
}) })
} }
if err != nil {
return nil, nil, err
}
defer func() { defer func() {
if netConn != nil { if netConn != nil {
@ -307,7 +337,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
} }
}() }()
if u.Scheme == "https" { if u.Scheme == "https" && d.NetDialTLSContext == nil {
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
cfg := cloneTLSConfig(d.TLSClientConfig) cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" { if cfg.ServerName == "" {
cfg.ServerName = hostNoPort cfg.ServerName = hostNoPort
@ -315,11 +347,12 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
tlsConn := tls.Client(netConn, cfg) tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn netConn = tlsConn
var err error if trace != nil && trace.TLSHandshakeStart != nil {
if trace != nil { trace.TLSHandshakeStart()
err = doHandshakeWithTrace(trace, tlsConn, cfg) }
} else { err := doHandshake(ctx, tlsConn, cfg)
err = doHandshake(tlsConn, cfg) if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
} }
if err != nil { if err != nil {
@ -341,6 +374,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
resp, err := http.ReadResponse(conn.br, req) resp, err := http.ReadResponse(conn.br, req)
if err != nil { if err != nil {
if d.TLSClientConfig != nil {
for _, proto := range d.TLSClientConfig.NextProtos {
if proto != "http/1.1" {
return nil, nil, fmt.Errorf(
"websocket: protocol %q was given but is not supported;"+
"sharing tls.Config with net/http Transport can cause this error: %w",
proto, err,
)
}
}
}
return nil, nil, err return nil, nil, err
} }
@ -351,8 +395,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
} }
if resp.StatusCode != 101 || if resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
// Before closing the network connection on return from this // Before closing the network connection on return from this
// function, slurp up some of the response to aid application // function, slurp up some of the response to aid application
@ -385,14 +429,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return conn, resp, nil return conn, resp, nil
} }
func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error { func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if err := tlsConn.Handshake(); err != nil { if cfg == nil {
return err return &tls.Config{}
} }
if !cfg.InsecureSkipVerify { return cfg.Clone()
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
} }

View File

@ -1,16 +0,0 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}

View File

@ -1,38 +0,0 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
import "crypto/tls"
// cloneTLSConfig clones all public fields except the fields
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
// config in active use.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

View File

@ -11,6 +11,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -166,7 +167,7 @@ func TestProxyDial(t *testing.T) {
// Capture the request Host header. // Capture the request Host header.
s.Server.Config.Handler = http.HandlerFunc( s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" { if r.Method == http.MethodConnect {
connect = true connect = true
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
if r.Header.Get("User-Agents") != "xxx" { if r.Header.Get("User-Agents") != "xxx" {
@ -210,7 +211,7 @@ func TestProxyAuthorizationDial(t *testing.T) {
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
proxyAuth := r.Header.Get("Proxy-Authorization") proxyAuth := r.Header.Get("Proxy-Authorization")
expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth { if r.Method == http.MethodConnect && proxyAuth == expectedProxyAuth {
connect = true connect = true
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
return return
@ -470,7 +471,7 @@ func TestBadMethod(t *testing.T) {
})) }))
defer s.Close() defer s.Close()
req, err := http.NewRequest("POST", s.URL, strings.NewReader("")) req, err := http.NewRequest(http.MethodPost, s.URL, strings.NewReader(""))
if err != nil { if err != nil {
t.Fatalf("NewRequest returned error %v", err) t.Fatalf("NewRequest returned error %v", err)
} }
@ -488,6 +489,23 @@ func TestBadMethod(t *testing.T) {
} }
} }
func TestDialExtraTokensInRespHeaders(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
challengeKey := r.Header.Get("Sec-Websocket-Key")
w.Header().Set("Upgrade", "foo, websocket")
w.Header().Set("Connection", "upgrade, keep-alive")
w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))
w.WriteHeader(101)
}))
defer s.Close()
ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
}
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -725,7 +743,7 @@ func TestHost(t *testing.T) {
Dial: dialer.NetDial, Dial: dialer.NetDial,
TLSClientConfig: dialer.TLSClientConfig, TLSClientConfig: dialer.TLSClientConfig,
} }
req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil) req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil)
if tt.header != "" { if tt.header != "" {
req.Host = tt.header req.Host = tt.header
} }
@ -910,3 +928,215 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
defer ws.Close() defer ws.Close()
sendRecv(t, ws) sendRecv(t, ws)
} }
// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
func TestNetDialConnect(t *testing.T) {
upgrader := Upgrader{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if IsWebSocketUpgrade(r) {
c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
if err != nil {
t.Fatal(err)
}
c.Close()
} else {
w.Header().Set("X-Test-Host", r.Host)
}
})
server := httptest.NewServer(handler)
defer server.Close()
tlsServer := httptest.NewTLSServer(handler)
defer tlsServer.Close()
testUrls := map[*httptest.Server]string{
server: "ws://" + server.Listener.Addr().String() + "/",
tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/",
}
cas := rootCAs(t, tlsServer)
tlsConfig := &tls.Config{
RootCAs: cas,
ServerName: "example.com",
InsecureSkipVerify: false,
}
tests := []struct {
name string
server *httptest.Server // server to use
netDial func(network, addr string) (net.Conn, error)
netDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
tlsClientConfig *tls.Config
}{
{
name: "HTTP server, all NetDial* defined, shall use NetDialContext",
server: server,
netDial: func(network, addr string) (net.Conn, error) {
return nil, errors.New("NetDial should not be called")
},
netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("NetDialTLSContext should not be called")
},
tlsClientConfig: nil,
},
{
name: "HTTP server, all NetDial* undefined",
server: server,
netDial: nil,
netDialContext: nil,
netDialTLSContext: nil,
tlsClientConfig: nil,
},
{
name: "HTTP server, NetDialContext undefined, shall fallback to NetDial",
server: server,
netDial: func(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
netDialContext: nil,
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("NetDialTLSContext should not be called")
},
tlsClientConfig: nil,
},
{
name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext",
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
return nil, errors.New("NetDial should not be called")
},
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("NetDialContext should not be called")
},
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
netConn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
tlsConn := tls.Client(netConn, tlsConfig)
err = tlsConn.Handshake()
if err != nil {
return nil, err
}
return tlsConn, nil
},
tlsClientConfig: nil,
},
{
name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake",
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
return nil, errors.New("NetDial should not be called")
},
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
netDialTLSContext: nil,
tlsClientConfig: tlsConfig,
},
{
name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake",
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
},
netDialContext: nil,
netDialTLSContext: nil,
tlsClientConfig: tlsConfig,
},
{
name: "HTTPS server, all NetDial* undefined",
server: tlsServer,
netDial: nil,
netDialContext: nil,
netDialTLSContext: nil,
tlsClientConfig: tlsConfig,
},
{
name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake",
server: tlsServer,
netDial: func(network, addr string) (net.Conn, error) {
return nil, errors.New("NetDial should not be called")
},
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, errors.New("NetDialContext should not be called")
},
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
netConn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
tlsConn := tls.Client(netConn, tlsConfig)
err = tlsConn.Handshake()
if err != nil {
return nil, err
}
return tlsConn, nil
},
tlsClientConfig: &tls.Config{
RootCAs: nil,
ServerName: "badserver.com",
InsecureSkipVerify: false,
},
},
}
for _, tc := range tests {
dialer := Dialer{
NetDial: tc.netDial,
NetDialContext: tc.netDialContext,
NetDialTLSContext: tc.netDialTLSContext,
TLSClientConfig: tc.tlsClientConfig,
}
// Test websocket dial
c, _, err := dialer.Dial(testUrls[tc.server], nil)
if err != nil {
t.Errorf("FAILED %s, err: %s", tc.name, err.Error())
} else {
c.Close()
}
}
}
func TestNextProtos(t *testing.T) {
ts := httptest.NewUnstartedServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
)
ts.EnableHTTP2 = true
ts.StartTLS()
defer ts.Close()
d := Dialer{
TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
}
r, err := ts.Client().Get(ts.URL)
if err != nil {
t.Fatalf("Get: %v", err)
}
r.Body.Close()
// Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
// after the Client.Get call from net/http above.
var containsHTTP2 bool = false
for _, proto := range d.TLSClientConfig.NextProtos {
if proto == "h2" {
containsHTTP2 = true
}
}
if !containsHTTP2 {
t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
}
_, _, err = d.Dial(makeWsProto(ts.URL), nil)
if err == nil {
t.Fatalf("Dial succeeded, expect fail ")
}
}

71
conn.go
View File

@ -13,6 +13,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
"unicode/utf8" "unicode/utf8"
@ -401,6 +402,12 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return nil return nil
} }
func (c *Conn) writeBufs(bufs ...[]byte) error {
b := net.Buffers(bufs)
_, err := b.WriteTo(c.conn)
return err
}
// WriteControl writes a control message with the given deadline. The allowed // WriteControl writes a control message with the given deadline. The allowed
// message types are CloseMessage, PingMessage and PongMessage. // message types are CloseMessage, PingMessage and PongMessage.
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
@ -794,47 +801,69 @@ func (c *Conn) advanceFrame() (int, error) {
} }
// 2. Read and parse first two bytes of frame header. // 2. Read and parse first two bytes of frame header.
// To aid debugging, collect and report all errors in the first two bytes
// of the header.
var errors []string
p, err := c.read(2) p, err := c.read(2)
if err != nil { if err != nil {
return noFrame, err return noFrame, err
} }
final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf) frameType := int(p[0] & 0xf)
final := p[0]&finalBit != 0
rsv1 := p[0]&rsv1Bit != 0
rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0 mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f)) c.setReadRemaining(int64(p[1] & 0x7f))
c.readDecompress = false c.readDecompress = false
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { if rsv1 {
c.readDecompress = true if c.newDecompressionReader != nil {
p[0] &^= rsv1Bit c.readDecompress = true
} else {
errors = append(errors, "RSV1 set")
}
} }
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 { if rsv2 {
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16)) errors = append(errors, "RSV2 set")
}
if rsv3 {
errors = append(errors, "RSV3 set")
} }
switch frameType { switch frameType {
case CloseMessage, PingMessage, PongMessage: case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize { if c.readRemaining > maxControlFramePayloadSize {
return noFrame, c.handleProtocolError("control frame length > 125") errors = append(errors, "len > 125 for control")
} }
if !final { if !final {
return noFrame, c.handleProtocolError("control frame not final") errors = append(errors, "FIN not set on control")
} }
case TextMessage, BinaryMessage: case TextMessage, BinaryMessage:
if !c.readFinal { if !c.readFinal {
return noFrame, c.handleProtocolError("message start before final message frame") errors = append(errors, "data before FIN")
} }
c.readFinal = final c.readFinal = final
case continuationFrame: case continuationFrame:
if c.readFinal { if c.readFinal {
return noFrame, c.handleProtocolError("continuation after final message frame") errors = append(errors, "continuation after FIN")
} }
c.readFinal = final c.readFinal = final
default: default:
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
}
if mask != c.isServer {
errors = append(errors, "bad MASK")
}
if len(errors) > 0 {
return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
} }
// 3. Read and parse frame length as per // 3. Read and parse frame length as per
@ -872,10 +901,6 @@ func (c *Conn) advanceFrame() (int, error) {
// 4. Handle frame masking. // 4. Handle frame masking.
if mask != c.isServer {
return noFrame, c.handleProtocolError("incorrect mask flag")
}
if mask { if mask {
c.readMaskPos = 0 c.readMaskPos = 0
p, err := c.read(len(c.readMaskKey)) p, err := c.read(len(c.readMaskKey))
@ -935,7 +960,7 @@ func (c *Conn) advanceFrame() (int, error) {
if len(payload) >= 2 { if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload)) closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) { if !isValidReceivedCloseCode(closeCode) {
return noFrame, c.handleProtocolError("invalid close code") return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
} }
closeText = string(payload[2:]) closeText = string(payload[2:])
if !utf8.ValidString(closeText) { if !utf8.ValidString(closeText) {
@ -952,7 +977,11 @@ func (c *Conn) advanceFrame() (int, error) {
} }
func (c *Conn) handleProtocolError(message string) error { func (c *Conn) handleProtocolError(message string) error {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) data := FormatCloseMessage(CloseProtocolError, message)
if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize]
}
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message) return errors.New("websocket: " + message)
} }
@ -1160,8 +1189,16 @@ func (c *Conn) SetPongHandler(h func(appData string) error) {
c.handlePong = h c.handlePong = h
} }
// NetConn returns the underlying connection that is wrapped by c.
// Note that writing to or reading from this connection directly will corrupt the
// WebSocket connection.
func (c *Conn) NetConn() net.Conn {
return c.conn
}
// UnderlyingConn returns the internal net.Conn. This can be used to further // UnderlyingConn returns the internal net.Conn. This can be used to further
// modifications to connection specific flags. // modifications to connection specific flags.
// Deprecated: Use the NetConn method.
func (c *Conn) UnderlyingConn() net.Conn { func (c *Conn) UnderlyingConn() net.Conn {
return c.conn return c.conn
} }

View File

@ -18,7 +18,6 @@ import (
// scenarios with many subscribers in one channel. // scenarios with many subscribers in one channel.
type broadcastBench struct { type broadcastBench struct {
w io.Writer w io.Writer
message *broadcastMessage
closeCh chan struct{} closeCh chan struct{}
doneCh chan struct{} doneCh chan struct{}
count int32 count int32
@ -52,14 +51,6 @@ func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
usePrepared: usePrepared, usePrepared: usePrepared,
compression: compression, compression: compression,
} }
msg := &broadcastMessage{
payload: textMessages(1)[0],
}
if usePrepared {
pm, _ := NewPreparedMessage(TextMessage, msg.payload)
msg.prepared = pm
}
bench.message = msg
bench.makeConns(10000) bench.makeConns(10000)
return bench return bench
} }
@ -78,7 +69,7 @@ func (b *broadcastBench) makeConns(numConns int) {
for { for {
select { select {
case msg := <-c.msgCh: case msg := <-c.msgCh:
if b.usePrepared { if msg.prepared != nil {
c.conn.WritePreparedMessage(msg.prepared) c.conn.WritePreparedMessage(msg.prepared)
} else { } else {
c.conn.WriteMessage(TextMessage, msg.payload) c.conn.WriteMessage(TextMessage, msg.payload)
@ -100,9 +91,9 @@ func (b *broadcastBench) close() {
close(b.closeCh) close(b.closeCh)
} }
func (b *broadcastBench) runOnce() { func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) {
for _, c := range b.conns { for _, c := range b.conns {
c.msgCh <- b.message c.msgCh <- msg
} }
<-b.doneCh <-b.doneCh
} }
@ -114,17 +105,25 @@ func BenchmarkBroadcast(b *testing.B) {
compression bool compression bool
}{ }{
{"NoCompression", false, false}, {"NoCompression", false, false},
{"WithCompression", false, true}, {"Compression", false, true},
{"NoCompressionPrepared", true, false}, {"NoCompressionPrepared", true, false},
{"WithCompressionPrepared", true, true}, {"CompressionPrepared", true, true},
} }
payload := textMessages(1)[0]
for _, bm := range benchmarks { for _, bm := range benchmarks {
b.Run(bm.name, func(b *testing.B) { b.Run(bm.name, func(b *testing.B) {
bench := newBroadcastBench(bm.usePrepared, bm.compression) bench := newBroadcastBench(bm.usePrepared, bm.compression)
defer bench.close() defer bench.close()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
bench.runOnce() message := &broadcastMessage{
payload: payload,
}
if bench.usePrepared {
pm, _ := NewPreparedMessage(TextMessage, message.payload)
message.prepared = pm
}
bench.broadcastOnce(message)
} }
b.ReportAllocs() b.ReportAllocs()
}) })

View File

@ -562,7 +562,7 @@ func TestAddrs(t *testing.T) {
} }
} }
func TestUnderlyingConn(t *testing.T) { func TestDeprecatedUnderlyingConn(t *testing.T) {
var b1, b2 bytes.Buffer var b1, b2 bytes.Buffer
fc := fakeNetConn{Reader: &b1, Writer: &b2} fc := fakeNetConn{Reader: &b1, Writer: &b2}
c := newConn(fc, true, 1024, 1024, nil, nil, nil) c := newConn(fc, true, 1024, 1024, nil, nil, nil)
@ -572,6 +572,16 @@ func TestUnderlyingConn(t *testing.T) {
} }
} }
func TestNetConn(t *testing.T) {
var b1, b2 bytes.Buffer
fc := fakeNetConn{Reader: &b1, Writer: &b2}
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
ul := c.NetConn()
if ul != fc {
t.Fatalf("Underlying conn is not what it should be.")
}
}
func TestBufioReadBytes(t *testing.T) { func TestBufioReadBytes(t *testing.T) {
// Test calling bufio.ReadBytes for value longer than read buffer size. // Test calling bufio.ReadBytes for value longer than read buffer size.

View File

@ -1,15 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "net"
func (c *Conn) writeBufs(bufs ...[]byte) error {
b := net.Buffers(bufs)
_, err := b.WriteTo(c.conn)
return err
}

View File

@ -1,18 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
func (c *Conn) writeBufs(bufs ...[]byte) error {
for _, buf := range bufs {
if len(buf) > 0 {
if _, err := c.conn.Write(buf); err != nil {
return err
}
}
}
return nil
}

View File

@ -8,6 +8,11 @@ To test the server, run
and start the client test driver and start the client test driver
wstest -m fuzzingclient -s fuzzingclient.json mkdir -p reports
docker run -it --rm \
-v ${PWD}/config:/config \
-v ${PWD}/reports:/reports \
crossbario/autobahn-testsuite \
wstest -m fuzzingclient -s /config/fuzzingclient.json
When the client completes, it writes a report to reports/clients/index.html. When the client completes, it writes a report to reports/index.html.

View File

@ -0,0 +1,29 @@
{
"cases": ["*"],
"exclude-cases": [],
"exclude-agent-cases": {},
"outdir": "/reports",
"options": {"failByDrop": false},
"servers": [
{
"agent": "ReadAllWriteMessage",
"url": "ws://host.docker.internal:9000/m"
},
{
"agent": "ReadAllWritePreparedMessage",
"url": "ws://host.docker.internal:9000/p"
},
{
"agent": "CopyFull",
"url": "ws://host.docker.internal:9000/f"
},
{
"agent": "ReadAllWrite",
"url": "ws://host.docker.internal:9000/r"
},
{
"agent": "CopyWriterOnly",
"url": "ws://host.docker.internal:9000/c"
}
]
}

View File

@ -1,15 +0,0 @@
{
"options": {"failByDrop": false},
"outdir": "./reports/clients",
"servers": [
{"agent": "ReadAllWriteMessage", "url": "ws://localhost:9000/m", "options": {"version": 18}},
{"agent": "ReadAllWritePreparedMessage", "url": "ws://localhost:9000/p", "options": {"version": 18}},
{"agent": "ReadAllWrite", "url": "ws://localhost:9000/r", "options": {"version": 18}},
{"agent": "CopyFull", "url": "ws://localhost:9000/f", "options": {"version": 18}},
{"agent": "CopyWriterOnly", "url": "ws://localhost:9000/c", "options": {"version": 18}}
],
"cases": ["*"],
"exclude-cases": [],
"exclude-agent-cases": {}
}

View File

@ -160,7 +160,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not found.", http.StatusNotFound) http.Error(w, "Not found.", http.StatusNotFound)
return return
} }
if r.Method != "GET" { if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return return
} }

View File

@ -18,7 +18,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not found", http.StatusNotFound) http.Error(w, "Not found", http.StatusNotFound)
return return
} }
if r.Method != "GET" { if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return return
} }

View File

@ -170,7 +170,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not found", http.StatusNotFound) http.Error(w, "Not found", http.StatusNotFound)
return return
} }
if r.Method != "GET" { if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return return
} }

View File

@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build ignore
// +build ignore // +build ignore
package main package main

View File

@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build ignore
// +build ignore // +build ignore
package main package main
@ -69,6 +70,7 @@ window.addEventListener("load", function(evt) {
var d = document.createElement("div"); var d = document.createElement("div");
d.textContent = message; d.textContent = message;
output.appendChild(d); output.appendChild(d);
output.scroll(0, output.scrollHeight);
}; };
document.getElementById("open").onclick = function(evt) { document.getElementById("open").onclick = function(evt) {
@ -126,7 +128,7 @@ You can change the message and send multiple times.
<button id="send">Send</button> <button id="send">Send</button>
</form> </form>
</td><td valign="top" width="50%"> </td><td valign="top" width="50%">
<div id="output"></div> <div id="output" style="max-height: 70vh;overflow-y: scroll;"></div>
</td></tr></table> </td></tr></table>
</body> </body>
</html> </html>

View File

@ -133,7 +133,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Not found", http.StatusNotFound) http.Error(w, "Not found", http.StatusNotFound)
return return
} }
if r.Method != "GET" { if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return return
} }

View File

@ -2,6 +2,7 @@
// this source code is governed by a BSD-style license that can be found in the // this source code is governed by a BSD-style license that can be found in the
// LICENSE file. // LICENSE file.
//go:build !appengine
// +build !appengine // +build !appengine
package websocket package websocket

View File

@ -2,6 +2,7 @@
// this source code is governed by a BSD-style license that can be found in the // this source code is governed by a BSD-style license that can be found in the
// LICENSE file. // LICENSE file.
//go:build appengine
// +build appengine // +build appengine
package websocket package websocket

View File

@ -60,7 +60,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
} }
connectReq := &http.Request{ connectReq := &http.Request{
Method: "CONNECT", Method: http.MethodConnect,
URL: &url.URL{Opaque: addr}, URL: &url.URL{Opaque: addr},
Host: addr, Host: addr,
Header: connectHeader, Header: connectHeader,

View File

@ -23,6 +23,8 @@ func (e HandshakeError) Error() string { return e.message }
// Upgrader specifies parameters for upgrading an HTTP connection to a // Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection. // WebSocket connection.
//
// It is safe to call Upgrader's methods concurrently.
type Upgrader struct { type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete. // HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration HandshakeTimeout time.Duration
@ -115,8 +117,8 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
// Upgrade upgrades the HTTP server connection to the WebSocket protocol. // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
// //
// The responseHeader is included in the response to the client's upgrade // The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the // request. Use the responseHeader to specify cookies (Set-Cookie). To specify
// application negotiated subprotocol (Sec-WebSocket-Protocol). // subprotocols supported by the server, set Upgrader.Subprotocols directly.
// //
// If the upgrade fails, then Upgrade replies to the client with an HTTP error // If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response. // response.
@ -131,7 +133,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header") return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
} }
if r.Method != "GET" { if r.Method != http.MethodGet {
return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET") return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
} }
@ -152,8 +154,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
challengeKey := r.Header.Get("Sec-Websocket-Key") challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" { if !isValidChallengeKey(challengeKey) {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank") return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
} }
subprotocol := u.selectSubprotocol(r, responseHeader) subprotocol := u.selectSubprotocol(r, responseHeader)

View File

@ -98,7 +98,7 @@ func TestBufioReuse(t *testing.T) {
} }
upgrader := Upgrader{} upgrader := Upgrader{}
c, err := upgrader.Upgrade(resp, &http.Request{ c, err := upgrader.Upgrade(resp, &http.Request{
Method: "GET", Method: http.MethodGet,
Header: http.Header{ Header: http.Header{
"Upgrade": []string{"websocket"}, "Upgrade": []string{"websocket"},
"Connection": []string{"upgrade"}, "Connection": []string{"upgrade"},
@ -111,7 +111,7 @@ func TestBufioReuse(t *testing.T) {
if reuse := c.br == br; reuse != tt.reuse { if reuse := c.br == br; reuse != tt.reuse {
t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse) t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
} }
writeBuf := bufioWriterBuffer(c.UnderlyingConn(), bw) writeBuf := bufioWriterBuffer(c.NetConn(), bw)
if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse { if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse) t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
} }

21
tls_handshake.go Normal file
View File

@ -0,0 +1,21 @@
//go:build go1.17
// +build go1.17
package websocket
import (
"context"
"crypto/tls"
)
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.HandshakeContext(ctx); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

21
tls_handshake_116.go Normal file
View File

@ -0,0 +1,21 @@
//go:build !go1.17
// +build !go1.17
package websocket
import (
"context"
"crypto/tls"
)
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View File

@ -1,19 +0,0 @@
// +build go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
if trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(tlsConn, cfg)
if trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
return err
}

View File

@ -1,12 +0,0 @@
// +build !go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
return doHandshake(tlsConn, cfg)
}

15
util.go
View File

@ -281,3 +281,18 @@ headers:
} }
return result return result
} }
// isValidChallengeKey checks if the argument meets RFC6455 specification.
func isValidChallengeKey(s string) bool {
// From RFC6455:
//
// A |Sec-WebSocket-Key| header field with a base64-encoded (see
// Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
// length.
if s == "" {
return false
}
decoded, err := base64.StdEncoding.DecodeString(s)
return err == nil && len(decoded) == 16
}

View File

@ -53,6 +53,25 @@ func TestTokenListContainsValue(t *testing.T) {
} }
} }
var isValidChallengeKeyTests = []struct {
key string
ok bool
}{
{"dGhlIHNhbXBsZSBub25jZQ==", true},
{"", false},
{"InvalidKey", false},
{"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false},
}
func TestIsValidChallengeKey(t *testing.T) {
for _, tt := range isValidChallengeKeyTests {
ok := isValidChallengeKey(tt.key)
if ok != tt.ok {
t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok)
}
}
}
var parseExtensionTests = []struct { var parseExtensionTests = []struct {
value string value string
extensions []map[string]string extensions []map[string]string