Compare commits

..

1 Commits

Author SHA1 Message Date
andig 1532d57c38
Merge c769d0214c into 9ec25ca502 2024-06-12 17:45:45 +08:00
40 changed files with 674 additions and 215 deletions

View File

@ -1,70 +0,0 @@
version: 2.1
jobs:
"test":
parameters:
version:
type: string
default: "latest"
golint:
type: boolean
default: true
modules:
type: boolean
default: true
goproxy:
type: string
default: ""
docker:
- image: "cimg/go:<< parameters.version >>"
working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket
environment:
GO111MODULE: "on"
GOPROXY: "<< parameters.goproxy >>"
steps:
- checkout
- run:
name: "Print the Go version"
command: >
go version
- run:
name: "Fetch dependencies"
command: >
if [[ << parameters.modules >> = true ]]; then
go mod download
export GO111MODULE=on
else
go get -v ./...
fi
# Only run gofmt, vet & lint against the latest Go version
- run:
name: "Run golint"
command: >
if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then
go get -u golang.org/x/lint/golint
golint ./...
fi
- run:
name: "Run gofmt"
command: >
if [[ << parameters.version >> = "latest" ]]; then
diff -u <(echo -n) <(gofmt -d -e .)
fi
- run:
name: "Run go vet"
command: >
if [[ << parameters.version >> = "latest" ]]; then
go vet -v ./...
fi
- run:
name: "Run go test (+ race detector)"
command: >
go test -v -race ./...
workflows:
tests:
jobs:
- test:
matrix:
parameters:
version: ["1.18", "1.17", "1.16"]

20
.editorconfig Normal file
View File

@ -0,0 +1,20 @@
; https://editorconfig.org/
root = true
[*]
insert_final_newline = true
charset = utf-8
trim_trailing_whitespace = true
indent_style = space
indent_size = 2
[{Makefile,go.mod,go.sum,*.go,.gitmodules}]
indent_style = tab
indent_size = 4
[*.md]
indent_size = 4
trim_trailing_whitespace = false
eclint_indent_style = unset

View File

@ -1,7 +0,0 @@
# Config for https://github.com/apps/release-drafter
template: |
<summary of changes here>
## CHANGELOG
$CHANGES

21
.github/workflows/issues.yml vendored Normal file
View File

@ -0,0 +1,21 @@
# Add all the issues created to the project.
name: Add issue or pull request to Project
on:
issues:
types:
- opened
pull_request_target:
types:
- opened
- reopened
jobs:
add-to-project:
runs-on: ubuntu-latest
steps:
- name: Add issue to project
uses: actions/add-to-project@v0.5.0
with:
project-url: https://github.com/orgs/gorilla/projects/4
github-token: ${{ secrets.ADD_TO_PROJECT_TOKEN }}

37
.github/workflows/security.yml vendored Normal file
View File

@ -0,0 +1,37 @@
name: Security
on:
push:
branches:
- main
pull_request:
branches:
- main
permissions:
contents: read
jobs:
scan:
strategy:
matrix:
go: ['1.20','1.21']
fail-fast: true
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Setup Go ${{ matrix.go }}
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
cache: false
- name: Run GoSec
uses: securego/gosec@master
with:
args: -exclude-dir examples ./...
- name: Run GoVulnCheck
uses: golang/govulncheck-action@v1
with:
go-version-input: ${{ matrix.go }}
go-package: ./...

35
.github/workflows/test.yml vendored Normal file
View File

@ -0,0 +1,35 @@
name: Test
on:
push:
branches:
- main
pull_request:
branches:
- main
permissions:
contents: read
jobs:
unit:
strategy:
matrix:
go: ['1.20','1.21']
os: [ubuntu-latest, macos-latest, windows-latest]
fail-fast: true
runs-on: ${{ matrix.os }}
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Setup Go ${{ matrix.go }}
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
cache: false
- name: Run Tests
run: go test -race -cover -coverprofile=coverage -covermode=atomic -v ./...
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
files: ./coverage

31
.github/workflows/verify.yml vendored Normal file
View File

@ -0,0 +1,31 @@
name: Verify
on:
push:
branches:
- main
pull_request:
branches:
- main
permissions:
contents: read
jobs:
lint:
strategy:
matrix:
go: ['1.20','1.21']
fail-fast: true
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Setup Go ${{ matrix.go }}
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
cache: false
- name: Run GolangCI-Lint
uses: golangci/golangci-lint-action@v3
with:
version: v1.53

26
.gitignore vendored
View File

@ -1,25 +1 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects) coverage.coverprofile
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
.idea/
*.iml

13
.golangci.yml Normal file
View File

@ -0,0 +1,13 @@
run:
timeout: "5m"
# will not run golangci-lint against *_test.go
tests: false
issues:
exclude-dirs:
- examples/*.go
exclude-rules:
# excluding error checks from all the .go files
- path: ./*.go
linters:
- errcheck

View File

@ -1,9 +0,0 @@
# This is the official list of Gorilla WebSocket authors for copyright
# purposes.
#
# Please keep the list sorted.
Gary Burd <gary@beagledreams.com>
Google LLC (https://opensource.google.com/)
Joachim Bauch <mail@joachim-bauch.de>

41
LICENSE
View File

@ -1,22 +1,27 @@
Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved. Copyright (c) 2023 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met: modification, are permitted provided that the following conditions are
met:
Redistributions of source code must retain the above copyright notice, this * Redistributions of source code must retain the above copyright
list of conditions and the following disclaimer. notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
Redistributions in binary form must reproduce the above copyright notice, THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
this list of conditions and the following disclaimer in the documentation "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
and/or other materials provided with the distribution. LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

34
Makefile Normal file
View File

@ -0,0 +1,34 @@
GO_LINT=$(shell which golangci-lint 2> /dev/null || echo '')
GO_LINT_URI=github.com/golangci/golangci-lint/cmd/golangci-lint@latest
GO_SEC=$(shell which gosec 2> /dev/null || echo '')
GO_SEC_URI=github.com/securego/gosec/v2/cmd/gosec@latest
GO_VULNCHECK=$(shell which govulncheck 2> /dev/null || echo '')
GO_VULNCHECK_URI=golang.org/x/vuln/cmd/govulncheck@latest
.PHONY: golangci-lint
golangci-lint:
$(if $(GO_LINT), ,go install $(GO_LINT_URI))
@echo "##### Running golangci-lint"
golangci-lint run -v
.PHONY: gosec
gosec:
$(if $(GO_SEC), ,go install $(GO_SEC_URI))
@echo "##### Running gosec"
gosec -exclude-dir examples ./...
.PHONY: govulncheck
govulncheck:
$(if $(GO_VULNCHECK), ,go install $(GO_VULNCHECK_URI))
@echo "##### Running govulncheck"
govulncheck ./...
.PHONY: verify
verify: golangci-lint gosec govulncheck
.PHONY: test
test:
@echo "##### Running tests"
go test -race -cover -coverprofile=coverage.coverprofile -covermode=atomic -v ./...

View File

@ -1,19 +1,23 @@
# Gorilla WebSocket # gorilla/websocket
[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) ![testing](https://github.com/gorilla/websocket/actions/workflows/test.yml/badge.svg)
[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket) [![codecov](https://codecov.io/github/gorilla/websocket/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/websocket)
[![godoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
[![sourcegraph](https://sourcegraph.com/github.com/gorilla/websocket/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/websocket?badge)
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
![Gorilla Logo](https://github.com/gorilla/.github/assets/53367916/d92caabf-98e0-473e-bfbf-ab554ba435e5)
### Documentation ### Documentation
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) * [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) * [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat)
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) * [Command example](https://github.com/gorilla/websocket/tree/main/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) * [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo)
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) * [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch)
* [Write buffer pool example](https://github.com/gorilla/websocket/tree/main/examples/bufferpool)
### Status ### Status
@ -29,5 +33,4 @@ package API is stable.
The Gorilla WebSocket package passes the server tests in the [Autobahn Test The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn).

View File

@ -11,13 +11,14 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
"net/url" "net/url"
"strings" "strings"
"time" "time"
"golang.org/x/net/proxy"
) )
// ErrBadHandshake is returned when the server response to opening handshake is // ErrBadHandshake is returned when the server response to opening handshake is
@ -304,7 +305,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return nil, nil, err return nil, nil, err
} }
if proxyURL != nil { if proxyURL != nil {
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -391,7 +392,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
} }
} }
if resp.StatusCode != 101 || if resp.StatusCode != http.StatusSwitchingProtocols ||
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") || !tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
!tokenListContainsValue(resp.Header, "Connection", "upgrade") || !tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
@ -400,7 +401,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// debugging. // debugging.
buf := make([]byte, 1024) buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf) n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake return nil, resp, ErrBadHandshake
} }
@ -418,7 +419,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
break break
} }
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{}) netConn.SetDeadline(time.Time{})
@ -428,7 +429,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
func cloneTLSConfig(cfg *tls.Config) *tls.Config { func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil { if cfg == nil {
return &tls.Config{} return &tls.Config{MinVersion: tls.VersionTLS12}
} }
return cfg.Clone() return cfg.Clone()
} }

View File

@ -14,7 +14,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -149,6 +148,7 @@ func sendRecv(t *testing.T, ws *Conn) {
} }
func TestProxyDial(t *testing.T) { func TestProxyDial(t *testing.T) {
t.Parallel()
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -187,6 +187,7 @@ func TestProxyDial(t *testing.T) {
} }
func TestProxyAuthorizationDial(t *testing.T) { func TestProxyAuthorizationDial(t *testing.T) {
t.Parallel()
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -227,6 +228,7 @@ func TestProxyAuthorizationDial(t *testing.T) {
} }
func TestDial(t *testing.T) { func TestDial(t *testing.T) {
t.Parallel()
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -239,6 +241,7 @@ func TestDial(t *testing.T) {
} }
func TestDialCookieJar(t *testing.T) { func TestDialCookieJar(t *testing.T) {
t.Parallel()
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -301,6 +304,7 @@ func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
} }
func TestDialTLS(t *testing.T) { func TestDialTLS(t *testing.T) {
t.Parallel()
s := newTLSServer(t) s := newTLSServer(t)
defer s.Close() defer s.Close()
@ -315,6 +319,7 @@ func TestDialTLS(t *testing.T) {
} }
func TestDialTimeout(t *testing.T) { func TestDialTimeout(t *testing.T) {
t.Parallel()
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -371,6 +376,7 @@ func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr()
func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
func TestHandshakeTimeout(t *testing.T) { func TestHandshakeTimeout(t *testing.T) {
t.Parallel()
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -387,6 +393,7 @@ func TestHandshakeTimeout(t *testing.T) {
} }
func TestHandshakeTimeoutInContext(t *testing.T) { func TestHandshakeTimeoutInContext(t *testing.T) {
t.Parallel()
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -408,6 +415,7 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
} }
func TestDialBadScheme(t *testing.T) { func TestDialBadScheme(t *testing.T) {
t.Parallel()
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -419,6 +427,7 @@ func TestDialBadScheme(t *testing.T) {
} }
func TestDialBadOrigin(t *testing.T) { func TestDialBadOrigin(t *testing.T) {
t.Parallel()
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -436,6 +445,7 @@ func TestDialBadOrigin(t *testing.T) {
} }
func TestDialBadHeader(t *testing.T) { func TestDialBadHeader(t *testing.T) {
t.Parallel()
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -455,6 +465,7 @@ func TestDialBadHeader(t *testing.T) {
} }
func TestBadMethod(t *testing.T) { func TestBadMethod(t *testing.T) {
t.Parallel()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws, err := cstUpgrader.Upgrade(w, r, nil) ws, err := cstUpgrader.Upgrade(w, r, nil)
if err == nil { if err == nil {
@ -483,6 +494,7 @@ func TestBadMethod(t *testing.T) {
} }
func TestDialExtraTokensInRespHeaders(t *testing.T) { func TestDialExtraTokensInRespHeaders(t *testing.T) {
t.Parallel()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
challengeKey := r.Header.Get("Sec-Websocket-Key") challengeKey := r.Header.Get("Sec-Websocket-Key")
w.Header().Set("Upgrade", "foo, websocket") w.Header().Set("Upgrade", "foo, websocket")
@ -500,6 +512,7 @@ func TestDialExtraTokensInRespHeaders(t *testing.T) {
} }
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
t.Parallel()
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -526,6 +539,7 @@ func TestHandshake(t *testing.T) {
} }
func TestRespOnBadHandshake(t *testing.T) { func TestRespOnBadHandshake(t *testing.T) {
t.Parallel()
const expectedStatus = http.StatusGone const expectedStatus = http.StatusGone
const expectedBody = "This is the response body." const expectedBody = "This is the response body."
@ -549,7 +563,7 @@ func TestRespOnBadHandshake(t *testing.T) {
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
} }
p, err := ioutil.ReadAll(resp.Body) p, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatalf("ReadFull(resp.Body) returned error %v", err) t.Fatalf("ReadFull(resp.Body) returned error %v", err)
} }
@ -570,6 +584,7 @@ func (w testLogWriter) Write(p []byte) (int, error) {
// TestHost tests handling of host names and confirms that it matches net/http. // TestHost tests handling of host names and confirms that it matches net/http.
func TestHost(t *testing.T) { func TestHost(t *testing.T) {
t.Parallel()
upgrader := Upgrader{} upgrader := Upgrader{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -751,6 +766,7 @@ func TestHost(t *testing.T) {
} }
func TestDialCompression(t *testing.T) { func TestDialCompression(t *testing.T) {
t.Parallel()
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -765,6 +781,7 @@ func TestDialCompression(t *testing.T) {
} }
func TestSocksProxyDial(t *testing.T) { func TestSocksProxyDial(t *testing.T) {
t.Parallel()
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
@ -844,6 +861,7 @@ func TestSocksProxyDial(t *testing.T) {
} }
func TestTracingDialWithContext(t *testing.T) { func TestTracingDialWithContext(t *testing.T) {
t.Parallel()
var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
trace := &httptrace.ClientTrace{ trace := &httptrace.ClientTrace{
@ -903,6 +921,7 @@ func TestTracingDialWithContext(t *testing.T) {
} }
func TestEmptyTracingDialWithContext(t *testing.T) { func TestEmptyTracingDialWithContext(t *testing.T) {
t.Parallel()
trace := &httptrace.ClientTrace{} trace := &httptrace.ClientTrace{}
ctx := httptrace.WithClientTrace(context.Background(), trace) ctx := httptrace.WithClientTrace(context.Background(), trace)
@ -924,6 +943,7 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext // TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
func TestNetDialConnect(t *testing.T) { func TestNetDialConnect(t *testing.T) {
t.Parallel()
upgrader := Upgrader{} upgrader := Upgrader{}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -1099,6 +1119,7 @@ func TestNetDialConnect(t *testing.T) {
} }
} }
func TestNextProtos(t *testing.T) { func TestNextProtos(t *testing.T) {
t.Parallel()
ts := httptest.NewUnstartedServer( ts := httptest.NewUnstartedServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
) )

View File

@ -20,6 +20,7 @@ var hostPortNoPortTests = []struct {
} }
func TestHostPortNoPort(t *testing.T) { func TestHostPortNoPort(t *testing.T) {
t.Parallel()
for _, tt := range hostPortNoPortTests { for _, tt := range hostPortNoPortTests {
hostPort, hostNoPort := hostPortNoPort(tt.u) hostPort, hostNoPort := hostPortNoPort(tt.u)
if hostPort != tt.hostPort { if hostPort != tt.hostPort {

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"testing" "testing"
) )
@ -13,6 +12,7 @@ type nopCloser struct{ io.Writer }
func (nopCloser) Close() error { return nil } func (nopCloser) Close() error { return nil }
func TestTruncWriter(t *testing.T) { func TestTruncWriter(t *testing.T) {
t.Parallel()
const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321" const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
for n := 1; n <= 10; n++ { for n := 1; n <= 10; n++ {
var b bytes.Buffer var b bytes.Buffer
@ -42,7 +42,7 @@ func textMessages(num int) [][]byte {
} }
func BenchmarkWriteNoCompression(b *testing.B) { func BenchmarkWriteNoCompression(b *testing.B) {
w := ioutil.Discard w := io.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
b.ResetTimer() b.ResetTimer()
@ -53,7 +53,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
} }
func BenchmarkWriteWithCompression(b *testing.B) { func BenchmarkWriteWithCompression(b *testing.B) {
w := ioutil.Discard w := io.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
c.enableWriteCompression = true c.enableWriteCompression = true
@ -66,6 +66,7 @@ func BenchmarkWriteWithCompression(b *testing.B) {
} }
func TestValidCompressionLevel(t *testing.T) { func TestValidCompressionLevel(t *testing.T) {
t.Parallel()
c := newTestConn(nil, nil, false) c := newTestConn(nil, nil, false)
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
if err := c.SetCompressionLevel(level); err == nil { if err := c.SetCompressionLevel(level); err == nil {

72
conn.go
View File

@ -6,11 +6,10 @@ package websocket
import ( import (
"bufio" "bufio"
"crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"io/ioutil"
"math/rand"
"net" "net"
"strconv" "strconv"
"strings" "strings"
@ -181,16 +180,16 @@ var (
errInvalidControlFrame = errors.New("websocket: invalid control frame") errInvalidControlFrame = errors.New("websocket: invalid control frame")
) )
func newMaskKey() [4]byte { // maskRand is an io.Reader for generating mask bytes. The reader is initialized
n := rand.Uint32() // to crypto/rand Reader. Tests swap the reader to a math/rand reader for
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} // reproducible results.
} var maskRand = rand.Reader
func hideTempErr(err error) error { // newMaskKey returns a new 32 bit value for masking client frames.
if e, ok := err.(net.Error); ok && e.Temporary() { func newMaskKey() [4]byte {
err = &netError{msg: e.Error(), timeout: e.Timeout()} var k [4]byte
} _, _ = io.ReadFull(maskRand, k[:])
return err return k
} }
func isControl(frameType int) bool { func isControl(frameType int) bool {
@ -358,7 +357,6 @@ func (c *Conn) RemoteAddr() net.Addr {
// Write methods // Write methods
func (c *Conn) writeFatal(err error) error { func (c *Conn) writeFatal(err error) error {
err = hideTempErr(err)
c.writeErrMu.Lock() c.writeErrMu.Lock()
if c.writeErr == nil { if c.writeErr == nil {
c.writeErr = err c.writeErr = err
@ -436,21 +434,27 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
maskBytes(key, 0, buf[6:]) maskBytes(key, 0, buf[6:])
} }
d := 1000 * time.Hour if deadline.IsZero() {
if !deadline.IsZero() { // No timeout for zero time.
d = deadline.Sub(time.Now()) <-c.mu
} else {
d := time.Until(deadline)
if d < 0 { if d < 0 {
return errWriteTimeout return errWriteTimeout
} }
select {
case <-c.mu:
default:
timer := time.NewTimer(d)
select {
case <-c.mu:
timer.Stop()
case <-timer.C:
return errWriteTimeout
}
}
} }
timer := time.NewTimer(d)
select {
case <-c.mu:
timer.Stop()
case <-timer.C:
return errWriteTimeout
}
defer func() { c.mu <- struct{}{} }() defer func() { c.mu <- struct{}{} }()
c.writeErrMu.Lock() c.writeErrMu.Lock()
@ -795,7 +799,7 @@ func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame. // 1. Skip remainder of previous frame.
if c.readRemaining > 0 { if c.readRemaining > 0 {
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
return noFrame, err return noFrame, err
} }
} }
@ -1008,7 +1012,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
for c.readErr == nil { for c.readErr == nil {
frameType, err := c.advanceFrame() frameType, err := c.advanceFrame()
if err != nil { if err != nil {
c.readErr = hideTempErr(err) c.readErr = err
break break
} }
@ -1048,7 +1052,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
b = b[:c.readRemaining] b = b[:c.readRemaining]
} }
n, err := c.br.Read(b) n, err := c.br.Read(b)
c.readErr = hideTempErr(err) c.readErr = err
if c.isServer { if c.isServer {
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
} }
@ -1069,7 +1073,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
frameType, err := c.advanceFrame() frameType, err := c.advanceFrame()
switch { switch {
case err != nil: case err != nil:
c.readErr = hideTempErr(err) c.readErr = err
case frameType == TextMessage || frameType == BinaryMessage: case frameType == TextMessage || frameType == BinaryMessage:
c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
} }
@ -1094,7 +1098,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
if err != nil { if err != nil {
return messageType, nil, err return messageType, nil, err
} }
p, err = ioutil.ReadAll(r) p, err = io.ReadAll(r)
return messageType, p, err return messageType, p, err
} }
@ -1161,7 +1165,7 @@ func (c *Conn) SetPingHandler(h func(appData string) error) {
err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
if err == ErrCloseSent { if err == ErrCloseSent {
return nil return nil
} else if e, ok := err.(net.Error); ok && e.Temporary() { } else if _, ok := err.(net.Error); ok {
return nil return nil
} }
return err return err
@ -1236,3 +1240,15 @@ func FormatCloseMessage(closeCode int, text string) []byte {
copy(buf[2:], text) copy(buf[2:], text)
return buf return buf
} }
var messageTypes = map[int]string{
TextMessage: "TextMessage",
BinaryMessage: "BinaryMessage",
CloseMessage: "CloseMessage",
PingMessage: "PingMessage",
PongMessage: "PongMessage",
}
func FormatMessageType(mt int) string {
return messageTypes[mt]
}

View File

@ -6,7 +6,6 @@ package websocket
import ( import (
"io" "io"
"io/ioutil"
"sync/atomic" "sync/atomic"
"testing" "testing"
) )
@ -45,7 +44,7 @@ func newBroadcastConn(c *Conn) *broadcastConn {
func newBroadcastBench(usePrepared, compression bool) *broadcastBench { func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
bench := &broadcastBench{ bench := &broadcastBench{
w: ioutil.Discard, w: io.Discard,
doneCh: make(chan struct{}), doneCh: make(chan struct{}),
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
usePrepared: usePrepared, usePrepared: usePrepared,

View File

@ -10,7 +10,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"reflect" "reflect"
"sync" "sync"
@ -55,6 +54,7 @@ func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
} }
func TestFraming(t *testing.T) { func TestFraming(t *testing.T) {
t.Parallel()
frameSizes := []int{ frameSizes := []int{
0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
// 65536, 65537 // 65536, 65537
@ -125,7 +125,7 @@ func TestFraming(t *testing.T) {
} }
t.Logf("frame size: %d", n) t.Logf("frame size: %d", n)
rbuf, err := ioutil.ReadAll(r) rbuf, err := io.ReadAll(r)
if err != nil { if err != nil {
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
continue continue
@ -149,7 +149,49 @@ func TestFraming(t *testing.T) {
} }
} }
func TestWriteControlDeadline(t *testing.T) {
t.Parallel()
message := []byte("hello")
var connBuf bytes.Buffer
c := newTestConn(nil, &connBuf, true)
if err := c.WriteControl(PongMessage, message, time.Time{}); err != nil {
t.Errorf("WriteControl(..., zero deadline) = %v, want nil", err)
}
if err := c.WriteControl(PongMessage, message, time.Now().Add(time.Second)); err != nil {
t.Errorf("WriteControl(..., future deadline) = %v, want nil", err)
}
if err := c.WriteControl(PongMessage, message, time.Now().Add(-time.Second)); err == nil {
t.Errorf("WriteControl(..., past deadline) = nil, want timeout error")
}
}
func TestConcurrencyWriteControl(t *testing.T) {
const message = "this is a ping/pong messsage"
loop := 10
workers := 10
for i := 0; i < loop; i++ {
var connBuf bytes.Buffer
wg := sync.WaitGroup{}
wc := newTestConn(nil, &connBuf, true)
for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
t.Errorf("concurrently wc.WriteControl() returned %v", err)
}
}()
}
wg.Wait()
wc.Close()
}
}
func TestControl(t *testing.T) { func TestControl(t *testing.T) {
t.Parallel()
const message = "this is a ping/pong messsage" const message = "this is a ping/pong messsage"
for _, isServer := range []bool{true, false} { for _, isServer := range []bool{true, false} {
for _, isWriteControl := range []bool{true, false} { for _, isWriteControl := range []bool{true, false} {
@ -201,6 +243,7 @@ func (p *simpleBufferPool) Put(v interface{}) {
} }
func TestWriteBufferPool(t *testing.T) { func TestWriteBufferPool(t *testing.T) {
t.Parallel()
const message = "Now is the time for all good people to come to the aid of the party." const message = "Now is the time for all good people to come to the aid of the party."
var buf bytes.Buffer var buf bytes.Buffer
@ -279,6 +322,7 @@ func TestWriteBufferPool(t *testing.T) {
// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. // TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
func TestWriteBufferPoolSync(t *testing.T) { func TestWriteBufferPoolSync(t *testing.T) {
t.Parallel()
var buf bytes.Buffer var buf bytes.Buffer
var pool sync.Pool var pool sync.Pool
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
@ -307,6 +351,7 @@ func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error
// TestWriteBufferPoolError ensures that buffer is returned to pool after error // TestWriteBufferPoolError ensures that buffer is returned to pool after error
// on write. // on write.
func TestWriteBufferPoolError(t *testing.T) { func TestWriteBufferPoolError(t *testing.T) {
t.Parallel()
// Part 1: Test NextWriter/Write/Close // Part 1: Test NextWriter/Write/Close
@ -350,6 +395,7 @@ func TestWriteBufferPoolError(t *testing.T) {
} }
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
t.Parallel()
const bufSize = 512 const bufSize = 512
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
@ -367,7 +413,7 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if !reflect.DeepEqual(err, expectedErr) { if !reflect.DeepEqual(err, expectedErr) {
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
} }
@ -378,6 +424,7 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
} }
func TestEOFWithinFrame(t *testing.T) { func TestEOFWithinFrame(t *testing.T) {
t.Parallel()
const bufSize = 64 const bufSize = 64
for n := 0; ; n++ { for n := 0; ; n++ {
@ -401,7 +448,7 @@ func TestEOFWithinFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
} }
@ -413,6 +460,7 @@ func TestEOFWithinFrame(t *testing.T) {
} }
func TestEOFBeforeFinalFrame(t *testing.T) { func TestEOFBeforeFinalFrame(t *testing.T) {
t.Parallel()
const bufSize = 512 const bufSize = 512
var b1, b2 bytes.Buffer var b1, b2 bytes.Buffer
@ -426,7 +474,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
} }
@ -437,6 +485,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
} }
func TestWriteAfterMessageWriterClose(t *testing.T) { func TestWriteAfterMessageWriterClose(t *testing.T) {
t.Parallel()
wc := newTestConn(nil, &bytes.Buffer{}, false) wc := newTestConn(nil, &bytes.Buffer{}, false)
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
io.WriteString(w, "hello") io.WriteString(w, "hello")
@ -462,7 +511,29 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
} }
} }
func TestWriteHandlerDoesNotReturnErrCloseSent(t *testing.T) {
t.Parallel()
var b1, b2 bytes.Buffer
client := newTestConn(&b2, &b1, false)
server := newTestConn(&b1, &b2, true)
msg := FormatCloseMessage(CloseNormalClosure, "")
if err := client.WriteMessage(CloseMessage, msg); err != nil {
t.Fatalf("unexpected error when writing close message, %v", err)
}
if _, _, err := server.NextReader(); !IsCloseError(err, 1000) {
t.Fatalf("server expects a close message, %v returned", err)
}
if _, _, err := client.NextReader(); !IsCloseError(err, 1000) {
t.Fatalf("client expects a close message, %v returned", err)
}
}
func TestReadLimit(t *testing.T) { func TestReadLimit(t *testing.T) {
t.Parallel()
t.Run("Test ReadLimit is enforced", func(t *testing.T) { t.Run("Test ReadLimit is enforced", func(t *testing.T) {
const readLimit = 512 const readLimit = 512
message := make([]byte, readLimit+1) message := make([]byte, readLimit+1)
@ -490,7 +561,7 @@ func TestReadLimit(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("2: NextReader() returned %d, %v", op, err) t.Fatalf("2: NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != ErrReadLimit { if err != ErrReadLimit {
t.Fatalf("io.Copy() returned %v", err) t.Fatalf("io.Copy() returned %v", err)
} }
@ -553,6 +624,7 @@ func TestReadLimit(t *testing.T) {
} }
func TestAddrs(t *testing.T) { func TestAddrs(t *testing.T) {
t.Parallel()
c := newTestConn(nil, nil, true) c := newTestConn(nil, nil, true)
if c.LocalAddr() != localAddr { if c.LocalAddr() != localAddr {
t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
@ -563,6 +635,7 @@ func TestAddrs(t *testing.T) {
} }
func TestDeprecatedUnderlyingConn(t *testing.T) { func TestDeprecatedUnderlyingConn(t *testing.T) {
t.Parallel()
var b1, b2 bytes.Buffer var b1, b2 bytes.Buffer
fc := fakeNetConn{Reader: &b1, Writer: &b2} fc := fakeNetConn{Reader: &b1, Writer: &b2}
c := newConn(fc, true, 1024, 1024, nil, nil, nil) c := newConn(fc, true, 1024, 1024, nil, nil, nil)
@ -573,6 +646,7 @@ func TestDeprecatedUnderlyingConn(t *testing.T) {
} }
func TestNetConn(t *testing.T) { func TestNetConn(t *testing.T) {
t.Parallel()
var b1, b2 bytes.Buffer var b1, b2 bytes.Buffer
fc := fakeNetConn{Reader: &b1, Writer: &b2} fc := fakeNetConn{Reader: &b1, Writer: &b2}
c := newConn(fc, true, 1024, 1024, nil, nil, nil) c := newConn(fc, true, 1024, 1024, nil, nil, nil)
@ -583,6 +657,7 @@ func TestNetConn(t *testing.T) {
} }
func TestBufioReadBytes(t *testing.T) { func TestBufioReadBytes(t *testing.T) {
t.Parallel()
// Test calling bufio.ReadBytes for value longer than read buffer size. // Test calling bufio.ReadBytes for value longer than read buffer size.
m := make([]byte, 512) m := make([]byte, 512)
@ -623,6 +698,7 @@ var closeErrorTests = []struct {
} }
func TestCloseError(t *testing.T) { func TestCloseError(t *testing.T) {
t.Parallel()
for _, tt := range closeErrorTests { for _, tt := range closeErrorTests {
ok := IsCloseError(tt.err, tt.codes...) ok := IsCloseError(tt.err, tt.codes...)
if ok != tt.ok { if ok != tt.ok {
@ -643,6 +719,7 @@ var unexpectedCloseErrorTests = []struct {
} }
func TestUnexpectedCloseErrors(t *testing.T) { func TestUnexpectedCloseErrors(t *testing.T) {
t.Parallel()
for _, tt := range unexpectedCloseErrorTests { for _, tt := range unexpectedCloseErrorTests {
ok := IsUnexpectedCloseError(tt.err, tt.codes...) ok := IsUnexpectedCloseError(tt.err, tt.codes...)
if ok != tt.ok { if ok != tt.ok {
@ -664,6 +741,7 @@ func (w blockingWriter) Write(p []byte) (int, error) {
} }
func TestConcurrentWritePanic(t *testing.T) { func TestConcurrentWritePanic(t *testing.T) {
t.Parallel()
w := blockingWriter{make(chan struct{}), make(chan struct{})} w := blockingWriter{make(chan struct{}), make(chan struct{})}
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
go func() { go func() {
@ -691,6 +769,7 @@ func (r failingReader) Read(p []byte) (int, error) {
} }
func TestFailedConnectionReadPanic(t *testing.T) { func TestFailedConnectionReadPanic(t *testing.T) {
t.Parallel()
c := newTestConn(failingReader{}, nil, false) c := newTestConn(failingReader{}, nil, false)
defer func() { defer func() {
@ -704,3 +783,42 @@ func TestFailedConnectionReadPanic(t *testing.T) {
} }
t.Fatal("should not get here") t.Fatal("should not get here")
} }
func TestFormatMessageType(t *testing.T) {
str := FormatMessageType(TextMessage)
if str != messageTypes[TextMessage] {
t.Error("failed to format message type")
}
str = FormatMessageType(CloseMessage)
if str != messageTypes[CloseMessage] {
t.Error("failed to format message type")
}
str = FormatMessageType(123)
if str != messageTypes[123] {
t.Error("failed to format message type")
}
}
type fakeNetClosedReader struct {
}
func (r fakeNetClosedReader) Read([]byte) (int, error) {
return 0, net.ErrClosed
}
func TestConnectionClosed(t *testing.T) {
var b1, b2 bytes.Buffer
client := newTestConn(fakeNetClosedReader{}, &b1, false)
server := newTestConn(fakeNetClosedReader{}, &b2, true)
if _, _, err := server.NextReader(); !errors.Is(err, net.ErrClosed) {
t.Fatalf("server expects a net.ErrClosed error, %v returned", err)
}
if _, _, err := client.NextReader(); !errors.Is(err, net.ErrClosed) {
t.Fatalf("client expects a net.ErrClosed error, %v returned", err)
}
}

View File

@ -42,4 +42,4 @@ func processMessage(mt int, p []byte) {}
// TestX prevents godoc from showing this entire file in the example. Remove // TestX prevents godoc from showing this entire file in the example. Remove
// this function when a second example is added. // this function when a second example is added.
func TestX(t *testing.T) {} func TestX(t *testing.T) { t.Parallel() }

View File

@ -84,7 +84,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
} }
// echoReadAll echoes messages from the client by reading the entire message // echoReadAll echoes messages from the client by reading the entire message
// with ioutil.ReadAll. // with io.ReadAll.
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) { func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) {
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {

View File

@ -0,0 +1,89 @@
//go:build ignore
// +build ignore
package main
import (
"flag"
"log"
"net/url"
"os"
"os/signal"
"sync"
"time"
"github.com/gorilla/websocket"
)
var addr = flag.String("addr", "localhost:8080", "http service address")
func runNewConn(wg *sync.WaitGroup) {
defer wg.Done()
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt)
u := url.URL{Scheme: "ws", Host: *addr, Path: "/ws"}
log.Printf("connecting to %s", u.String())
c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil {
log.Fatal("dial:", err)
}
defer c.Close()
done := make(chan struct{})
go func() {
defer close(done)
for {
_, message, err := c.ReadMessage()
if err != nil {
log.Println("read:", err)
return
}
log.Printf("recv: %s", message)
}
}()
ticker := time.NewTicker(time.Minute * 5)
defer ticker.Stop()
for {
select {
case <-done:
return
case t := <-ticker.C:
err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
if err != nil {
log.Println("write:", err)
return
}
case <-interrupt:
log.Println("interrupt")
// Cleanly close the connection by sending a close message and then
// waiting (with timeout) for the server to close the connection.
err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
log.Println("write close:", err)
return
}
select {
case <-done:
case <-time.After(time.Second):
}
return
}
}
}
func main() {
flag.Parse()
log.SetFlags(0)
wg := &sync.WaitGroup{}
for i := 0; i < 1000; i++ {
wg.Add(1)
go runNewConn(wg)
}
wg.Wait()
}

View File

@ -0,0 +1,55 @@
//go:build ignore
// +build ignore
package main
import (
"flag"
"log"
"net/http"
"sync"
_ "net/http/pprof"
"github.com/gorilla/websocket"
)
var addr = flag.String("addr", "localhost:8080", "http service address")
var upgrader = websocket.Upgrader{
ReadBufferSize: 256,
WriteBufferSize: 256,
WriteBufferPool: &sync.Pool{},
}
func process(c *websocket.Conn) {
defer c.Close()
for {
_, message, err := c.ReadMessage()
if err != nil {
log.Println("read:", err)
break
}
log.Printf("recv: %s", message)
}
}
func handler(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Print("upgrade:", err)
return
}
// Process connection in a new goroutine
go process(c)
// Let the http handler return, the 8k buffer created by it will be garbage collected
}
func main() {
flag.Parse()
log.SetFlags(0)
http.HandleFunc("/ws", handler)
log.Fatal(http.ListenAndServe(*addr, nil))
}

View File

@ -38,7 +38,7 @@ sends them to the hub.
### Hub ### Hub
The code for the `Hub` type is in The code for the `Hub` type is in
[hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go). [hub.go](https://github.com/gorilla/websocket/blob/main/examples/chat/hub.go).
The application's `main` function starts the hub's `run` method as a goroutine. The application's `main` function starts the hub's `run` method as a goroutine.
Clients send requests to the hub using the `register`, `unregister` and Clients send requests to the hub using the `register`, `unregister` and
`broadcast` channels. `broadcast` channels.
@ -57,7 +57,7 @@ unregisters the client and closes the websocket.
### Client ### Client
The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go). The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/main/examples/chat/client.go).
The `serveWs` function is registered by the application's `main` function as The `serveWs` function is registered by the application's `main` function as
an HTTP handler. The handler upgrades the HTTP connection to the WebSocket an HTTP handler. The handler upgrades the HTTP connection to the WebSocket
@ -85,7 +85,7 @@ network.
## Frontend ## Frontend
The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html). The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/main/examples/chat/home.html).
On document load, the script checks for websocket functionality in the browser. On document load, the script checks for websocket functionality in the browser.
If websocket functionality is available, then the script opens a connection to If websocket functionality is available, then the script opens a connection to

View File

@ -57,8 +57,6 @@ func pumpStdin(ws *websocket.Conn, w io.Writer) {
} }
func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) { func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) {
defer func() {
}()
s := bufio.NewScanner(r) s := bufio.NewScanner(r)
for s.Scan() { for s.Scan() {
ws.SetWriteDeadline(time.Now().Add(writeWait)) ws.SetWriteDeadline(time.Now().Add(writeWait))

View File

@ -41,12 +41,12 @@ func main() {
go func() { go func() {
defer close(done) defer close(done)
for { for {
_, message, err := c.ReadMessage() mt, message, err := c.ReadMessage()
if err != nil { if err != nil {
log.Println("read:", err) log.Println("read:", err)
return return
} }
log.Printf("recv: %s", message) log.Printf("recv: %s, type: %s", message, websocket.FormatMessageType(mt))
} }
}() }()

View File

@ -33,7 +33,8 @@ func echo(w http.ResponseWriter, r *http.Request) {
log.Println("read:", err) log.Println("read:", err)
break break
} }
log.Printf("recv: %s", message)
log.Printf("recv: %s, type: %s", message, websocket.FormatMessageType(mt))
err = c.WriteMessage(mt, message) err = c.WriteMessage(mt, message)
if err != nil { if err != nil {
log.Println("write:", err) log.Println("write:", err)

View File

@ -7,10 +7,10 @@ package main
import ( import (
"flag" "flag"
"html/template" "html/template"
"io/ioutil"
"log" "log"
"net/http" "net/http"
"os" "os"
"path/filepath"
"strconv" "strconv"
"time" "time"
@ -49,7 +49,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) {
if !fi.ModTime().After(lastMod) { if !fi.ModTime().After(lastMod) {
return nil, lastMod, nil return nil, lastMod, nil
} }
p, err := ioutil.ReadFile(filename) p, err := os.ReadFile(filepath.Clean(filename))
if err != nil { if err != nil {
return nil, fi.ModTime(), err return nil, fi.ModTime(), err
} }

4
go.mod
View File

@ -1,3 +1,5 @@
module github.com/gorilla/websocket module github.com/gorilla/websocket
go 1.12 go 1.20
require golang.org/x/net v0.23.0

2
go.sum
View File

@ -0,0 +1,2 @@
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=

View File

@ -12,6 +12,7 @@ import (
) )
func TestJoinMessages(t *testing.T) { func TestJoinMessages(t *testing.T) {
t.Parallel()
messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"} messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"}
for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} { for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} {
for _, term := range []string{"", ","} { for _, term := range []string{"", ","} {

View File

@ -13,6 +13,7 @@ import (
) )
func TestJSON(t *testing.T) { func TestJSON(t *testing.T) {
t.Parallel()
var buf bytes.Buffer var buf bytes.Buffer
wc := newTestConn(nil, &buf, true) wc := newTestConn(nil, &buf, true)
rc := newTestConn(&buf, nil, false) rc := newTestConn(&buf, nil, false)
@ -38,6 +39,7 @@ func TestJSON(t *testing.T) {
} }
func TestPartialJSONRead(t *testing.T) { func TestPartialJSONRead(t *testing.T) {
t.Parallel()
var buf0, buf1 bytes.Buffer var buf0, buf1 bytes.Buffer
wc := newTestConn(nil, &buf0, true) wc := newTestConn(nil, &buf0, true)
rc := newTestConn(&buf0, &buf1, false) rc := newTestConn(&buf0, &buf1, false)
@ -91,6 +93,7 @@ func TestPartialJSONRead(t *testing.T) {
} }
func TestDeprecatedJSON(t *testing.T) { func TestDeprecatedJSON(t *testing.T) {
t.Parallel()
var buf bytes.Buffer var buf bytes.Buffer
wc := newTestConn(nil, &buf, true) wc := newTestConn(nil, &buf, true)
rc := newTestConn(&buf, nil, false) rc := newTestConn(&buf, nil, false)

View File

@ -29,6 +29,7 @@ func notzero(b []byte) int {
} }
func TestMaskBytes(t *testing.T) { func TestMaskBytes(t *testing.T) {
t.Parallel()
key := [4]byte{1, 2, 3, 4} key := [4]byte{1, 2, 3, 4}
for size := 1; size <= 1024; size++ { for size := 1; size <= 1024; size++ {
for align := 0; align < wordSize; align++ { for align := 0; align < wordSize; align++ {

View File

@ -33,6 +33,11 @@ var preparedMessageTests = []struct {
} }
func TestPreparedMessage(t *testing.T) { func TestPreparedMessage(t *testing.T) {
testRand := rand.New(rand.NewSource(99))
prevMaskRand := maskRand
maskRand = testRand
defer func() { maskRand = prevMaskRand }()
for _, tt := range preparedMessageTests { for _, tt := range preparedMessageTests {
var data = []byte("this is a test") var data = []byte("this is a test")
var buf bytes.Buffer var buf bytes.Buffer
@ -43,7 +48,7 @@ func TestPreparedMessage(t *testing.T) {
c.SetCompressionLevel(tt.compressionLevel) c.SetCompressionLevel(tt.compressionLevel)
// Seed random number generator for consistent frame mask. // Seed random number generator for consistent frame mask.
rand.Seed(1234) testRand.Seed(1234)
if err := c.WriteMessage(tt.messageType, data); err != nil { if err := c.WriteMessage(tt.messageType, data); err != nil {
t.Fatal(err) t.Fatal(err)
@ -59,7 +64,7 @@ func TestPreparedMessage(t *testing.T) {
copy(data, "hello world") copy(data, "hello world")
// Seed random number generator for consistent frame mask. // Seed random number generator for consistent frame mask.
rand.Seed(1234) testRand.Seed(1234)
buf.Reset() buf.Reset()
if err := c.WritePreparedMessage(pm); err != nil { if err := c.WritePreparedMessage(pm); err != nil {

View File

@ -12,6 +12,8 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"golang.org/x/net/proxy"
) )
type netDialerFunc func(network, addr string) (net.Conn, error) type netDialerFunc func(network, addr string) (net.Conn, error)
@ -21,7 +23,7 @@ func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
} }
func init() { func init() {
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { proxy.RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy.Dialer) (proxy.Dialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
}) })
} }
@ -68,7 +70,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
return nil, err return nil, err
} }
if resp.StatusCode != 200 { if resp.StatusCode != http.StatusOK {
conn.Close() conn.Close()
f := strings.SplitN(resp.Status, " ", 2) f := strings.SplitN(resp.Status, " ", 2)
return nil, errors.New(f[1]) return nil, errors.New(f[1])

View File

@ -33,6 +33,7 @@ type Upgrader struct {
// size is zero, then buffers allocated by the HTTP server are used. The // size is zero, then buffers allocated by the HTTP server are used. The
// I/O buffer sizes do not limit the size of the messages that can be sent // I/O buffer sizes do not limit the size of the messages that can be sent
// or received. // or received.
// The default value is 4096 bytes, 4kb.
ReadBufferSize, WriteBufferSize int ReadBufferSize, WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value // WriteBufferPool is a pool of buffers for write operations. If the value
@ -101,8 +102,8 @@ func checkSameOrigin(r *http.Request) bool {
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
if u.Subprotocols != nil { if u.Subprotocols != nil {
clientProtocols := Subprotocols(r) clientProtocols := Subprotocols(r)
for _, serverProtocol := range u.Subprotocols { for _, clientProtocol := range clientProtocols {
for _, clientProtocol := range clientProtocols { for _, serverProtocol := range u.Subprotocols {
if clientProtocol == serverProtocol { if clientProtocol == serverProtocol {
return clientProtocol return clientProtocol
} }
@ -172,12 +173,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
} }
h, ok := w.(http.Hijacker) netConn, brw, err := http.NewResponseController(w).Hijack()
if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var brw *bufio.ReadWriter
netConn, brw, err := h.Hijack()
if err != nil { if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error()) return u.returnError(w, r, http.StatusInternalServerError, err.Error())
} }

View File

@ -7,8 +7,10 @@ package websocket
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -27,6 +29,7 @@ var subprotocolTests = []struct {
} }
func TestSubprotocols(t *testing.T) { func TestSubprotocols(t *testing.T) {
t.Parallel()
for _, st := range subprotocolTests { for _, st := range subprotocolTests {
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}} r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}}
protocols := Subprotocols(&r) protocols := Subprotocols(&r)
@ -46,6 +49,7 @@ var isWebSocketUpgradeTests = []struct {
} }
func TestIsWebSocketUpgrade(t *testing.T) { func TestIsWebSocketUpgrade(t *testing.T) {
t.Parallel()
for _, tt := range isWebSocketUpgradeTests { for _, tt := range isWebSocketUpgradeTests {
ok := IsWebSocketUpgrade(&http.Request{Header: tt.h}) ok := IsWebSocketUpgrade(&http.Request{Header: tt.h})
if tt.ok != ok { if tt.ok != ok {
@ -54,6 +58,37 @@ func TestIsWebSocketUpgrade(t *testing.T) {
} }
} }
func TestSubProtocolSelection(t *testing.T) {
t.Parallel()
upgrader := Upgrader{
Subprotocols: []string{"foo", "bar", "baz"},
}
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}}
s := upgrader.selectSubprotocol(&r, nil)
if s != "foo" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "bar" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "baz" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string")
}
}
var checkSameOriginTests = []struct { var checkSameOriginTests = []struct {
ok bool ok bool
r *http.Request r *http.Request
@ -64,6 +99,7 @@ var checkSameOriginTests = []struct {
} }
func TestCheckSameOrigin(t *testing.T) { func TestCheckSameOrigin(t *testing.T) {
t.Parallel()
for _, tt := range checkSameOriginTests { for _, tt := range checkSameOriginTests {
ok := checkSameOrigin(tt.r) ok := checkSameOrigin(tt.r)
if tt.ok != ok { if tt.ok != ok {
@ -90,6 +126,7 @@ var bufioReuseTests = []struct {
} }
func TestBufioReuse(t *testing.T) { func TestBufioReuse(t *testing.T) {
t.Parallel()
for i, tt := range bufioReuseTests { for i, tt := range bufioReuseTests {
br := bufio.NewReaderSize(strings.NewReader(""), tt.n) br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n) bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
@ -117,3 +154,23 @@ func TestBufioReuse(t *testing.T) {
} }
} }
} }
func TestHijack_NotSupported(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "upgrade")
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Sec-Websocket-Version", "13")
recorder := httptest.NewRecorder()
upgrader := Upgrader{}
_, err := upgrader.Upgrade(recorder, req, nil)
if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
}
}

View File

@ -1,6 +1,3 @@
//go:build go1.17
// +build go1.17
package websocket package websocket
import ( import (

View File

@ -21,6 +21,7 @@ var equalASCIIFoldTests = []struct {
} }
func TestEqualASCIIFold(t *testing.T) { func TestEqualASCIIFold(t *testing.T) {
t.Parallel()
for _, tt := range equalASCIIFoldTests { for _, tt := range equalASCIIFoldTests {
eq := equalASCIIFold(tt.s, tt.t) eq := equalASCIIFold(tt.s, tt.t)
if eq != tt.eq { if eq != tt.eq {
@ -44,6 +45,7 @@ var tokenListContainsValueTests = []struct {
} }
func TestTokenListContainsValue(t *testing.T) { func TestTokenListContainsValue(t *testing.T) {
t.Parallel()
for _, tt := range tokenListContainsValueTests { for _, tt := range tokenListContainsValueTests {
h := http.Header{"Upgrade": {tt.value}} h := http.Header{"Upgrade": {tt.value}}
ok := tokenListContainsValue(h, "Upgrade", "websocket") ok := tokenListContainsValue(h, "Upgrade", "websocket")
@ -64,6 +66,7 @@ var isValidChallengeKeyTests = []struct {
} }
func TestIsValidChallengeKey(t *testing.T) { func TestIsValidChallengeKey(t *testing.T) {
t.Parallel()
for _, tt := range isValidChallengeKeyTests { for _, tt := range isValidChallengeKeyTests {
ok := isValidChallengeKey(tt.key) ok := isValidChallengeKey(tt.key)
if ok != tt.ok { if ok != tt.ok {
@ -105,6 +108,7 @@ var parseExtensionTests = []struct {
} }
func TestParseExtensions(t *testing.T) { func TestParseExtensions(t *testing.T) {
t.Parallel()
for _, tt := range parseExtensionTests { for _, tt := range parseExtensionTests {
h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}} h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}}
extensions := parseExtensions(h) extensions := parseExtensions(h)