mirror of https://github.com/gorilla/websocket.git
Compare commits
1 Commits
0e6933cf76
...
1532d57c38
Author | SHA1 | Date |
---|---|---|
andig | 1532d57c38 |
|
@ -1,70 +0,0 @@
|
|||
version: 2.1
|
||||
|
||||
jobs:
|
||||
"test":
|
||||
parameters:
|
||||
version:
|
||||
type: string
|
||||
default: "latest"
|
||||
golint:
|
||||
type: boolean
|
||||
default: true
|
||||
modules:
|
||||
type: boolean
|
||||
default: true
|
||||
goproxy:
|
||||
type: string
|
||||
default: ""
|
||||
docker:
|
||||
- image: "cimg/go:<< parameters.version >>"
|
||||
working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket
|
||||
environment:
|
||||
GO111MODULE: "on"
|
||||
GOPROXY: "<< parameters.goproxy >>"
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: "Print the Go version"
|
||||
command: >
|
||||
go version
|
||||
- run:
|
||||
name: "Fetch dependencies"
|
||||
command: >
|
||||
if [[ << parameters.modules >> = true ]]; then
|
||||
go mod download
|
||||
export GO111MODULE=on
|
||||
else
|
||||
go get -v ./...
|
||||
fi
|
||||
# Only run gofmt, vet & lint against the latest Go version
|
||||
- run:
|
||||
name: "Run golint"
|
||||
command: >
|
||||
if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then
|
||||
go get -u golang.org/x/lint/golint
|
||||
golint ./...
|
||||
fi
|
||||
- run:
|
||||
name: "Run gofmt"
|
||||
command: >
|
||||
if [[ << parameters.version >> = "latest" ]]; then
|
||||
diff -u <(echo -n) <(gofmt -d -e .)
|
||||
fi
|
||||
- run:
|
||||
name: "Run go vet"
|
||||
command: >
|
||||
if [[ << parameters.version >> = "latest" ]]; then
|
||||
go vet -v ./...
|
||||
fi
|
||||
- run:
|
||||
name: "Run go test (+ race detector)"
|
||||
command: >
|
||||
go test -v -race ./...
|
||||
|
||||
workflows:
|
||||
tests:
|
||||
jobs:
|
||||
- test:
|
||||
matrix:
|
||||
parameters:
|
||||
version: ["1.18", "1.17", "1.16"]
|
|
@ -0,0 +1,20 @@
|
|||
; https://editorconfig.org/
|
||||
|
||||
root = true
|
||||
|
||||
[*]
|
||||
insert_final_newline = true
|
||||
charset = utf-8
|
||||
trim_trailing_whitespace = true
|
||||
indent_style = space
|
||||
indent_size = 2
|
||||
|
||||
[{Makefile,go.mod,go.sum,*.go,.gitmodules}]
|
||||
indent_style = tab
|
||||
indent_size = 4
|
||||
|
||||
[*.md]
|
||||
indent_size = 4
|
||||
trim_trailing_whitespace = false
|
||||
|
||||
eclint_indent_style = unset
|
|
@ -1,7 +0,0 @@
|
|||
# Config for https://github.com/apps/release-drafter
|
||||
template: |
|
||||
|
||||
<summary of changes here>
|
||||
|
||||
## CHANGELOG
|
||||
$CHANGES
|
|
@ -0,0 +1,21 @@
|
|||
# Add all the issues created to the project.
|
||||
name: Add issue or pull request to Project
|
||||
|
||||
on:
|
||||
issues:
|
||||
types:
|
||||
- opened
|
||||
pull_request_target:
|
||||
types:
|
||||
- opened
|
||||
- reopened
|
||||
|
||||
jobs:
|
||||
add-to-project:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Add issue to project
|
||||
uses: actions/add-to-project@v0.5.0
|
||||
with:
|
||||
project-url: https://github.com/orgs/gorilla/projects/4
|
||||
github-token: ${{ secrets.ADD_TO_PROJECT_TOKEN }}
|
|
@ -0,0 +1,37 @@
|
|||
name: Security
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
permissions:
|
||||
contents: read
|
||||
jobs:
|
||||
scan:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.20','1.21']
|
||||
fail-fast: true
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Go ${{ matrix.go }}
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
cache: false
|
||||
|
||||
- name: Run GoSec
|
||||
uses: securego/gosec@master
|
||||
with:
|
||||
args: -exclude-dir examples ./...
|
||||
|
||||
- name: Run GoVulnCheck
|
||||
uses: golang/govulncheck-action@v1
|
||||
with:
|
||||
go-version-input: ${{ matrix.go }}
|
||||
go-package: ./...
|
|
@ -0,0 +1,35 @@
|
|||
name: Test
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
permissions:
|
||||
contents: read
|
||||
jobs:
|
||||
unit:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.20','1.21']
|
||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||
fail-fast: true
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Go ${{ matrix.go }}
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
cache: false
|
||||
|
||||
- name: Run Tests
|
||||
run: go test -race -cover -coverprofile=coverage -covermode=atomic -v ./...
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
with:
|
||||
files: ./coverage
|
|
@ -0,0 +1,31 @@
|
|||
name: Verify
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
permissions:
|
||||
contents: read
|
||||
jobs:
|
||||
lint:
|
||||
strategy:
|
||||
matrix:
|
||||
go: ['1.20','1.21']
|
||||
fail-fast: true
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Go ${{ matrix.go }}
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
cache: false
|
||||
|
||||
- name: Run GolangCI-Lint
|
||||
uses: golangci/golangci-lint-action@v3
|
||||
with:
|
||||
version: v1.53
|
|
@ -1,25 +1 @@
|
|||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||
*.o
|
||||
*.a
|
||||
*.so
|
||||
|
||||
# Folders
|
||||
_obj
|
||||
_test
|
||||
|
||||
# Architecture specific extensions/prefixes
|
||||
*.[568vq]
|
||||
[568vq].out
|
||||
|
||||
*.cgo1.go
|
||||
*.cgo2.c
|
||||
_cgo_defun.c
|
||||
_cgo_gotypes.go
|
||||
_cgo_export.*
|
||||
|
||||
_testmain.go
|
||||
|
||||
*.exe
|
||||
|
||||
.idea/
|
||||
*.iml
|
||||
coverage.coverprofile
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
run:
|
||||
timeout: "5m"
|
||||
# will not run golangci-lint against *_test.go
|
||||
tests: false
|
||||
issues:
|
||||
exclude-dirs:
|
||||
- examples/*.go
|
||||
exclude-rules:
|
||||
# excluding error checks from all the .go files
|
||||
- path: ./*.go
|
||||
linters:
|
||||
- errcheck
|
||||
|
9
AUTHORS
9
AUTHORS
|
@ -1,9 +0,0 @@
|
|||
# This is the official list of Gorilla WebSocket authors for copyright
|
||||
# purposes.
|
||||
#
|
||||
# Please keep the list sorted.
|
||||
|
||||
Gary Burd <gary@beagledreams.com>
|
||||
Google LLC (https://opensource.google.com/)
|
||||
Joachim Bauch <mail@joachim-bauch.de>
|
||||
|
39
LICENSE
39
LICENSE
|
@ -1,22 +1,27 @@
|
|||
Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved.
|
||||
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,34 @@
|
|||
GO_LINT=$(shell which golangci-lint 2> /dev/null || echo '')
|
||||
GO_LINT_URI=github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
GO_SEC=$(shell which gosec 2> /dev/null || echo '')
|
||||
GO_SEC_URI=github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
|
||||
GO_VULNCHECK=$(shell which govulncheck 2> /dev/null || echo '')
|
||||
GO_VULNCHECK_URI=golang.org/x/vuln/cmd/govulncheck@latest
|
||||
|
||||
.PHONY: golangci-lint
|
||||
golangci-lint:
|
||||
$(if $(GO_LINT), ,go install $(GO_LINT_URI))
|
||||
@echo "##### Running golangci-lint"
|
||||
golangci-lint run -v
|
||||
|
||||
.PHONY: gosec
|
||||
gosec:
|
||||
$(if $(GO_SEC), ,go install $(GO_SEC_URI))
|
||||
@echo "##### Running gosec"
|
||||
gosec -exclude-dir examples ./...
|
||||
|
||||
.PHONY: govulncheck
|
||||
govulncheck:
|
||||
$(if $(GO_VULNCHECK), ,go install $(GO_VULNCHECK_URI))
|
||||
@echo "##### Running govulncheck"
|
||||
govulncheck ./...
|
||||
|
||||
.PHONY: verify
|
||||
verify: golangci-lint gosec govulncheck
|
||||
|
||||
.PHONY: test
|
||||
test:
|
||||
@echo "##### Running tests"
|
||||
go test -race -cover -coverprofile=coverage.coverprofile -covermode=atomic -v ./...
|
25
README.md
25
README.md
|
@ -1,19 +1,23 @@
|
|||
# Gorilla WebSocket
|
||||
# gorilla/websocket
|
||||
|
||||
[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
|
||||
[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket)
|
||||
![testing](https://github.com/gorilla/websocket/actions/workflows/test.yml/badge.svg)
|
||||
[![codecov](https://codecov.io/github/gorilla/websocket/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/websocket)
|
||||
[![godoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
|
||||
[![sourcegraph](https://sourcegraph.com/github.com/gorilla/websocket/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/websocket?badge)
|
||||
|
||||
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
|
||||
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
|
||||
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
|
||||
|
||||
![Gorilla Logo](https://github.com/gorilla/.github/assets/53367916/d92caabf-98e0-473e-bfbf-ab554ba435e5)
|
||||
|
||||
|
||||
### Documentation
|
||||
|
||||
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
|
||||
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
|
||||
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
|
||||
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
|
||||
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
|
||||
* [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat)
|
||||
* [Command example](https://github.com/gorilla/websocket/tree/main/examples/command)
|
||||
* [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo)
|
||||
* [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch)
|
||||
* [Write buffer pool example](https://github.com/gorilla/websocket/tree/main/examples/bufferpool)
|
||||
|
||||
### Status
|
||||
|
||||
|
@ -29,5 +33,4 @@ package API is stable.
|
|||
|
||||
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
|
||||
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
|
||||
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
|
||||
|
||||
subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn).
|
||||
|
|
13
client.go
13
client.go
|
@ -11,13 +11,14 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// ErrBadHandshake is returned when the server response to opening handshake is
|
||||
|
@ -304,7 +305,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
return nil, nil, err
|
||||
}
|
||||
if proxyURL != nil {
|
||||
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
|
||||
dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -391,7 +392,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != 101 ||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols ||
|
||||
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
|
||||
!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
|
||||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
|
||||
|
@ -400,7 +401,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
// debugging.
|
||||
buf := make([]byte, 1024)
|
||||
n, _ := io.ReadFull(resp.Body, buf)
|
||||
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
|
||||
resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
|
||||
return nil, resp, ErrBadHandshake
|
||||
}
|
||||
|
||||
|
@ -418,7 +419,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
break
|
||||
}
|
||||
|
||||
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
||||
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
||||
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
||||
|
||||
netConn.SetDeadline(time.Time{})
|
||||
|
@ -428,7 +429,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
|
||||
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||
if cfg == nil {
|
||||
return &tls.Config{}
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
return cfg.Clone()
|
||||
}
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -149,6 +148,7 @@ func sendRecv(t *testing.T, ws *Conn) {
|
|||
}
|
||||
|
||||
func TestProxyDial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
@ -187,6 +187,7 @@ func TestProxyDial(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestProxyAuthorizationDial(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -227,6 +228,7 @@ func TestProxyAuthorizationDial(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDial(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -239,6 +241,7 @@ func TestDial(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDialCookieJar(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -301,6 +304,7 @@ func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
|
|||
}
|
||||
|
||||
func TestDialTLS(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newTLSServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -315,6 +319,7 @@ func TestDialTLS(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDialTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -371,6 +376,7 @@ func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr()
|
|||
func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
|
||||
|
||||
func TestHandshakeTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -387,6 +393,7 @@ func TestHandshakeTimeout(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandshakeTimeoutInContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -408,6 +415,7 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDialBadScheme(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -419,6 +427,7 @@ func TestDialBadScheme(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDialBadOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -436,6 +445,7 @@ func TestDialBadOrigin(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDialBadHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -455,6 +465,7 @@ func TestDialBadHeader(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBadMethod(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ws, err := cstUpgrader.Upgrade(w, r, nil)
|
||||
if err == nil {
|
||||
|
@ -483,6 +494,7 @@ func TestBadMethod(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDialExtraTokensInRespHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
||||
w.Header().Set("Upgrade", "foo, websocket")
|
||||
|
@ -500,6 +512,7 @@ func TestDialExtraTokensInRespHeaders(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -526,6 +539,7 @@ func TestHandshake(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRespOnBadHandshake(t *testing.T) {
|
||||
t.Parallel()
|
||||
const expectedStatus = http.StatusGone
|
||||
const expectedBody = "This is the response body."
|
||||
|
||||
|
@ -549,7 +563,7 @@ func TestRespOnBadHandshake(t *testing.T) {
|
|||
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
|
||||
}
|
||||
|
||||
p, err := ioutil.ReadAll(resp.Body)
|
||||
p, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFull(resp.Body) returned error %v", err)
|
||||
}
|
||||
|
@ -570,6 +584,7 @@ func (w testLogWriter) Write(p []byte) (int, error) {
|
|||
|
||||
// TestHost tests handling of host names and confirms that it matches net/http.
|
||||
func TestHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
upgrader := Upgrader{}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -751,6 +766,7 @@ func TestHost(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDialCompression(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -765,6 +781,7 @@ func TestDialCompression(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSocksProxyDial(t *testing.T) {
|
||||
t.Parallel()
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -844,6 +861,7 @@ func TestSocksProxyDial(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTracingDialWithContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
|
||||
trace := &httptrace.ClientTrace{
|
||||
|
@ -903,6 +921,7 @@ func TestTracingDialWithContext(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestEmptyTracingDialWithContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
trace := &httptrace.ClientTrace{}
|
||||
ctx := httptrace.WithClientTrace(context.Background(), trace)
|
||||
|
@ -924,6 +943,7 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
|
|||
|
||||
// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
|
||||
func TestNetDialConnect(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
upgrader := Upgrader{}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -1099,6 +1119,7 @@ func TestNetDialConnect(t *testing.T) {
|
|||
}
|
||||
}
|
||||
func TestNextProtos(t *testing.T) {
|
||||
t.Parallel()
|
||||
ts := httptest.NewUnstartedServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
)
|
||||
|
|
|
@ -20,6 +20,7 @@ var hostPortNoPortTests = []struct {
|
|||
}
|
||||
|
||||
func TestHostPortNoPort(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tt := range hostPortNoPortTests {
|
||||
hostPort, hostNoPort := hostPortNoPort(tt.u)
|
||||
if hostPort != tt.hostPort {
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -13,6 +12,7 @@ type nopCloser struct{ io.Writer }
|
|||
func (nopCloser) Close() error { return nil }
|
||||
|
||||
func TestTruncWriter(t *testing.T) {
|
||||
t.Parallel()
|
||||
const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
|
||||
for n := 1; n <= 10; n++ {
|
||||
var b bytes.Buffer
|
||||
|
@ -42,7 +42,7 @@ func textMessages(num int) [][]byte {
|
|||
}
|
||||
|
||||
func BenchmarkWriteNoCompression(b *testing.B) {
|
||||
w := ioutil.Discard
|
||||
w := io.Discard
|
||||
c := newTestConn(nil, w, false)
|
||||
messages := textMessages(100)
|
||||
b.ResetTimer()
|
||||
|
@ -53,7 +53,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkWriteWithCompression(b *testing.B) {
|
||||
w := ioutil.Discard
|
||||
w := io.Discard
|
||||
c := newTestConn(nil, w, false)
|
||||
messages := textMessages(100)
|
||||
c.enableWriteCompression = true
|
||||
|
@ -66,6 +66,7 @@ func BenchmarkWriteWithCompression(b *testing.B) {
|
|||
}
|
||||
|
||||
func TestValidCompressionLevel(t *testing.T) {
|
||||
t.Parallel()
|
||||
c := newTestConn(nil, nil, false)
|
||||
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
|
||||
if err := c.SetCompressionLevel(level); err == nil {
|
||||
|
|
62
conn.go
62
conn.go
|
@ -6,11 +6,10 @@ package websocket
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -181,16 +180,16 @@ var (
|
|||
errInvalidControlFrame = errors.New("websocket: invalid control frame")
|
||||
)
|
||||
|
||||
func newMaskKey() [4]byte {
|
||||
n := rand.Uint32()
|
||||
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
|
||||
}
|
||||
// maskRand is an io.Reader for generating mask bytes. The reader is initialized
|
||||
// to crypto/rand Reader. Tests swap the reader to a math/rand reader for
|
||||
// reproducible results.
|
||||
var maskRand = rand.Reader
|
||||
|
||||
func hideTempErr(err error) error {
|
||||
if e, ok := err.(net.Error); ok && e.Temporary() {
|
||||
err = &netError{msg: e.Error(), timeout: e.Timeout()}
|
||||
}
|
||||
return err
|
||||
// newMaskKey returns a new 32 bit value for masking client frames.
|
||||
func newMaskKey() [4]byte {
|
||||
var k [4]byte
|
||||
_, _ = io.ReadFull(maskRand, k[:])
|
||||
return k
|
||||
}
|
||||
|
||||
func isControl(frameType int) bool {
|
||||
|
@ -358,7 +357,6 @@ func (c *Conn) RemoteAddr() net.Addr {
|
|||
// Write methods
|
||||
|
||||
func (c *Conn) writeFatal(err error) error {
|
||||
err = hideTempErr(err)
|
||||
c.writeErrMu.Lock()
|
||||
if c.writeErr == nil {
|
||||
c.writeErr = err
|
||||
|
@ -436,14 +434,17 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
|||
maskBytes(key, 0, buf[6:])
|
||||
}
|
||||
|
||||
d := 1000 * time.Hour
|
||||
if !deadline.IsZero() {
|
||||
d = deadline.Sub(time.Now())
|
||||
if deadline.IsZero() {
|
||||
// No timeout for zero time.
|
||||
<-c.mu
|
||||
} else {
|
||||
d := time.Until(deadline)
|
||||
if d < 0 {
|
||||
return errWriteTimeout
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-c.mu:
|
||||
default:
|
||||
timer := time.NewTimer(d)
|
||||
select {
|
||||
case <-c.mu:
|
||||
|
@ -451,6 +452,9 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
|||
case <-timer.C:
|
||||
return errWriteTimeout
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
defer func() { c.mu <- struct{}{} }()
|
||||
|
||||
c.writeErrMu.Lock()
|
||||
|
@ -795,7 +799,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
|||
// 1. Skip remainder of previous frame.
|
||||
|
||||
if c.readRemaining > 0 {
|
||||
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
|
||||
if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
|
||||
return noFrame, err
|
||||
}
|
||||
}
|
||||
|
@ -1008,7 +1012,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
|||
for c.readErr == nil {
|
||||
frameType, err := c.advanceFrame()
|
||||
if err != nil {
|
||||
c.readErr = hideTempErr(err)
|
||||
c.readErr = err
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -1048,7 +1052,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
|||
b = b[:c.readRemaining]
|
||||
}
|
||||
n, err := c.br.Read(b)
|
||||
c.readErr = hideTempErr(err)
|
||||
c.readErr = err
|
||||
if c.isServer {
|
||||
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
|
||||
}
|
||||
|
@ -1069,7 +1073,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
|||
frameType, err := c.advanceFrame()
|
||||
switch {
|
||||
case err != nil:
|
||||
c.readErr = hideTempErr(err)
|
||||
c.readErr = err
|
||||
case frameType == TextMessage || frameType == BinaryMessage:
|
||||
c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
||||
}
|
||||
|
@ -1094,7 +1098,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
|
|||
if err != nil {
|
||||
return messageType, nil, err
|
||||
}
|
||||
p, err = ioutil.ReadAll(r)
|
||||
p, err = io.ReadAll(r)
|
||||
return messageType, p, err
|
||||
}
|
||||
|
||||
|
@ -1161,7 +1165,7 @@ func (c *Conn) SetPingHandler(h func(appData string) error) {
|
|||
err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
|
||||
if err == ErrCloseSent {
|
||||
return nil
|
||||
} else if e, ok := err.(net.Error); ok && e.Temporary() {
|
||||
} else if _, ok := err.(net.Error); ok {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
|
@ -1236,3 +1240,15 @@ func FormatCloseMessage(closeCode int, text string) []byte {
|
|||
copy(buf[2:], text)
|
||||
return buf
|
||||
}
|
||||
|
||||
var messageTypes = map[int]string{
|
||||
TextMessage: "TextMessage",
|
||||
BinaryMessage: "BinaryMessage",
|
||||
CloseMessage: "CloseMessage",
|
||||
PingMessage: "PingMessage",
|
||||
PongMessage: "PongMessage",
|
||||
}
|
||||
|
||||
func FormatMessageType(mt int) string {
|
||||
return messageTypes[mt]
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ package websocket
|
|||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
@ -45,7 +44,7 @@ func newBroadcastConn(c *Conn) *broadcastConn {
|
|||
|
||||
func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
|
||||
bench := &broadcastBench{
|
||||
w: ioutil.Discard,
|
||||
w: io.Discard,
|
||||
doneCh: make(chan struct{}),
|
||||
closeCh: make(chan struct{}),
|
||||
usePrepared: usePrepared,
|
||||
|
|
130
conn_test.go
130
conn_test.go
|
@ -10,7 +10,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
@ -55,6 +54,7 @@ func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
|
|||
}
|
||||
|
||||
func TestFraming(t *testing.T) {
|
||||
t.Parallel()
|
||||
frameSizes := []int{
|
||||
0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
|
||||
// 65536, 65537
|
||||
|
@ -125,7 +125,7 @@ func TestFraming(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Logf("frame size: %d", n)
|
||||
rbuf, err := ioutil.ReadAll(r)
|
||||
rbuf, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
|
||||
continue
|
||||
|
@ -149,7 +149,49 @@ func TestFraming(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestWriteControlDeadline(t *testing.T) {
|
||||
t.Parallel()
|
||||
message := []byte("hello")
|
||||
var connBuf bytes.Buffer
|
||||
c := newTestConn(nil, &connBuf, true)
|
||||
if err := c.WriteControl(PongMessage, message, time.Time{}); err != nil {
|
||||
t.Errorf("WriteControl(..., zero deadline) = %v, want nil", err)
|
||||
}
|
||||
if err := c.WriteControl(PongMessage, message, time.Now().Add(time.Second)); err != nil {
|
||||
t.Errorf("WriteControl(..., future deadline) = %v, want nil", err)
|
||||
}
|
||||
if err := c.WriteControl(PongMessage, message, time.Now().Add(-time.Second)); err == nil {
|
||||
t.Errorf("WriteControl(..., past deadline) = nil, want timeout error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrencyWriteControl(t *testing.T) {
|
||||
const message = "this is a ping/pong messsage"
|
||||
loop := 10
|
||||
workers := 10
|
||||
for i := 0; i < loop; i++ {
|
||||
var connBuf bytes.Buffer
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wc := newTestConn(nil, &connBuf, true)
|
||||
|
||||
for i := 0; i < workers; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
|
||||
t.Errorf("concurrently wc.WriteControl() returned %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
wc.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestControl(t *testing.T) {
|
||||
t.Parallel()
|
||||
const message = "this is a ping/pong messsage"
|
||||
for _, isServer := range []bool{true, false} {
|
||||
for _, isWriteControl := range []bool{true, false} {
|
||||
|
@ -201,6 +243,7 @@ func (p *simpleBufferPool) Put(v interface{}) {
|
|||
}
|
||||
|
||||
func TestWriteBufferPool(t *testing.T) {
|
||||
t.Parallel()
|
||||
const message = "Now is the time for all good people to come to the aid of the party."
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
@ -279,6 +322,7 @@ func TestWriteBufferPool(t *testing.T) {
|
|||
|
||||
// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
|
||||
func TestWriteBufferPoolSync(t *testing.T) {
|
||||
t.Parallel()
|
||||
var buf bytes.Buffer
|
||||
var pool sync.Pool
|
||||
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
|
||||
|
@ -307,6 +351,7 @@ func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error
|
|||
// TestWriteBufferPoolError ensures that buffer is returned to pool after error
|
||||
// on write.
|
||||
func TestWriteBufferPoolError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Part 1: Test NextWriter/Write/Close
|
||||
|
||||
|
@ -350,6 +395,7 @@ func TestWriteBufferPoolError(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||||
t.Parallel()
|
||||
const bufSize = 512
|
||||
|
||||
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
|
||||
|
@ -367,7 +413,7 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
|||
if op != BinaryMessage || err != nil {
|
||||
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||
}
|
||||
_, err = io.Copy(ioutil.Discard, r)
|
||||
_, err = io.Copy(io.Discard, r)
|
||||
if !reflect.DeepEqual(err, expectedErr) {
|
||||
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
|
||||
}
|
||||
|
@ -378,6 +424,7 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestEOFWithinFrame(t *testing.T) {
|
||||
t.Parallel()
|
||||
const bufSize = 64
|
||||
|
||||
for n := 0; ; n++ {
|
||||
|
@ -401,7 +448,7 @@ func TestEOFWithinFrame(t *testing.T) {
|
|||
if op != BinaryMessage || err != nil {
|
||||
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
|
||||
}
|
||||
_, err = io.Copy(ioutil.Discard, r)
|
||||
_, err = io.Copy(io.Discard, r)
|
||||
if err != errUnexpectedEOF {
|
||||
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
|
||||
}
|
||||
|
@ -413,6 +460,7 @@ func TestEOFWithinFrame(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||
t.Parallel()
|
||||
const bufSize = 512
|
||||
|
||||
var b1, b2 bytes.Buffer
|
||||
|
@ -426,7 +474,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
|||
if op != BinaryMessage || err != nil {
|
||||
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||
}
|
||||
_, err = io.Copy(ioutil.Discard, r)
|
||||
_, err = io.Copy(io.Discard, r)
|
||||
if err != errUnexpectedEOF {
|
||||
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
|
||||
}
|
||||
|
@ -437,6 +485,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
wc := newTestConn(nil, &bytes.Buffer{}, false)
|
||||
w, _ := wc.NextWriter(BinaryMessage)
|
||||
io.WriteString(w, "hello")
|
||||
|
@ -462,7 +511,29 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestWriteHandlerDoesNotReturnErrCloseSent(t *testing.T) {
|
||||
t.Parallel()
|
||||
var b1, b2 bytes.Buffer
|
||||
|
||||
client := newTestConn(&b2, &b1, false)
|
||||
server := newTestConn(&b1, &b2, true)
|
||||
|
||||
msg := FormatCloseMessage(CloseNormalClosure, "")
|
||||
if err := client.WriteMessage(CloseMessage, msg); err != nil {
|
||||
t.Fatalf("unexpected error when writing close message, %v", err)
|
||||
}
|
||||
|
||||
if _, _, err := server.NextReader(); !IsCloseError(err, 1000) {
|
||||
t.Fatalf("server expects a close message, %v returned", err)
|
||||
}
|
||||
|
||||
if _, _, err := client.NextReader(); !IsCloseError(err, 1000) {
|
||||
t.Fatalf("client expects a close message, %v returned", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("Test ReadLimit is enforced", func(t *testing.T) {
|
||||
const readLimit = 512
|
||||
message := make([]byte, readLimit+1)
|
||||
|
@ -490,7 +561,7 @@ func TestReadLimit(t *testing.T) {
|
|||
if op != BinaryMessage || err != nil {
|
||||
t.Fatalf("2: NextReader() returned %d, %v", op, err)
|
||||
}
|
||||
_, err = io.Copy(ioutil.Discard, r)
|
||||
_, err = io.Copy(io.Discard, r)
|
||||
if err != ErrReadLimit {
|
||||
t.Fatalf("io.Copy() returned %v", err)
|
||||
}
|
||||
|
@ -553,6 +624,7 @@ func TestReadLimit(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAddrs(t *testing.T) {
|
||||
t.Parallel()
|
||||
c := newTestConn(nil, nil, true)
|
||||
if c.LocalAddr() != localAddr {
|
||||
t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
|
||||
|
@ -563,6 +635,7 @@ func TestAddrs(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDeprecatedUnderlyingConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
var b1, b2 bytes.Buffer
|
||||
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
||||
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
|
||||
|
@ -573,6 +646,7 @@ func TestDeprecatedUnderlyingConn(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNetConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
var b1, b2 bytes.Buffer
|
||||
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
||||
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
|
||||
|
@ -583,6 +657,7 @@ func TestNetConn(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestBufioReadBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Test calling bufio.ReadBytes for value longer than read buffer size.
|
||||
|
||||
m := make([]byte, 512)
|
||||
|
@ -623,6 +698,7 @@ var closeErrorTests = []struct {
|
|||
}
|
||||
|
||||
func TestCloseError(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tt := range closeErrorTests {
|
||||
ok := IsCloseError(tt.err, tt.codes...)
|
||||
if ok != tt.ok {
|
||||
|
@ -643,6 +719,7 @@ var unexpectedCloseErrorTests = []struct {
|
|||
}
|
||||
|
||||
func TestUnexpectedCloseErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tt := range unexpectedCloseErrorTests {
|
||||
ok := IsUnexpectedCloseError(tt.err, tt.codes...)
|
||||
if ok != tt.ok {
|
||||
|
@ -664,6 +741,7 @@ func (w blockingWriter) Write(p []byte) (int, error) {
|
|||
}
|
||||
|
||||
func TestConcurrentWritePanic(t *testing.T) {
|
||||
t.Parallel()
|
||||
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
||||
c := newTestConn(nil, w, false)
|
||||
go func() {
|
||||
|
@ -691,6 +769,7 @@ func (r failingReader) Read(p []byte) (int, error) {
|
|||
}
|
||||
|
||||
func TestFailedConnectionReadPanic(t *testing.T) {
|
||||
t.Parallel()
|
||||
c := newTestConn(failingReader{}, nil, false)
|
||||
|
||||
defer func() {
|
||||
|
@ -704,3 +783,42 @@ func TestFailedConnectionReadPanic(t *testing.T) {
|
|||
}
|
||||
t.Fatal("should not get here")
|
||||
}
|
||||
|
||||
func TestFormatMessageType(t *testing.T) {
|
||||
str := FormatMessageType(TextMessage)
|
||||
if str != messageTypes[TextMessage] {
|
||||
t.Error("failed to format message type")
|
||||
}
|
||||
|
||||
str = FormatMessageType(CloseMessage)
|
||||
if str != messageTypes[CloseMessage] {
|
||||
t.Error("failed to format message type")
|
||||
}
|
||||
|
||||
str = FormatMessageType(123)
|
||||
if str != messageTypes[123] {
|
||||
t.Error("failed to format message type")
|
||||
}
|
||||
}
|
||||
|
||||
type fakeNetClosedReader struct {
|
||||
}
|
||||
|
||||
func (r fakeNetClosedReader) Read([]byte) (int, error) {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
func TestConnectionClosed(t *testing.T) {
|
||||
var b1, b2 bytes.Buffer
|
||||
|
||||
client := newTestConn(fakeNetClosedReader{}, &b1, false)
|
||||
server := newTestConn(fakeNetClosedReader{}, &b2, true)
|
||||
|
||||
if _, _, err := server.NextReader(); !errors.Is(err, net.ErrClosed) {
|
||||
t.Fatalf("server expects a net.ErrClosed error, %v returned", err)
|
||||
}
|
||||
|
||||
if _, _, err := client.NextReader(); !errors.Is(err, net.ErrClosed) {
|
||||
t.Fatalf("client expects a net.ErrClosed error, %v returned", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -42,4 +42,4 @@ func processMessage(mt int, p []byte) {}
|
|||
|
||||
// TestX prevents godoc from showing this entire file in the example. Remove
|
||||
// this function when a second example is added.
|
||||
func TestX(t *testing.T) {}
|
||||
func TestX(t *testing.T) { t.Parallel() }
|
||||
|
|
|
@ -84,7 +84,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// echoReadAll echoes messages from the client by reading the entire message
|
||||
// with ioutil.ReadAll.
|
||||
// with io.ReadAll.
|
||||
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
//go:build ignore
|
||||
// +build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var addr = flag.String("addr", "localhost:8080", "http service address")
|
||||
|
||||
func runNewConn(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
interrupt := make(chan os.Signal, 1)
|
||||
signal.Notify(interrupt, os.Interrupt)
|
||||
|
||||
u := url.URL{Scheme: "ws", Host: *addr, Path: "/ws"}
|
||||
log.Printf("connecting to %s", u.String())
|
||||
c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||
if err != nil {
|
||||
log.Fatal("dial:", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
_, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
log.Println("read:", err)
|
||||
return
|
||||
}
|
||||
log.Printf("recv: %s", message)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(time.Minute * 5)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case t := <-ticker.C:
|
||||
err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
|
||||
if err != nil {
|
||||
log.Println("write:", err)
|
||||
return
|
||||
}
|
||||
case <-interrupt:
|
||||
log.Println("interrupt")
|
||||
|
||||
// Cleanly close the connection by sending a close message and then
|
||||
// waiting (with timeout) for the server to close the connection.
|
||||
err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
if err != nil {
|
||||
log.Println("write close:", err)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
log.SetFlags(0)
|
||||
wg := &sync.WaitGroup{}
|
||||
for i := 0; i < 1000; i++ {
|
||||
wg.Add(1)
|
||||
go runNewConn(wg)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
//go:build ignore
|
||||
// +build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
_ "net/http/pprof"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var addr = flag.String("addr", "localhost:8080", "http service address")
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 256,
|
||||
WriteBufferSize: 256,
|
||||
WriteBufferPool: &sync.Pool{},
|
||||
}
|
||||
|
||||
func process(c *websocket.Conn) {
|
||||
defer c.Close()
|
||||
for {
|
||||
_, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
log.Println("read:", err)
|
||||
break
|
||||
}
|
||||
log.Printf("recv: %s", message)
|
||||
}
|
||||
}
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Print("upgrade:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Process connection in a new goroutine
|
||||
go process(c)
|
||||
|
||||
// Let the http handler return, the 8k buffer created by it will be garbage collected
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
log.SetFlags(0)
|
||||
http.HandleFunc("/ws", handler)
|
||||
log.Fatal(http.ListenAndServe(*addr, nil))
|
||||
}
|
|
@ -38,7 +38,7 @@ sends them to the hub.
|
|||
### Hub
|
||||
|
||||
The code for the `Hub` type is in
|
||||
[hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go).
|
||||
[hub.go](https://github.com/gorilla/websocket/blob/main/examples/chat/hub.go).
|
||||
The application's `main` function starts the hub's `run` method as a goroutine.
|
||||
Clients send requests to the hub using the `register`, `unregister` and
|
||||
`broadcast` channels.
|
||||
|
@ -57,7 +57,7 @@ unregisters the client and closes the websocket.
|
|||
|
||||
### Client
|
||||
|
||||
The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go).
|
||||
The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/main/examples/chat/client.go).
|
||||
|
||||
The `serveWs` function is registered by the application's `main` function as
|
||||
an HTTP handler. The handler upgrades the HTTP connection to the WebSocket
|
||||
|
@ -85,7 +85,7 @@ network.
|
|||
|
||||
## Frontend
|
||||
|
||||
The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html).
|
||||
The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/main/examples/chat/home.html).
|
||||
|
||||
On document load, the script checks for websocket functionality in the browser.
|
||||
If websocket functionality is available, then the script opens a connection to
|
||||
|
|
|
@ -57,8 +57,6 @@ func pumpStdin(ws *websocket.Conn, w io.Writer) {
|
|||
}
|
||||
|
||||
func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) {
|
||||
defer func() {
|
||||
}()
|
||||
s := bufio.NewScanner(r)
|
||||
for s.Scan() {
|
||||
ws.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
|
|
|
@ -41,12 +41,12 @@ func main() {
|
|||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
_, message, err := c.ReadMessage()
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
log.Println("read:", err)
|
||||
return
|
||||
}
|
||||
log.Printf("recv: %s", message)
|
||||
log.Printf("recv: %s, type: %s", message, websocket.FormatMessageType(mt))
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
|
@ -33,7 +33,8 @@ func echo(w http.ResponseWriter, r *http.Request) {
|
|||
log.Println("read:", err)
|
||||
break
|
||||
}
|
||||
log.Printf("recv: %s", message)
|
||||
|
||||
log.Printf("recv: %s, type: %s", message, websocket.FormatMessageType(mt))
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
log.Println("write:", err)
|
||||
|
|
|
@ -7,10 +7,10 @@ package main
|
|||
import (
|
||||
"flag"
|
||||
"html/template"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
|
@ -49,7 +49,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) {
|
|||
if !fi.ModTime().After(lastMod) {
|
||||
return nil, lastMod, nil
|
||||
}
|
||||
p, err := ioutil.ReadFile(filename)
|
||||
p, err := os.ReadFile(filepath.Clean(filename))
|
||||
if err != nil {
|
||||
return nil, fi.ModTime(), err
|
||||
}
|
||||
|
|
4
go.mod
4
go.mod
|
@ -1,3 +1,5 @@
|
|||
module github.com/gorilla/websocket
|
||||
|
||||
go 1.12
|
||||
go 1.20
|
||||
|
||||
require golang.org/x/net v0.23.0
|
||||
|
|
2
go.sum
2
go.sum
|
@ -0,0 +1,2 @@
|
|||
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
|
||||
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
|
@ -12,6 +12,7 @@ import (
|
|||
)
|
||||
|
||||
func TestJoinMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"}
|
||||
for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} {
|
||||
for _, term := range []string{"", ","} {
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
)
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
var buf bytes.Buffer
|
||||
wc := newTestConn(nil, &buf, true)
|
||||
rc := newTestConn(&buf, nil, false)
|
||||
|
@ -38,6 +39,7 @@ func TestJSON(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPartialJSONRead(t *testing.T) {
|
||||
t.Parallel()
|
||||
var buf0, buf1 bytes.Buffer
|
||||
wc := newTestConn(nil, &buf0, true)
|
||||
rc := newTestConn(&buf0, &buf1, false)
|
||||
|
@ -91,6 +93,7 @@ func TestPartialJSONRead(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDeprecatedJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
var buf bytes.Buffer
|
||||
wc := newTestConn(nil, &buf, true)
|
||||
rc := newTestConn(&buf, nil, false)
|
||||
|
|
|
@ -29,6 +29,7 @@ func notzero(b []byte) int {
|
|||
}
|
||||
|
||||
func TestMaskBytes(t *testing.T) {
|
||||
t.Parallel()
|
||||
key := [4]byte{1, 2, 3, 4}
|
||||
for size := 1; size <= 1024; size++ {
|
||||
for align := 0; align < wordSize; align++ {
|
||||
|
|
|
@ -33,6 +33,11 @@ var preparedMessageTests = []struct {
|
|||
}
|
||||
|
||||
func TestPreparedMessage(t *testing.T) {
|
||||
testRand := rand.New(rand.NewSource(99))
|
||||
prevMaskRand := maskRand
|
||||
maskRand = testRand
|
||||
defer func() { maskRand = prevMaskRand }()
|
||||
|
||||
for _, tt := range preparedMessageTests {
|
||||
var data = []byte("this is a test")
|
||||
var buf bytes.Buffer
|
||||
|
@ -43,7 +48,7 @@ func TestPreparedMessage(t *testing.T) {
|
|||
c.SetCompressionLevel(tt.compressionLevel)
|
||||
|
||||
// Seed random number generator for consistent frame mask.
|
||||
rand.Seed(1234)
|
||||
testRand.Seed(1234)
|
||||
|
||||
if err := c.WriteMessage(tt.messageType, data); err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -59,7 +64,7 @@ func TestPreparedMessage(t *testing.T) {
|
|||
copy(data, "hello world")
|
||||
|
||||
// Seed random number generator for consistent frame mask.
|
||||
rand.Seed(1234)
|
||||
testRand.Seed(1234)
|
||||
|
||||
buf.Reset()
|
||||
if err := c.WritePreparedMessage(pm); err != nil {
|
||||
|
|
6
proxy.go
6
proxy.go
|
@ -12,6 +12,8 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
type netDialerFunc func(network, addr string) (net.Conn, error)
|
||||
|
@ -21,7 +23,7 @@ func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
|
|||
}
|
||||
|
||||
func init() {
|
||||
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
||||
proxy.RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy.Dialer) (proxy.Dialer, error) {
|
||||
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
|
||||
})
|
||||
}
|
||||
|
@ -68,7 +70,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
conn.Close()
|
||||
f := strings.SplitN(resp.Status, " ", 2)
|
||||
return nil, errors.New(f[1])
|
||||
|
|
10
server.go
10
server.go
|
@ -33,6 +33,7 @@ type Upgrader struct {
|
|||
// size is zero, then buffers allocated by the HTTP server are used. The
|
||||
// I/O buffer sizes do not limit the size of the messages that can be sent
|
||||
// or received.
|
||||
// The default value is 4096 bytes, 4kb.
|
||||
ReadBufferSize, WriteBufferSize int
|
||||
|
||||
// WriteBufferPool is a pool of buffers for write operations. If the value
|
||||
|
@ -101,8 +102,8 @@ func checkSameOrigin(r *http.Request) bool {
|
|||
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
|
||||
if u.Subprotocols != nil {
|
||||
clientProtocols := Subprotocols(r)
|
||||
for _, serverProtocol := range u.Subprotocols {
|
||||
for _, clientProtocol := range clientProtocols {
|
||||
for _, serverProtocol := range u.Subprotocols {
|
||||
if clientProtocol == serverProtocol {
|
||||
return clientProtocol
|
||||
}
|
||||
|
@ -172,12 +173,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
|||
}
|
||||
}
|
||||
|
||||
h, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
|
||||
}
|
||||
var brw *bufio.ReadWriter
|
||||
netConn, brw, err := h.Hijack()
|
||||
netConn, brw, err := http.NewResponseController(w).Hijack()
|
||||
if err != nil {
|
||||
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
|
|
@ -7,8 +7,10 @@ package websocket
|
|||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -27,6 +29,7 @@ var subprotocolTests = []struct {
|
|||
}
|
||||
|
||||
func TestSubprotocols(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, st := range subprotocolTests {
|
||||
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}}
|
||||
protocols := Subprotocols(&r)
|
||||
|
@ -46,6 +49,7 @@ var isWebSocketUpgradeTests = []struct {
|
|||
}
|
||||
|
||||
func TestIsWebSocketUpgrade(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tt := range isWebSocketUpgradeTests {
|
||||
ok := IsWebSocketUpgrade(&http.Request{Header: tt.h})
|
||||
if tt.ok != ok {
|
||||
|
@ -54,6 +58,37 @@ func TestIsWebSocketUpgrade(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSubProtocolSelection(t *testing.T) {
|
||||
t.Parallel()
|
||||
upgrader := Upgrader{
|
||||
Subprotocols: []string{"foo", "bar", "baz"},
|
||||
}
|
||||
|
||||
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}}
|
||||
s := upgrader.selectSubprotocol(&r, nil)
|
||||
if s != "foo" {
|
||||
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo")
|
||||
}
|
||||
|
||||
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}}
|
||||
s = upgrader.selectSubprotocol(&r, nil)
|
||||
if s != "bar" {
|
||||
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar")
|
||||
}
|
||||
|
||||
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}}
|
||||
s = upgrader.selectSubprotocol(&r, nil)
|
||||
if s != "baz" {
|
||||
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz")
|
||||
}
|
||||
|
||||
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}}
|
||||
s = upgrader.selectSubprotocol(&r, nil)
|
||||
if s != "" {
|
||||
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string")
|
||||
}
|
||||
}
|
||||
|
||||
var checkSameOriginTests = []struct {
|
||||
ok bool
|
||||
r *http.Request
|
||||
|
@ -64,6 +99,7 @@ var checkSameOriginTests = []struct {
|
|||
}
|
||||
|
||||
func TestCheckSameOrigin(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tt := range checkSameOriginTests {
|
||||
ok := checkSameOrigin(tt.r)
|
||||
if tt.ok != ok {
|
||||
|
@ -90,6 +126,7 @@ var bufioReuseTests = []struct {
|
|||
}
|
||||
|
||||
func TestBufioReuse(t *testing.T) {
|
||||
t.Parallel()
|
||||
for i, tt := range bufioReuseTests {
|
||||
br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
|
||||
bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
|
||||
|
@ -117,3 +154,23 @@ func TestBufioReuse(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHijack_NotSupported(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "upgrade")
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
req.Header.Set("Sec-Websocket-Version", "13")
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
upgrader := Upgrader{}
|
||||
_, err := upgrader.Upgrade(recorder, req, nil)
|
||||
|
||||
if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
|
||||
t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
|
||||
t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,3 @@
|
|||
//go:build go1.17
|
||||
// +build go1.17
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
|
|
|
@ -21,6 +21,7 @@ var equalASCIIFoldTests = []struct {
|
|||
}
|
||||
|
||||
func TestEqualASCIIFold(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tt := range equalASCIIFoldTests {
|
||||
eq := equalASCIIFold(tt.s, tt.t)
|
||||
if eq != tt.eq {
|
||||
|
@ -44,6 +45,7 @@ var tokenListContainsValueTests = []struct {
|
|||
}
|
||||
|
||||
func TestTokenListContainsValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tt := range tokenListContainsValueTests {
|
||||
h := http.Header{"Upgrade": {tt.value}}
|
||||
ok := tokenListContainsValue(h, "Upgrade", "websocket")
|
||||
|
@ -64,6 +66,7 @@ var isValidChallengeKeyTests = []struct {
|
|||
}
|
||||
|
||||
func TestIsValidChallengeKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tt := range isValidChallengeKeyTests {
|
||||
ok := isValidChallengeKey(tt.key)
|
||||
if ok != tt.ok {
|
||||
|
@ -105,6 +108,7 @@ var parseExtensionTests = []struct {
|
|||
}
|
||||
|
||||
func TestParseExtensions(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tt := range parseExtensionTests {
|
||||
h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}}
|
||||
extensions := parseExtensions(h)
|
||||
|
|
Loading…
Reference in New Issue