mirror of https://github.com/gorilla/websocket.git
Merge branch 'main' into feature/context-takeover
This commit is contained in:
commit
4f8acdcb7f
|
@ -0,0 +1,70 @@
|
||||||
|
version: 2.1
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
"test":
|
||||||
|
parameters:
|
||||||
|
version:
|
||||||
|
type: string
|
||||||
|
default: "latest"
|
||||||
|
golint:
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
|
modules:
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
|
goproxy:
|
||||||
|
type: string
|
||||||
|
default: ""
|
||||||
|
docker:
|
||||||
|
- image: "cimg/go:<< parameters.version >>"
|
||||||
|
working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket
|
||||||
|
environment:
|
||||||
|
GO111MODULE: "on"
|
||||||
|
GOPROXY: "<< parameters.goproxy >>"
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run:
|
||||||
|
name: "Print the Go version"
|
||||||
|
command: >
|
||||||
|
go version
|
||||||
|
- run:
|
||||||
|
name: "Fetch dependencies"
|
||||||
|
command: >
|
||||||
|
if [[ << parameters.modules >> = true ]]; then
|
||||||
|
go mod download
|
||||||
|
export GO111MODULE=on
|
||||||
|
else
|
||||||
|
go get -v ./...
|
||||||
|
fi
|
||||||
|
# Only run gofmt, vet & lint against the latest Go version
|
||||||
|
- run:
|
||||||
|
name: "Run golint"
|
||||||
|
command: >
|
||||||
|
if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then
|
||||||
|
go get -u golang.org/x/lint/golint
|
||||||
|
golint ./...
|
||||||
|
fi
|
||||||
|
- run:
|
||||||
|
name: "Run gofmt"
|
||||||
|
command: >
|
||||||
|
if [[ << parameters.version >> = "latest" ]]; then
|
||||||
|
diff -u <(echo -n) <(gofmt -d -e .)
|
||||||
|
fi
|
||||||
|
- run:
|
||||||
|
name: "Run go vet"
|
||||||
|
command: >
|
||||||
|
if [[ << parameters.version >> = "latest" ]]; then
|
||||||
|
go vet -v ./...
|
||||||
|
fi
|
||||||
|
- run:
|
||||||
|
name: "Run go test (+ race detector)"
|
||||||
|
command: >
|
||||||
|
go test -v -race ./...
|
||||||
|
|
||||||
|
workflows:
|
||||||
|
tests:
|
||||||
|
jobs:
|
||||||
|
- test:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
version: ["1.18", "1.17", "1.16"]
|
|
@ -0,0 +1,7 @@
|
||||||
|
# Config for https://github.com/apps/release-drafter
|
||||||
|
template: |
|
||||||
|
|
||||||
|
<summary of changes here>
|
||||||
|
|
||||||
|
## CHANGELOG
|
||||||
|
$CHANGES
|
19
.travis.yml
19
.travis.yml
|
@ -1,19 +0,0 @@
|
||||||
language: go
|
|
||||||
sudo: false
|
|
||||||
|
|
||||||
matrix:
|
|
||||||
include:
|
|
||||||
- go: 1.7.x
|
|
||||||
- go: 1.8.x
|
|
||||||
- go: 1.9.x
|
|
||||||
- go: 1.10.x
|
|
||||||
- go: 1.11.x
|
|
||||||
- go: tip
|
|
||||||
allow_failures:
|
|
||||||
- go: tip
|
|
||||||
|
|
||||||
script:
|
|
||||||
- go get -t -v ./...
|
|
||||||
- diff -u <(echo -n) <(gofmt -d .)
|
|
||||||
- go vet $(go list ./... | grep -v /vendor/)
|
|
||||||
- go test -v -race ./...
|
|
40
README.md
40
README.md
|
@ -1,18 +1,20 @@
|
||||||
# Gorilla WebSocket
|
# Gorilla WebSocket
|
||||||
|
|
||||||
|
[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
|
||||||
|
[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket)
|
||||||
|
|
||||||
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
|
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
|
||||||
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
|
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
|
||||||
|
|
||||||
[![Build Status](https://travis-ci.org/gorilla/websocket.svg?branch=master)](https://travis-ci.org/gorilla/websocket)
|
|
||||||
[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
|
|
||||||
|
|
||||||
### Documentation
|
### Documentation
|
||||||
|
|
||||||
* [API Reference](http://godoc.org/github.com/gorilla/websocket)
|
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
|
||||||
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
|
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
|
||||||
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
|
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
|
||||||
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
|
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
|
||||||
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
|
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
|
||||||
|
* [Write buffer pool example](https://github.com/gorilla/websocket/tree/master/examples/bufferpool)
|
||||||
|
|
||||||
### Status
|
### Status
|
||||||
|
|
||||||
|
@ -30,35 +32,3 @@ The Gorilla WebSocket package passes the server tests in the [Autobahn Test
|
||||||
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
|
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
|
||||||
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
|
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
|
||||||
|
|
||||||
### Gorilla WebSocket compared with other packages
|
|
||||||
|
|
||||||
<table>
|
|
||||||
<tr>
|
|
||||||
<th></th>
|
|
||||||
<th><a href="http://godoc.org/github.com/gorilla/websocket">github.com/gorilla</a></th>
|
|
||||||
<th><a href="http://godoc.org/golang.org/x/net/websocket">golang.org/x/net</a></th>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<tr><td colspan="3"><a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a> Features</td></tr>
|
|
||||||
<tr><td>Passes <a href="http://autobahn.ws/testsuite/">Autobahn Test Suite</a></td><td><a href="https://github.com/gorilla/websocket/tree/master/examples/autobahn">Yes</a></td><td>No</td></tr>
|
|
||||||
<tr><td>Receive <a href="https://tools.ietf.org/html/rfc6455#section-5.4">fragmented</a> message<td>Yes</td><td><a href="https://code.google.com/p/go/issues/detail?id=7632">No</a>, see note 1</td></tr>
|
|
||||||
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.1">close</a> message</td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td><a href="https://code.google.com/p/go/issues/detail?id=4588">No</a></td></tr>
|
|
||||||
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.2">pings</a> and receive <a href="https://tools.ietf.org/html/rfc6455#section-5.5.3">pongs</a></td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td>No</td></tr>
|
|
||||||
<tr><td>Get the <a href="https://tools.ietf.org/html/rfc6455#section-5.6">type</a> of a received data message</td><td>Yes</td><td>Yes, see note 2</td></tr>
|
|
||||||
<tr><td colspan="3">Other Features</tr></td>
|
|
||||||
<tr><td><a href="https://tools.ietf.org/html/rfc7692">Compression Extensions</a></td><td>Experimental</td><td>No</td></tr>
|
|
||||||
<tr><td>Read message using io.Reader</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextReader">Yes</a></td><td>No, see note 3</td></tr>
|
|
||||||
<tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr>
|
|
||||||
</table>
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
|
|
||||||
1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
|
|
||||||
2. The application can get the type of a received data message by implementing
|
|
||||||
a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal)
|
|
||||||
function.
|
|
||||||
3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries.
|
|
||||||
Read returns when the input buffer is full or a frame boundary is
|
|
||||||
encountered. Each call to Write sends a single frame message. The Gorilla
|
|
||||||
io.Reader and io.WriteCloser operate on a single WebSocket message.
|
|
||||||
|
|
||||||
|
|
95
client.go
95
client.go
|
@ -10,6 +10,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
|
@ -49,15 +50,23 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
|
||||||
}
|
}
|
||||||
|
|
||||||
// A Dialer contains options for connecting to WebSocket server.
|
// A Dialer contains options for connecting to WebSocket server.
|
||||||
|
//
|
||||||
|
// It is safe to call Dialer's methods concurrently.
|
||||||
type Dialer struct {
|
type Dialer struct {
|
||||||
// NetDial specifies the dial function for creating TCP connections. If
|
// NetDial specifies the dial function for creating TCP connections. If
|
||||||
// NetDial is nil, net.Dial is used.
|
// NetDial is nil, net.Dial is used.
|
||||||
NetDial func(network, addr string) (net.Conn, error)
|
NetDial func(network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
// NetDialContext specifies the dial function for creating TCP connections. If
|
// NetDialContext specifies the dial function for creating TCP connections. If
|
||||||
// NetDialContext is nil, net.DialContext is used.
|
// NetDialContext is nil, NetDial is used.
|
||||||
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
|
// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
|
||||||
|
// NetDialTLSContext is nil, NetDialContext is used.
|
||||||
|
// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
|
||||||
|
// TLSClientConfig is ignored.
|
||||||
|
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
// Proxy specifies a function to return a proxy for a given
|
// Proxy specifies a function to return a proxy for a given
|
||||||
// Request. If the function returns a non-nil error, the
|
// Request. If the function returns a non-nil error, the
|
||||||
// request is aborted with the provided error.
|
// request is aborted with the provided error.
|
||||||
|
@ -66,6 +75,8 @@ type Dialer struct {
|
||||||
|
|
||||||
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
|
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
|
||||||
// If nil, the default configuration is used.
|
// If nil, the default configuration is used.
|
||||||
|
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
|
||||||
|
// is done there and TLSClientConfig is ignored.
|
||||||
TLSClientConfig *tls.Config
|
TLSClientConfig *tls.Config
|
||||||
|
|
||||||
// HandshakeTimeout specifies the duration for the handshake to complete.
|
// HandshakeTimeout specifies the duration for the handshake to complete.
|
||||||
|
@ -180,7 +191,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &http.Request{
|
req := &http.Request{
|
||||||
Method: "GET",
|
Method: http.MethodGet,
|
||||||
URL: u,
|
URL: u,
|
||||||
Proto: "HTTP/1.1",
|
Proto: "HTTP/1.1",
|
||||||
ProtoMajor: 1,
|
ProtoMajor: 1,
|
||||||
|
@ -244,13 +255,32 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
// Get network dial function.
|
// Get network dial function.
|
||||||
var netDial func(network, add string) (net.Conn, error)
|
var netDial func(network, add string) (net.Conn, error)
|
||||||
|
|
||||||
if d.NetDialContext != nil {
|
switch u.Scheme {
|
||||||
netDial = func(network, addr string) (net.Conn, error) {
|
case "http":
|
||||||
return d.NetDialContext(ctx, network, addr)
|
if d.NetDialContext != nil {
|
||||||
|
netDial = func(network, addr string) (net.Conn, error) {
|
||||||
|
return d.NetDialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
} else if d.NetDial != nil {
|
||||||
|
netDial = d.NetDial
|
||||||
}
|
}
|
||||||
} else if d.NetDial != nil {
|
case "https":
|
||||||
netDial = d.NetDial
|
if d.NetDialTLSContext != nil {
|
||||||
} else {
|
netDial = func(network, addr string) (net.Conn, error) {
|
||||||
|
return d.NetDialTLSContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
} else if d.NetDialContext != nil {
|
||||||
|
netDial = func(network, addr string) (net.Conn, error) {
|
||||||
|
return d.NetDialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
} else if d.NetDial != nil {
|
||||||
|
netDial = d.NetDial
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, nil, errMalformedURL
|
||||||
|
}
|
||||||
|
|
||||||
|
if netDial == nil {
|
||||||
netDialer := &net.Dialer{}
|
netDialer := &net.Dialer{}
|
||||||
netDial = func(network, addr string) (net.Conn, error) {
|
netDial = func(network, addr string) (net.Conn, error) {
|
||||||
return netDialer.DialContext(ctx, network, addr)
|
return netDialer.DialContext(ctx, network, addr)
|
||||||
|
@ -296,14 +326,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
}
|
}
|
||||||
|
|
||||||
netConn, err := netDial("tcp", hostPort)
|
netConn, err := netDial("tcp", hostPort)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
if trace != nil && trace.GotConn != nil {
|
if trace != nil && trace.GotConn != nil {
|
||||||
trace.GotConn(httptrace.GotConnInfo{
|
trace.GotConn(httptrace.GotConnInfo{
|
||||||
Conn: netConn,
|
Conn: netConn,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if netConn != nil {
|
if netConn != nil {
|
||||||
|
@ -311,7 +341,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if u.Scheme == "https" {
|
if u.Scheme == "https" && d.NetDialTLSContext == nil {
|
||||||
|
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
|
||||||
|
|
||||||
cfg := cloneTLSConfig(d.TLSClientConfig)
|
cfg := cloneTLSConfig(d.TLSClientConfig)
|
||||||
if cfg.ServerName == "" {
|
if cfg.ServerName == "" {
|
||||||
cfg.ServerName = hostNoPort
|
cfg.ServerName = hostNoPort
|
||||||
|
@ -319,11 +351,12 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
tlsConn := tls.Client(netConn, cfg)
|
tlsConn := tls.Client(netConn, cfg)
|
||||||
netConn = tlsConn
|
netConn = tlsConn
|
||||||
|
|
||||||
var err error
|
if trace != nil && trace.TLSHandshakeStart != nil {
|
||||||
if trace != nil {
|
trace.TLSHandshakeStart()
|
||||||
err = doHandshakeWithTrace(trace, tlsConn, cfg)
|
}
|
||||||
} else {
|
err := doHandshake(ctx, tlsConn, cfg)
|
||||||
err = doHandshake(tlsConn, cfg)
|
if trace != nil && trace.TLSHandshakeDone != nil {
|
||||||
|
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -345,6 +378,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
|
|
||||||
resp, err := http.ReadResponse(conn.br, req)
|
resp, err := http.ReadResponse(conn.br, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if d.TLSClientConfig != nil {
|
||||||
|
for _, proto := range d.TLSClientConfig.NextProtos {
|
||||||
|
if proto != "http/1.1" {
|
||||||
|
return nil, nil, fmt.Errorf(
|
||||||
|
"websocket: protocol %q was given but is not supported;"+
|
||||||
|
"sharing tls.Config with net/http Transport can cause this error: %w",
|
||||||
|
proto, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -355,8 +399,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 101 ||
|
if resp.StatusCode != 101 ||
|
||||||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
|
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
|
||||||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
|
!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
|
||||||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
|
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
|
||||||
// Before closing the network connection on return from this
|
// Before closing the network connection on return from this
|
||||||
// function, slurp up some of the response to aid application
|
// function, slurp up some of the response to aid application
|
||||||
|
@ -400,14 +444,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
return conn, resp, nil
|
return conn, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
|
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if cfg == nil {
|
||||||
return err
|
return &tls.Config{}
|
||||||
}
|
}
|
||||||
if !cfg.InsecureSkipVerify {
|
return cfg.Clone()
|
||||||
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,16 +0,0 @@
|
||||||
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build go1.8
|
|
||||||
|
|
||||||
package websocket
|
|
||||||
|
|
||||||
import "crypto/tls"
|
|
||||||
|
|
||||||
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
|
||||||
if cfg == nil {
|
|
||||||
return &tls.Config{}
|
|
||||||
}
|
|
||||||
return cfg.Clone()
|
|
||||||
}
|
|
|
@ -1,38 +0,0 @@
|
||||||
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build !go1.8
|
|
||||||
|
|
||||||
package websocket
|
|
||||||
|
|
||||||
import "crypto/tls"
|
|
||||||
|
|
||||||
// cloneTLSConfig clones all public fields except the fields
|
|
||||||
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
|
|
||||||
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
|
|
||||||
// config in active use.
|
|
||||||
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
|
||||||
if cfg == nil {
|
|
||||||
return &tls.Config{}
|
|
||||||
}
|
|
||||||
return &tls.Config{
|
|
||||||
Rand: cfg.Rand,
|
|
||||||
Time: cfg.Time,
|
|
||||||
Certificates: cfg.Certificates,
|
|
||||||
NameToCertificate: cfg.NameToCertificate,
|
|
||||||
GetCertificate: cfg.GetCertificate,
|
|
||||||
RootCAs: cfg.RootCAs,
|
|
||||||
NextProtos: cfg.NextProtos,
|
|
||||||
ServerName: cfg.ServerName,
|
|
||||||
ClientAuth: cfg.ClientAuth,
|
|
||||||
ClientCAs: cfg.ClientCAs,
|
|
||||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
|
||||||
CipherSuites: cfg.CipherSuites,
|
|
||||||
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
|
|
||||||
ClientSessionCache: cfg.ClientSessionCache,
|
|
||||||
MinVersion: cfg.MinVersion,
|
|
||||||
MaxVersion: cfg.MaxVersion,
|
|
||||||
CurvePreferences: cfg.CurvePreferences,
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -216,7 +217,7 @@ func TestProxyDial(t *testing.T) {
|
||||||
// Capture the request Host header.
|
// Capture the request Host header.
|
||||||
s.Server.Config.Handler = http.HandlerFunc(
|
s.Server.Config.Handler = http.HandlerFunc(
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method == "CONNECT" {
|
if r.Method == http.MethodConnect {
|
||||||
connect = true
|
connect = true
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
return
|
return
|
||||||
|
@ -256,7 +257,7 @@ func TestProxyAuthorizationDial(t *testing.T) {
|
||||||
func(w http.ResponseWriter, r *http.Request) {
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
proxyAuth := r.Header.Get("Proxy-Authorization")
|
proxyAuth := r.Header.Get("Proxy-Authorization")
|
||||||
expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
|
expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
|
||||||
if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth {
|
if r.Method == http.MethodConnect && proxyAuth == expectedProxyAuth {
|
||||||
connect = true
|
connect = true
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
return
|
return
|
||||||
|
@ -516,7 +517,7 @@ func TestBadMethod(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", s.URL, strings.NewReader(""))
|
req, err := http.NewRequest(http.MethodPost, s.URL, strings.NewReader(""))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewRequest returned error %v", err)
|
t.Fatalf("NewRequest returned error %v", err)
|
||||||
}
|
}
|
||||||
|
@ -534,6 +535,23 @@ func TestBadMethod(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDialExtraTokensInRespHeaders(t *testing.T) {
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
||||||
|
w.Header().Set("Upgrade", "foo, websocket")
|
||||||
|
w.Header().Set("Connection", "upgrade, keep-alive")
|
||||||
|
w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))
|
||||||
|
w.WriteHeader(101)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial: %v", err)
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func TestHandshake(t *testing.T) {
|
func TestHandshake(t *testing.T) {
|
||||||
s := newServer(t, cstHandlerConfig{})
|
s := newServer(t, cstHandlerConfig{})
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
@ -640,7 +658,7 @@ func TestHost(t *testing.T) {
|
||||||
server *httptest.Server // server to use
|
server *httptest.Server // server to use
|
||||||
url string // host for request URI
|
url string // host for request URI
|
||||||
header string // optional request host header
|
header string // optional request host header
|
||||||
tls string // optiona host for tls ServerName
|
tls string // optional host for tls ServerName
|
||||||
wantAddr string // expected host for dial
|
wantAddr string // expected host for dial
|
||||||
wantHeader string // expected request header on server
|
wantHeader string // expected request header on server
|
||||||
insecureSkipVerify bool
|
insecureSkipVerify bool
|
||||||
|
@ -771,7 +789,7 @@ func TestHost(t *testing.T) {
|
||||||
Dial: dialer.NetDial,
|
Dial: dialer.NetDial,
|
||||||
TLSClientConfig: dialer.TLSClientConfig,
|
TLSClientConfig: dialer.TLSClientConfig,
|
||||||
}
|
}
|
||||||
req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
|
req, _ := http.NewRequest(http.MethodGet, httpProtos[tt.server]+tt.url+"/", nil)
|
||||||
if tt.header != "" {
|
if tt.header != "" {
|
||||||
req.Host = tt.header
|
req.Host = tt.header
|
||||||
}
|
}
|
||||||
|
@ -972,3 +990,215 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
|
||||||
defer ws.Close()
|
defer ws.Close()
|
||||||
sendRecv(t, ws)
|
sendRecv(t, ws)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
|
||||||
|
func TestNetDialConnect(t *testing.T) {
|
||||||
|
|
||||||
|
upgrader := Upgrader{}
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if IsWebSocketUpgrade(r) {
|
||||||
|
c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
c.Close()
|
||||||
|
} else {
|
||||||
|
w.Header().Set("X-Test-Host", r.Host)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
server := httptest.NewServer(handler)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tlsServer := httptest.NewTLSServer(handler)
|
||||||
|
defer tlsServer.Close()
|
||||||
|
|
||||||
|
testUrls := map[*httptest.Server]string{
|
||||||
|
server: "ws://" + server.Listener.Addr().String() + "/",
|
||||||
|
tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/",
|
||||||
|
}
|
||||||
|
|
||||||
|
cas := rootCAs(t, tlsServer)
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
RootCAs: cas,
|
||||||
|
ServerName: "example.com",
|
||||||
|
InsecureSkipVerify: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
server *httptest.Server // server to use
|
||||||
|
netDial func(network, addr string) (net.Conn, error)
|
||||||
|
netDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
tlsClientConfig *tls.Config
|
||||||
|
}{
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "HTTP server, all NetDial* defined, shall use NetDialContext",
|
||||||
|
server: server,
|
||||||
|
netDial: func(network, addr string) (net.Conn, error) {
|
||||||
|
return nil, errors.New("NetDial should not be called")
|
||||||
|
},
|
||||||
|
netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return net.Dial(network, addr)
|
||||||
|
},
|
||||||
|
netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return nil, errors.New("NetDialTLSContext should not be called")
|
||||||
|
},
|
||||||
|
tlsClientConfig: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTP server, all NetDial* undefined",
|
||||||
|
server: server,
|
||||||
|
netDial: nil,
|
||||||
|
netDialContext: nil,
|
||||||
|
netDialTLSContext: nil,
|
||||||
|
tlsClientConfig: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTP server, NetDialContext undefined, shall fallback to NetDial",
|
||||||
|
server: server,
|
||||||
|
netDial: func(network, addr string) (net.Conn, error) {
|
||||||
|
return net.Dial(network, addr)
|
||||||
|
},
|
||||||
|
netDialContext: nil,
|
||||||
|
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return nil, errors.New("NetDialTLSContext should not be called")
|
||||||
|
},
|
||||||
|
tlsClientConfig: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext",
|
||||||
|
server: tlsServer,
|
||||||
|
netDial: func(network, addr string) (net.Conn, error) {
|
||||||
|
return nil, errors.New("NetDial should not be called")
|
||||||
|
},
|
||||||
|
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return nil, errors.New("NetDialContext should not be called")
|
||||||
|
},
|
||||||
|
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
netConn, err := net.Dial(network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConn := tls.Client(netConn, tlsConfig)
|
||||||
|
err = tlsConn.Handshake()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return tlsConn, nil
|
||||||
|
},
|
||||||
|
tlsClientConfig: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake",
|
||||||
|
server: tlsServer,
|
||||||
|
netDial: func(network, addr string) (net.Conn, error) {
|
||||||
|
return nil, errors.New("NetDial should not be called")
|
||||||
|
},
|
||||||
|
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return net.Dial(network, addr)
|
||||||
|
},
|
||||||
|
netDialTLSContext: nil,
|
||||||
|
tlsClientConfig: tlsConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake",
|
||||||
|
server: tlsServer,
|
||||||
|
netDial: func(network, addr string) (net.Conn, error) {
|
||||||
|
return net.Dial(network, addr)
|
||||||
|
},
|
||||||
|
netDialContext: nil,
|
||||||
|
netDialTLSContext: nil,
|
||||||
|
tlsClientConfig: tlsConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTPS server, all NetDial* undefined",
|
||||||
|
server: tlsServer,
|
||||||
|
netDial: nil,
|
||||||
|
netDialContext: nil,
|
||||||
|
netDialTLSContext: nil,
|
||||||
|
tlsClientConfig: tlsConfig,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake",
|
||||||
|
server: tlsServer,
|
||||||
|
netDial: func(network, addr string) (net.Conn, error) {
|
||||||
|
return nil, errors.New("NetDial should not be called")
|
||||||
|
},
|
||||||
|
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return nil, errors.New("NetDialContext should not be called")
|
||||||
|
},
|
||||||
|
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
netConn, err := net.Dial(network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConn := tls.Client(netConn, tlsConfig)
|
||||||
|
err = tlsConn.Handshake()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return tlsConn, nil
|
||||||
|
},
|
||||||
|
tlsClientConfig: &tls.Config{
|
||||||
|
RootCAs: nil,
|
||||||
|
ServerName: "badserver.com",
|
||||||
|
InsecureSkipVerify: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
dialer := Dialer{
|
||||||
|
NetDial: tc.netDial,
|
||||||
|
NetDialContext: tc.netDialContext,
|
||||||
|
NetDialTLSContext: tc.netDialTLSContext,
|
||||||
|
TLSClientConfig: tc.tlsClientConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test websocket dial
|
||||||
|
c, _, err := dialer.Dial(testUrls[tc.server], nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("FAILED %s, err: %s", tc.name, err.Error())
|
||||||
|
} else {
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func TestNextProtos(t *testing.T) {
|
||||||
|
ts := httptest.NewUnstartedServer(
|
||||||
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||||
|
)
|
||||||
|
ts.EnableHTTP2 = true
|
||||||
|
ts.StartTLS()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
d := Dialer{
|
||||||
|
TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := ts.Client().Get(ts.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get: %v", err)
|
||||||
|
}
|
||||||
|
r.Body.Close()
|
||||||
|
|
||||||
|
// Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
|
||||||
|
// after the Client.Get call from net/http above.
|
||||||
|
var containsHTTP2 bool = false
|
||||||
|
for _, proto := range d.TLSClientConfig.NextProtos {
|
||||||
|
if proto == "h2" {
|
||||||
|
containsHTTP2 = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !containsHTTP2 {
|
||||||
|
t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err = d.Dial(makeWsProto(ts.URL), nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Dial succeeded, expect fail ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
148
conn.go
148
conn.go
|
@ -13,6 +13,7 @@ import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
@ -246,8 +247,8 @@ type Conn struct {
|
||||||
subprotocol string
|
subprotocol string
|
||||||
|
|
||||||
// Write fields
|
// Write fields
|
||||||
mu chan bool // used as mutex to protect write to conn
|
mu chan struct{} // used as mutex to protect write to conn
|
||||||
writeBuf []byte // frame is constructed in this buffer.
|
writeBuf []byte // frame is constructed in this buffer.
|
||||||
writePool BufferPool
|
writePool BufferPool
|
||||||
writeBufSize int
|
writeBufSize int
|
||||||
writeDeadline time.Time
|
writeDeadline time.Time
|
||||||
|
@ -262,10 +263,12 @@ type Conn struct {
|
||||||
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
|
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
|
||||||
|
|
||||||
// Read fields
|
// Read fields
|
||||||
reader io.ReadCloser // the current reader returned to the application
|
reader io.ReadCloser // the current reader returned to the application
|
||||||
readErr error
|
readErr error
|
||||||
br *bufio.Reader
|
br *bufio.Reader
|
||||||
readRemaining int64 // bytes remaining in current frame.
|
// bytes remaining in current frame.
|
||||||
|
// set setReadRemaining to safely update this value and prevent overflow
|
||||||
|
readRemaining int64
|
||||||
readFinal bool // true the current message has more frames.
|
readFinal bool // true the current message has more frames.
|
||||||
readLength int64 // Message size.
|
readLength int64 // Message size.
|
||||||
readLimit int64 // Maximum message size.
|
readLimit int64 // Maximum message size.
|
||||||
|
@ -302,8 +305,8 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
|
||||||
writeBuf = make([]byte, writeBufferSize)
|
writeBuf = make([]byte, writeBufferSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
mu := make(chan bool, 1)
|
mu := make(chan struct{}, 1)
|
||||||
mu <- true
|
mu <- struct{}{}
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
isServer: isServer,
|
isServer: isServer,
|
||||||
br: br,
|
br: br,
|
||||||
|
@ -322,6 +325,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setReadRemaining tracks the number of bytes remaining on the connection. If n
|
||||||
|
// overflows, an ErrReadLimit is returned.
|
||||||
|
func (c *Conn) setReadRemaining(n int64) error {
|
||||||
|
if n < 0 {
|
||||||
|
return ErrReadLimit
|
||||||
|
}
|
||||||
|
|
||||||
|
c.readRemaining = n
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Subprotocol returns the negotiated protocol for the connection.
|
// Subprotocol returns the negotiated protocol for the connection.
|
||||||
func (c *Conn) Subprotocol() string {
|
func (c *Conn) Subprotocol() string {
|
||||||
return c.subprotocol
|
return c.subprotocol
|
||||||
|
@ -366,7 +380,7 @@ func (c *Conn) read(n int) ([]byte, error) {
|
||||||
|
|
||||||
func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
|
func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
|
||||||
<-c.mu
|
<-c.mu
|
||||||
defer func() { c.mu <- true }()
|
defer func() { c.mu <- struct{}{} }()
|
||||||
|
|
||||||
c.writeErrMu.Lock()
|
c.writeErrMu.Lock()
|
||||||
err := c.writeErr
|
err := c.writeErr
|
||||||
|
@ -390,6 +404,12 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) writeBufs(bufs ...[]byte) error {
|
||||||
|
b := net.Buffers(bufs)
|
||||||
|
_, err := b.WriteTo(c.conn)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// WriteControl writes a control message with the given deadline. The allowed
|
// WriteControl writes a control message with the given deadline. The allowed
|
||||||
// message types are CloseMessage, PingMessage and PongMessage.
|
// message types are CloseMessage, PingMessage and PongMessage.
|
||||||
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
|
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
|
||||||
|
@ -418,7 +438,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
||||||
maskBytes(key, 0, buf[6:])
|
maskBytes(key, 0, buf[6:])
|
||||||
}
|
}
|
||||||
|
|
||||||
d := time.Hour * 1000
|
d := 1000 * time.Hour
|
||||||
if !deadline.IsZero() {
|
if !deadline.IsZero() {
|
||||||
d = deadline.Sub(time.Now())
|
d = deadline.Sub(time.Now())
|
||||||
if d < 0 {
|
if d < 0 {
|
||||||
|
@ -433,7 +453,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
return errWriteTimeout
|
return errWriteTimeout
|
||||||
}
|
}
|
||||||
defer func() { c.mu <- true }()
|
defer func() { c.mu <- struct{}{} }()
|
||||||
|
|
||||||
c.writeErrMu.Lock()
|
c.writeErrMu.Lock()
|
||||||
err := c.writeErr
|
err := c.writeErr
|
||||||
|
@ -710,10 +730,7 @@ func (w *messageWriter) Close() error {
|
||||||
if w.err != nil {
|
if w.err != nil {
|
||||||
return w.err
|
return w.err
|
||||||
}
|
}
|
||||||
if err := w.flushFrame(true, nil); err != nil {
|
return w.flushFrame(true, nil)
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// WritePreparedMessage writes prepared message into connection.
|
// WritePreparedMessage writes prepared message into connection.
|
||||||
|
@ -786,50 +803,82 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Read and parse first two bytes of frame header.
|
// 2. Read and parse first two bytes of frame header.
|
||||||
|
// To aid debugging, collect and report all errors in the first two bytes
|
||||||
|
// of the header.
|
||||||
|
|
||||||
|
var errors []string
|
||||||
|
|
||||||
p, err := c.read(2)
|
p, err := c.read(2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
|
|
||||||
final := p[0]&finalBit != 0
|
|
||||||
frameType := int(p[0] & 0xf)
|
frameType := int(p[0] & 0xf)
|
||||||
|
final := p[0]&finalBit != 0
|
||||||
|
rsv1 := p[0]&rsv1Bit != 0
|
||||||
|
rsv2 := p[0]&rsv2Bit != 0
|
||||||
|
rsv3 := p[0]&rsv3Bit != 0
|
||||||
mask := p[1]&maskBit != 0
|
mask := p[1]&maskBit != 0
|
||||||
c.readRemaining = int64(p[1] & 0x7f)
|
c.setReadRemaining(int64(p[1] & 0x7f))
|
||||||
|
|
||||||
c.readDecompress = false
|
c.readDecompress = false
|
||||||
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
|
if rsv1 {
|
||||||
c.readDecompress = true
|
if c.newDecompressionReader != nil {
|
||||||
p[0] &^= rsv1Bit
|
c.readDecompress = true
|
||||||
|
} else {
|
||||||
|
errors = append(errors, "RSV1 set")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
|
if rsv2 {
|
||||||
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
|
errors = append(errors, "RSV2 set")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rsv3 {
|
||||||
|
errors = append(errors, "RSV3 set")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch frameType {
|
switch frameType {
|
||||||
case CloseMessage, PingMessage, PongMessage:
|
case CloseMessage, PingMessage, PongMessage:
|
||||||
if c.readRemaining > maxControlFramePayloadSize {
|
if c.readRemaining > maxControlFramePayloadSize {
|
||||||
return noFrame, c.handleProtocolError("control frame length > 125")
|
errors = append(errors, "len > 125 for control")
|
||||||
}
|
}
|
||||||
if !final {
|
if !final {
|
||||||
return noFrame, c.handleProtocolError("control frame not final")
|
errors = append(errors, "FIN not set on control")
|
||||||
}
|
}
|
||||||
case TextMessage, BinaryMessage:
|
case TextMessage, BinaryMessage:
|
||||||
if !c.readFinal {
|
if !c.readFinal {
|
||||||
return noFrame, c.handleProtocolError("message start before final message frame")
|
errors = append(errors, "data before FIN")
|
||||||
}
|
}
|
||||||
c.readFinal = final
|
c.readFinal = final
|
||||||
case continuationFrame:
|
case continuationFrame:
|
||||||
if c.readFinal {
|
if c.readFinal {
|
||||||
return noFrame, c.handleProtocolError("continuation after final message frame")
|
errors = append(errors, "continuation after FIN")
|
||||||
}
|
}
|
||||||
c.readFinal = final
|
c.readFinal = final
|
||||||
default:
|
default:
|
||||||
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
|
errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Read and parse frame length.
|
if mask != c.isServer {
|
||||||
|
errors = append(errors, "bad MASK")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errors) > 0 {
|
||||||
|
return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Read and parse frame length as per
|
||||||
|
// https://tools.ietf.org/html/rfc6455#section-5.2
|
||||||
|
//
|
||||||
|
// The length of the "Payload data", in bytes: if 0-125, that is the payload
|
||||||
|
// length.
|
||||||
|
// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
|
||||||
|
// integer are the payload length.
|
||||||
|
// - If 127, the following 8 bytes interpreted as
|
||||||
|
// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
|
||||||
|
// payload length. Multibyte length quantities are expressed in network byte
|
||||||
|
// order.
|
||||||
|
|
||||||
switch c.readRemaining {
|
switch c.readRemaining {
|
||||||
case 126:
|
case 126:
|
||||||
|
@ -837,21 +886,23 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
c.readRemaining = int64(binary.BigEndian.Uint16(p))
|
|
||||||
|
if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
|
||||||
|
return noFrame, err
|
||||||
|
}
|
||||||
case 127:
|
case 127:
|
||||||
p, err := c.read(8)
|
p, err := c.read(8)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
c.readRemaining = int64(binary.BigEndian.Uint64(p))
|
|
||||||
|
if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
|
||||||
|
return noFrame, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Handle frame masking.
|
// 4. Handle frame masking.
|
||||||
|
|
||||||
if mask != c.isServer {
|
|
||||||
return noFrame, c.handleProtocolError("incorrect mask flag")
|
|
||||||
}
|
|
||||||
|
|
||||||
if mask {
|
if mask {
|
||||||
c.readMaskPos = 0
|
c.readMaskPos = 0
|
||||||
p, err := c.read(len(c.readMaskKey))
|
p, err := c.read(len(c.readMaskKey))
|
||||||
|
@ -866,6 +917,12 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
|
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
|
||||||
|
|
||||||
c.readLength += c.readRemaining
|
c.readLength += c.readRemaining
|
||||||
|
// Don't allow readLength to overflow in the presence of a large readRemaining
|
||||||
|
// counter.
|
||||||
|
if c.readLength < 0 {
|
||||||
|
return noFrame, ErrReadLimit
|
||||||
|
}
|
||||||
|
|
||||||
if c.readLimit > 0 && c.readLength > c.readLimit {
|
if c.readLimit > 0 && c.readLength > c.readLimit {
|
||||||
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
|
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
|
||||||
return noFrame, ErrReadLimit
|
return noFrame, ErrReadLimit
|
||||||
|
@ -879,7 +936,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
var payload []byte
|
var payload []byte
|
||||||
if c.readRemaining > 0 {
|
if c.readRemaining > 0 {
|
||||||
payload, err = c.read(int(c.readRemaining))
|
payload, err = c.read(int(c.readRemaining))
|
||||||
c.readRemaining = 0
|
c.setReadRemaining(0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
|
@ -905,7 +962,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
if len(payload) >= 2 {
|
if len(payload) >= 2 {
|
||||||
closeCode = int(binary.BigEndian.Uint16(payload))
|
closeCode = int(binary.BigEndian.Uint16(payload))
|
||||||
if !isValidReceivedCloseCode(closeCode) {
|
if !isValidReceivedCloseCode(closeCode) {
|
||||||
return noFrame, c.handleProtocolError("invalid close code")
|
return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
|
||||||
}
|
}
|
||||||
closeText = string(payload[2:])
|
closeText = string(payload[2:])
|
||||||
if !utf8.ValidString(closeText) {
|
if !utf8.ValidString(closeText) {
|
||||||
|
@ -922,7 +979,11 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) handleProtocolError(message string) error {
|
func (c *Conn) handleProtocolError(message string) error {
|
||||||
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
|
data := FormatCloseMessage(CloseProtocolError, message)
|
||||||
|
if len(data) > maxControlFramePayloadSize {
|
||||||
|
data = data[:maxControlFramePayloadSize]
|
||||||
|
}
|
||||||
|
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
|
||||||
return errors.New("websocket: " + message)
|
return errors.New("websocket: " + message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -952,6 +1013,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
||||||
c.readErr = hideTempErr(err)
|
c.readErr = hideTempErr(err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if frameType == TextMessage || frameType == BinaryMessage {
|
if frameType == TextMessage || frameType == BinaryMessage {
|
||||||
c.messageReader = &messageReader{c}
|
c.messageReader = &messageReader{c}
|
||||||
c.reader = c.messageReader
|
c.reader = c.messageReader
|
||||||
|
@ -992,7 +1054,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
||||||
if c.isServer {
|
if c.isServer {
|
||||||
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
|
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
|
||||||
}
|
}
|
||||||
c.readRemaining -= int64(n)
|
rem := c.readRemaining
|
||||||
|
rem -= int64(n)
|
||||||
|
c.setReadRemaining(rem)
|
||||||
if c.readRemaining > 0 && c.readErr == io.EOF {
|
if c.readRemaining > 0 && c.readErr == io.EOF {
|
||||||
c.readErr = errUnexpectedEOF
|
c.readErr = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
@ -1127,8 +1191,16 @@ func (c *Conn) SetPongHandler(h func(appData string) error) {
|
||||||
c.handlePong = h
|
c.handlePong = h
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NetConn returns the underlying connection that is wrapped by c.
|
||||||
|
// Note that writing to or reading from this connection directly will corrupt the
|
||||||
|
// WebSocket connection.
|
||||||
|
func (c *Conn) NetConn() net.Conn {
|
||||||
|
return c.conn
|
||||||
|
}
|
||||||
|
|
||||||
// UnderlyingConn returns the internal net.Conn. This can be used to further
|
// UnderlyingConn returns the internal net.Conn. This can be used to further
|
||||||
// modifications to connection specific flags.
|
// modifications to connection specific flags.
|
||||||
|
// Deprecated: Use the NetConn method.
|
||||||
func (c *Conn) UnderlyingConn() net.Conn {
|
func (c *Conn) UnderlyingConn() net.Conn {
|
||||||
return c.conn
|
return c.conn
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,6 @@ import (
|
||||||
// scenarios with many subscribers in one channel.
|
// scenarios with many subscribers in one channel.
|
||||||
type broadcastBench struct {
|
type broadcastBench struct {
|
||||||
w io.Writer
|
w io.Writer
|
||||||
message *broadcastMessage
|
|
||||||
closeCh chan struct{}
|
closeCh chan struct{}
|
||||||
doneCh chan struct{}
|
doneCh chan struct{}
|
||||||
count int32
|
count int32
|
||||||
|
@ -52,14 +51,6 @@ func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
|
||||||
usePrepared: usePrepared,
|
usePrepared: usePrepared,
|
||||||
compression: compression,
|
compression: compression,
|
||||||
}
|
}
|
||||||
msg := &broadcastMessage{
|
|
||||||
payload: textMessages(1)[0],
|
|
||||||
}
|
|
||||||
if usePrepared {
|
|
||||||
pm, _ := NewPreparedMessage(TextMessage, msg.payload)
|
|
||||||
msg.prepared = pm
|
|
||||||
}
|
|
||||||
bench.message = msg
|
|
||||||
bench.makeConns(10000)
|
bench.makeConns(10000)
|
||||||
return bench
|
return bench
|
||||||
}
|
}
|
||||||
|
@ -78,7 +69,7 @@ func (b *broadcastBench) makeConns(numConns int) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case msg := <-c.msgCh:
|
case msg := <-c.msgCh:
|
||||||
if b.usePrepared {
|
if msg.prepared != nil {
|
||||||
c.conn.WritePreparedMessage(msg.prepared)
|
c.conn.WritePreparedMessage(msg.prepared)
|
||||||
} else {
|
} else {
|
||||||
c.conn.WriteMessage(TextMessage, msg.payload)
|
c.conn.WriteMessage(TextMessage, msg.payload)
|
||||||
|
@ -100,9 +91,9 @@ func (b *broadcastBench) close() {
|
||||||
close(b.closeCh)
|
close(b.closeCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *broadcastBench) runOnce() {
|
func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) {
|
||||||
for _, c := range b.conns {
|
for _, c := range b.conns {
|
||||||
c.msgCh <- b.message
|
c.msgCh <- msg
|
||||||
}
|
}
|
||||||
<-b.doneCh
|
<-b.doneCh
|
||||||
}
|
}
|
||||||
|
@ -114,17 +105,25 @@ func BenchmarkBroadcast(b *testing.B) {
|
||||||
compression bool
|
compression bool
|
||||||
}{
|
}{
|
||||||
{"NoCompression", false, false},
|
{"NoCompression", false, false},
|
||||||
{"WithCompression", false, true},
|
{"Compression", false, true},
|
||||||
{"NoCompressionPrepared", true, false},
|
{"NoCompressionPrepared", true, false},
|
||||||
{"WithCompressionPrepared", true, true},
|
{"CompressionPrepared", true, true},
|
||||||
}
|
}
|
||||||
|
payload := textMessages(1)[0]
|
||||||
for _, bm := range benchmarks {
|
for _, bm := range benchmarks {
|
||||||
b.Run(bm.name, func(b *testing.B) {
|
b.Run(bm.name, func(b *testing.B) {
|
||||||
bench := newBroadcastBench(bm.usePrepared, bm.compression)
|
bench := newBroadcastBench(bm.usePrepared, bm.compression)
|
||||||
defer bench.close()
|
defer bench.close()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
bench.runOnce()
|
message := &broadcastMessage{
|
||||||
|
payload: payload,
|
||||||
|
}
|
||||||
|
if bench.usePrepared {
|
||||||
|
pm, _ := NewPreparedMessage(TextMessage, message.payload)
|
||||||
|
message.prepared = pm
|
||||||
|
}
|
||||||
|
bench.broadcastOnce(message)
|
||||||
}
|
}
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
})
|
})
|
||||||
|
|
129
conn_test.go
129
conn_test.go
|
@ -56,7 +56,10 @@ func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFraming(t *testing.T) {
|
func TestFraming(t *testing.T) {
|
||||||
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
|
frameSizes := []int{
|
||||||
|
0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
|
||||||
|
// 65536, 65537
|
||||||
|
}
|
||||||
var readChunkers = []struct {
|
var readChunkers = []struct {
|
||||||
name string
|
name string
|
||||||
f func(io.Reader) io.Reader
|
f func(io.Reader) io.Reader
|
||||||
|
@ -151,6 +154,8 @@ func TestFraming(t *testing.T) {
|
||||||
t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
|
t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Logf("frame size: %d", n)
|
||||||
rbuf, err := ioutil.ReadAll(r)
|
rbuf, err := ioutil.ReadAll(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
|
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
|
||||||
|
@ -328,7 +333,7 @@ func TestWriteBufferPoolSync(t *testing.T) {
|
||||||
// errorWriter is an io.Writer than returns an error on all writes.
|
// errorWriter is an io.Writer than returns an error on all writes.
|
||||||
type errorWriter struct{}
|
type errorWriter struct{}
|
||||||
|
|
||||||
func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("Error!") }
|
func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") }
|
||||||
|
|
||||||
// TestWriteBufferPoolError ensures that buffer is returned to pool after error
|
// TestWriteBufferPoolError ensures that buffer is returned to pool after error
|
||||||
// on write.
|
// on write.
|
||||||
|
@ -489,37 +494,93 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadLimit(t *testing.T) {
|
func TestReadLimit(t *testing.T) {
|
||||||
|
t.Run("Test ReadLimit is enforced", func(t *testing.T) {
|
||||||
|
const readLimit = 512
|
||||||
|
message := make([]byte, readLimit+1)
|
||||||
|
|
||||||
const readLimit = 512
|
var b1, b2 bytes.Buffer
|
||||||
message := make([]byte, readLimit+1)
|
wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
|
||||||
|
rc := newTestConn(&b1, &b2, true)
|
||||||
|
rc.SetReadLimit(readLimit)
|
||||||
|
|
||||||
var b1, b2 bytes.Buffer
|
// Send message at the limit with interleaved pong.
|
||||||
wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
rc := newTestConn(&b1, &b2, true)
|
w.Write(message[:readLimit-1])
|
||||||
rc.SetReadLimit(readLimit)
|
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
|
||||||
|
w.Write(message[:1])
|
||||||
|
w.Close()
|
||||||
|
|
||||||
// Send message at the limit with interleaved pong.
|
// Send message larger than the limit.
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
wc.WriteMessage(BinaryMessage, message[:readLimit+1])
|
||||||
w.Write(message[:readLimit-1])
|
|
||||||
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
|
|
||||||
w.Write(message[:1])
|
|
||||||
w.Close()
|
|
||||||
|
|
||||||
// Send message larger than the limit.
|
op, _, err := rc.NextReader()
|
||||||
wc.WriteMessage(BinaryMessage, message[:readLimit+1])
|
if op != BinaryMessage || err != nil {
|
||||||
|
t.Fatalf("1: NextReader() returned %d, %v", op, err)
|
||||||
|
}
|
||||||
|
op, r, err := rc.NextReader()
|
||||||
|
if op != BinaryMessage || err != nil {
|
||||||
|
t.Fatalf("2: NextReader() returned %d, %v", op, err)
|
||||||
|
}
|
||||||
|
_, err = io.Copy(ioutil.Discard, r)
|
||||||
|
if err != ErrReadLimit {
|
||||||
|
t.Fatalf("io.Copy() returned %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
op, _, err := rc.NextReader()
|
t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
|
||||||
if op != BinaryMessage || err != nil {
|
const readLimit = 1
|
||||||
t.Fatalf("1: NextReader() returned %d, %v", op, err)
|
|
||||||
}
|
var b1, b2 bytes.Buffer
|
||||||
op, r, err := rc.NextReader()
|
rc := newTestConn(&b1, &b2, true)
|
||||||
if op != BinaryMessage || err != nil {
|
rc.SetReadLimit(readLimit)
|
||||||
t.Fatalf("2: NextReader() returned %d, %v", op, err)
|
|
||||||
}
|
// First, send a non-final binary message
|
||||||
_, err = io.Copy(ioutil.Discard, r)
|
b1.Write([]byte("\x02\x81"))
|
||||||
if err != ErrReadLimit {
|
|
||||||
t.Fatalf("io.Copy() returned %v", err)
|
// Mask key
|
||||||
}
|
b1.Write([]byte("\x00\x00\x00\x00"))
|
||||||
|
|
||||||
|
// First payload
|
||||||
|
b1.Write([]byte("A"))
|
||||||
|
|
||||||
|
// Next, send a negative-length, non-final continuation frame
|
||||||
|
b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))
|
||||||
|
|
||||||
|
// Mask key
|
||||||
|
b1.Write([]byte("\x00\x00\x00\x00"))
|
||||||
|
|
||||||
|
// Next, send a too long, final continuation frame
|
||||||
|
b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))
|
||||||
|
|
||||||
|
// Mask key
|
||||||
|
b1.Write([]byte("\x00\x00\x00\x00"))
|
||||||
|
|
||||||
|
// Too-long payload
|
||||||
|
b1.Write([]byte("BCDEF"))
|
||||||
|
|
||||||
|
op, r, err := rc.NextReader()
|
||||||
|
if op != BinaryMessage || err != nil {
|
||||||
|
t.Fatalf("1: NextReader() returned %d, %v", op, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf [10]byte
|
||||||
|
var read int
|
||||||
|
n, err := r.Read(buf[:])
|
||||||
|
if err != nil && err != ErrReadLimit {
|
||||||
|
t.Fatalf("unexpected error testing read limit: %v", err)
|
||||||
|
}
|
||||||
|
read += n
|
||||||
|
|
||||||
|
n, err = r.Read(buf[:])
|
||||||
|
if err != nil && err != ErrReadLimit {
|
||||||
|
t.Fatalf("unexpected error testing read limit: %v", err)
|
||||||
|
}
|
||||||
|
read += n
|
||||||
|
|
||||||
|
if err == nil && read > readLimit {
|
||||||
|
t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddrs(t *testing.T) {
|
func TestAddrs(t *testing.T) {
|
||||||
|
@ -532,7 +593,7 @@ func TestAddrs(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnderlyingConn(t *testing.T) {
|
func TestDeprecatedUnderlyingConn(t *testing.T) {
|
||||||
var b1, b2 bytes.Buffer
|
var b1, b2 bytes.Buffer
|
||||||
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
||||||
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
|
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
|
||||||
|
@ -542,6 +603,16 @@ func TestUnderlyingConn(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNetConn(t *testing.T) {
|
||||||
|
var b1, b2 bytes.Buffer
|
||||||
|
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
||||||
|
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
|
||||||
|
ul := c.NetConn()
|
||||||
|
if ul != fc {
|
||||||
|
t.Fatalf("Underlying conn is not what it should be.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBufioReadBytes(t *testing.T) {
|
func TestBufioReadBytes(t *testing.T) {
|
||||||
// Test calling bufio.ReadBytes for value longer than read buffer size.
|
// Test calling bufio.ReadBytes for value longer than read buffer size.
|
||||||
|
|
||||||
|
|
|
@ -1,15 +0,0 @@
|
||||||
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build go1.8
|
|
||||||
|
|
||||||
package websocket
|
|
||||||
|
|
||||||
import "net"
|
|
||||||
|
|
||||||
func (c *Conn) writeBufs(bufs ...[]byte) error {
|
|
||||||
b := net.Buffers(bufs)
|
|
||||||
_, err := b.WriteTo(c.conn)
|
|
||||||
return err
|
|
||||||
}
|
|
|
@ -1,18 +0,0 @@
|
||||||
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build !go1.8
|
|
||||||
|
|
||||||
package websocket
|
|
||||||
|
|
||||||
func (c *Conn) writeBufs(bufs ...[]byte) error {
|
|
||||||
for _, buf := range bufs {
|
|
||||||
if len(buf) > 0 {
|
|
||||||
if _, err := c.conn.Write(buf); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
47
doc.go
47
doc.go
|
@ -151,6 +151,53 @@
|
||||||
// checking. The application is responsible for checking the Origin header
|
// checking. The application is responsible for checking the Origin header
|
||||||
// before calling the Upgrade function.
|
// before calling the Upgrade function.
|
||||||
//
|
//
|
||||||
|
// Buffers
|
||||||
|
//
|
||||||
|
// Connections buffer network input and output to reduce the number
|
||||||
|
// of system calls when reading or writing messages.
|
||||||
|
//
|
||||||
|
// Write buffers are also used for constructing WebSocket frames. See RFC 6455,
|
||||||
|
// Section 5 for a discussion of message framing. A WebSocket frame header is
|
||||||
|
// written to the network each time a write buffer is flushed to the network.
|
||||||
|
// Decreasing the size of the write buffer can increase the amount of framing
|
||||||
|
// overhead on the connection.
|
||||||
|
//
|
||||||
|
// The buffer sizes in bytes are specified by the ReadBufferSize and
|
||||||
|
// WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default
|
||||||
|
// size of 4096 when a buffer size field is set to zero. The Upgrader reuses
|
||||||
|
// buffers created by the HTTP server when a buffer size field is set to zero.
|
||||||
|
// The HTTP server buffers have a size of 4096 at the time of this writing.
|
||||||
|
//
|
||||||
|
// The buffer sizes do not limit the size of a message that can be read or
|
||||||
|
// written by a connection.
|
||||||
|
//
|
||||||
|
// Buffers are held for the lifetime of the connection by default. If the
|
||||||
|
// Dialer or Upgrader WriteBufferPool field is set, then a connection holds the
|
||||||
|
// write buffer only when writing a message.
|
||||||
|
//
|
||||||
|
// Applications should tune the buffer sizes to balance memory use and
|
||||||
|
// performance. Increasing the buffer size uses more memory, but can reduce the
|
||||||
|
// number of system calls to read or write the network. In the case of writing,
|
||||||
|
// increasing the buffer size can reduce the number of frame headers written to
|
||||||
|
// the network.
|
||||||
|
//
|
||||||
|
// Some guidelines for setting buffer parameters are:
|
||||||
|
//
|
||||||
|
// Limit the buffer sizes to the maximum expected message size. Buffers larger
|
||||||
|
// than the largest message do not provide any benefit.
|
||||||
|
//
|
||||||
|
// Depending on the distribution of message sizes, setting the buffer size to
|
||||||
|
// a value less than the maximum expected message size can greatly reduce memory
|
||||||
|
// use with a small impact on performance. Here's an example: If 99% of the
|
||||||
|
// messages are smaller than 256 bytes and the maximum message size is 512
|
||||||
|
// bytes, then a buffer size of 256 bytes will result in 1.01 more system calls
|
||||||
|
// than a buffer size of 512 bytes. The memory savings is 50%.
|
||||||
|
//
|
||||||
|
// A write buffer pool is useful when the application has a modest number
|
||||||
|
// writes over a large number of connections. when buffers are pooled, a larger
|
||||||
|
// buffer size has a reduced impact on total memory use and has the benefit of
|
||||||
|
// reducing system calls and frame overhead.
|
||||||
|
//
|
||||||
// Compression EXPERIMENTAL
|
// Compression EXPERIMENTAL
|
||||||
//
|
//
|
||||||
// Per message compression extensions (RFC 7692) are experimentally supported
|
// Per message compression extensions (RFC 7692) are experimentally supported
|
||||||
|
|
|
@ -23,10 +23,9 @@ var (
|
||||||
// This server application works with a client application running in the
|
// This server application works with a client application running in the
|
||||||
// browser. The client application does not explicitly close the websocket. The
|
// browser. The client application does not explicitly close the websocket. The
|
||||||
// only expected close message from the client has the code
|
// only expected close message from the client has the code
|
||||||
// websocket.CloseGoingAway. All other other close messages are likely the
|
// websocket.CloseGoingAway. All other close messages are likely the
|
||||||
// result of an application or protocol error and are logged to aid debugging.
|
// result of an application or protocol error and are logged to aid debugging.
|
||||||
func ExampleIsUnexpectedCloseError() {
|
func ExampleIsUnexpectedCloseError() {
|
||||||
|
|
||||||
for {
|
for {
|
||||||
messageType, p, err := c.ReadMessage()
|
messageType, p, err := c.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -35,11 +34,11 @@ func ExampleIsUnexpectedCloseError() {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
processMesage(messageType, p)
|
processMessage(messageType, p)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func processMesage(mt int, p []byte) {}
|
func processMessage(mt int, p []byte) {}
|
||||||
|
|
||||||
// TestX prevents godoc from showing this entire file in the example. Remove
|
// TestX prevents godoc from showing this entire file in the example. Remove
|
||||||
// this function when a second example is added.
|
// this function when a second example is added.
|
||||||
|
|
|
@ -8,6 +8,11 @@ To test the server, run
|
||||||
|
|
||||||
and start the client test driver
|
and start the client test driver
|
||||||
|
|
||||||
wstest -m fuzzingclient -s fuzzingclient.json
|
mkdir -p reports
|
||||||
|
docker run -it --rm \
|
||||||
|
-v ${PWD}/config:/config \
|
||||||
|
-v ${PWD}/reports:/reports \
|
||||||
|
crossbario/autobahn-testsuite \
|
||||||
|
wstest -m fuzzingclient -s /config/fuzzingclient.json
|
||||||
|
|
||||||
When the client completes, it writes a report to reports/clients/index.html.
|
When the client completes, it writes a report to reports/index.html.
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
{
|
||||||
|
"cases": ["*"],
|
||||||
|
"exclude-cases": [],
|
||||||
|
"exclude-agent-cases": {},
|
||||||
|
"outdir": "/reports",
|
||||||
|
"options": {"failByDrop": false},
|
||||||
|
"servers": [
|
||||||
|
{
|
||||||
|
"agent": "ReadAllWriteMessage",
|
||||||
|
"url": "ws://host.docker.internal:9000/m"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"agent": "ReadAllWritePreparedMessage",
|
||||||
|
"url": "ws://host.docker.internal:9000/p"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"agent": "CopyFull",
|
||||||
|
"url": "ws://host.docker.internal:9000/f"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"agent": "ReadAllWrite",
|
||||||
|
"url": "ws://host.docker.internal:9000/r"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"agent": "CopyWriterOnly",
|
||||||
|
"url": "ws://host.docker.internal:9000/c"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
|
@ -1,15 +0,0 @@
|
||||||
|
|
||||||
{
|
|
||||||
"options": {"failByDrop": false},
|
|
||||||
"outdir": "./reports/clients",
|
|
||||||
"servers": [
|
|
||||||
{"agent": "ReadAllWriteMessage", "url": "ws://localhost:9000/m", "options": {"version": 18}},
|
|
||||||
{"agent": "ReadAllWritePreparedMessage", "url": "ws://localhost:9000/p", "options": {"version": 18}},
|
|
||||||
{"agent": "ReadAllWrite", "url": "ws://localhost:9000/r", "options": {"version": 18}},
|
|
||||||
{"agent": "CopyFull", "url": "ws://localhost:9000/f", "options": {"version": 18}},
|
|
||||||
{"agent": "CopyWriterOnly", "url": "ws://localhost:9000/c", "options": {"version": 18}}
|
|
||||||
],
|
|
||||||
"cases": ["*"],
|
|
||||||
"exclude-cases": [],
|
|
||||||
"exclude-agent-cases": {}
|
|
||||||
}
|
|
|
@ -160,7 +160,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Error(w, "Not found.", http.StatusNotFound)
|
http.Error(w, "Not found.", http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Method != "GET" {
|
if r.Method != http.MethodGet {
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
//go:build ignore
|
||||||
|
// +build ignore
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"log"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
var addr = flag.String("addr", "localhost:8080", "http service address")
|
||||||
|
|
||||||
|
func runNewConn(wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
interrupt := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(interrupt, os.Interrupt)
|
||||||
|
|
||||||
|
u := url.URL{Scheme: "ws", Host: *addr, Path: "/ws"}
|
||||||
|
log.Printf("connecting to %s", u.String())
|
||||||
|
c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("dial:", err)
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
for {
|
||||||
|
_, message, err := c.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
log.Println("read:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("recv: %s", message)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(time.Minute * 5)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
case t := <-ticker.C:
|
||||||
|
err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("write:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-interrupt:
|
||||||
|
log.Println("interrupt")
|
||||||
|
|
||||||
|
// Cleanly close the connection by sending a close message and then
|
||||||
|
// waiting (with timeout) for the server to close the connection.
|
||||||
|
err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("write close:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
log.SetFlags(0)
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go runNewConn(wg)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
|
@ -0,0 +1,55 @@
|
||||||
|
//go:build ignore
|
||||||
|
// +build ignore
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
_ "net/http/pprof"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
var addr = flag.String("addr", "localhost:8080", "http service address")
|
||||||
|
|
||||||
|
var upgrader = websocket.Upgrader{
|
||||||
|
ReadBufferSize: 256,
|
||||||
|
WriteBufferSize: 256,
|
||||||
|
WriteBufferPool: &sync.Pool{},
|
||||||
|
}
|
||||||
|
|
||||||
|
func process(c *websocket.Conn) {
|
||||||
|
defer c.Close()
|
||||||
|
for {
|
||||||
|
_, message, err := c.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
log.Println("read:", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Printf("recv: %s", message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
c, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Print("upgrade:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process connection in a new goroutine
|
||||||
|
go process(c)
|
||||||
|
|
||||||
|
// Let the http handler return, the 8k buffer created by it will be garbage collected
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
log.SetFlags(0)
|
||||||
|
http.HandleFunc("/ws", handler)
|
||||||
|
log.Fatal(http.ListenAndServe(*addr, nil))
|
||||||
|
}
|
|
@ -92,7 +92,7 @@ body {
|
||||||
<div id="log"></div>
|
<div id="log"></div>
|
||||||
<form id="form">
|
<form id="form">
|
||||||
<input type="submit" value="Send" />
|
<input type="submit" value="Send" />
|
||||||
<input type="text" id="msg" size="64"/>
|
<input type="text" id="msg" size="64" autofocus />
|
||||||
</form>
|
</form>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
|
|
@ -18,7 +18,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Error(w, "Not found", http.StatusNotFound)
|
http.Error(w, "Not found", http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Method != "GET" {
|
if r.Method != http.MethodGet {
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -170,7 +170,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Error(w, "Not found", http.StatusNotFound)
|
http.Error(w, "Not found", http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Method != "GET" {
|
if r.Method != http.MethodGet {
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
//go:build ignore
|
||||||
// +build ignore
|
// +build ignore
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
//go:build ignore
|
||||||
// +build ignore
|
// +build ignore
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
@ -67,8 +68,9 @@ window.addEventListener("load", function(evt) {
|
||||||
|
|
||||||
var print = function(message) {
|
var print = function(message) {
|
||||||
var d = document.createElement("div");
|
var d = document.createElement("div");
|
||||||
d.innerHTML = message;
|
d.textContent = message;
|
||||||
output.appendChild(d);
|
output.appendChild(d);
|
||||||
|
output.scroll(0, output.scrollHeight);
|
||||||
};
|
};
|
||||||
|
|
||||||
document.getElementById("open").onclick = function(evt) {
|
document.getElementById("open").onclick = function(evt) {
|
||||||
|
@ -126,7 +128,7 @@ You can change the message and send multiple times.
|
||||||
<button id="send">Send</button>
|
<button id="send">Send</button>
|
||||||
</form>
|
</form>
|
||||||
</td><td valign="top" width="50%">
|
</td><td valign="top" width="50%">
|
||||||
<div id="output"></div>
|
<div id="output" style="max-height: 70vh;overflow-y: scroll;"></div>
|
||||||
</td></tr></table>
|
</td></tr></table>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
|
|
@ -133,7 +133,7 @@ func serveHome(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Error(w, "Not found", http.StatusNotFound)
|
http.Error(w, "Not found", http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Method != "GET" {
|
if r.Method != http.MethodGet {
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JoinMessages concatenates received messages to create a single io.Reader.
|
||||||
|
// The string term is appended to each message. The returned reader does not
|
||||||
|
// support concurrent calls to the Read method.
|
||||||
|
func JoinMessages(c *Conn, term string) io.Reader {
|
||||||
|
return &joinReader{c: c, term: term}
|
||||||
|
}
|
||||||
|
|
||||||
|
type joinReader struct {
|
||||||
|
c *Conn
|
||||||
|
term string
|
||||||
|
r io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *joinReader) Read(p []byte) (int, error) {
|
||||||
|
if r.r == nil {
|
||||||
|
var err error
|
||||||
|
_, r.r, err = r.c.NextReader()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if r.term != "" {
|
||||||
|
r.r = io.MultiReader(r.r, strings.NewReader(r.term))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n, err := r.r.Read(p)
|
||||||
|
if err == io.EOF {
|
||||||
|
err = nil
|
||||||
|
r.r = nil
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
|
@ -0,0 +1,36 @@
|
||||||
|
// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJoinMessages(t *testing.T) {
|
||||||
|
messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"}
|
||||||
|
for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} {
|
||||||
|
for _, term := range []string{"", ","} {
|
||||||
|
var connBuf bytes.Buffer
|
||||||
|
wc := newTestConn(nil, &connBuf, true)
|
||||||
|
rc := newTestConn(&connBuf, nil, false)
|
||||||
|
for _, m := range messages {
|
||||||
|
wc.WriteMessage(BinaryMessage, []byte(m))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result bytes.Buffer
|
||||||
|
_, err := io.CopyBuffer(&result, JoinMessages(rc, term), make([]byte, readChunk))
|
||||||
|
if IsUnexpectedCloseError(err, CloseAbnormalClosure) {
|
||||||
|
t.Errorf("readChunk=%d, term=%q: unexpected error %v", readChunk, term, err)
|
||||||
|
}
|
||||||
|
want := strings.Join(messages, term) + term
|
||||||
|
if result.String() != want {
|
||||||
|
t.Errorf("readChunk=%d, term=%q, got %q, want %q", readChunk, term, result.String(), want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
1
mask.go
1
mask.go
|
@ -2,6 +2,7 @@
|
||||||
// this source code is governed by a BSD-style license that can be found in the
|
// this source code is governed by a BSD-style license that can be found in the
|
||||||
// LICENSE file.
|
// LICENSE file.
|
||||||
|
|
||||||
|
//go:build !appengine
|
||||||
// +build !appengine
|
// +build !appengine
|
||||||
|
|
||||||
package websocket
|
package websocket
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
// this source code is governed by a BSD-style license that can be found in the
|
// this source code is governed by a BSD-style license that can be found in the
|
||||||
// LICENSE file.
|
// LICENSE file.
|
||||||
|
|
||||||
|
//go:build appengine
|
||||||
// +build appengine
|
// +build appengine
|
||||||
|
|
||||||
package websocket
|
package websocket
|
||||||
|
|
|
@ -73,8 +73,8 @@ func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
|
||||||
// Prepare a frame using a 'fake' connection.
|
// Prepare a frame using a 'fake' connection.
|
||||||
// TODO: Refactor code in conn.go to allow more direct construction of
|
// TODO: Refactor code in conn.go to allow more direct construction of
|
||||||
// the frame.
|
// the frame.
|
||||||
mu := make(chan bool, 1)
|
mu := make(chan struct{}, 1)
|
||||||
mu <- true
|
mu <- struct{}{}
|
||||||
var nc prepareConn
|
var nc prepareConn
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
conn: &nc,
|
conn: &nc,
|
||||||
|
|
10
proxy.go
10
proxy.go
|
@ -22,18 +22,18 @@ func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
||||||
return &httpProxyDialer{proxyURL: proxyURL, fowardDial: forwardDialer.Dial}, nil
|
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpProxyDialer struct {
|
type httpProxyDialer struct {
|
||||||
proxyURL *url.URL
|
proxyURL *url.URL
|
||||||
fowardDial func(network, addr string) (net.Conn, error)
|
forwardDial func(network, addr string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
|
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
|
||||||
hostPort, _ := hostPortNoPort(hpd.proxyURL)
|
hostPort, _ := hostPortNoPort(hpd.proxyURL)
|
||||||
conn, err := hpd.fowardDial(network, hostPort)
|
conn, err := hpd.forwardDial(network, hostPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -48,7 +48,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
connectReq := &http.Request{
|
connectReq := &http.Request{
|
||||||
Method: "CONNECT",
|
Method: http.MethodConnect,
|
||||||
URL: &url.URL{Opaque: addr},
|
URL: &url.URL{Opaque: addr},
|
||||||
Host: addr,
|
Host: addr,
|
||||||
Header: connectHeader,
|
Header: connectHeader,
|
||||||
|
|
12
server.go
12
server.go
|
@ -24,6 +24,8 @@ func (e HandshakeError) Error() string { return e.message }
|
||||||
|
|
||||||
// Upgrader specifies parameters for upgrading an HTTP connection to a
|
// Upgrader specifies parameters for upgrading an HTTP connection to a
|
||||||
// WebSocket connection.
|
// WebSocket connection.
|
||||||
|
//
|
||||||
|
// It is safe to call Upgrader's methods concurrently.
|
||||||
type Upgrader struct {
|
type Upgrader struct {
|
||||||
// HandshakeTimeout specifies the duration for the handshake to complete.
|
// HandshakeTimeout specifies the duration for the handshake to complete.
|
||||||
HandshakeTimeout time.Duration
|
HandshakeTimeout time.Duration
|
||||||
|
@ -119,8 +121,8 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
|
||||||
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
|
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
|
||||||
//
|
//
|
||||||
// The responseHeader is included in the response to the client's upgrade
|
// The responseHeader is included in the response to the client's upgrade
|
||||||
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
|
// request. Use the responseHeader to specify cookies (Set-Cookie). To specify
|
||||||
// application negotiated subprotocol (Sec-WebSocket-Protocol).
|
// subprotocols supported by the server, set Upgrader.Subprotocols directly.
|
||||||
//
|
//
|
||||||
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
|
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
|
||||||
// response.
|
// response.
|
||||||
|
@ -135,7 +137,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
|
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method != "GET" {
|
if r.Method != http.MethodGet {
|
||||||
return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
|
return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -156,8 +158,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
}
|
}
|
||||||
|
|
||||||
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
||||||
if challengeKey == "" {
|
if !isValidChallengeKey(challengeKey) {
|
||||||
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank")
|
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
|
||||||
}
|
}
|
||||||
|
|
||||||
subprotocol := u.selectSubprotocol(r, responseHeader)
|
subprotocol := u.selectSubprotocol(r, responseHeader)
|
||||||
|
|
|
@ -98,7 +98,7 @@ func TestBufioReuse(t *testing.T) {
|
||||||
}
|
}
|
||||||
upgrader := Upgrader{}
|
upgrader := Upgrader{}
|
||||||
c, err := upgrader.Upgrade(resp, &http.Request{
|
c, err := upgrader.Upgrade(resp, &http.Request{
|
||||||
Method: "GET",
|
Method: http.MethodGet,
|
||||||
Header: http.Header{
|
Header: http.Header{
|
||||||
"Upgrade": []string{"websocket"},
|
"Upgrade": []string{"websocket"},
|
||||||
"Connection": []string{"upgrade"},
|
"Connection": []string{"upgrade"},
|
||||||
|
@ -111,7 +111,7 @@ func TestBufioReuse(t *testing.T) {
|
||||||
if reuse := c.br == br; reuse != tt.reuse {
|
if reuse := c.br == br; reuse != tt.reuse {
|
||||||
t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
|
t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
|
||||||
}
|
}
|
||||||
writeBuf := bufioWriterBuffer(c.UnderlyingConn(), bw)
|
writeBuf := bufioWriterBuffer(c.NetConn(), bw)
|
||||||
if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
|
if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
|
||||||
t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
|
t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
//go:build go1.17
|
||||||
|
// +build go1.17
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
)
|
||||||
|
|
||||||
|
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
|
||||||
|
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !cfg.InsecureSkipVerify {
|
||||||
|
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
//go:build !go1.17
|
||||||
|
// +build !go1.17
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
)
|
||||||
|
|
||||||
|
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !cfg.InsecureSkipVerify {
|
||||||
|
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
19
trace.go
19
trace.go
|
@ -1,19 +0,0 @@
|
||||||
// +build go1.8
|
|
||||||
|
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"net/http/httptrace"
|
|
||||||
)
|
|
||||||
|
|
||||||
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
|
|
||||||
if trace.TLSHandshakeStart != nil {
|
|
||||||
trace.TLSHandshakeStart()
|
|
||||||
}
|
|
||||||
err := doHandshake(tlsConn, cfg)
|
|
||||||
if trace.TLSHandshakeDone != nil {
|
|
||||||
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
12
trace_17.go
12
trace_17.go
|
@ -1,12 +0,0 @@
|
||||||
// +build !go1.8
|
|
||||||
|
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"net/http/httptrace"
|
|
||||||
)
|
|
||||||
|
|
||||||
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
|
|
||||||
return doHandshake(tlsConn, cfg)
|
|
||||||
}
|
|
15
util.go
15
util.go
|
@ -281,3 +281,18 @@ headers:
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isValidChallengeKey checks if the argument meets RFC6455 specification.
|
||||||
|
func isValidChallengeKey(s string) bool {
|
||||||
|
// From RFC6455:
|
||||||
|
//
|
||||||
|
// A |Sec-WebSocket-Key| header field with a base64-encoded (see
|
||||||
|
// Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
|
||||||
|
// length.
|
||||||
|
|
||||||
|
if s == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(s)
|
||||||
|
return err == nil && len(decoded) == 16
|
||||||
|
}
|
||||||
|
|
19
util_test.go
19
util_test.go
|
@ -53,6 +53,25 @@ func TestTokenListContainsValue(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var isValidChallengeKeyTests = []struct {
|
||||||
|
key string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{"dGhlIHNhbXBsZSBub25jZQ==", true},
|
||||||
|
{"", false},
|
||||||
|
{"InvalidKey", false},
|
||||||
|
{"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValidChallengeKey(t *testing.T) {
|
||||||
|
for _, tt := range isValidChallengeKeyTests {
|
||||||
|
ok := isValidChallengeKey(tt.key)
|
||||||
|
if ok != tt.ok {
|
||||||
|
t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var parseExtensionTests = []struct {
|
var parseExtensionTests = []struct {
|
||||||
value string
|
value string
|
||||||
extensions []map[string]string
|
extensions []map[string]string
|
||||||
|
|
Loading…
Reference in New Issue