mirror of https://github.com/gorilla/websocket.git
Compare commits
2 Commits
7785671495
...
71e79a20b9
Author | SHA1 | Date |
---|---|---|
好为 | 71e79a20b9 | |
apoorvajagtap | ce903f6d1d |
|
@ -0,0 +1,70 @@
|
||||||
|
version: 2.1
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
"test":
|
||||||
|
parameters:
|
||||||
|
version:
|
||||||
|
type: string
|
||||||
|
default: "latest"
|
||||||
|
golint:
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
|
modules:
|
||||||
|
type: boolean
|
||||||
|
default: true
|
||||||
|
goproxy:
|
||||||
|
type: string
|
||||||
|
default: ""
|
||||||
|
docker:
|
||||||
|
- image: "cimg/go:<< parameters.version >>"
|
||||||
|
working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket
|
||||||
|
environment:
|
||||||
|
GO111MODULE: "on"
|
||||||
|
GOPROXY: "<< parameters.goproxy >>"
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run:
|
||||||
|
name: "Print the Go version"
|
||||||
|
command: >
|
||||||
|
go version
|
||||||
|
- run:
|
||||||
|
name: "Fetch dependencies"
|
||||||
|
command: >
|
||||||
|
if [[ << parameters.modules >> = true ]]; then
|
||||||
|
go mod download
|
||||||
|
export GO111MODULE=on
|
||||||
|
else
|
||||||
|
go get -v ./...
|
||||||
|
fi
|
||||||
|
# Only run gofmt, vet & lint against the latest Go version
|
||||||
|
- run:
|
||||||
|
name: "Run golint"
|
||||||
|
command: >
|
||||||
|
if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then
|
||||||
|
go get -u golang.org/x/lint/golint
|
||||||
|
golint ./...
|
||||||
|
fi
|
||||||
|
- run:
|
||||||
|
name: "Run gofmt"
|
||||||
|
command: >
|
||||||
|
if [[ << parameters.version >> = "latest" ]]; then
|
||||||
|
diff -u <(echo -n) <(gofmt -d -e .)
|
||||||
|
fi
|
||||||
|
- run:
|
||||||
|
name: "Run go vet"
|
||||||
|
command: >
|
||||||
|
if [[ << parameters.version >> = "latest" ]]; then
|
||||||
|
go vet -v ./...
|
||||||
|
fi
|
||||||
|
- run:
|
||||||
|
name: "Run go test (+ race detector)"
|
||||||
|
command: >
|
||||||
|
go test -v -race ./...
|
||||||
|
|
||||||
|
workflows:
|
||||||
|
tests:
|
||||||
|
jobs:
|
||||||
|
- test:
|
||||||
|
matrix:
|
||||||
|
parameters:
|
||||||
|
version: ["1.18", "1.17", "1.16"]
|
|
@ -1,20 +0,0 @@
|
||||||
; https://editorconfig.org/
|
|
||||||
|
|
||||||
root = true
|
|
||||||
|
|
||||||
[*]
|
|
||||||
insert_final_newline = true
|
|
||||||
charset = utf-8
|
|
||||||
trim_trailing_whitespace = true
|
|
||||||
indent_style = space
|
|
||||||
indent_size = 2
|
|
||||||
|
|
||||||
[{Makefile,go.mod,go.sum,*.go,.gitmodules}]
|
|
||||||
indent_style = tab
|
|
||||||
indent_size = 4
|
|
||||||
|
|
||||||
[*.md]
|
|
||||||
indent_size = 4
|
|
||||||
trim_trailing_whitespace = false
|
|
||||||
|
|
||||||
eclint_indent_style = unset
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
# Config for https://github.com/apps/release-drafter
|
||||||
|
template: |
|
||||||
|
|
||||||
|
<summary of changes here>
|
||||||
|
|
||||||
|
## CHANGELOG
|
||||||
|
$CHANGES
|
|
@ -1,21 +0,0 @@
|
||||||
# Add all the issues created to the project.
|
|
||||||
name: Add issue or pull request to Project
|
|
||||||
|
|
||||||
on:
|
|
||||||
issues:
|
|
||||||
types:
|
|
||||||
- opened
|
|
||||||
pull_request_target:
|
|
||||||
types:
|
|
||||||
- opened
|
|
||||||
- reopened
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
add-to-project:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Add issue to project
|
|
||||||
uses: actions/add-to-project@v0.5.0
|
|
||||||
with:
|
|
||||||
project-url: https://github.com/orgs/gorilla/projects/4
|
|
||||||
github-token: ${{ secrets.ADD_TO_PROJECT_TOKEN }}
|
|
|
@ -1,37 +0,0 @@
|
||||||
name: Security
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
pull_request:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
jobs:
|
|
||||||
scan:
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
go: ['1.20','1.21']
|
|
||||||
fail-fast: true
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout Code
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
|
|
||||||
- name: Setup Go ${{ matrix.go }}
|
|
||||||
uses: actions/setup-go@v4
|
|
||||||
with:
|
|
||||||
go-version: ${{ matrix.go }}
|
|
||||||
cache: false
|
|
||||||
|
|
||||||
- name: Run GoSec
|
|
||||||
uses: securego/gosec@master
|
|
||||||
with:
|
|
||||||
args: -exclude-dir examples ./...
|
|
||||||
|
|
||||||
- name: Run GoVulnCheck
|
|
||||||
uses: golang/govulncheck-action@v1
|
|
||||||
with:
|
|
||||||
go-version-input: ${{ matrix.go }}
|
|
||||||
go-package: ./...
|
|
|
@ -1,35 +0,0 @@
|
||||||
name: Test
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
pull_request:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
jobs:
|
|
||||||
unit:
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
go: ['1.20','1.21']
|
|
||||||
os: [ubuntu-latest, macos-latest, windows-latest]
|
|
||||||
fail-fast: true
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
steps:
|
|
||||||
- name: Checkout Code
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
|
|
||||||
- name: Setup Go ${{ matrix.go }}
|
|
||||||
uses: actions/setup-go@v4
|
|
||||||
with:
|
|
||||||
go-version: ${{ matrix.go }}
|
|
||||||
cache: false
|
|
||||||
|
|
||||||
- name: Run Tests
|
|
||||||
run: go test -race -cover -coverprofile=coverage -covermode=atomic -v ./...
|
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
|
||||||
uses: codecov/codecov-action@v3
|
|
||||||
with:
|
|
||||||
files: ./coverage
|
|
|
@ -1,31 +0,0 @@
|
||||||
name: Verify
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
pull_request:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
jobs:
|
|
||||||
lint:
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
go: ['1.20','1.21']
|
|
||||||
fail-fast: true
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout Code
|
|
||||||
uses: actions/checkout@v3
|
|
||||||
|
|
||||||
- name: Setup Go ${{ matrix.go }}
|
|
||||||
uses: actions/setup-go@v4
|
|
||||||
with:
|
|
||||||
go-version: ${{ matrix.go }}
|
|
||||||
cache: false
|
|
||||||
|
|
||||||
- name: Run GolangCI-Lint
|
|
||||||
uses: golangci/golangci-lint-action@v3
|
|
||||||
with:
|
|
||||||
version: v1.53
|
|
|
@ -1 +1,25 @@
|
||||||
coverage.coverprofile
|
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||||
|
*.o
|
||||||
|
*.a
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Folders
|
||||||
|
_obj
|
||||||
|
_test
|
||||||
|
|
||||||
|
# Architecture specific extensions/prefixes
|
||||||
|
*.[568vq]
|
||||||
|
[568vq].out
|
||||||
|
|
||||||
|
*.cgo1.go
|
||||||
|
*.cgo2.c
|
||||||
|
_cgo_defun.c
|
||||||
|
_cgo_gotypes.go
|
||||||
|
_cgo_export.*
|
||||||
|
|
||||||
|
_testmain.go
|
||||||
|
|
||||||
|
*.exe
|
||||||
|
|
||||||
|
.idea/
|
||||||
|
*.iml
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
run:
|
|
||||||
timeout: "5m"
|
|
||||||
# will not run golangci-lint against *_test.go
|
|
||||||
tests: false
|
|
||||||
issues:
|
|
||||||
exclude-dirs:
|
|
||||||
- examples/*.go
|
|
||||||
exclude-rules:
|
|
||||||
# excluding error checks from all the .go files
|
|
||||||
- path: ./*.go
|
|
||||||
linters:
|
|
||||||
- errcheck
|
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
# This is the official list of Gorilla WebSocket authors for copyright
|
||||||
|
# purposes.
|
||||||
|
#
|
||||||
|
# Please keep the list sorted.
|
||||||
|
|
||||||
|
Gary Burd <gary@beagledreams.com>
|
||||||
|
Google LLC (https://opensource.google.com/)
|
||||||
|
Joachim Bauch <mail@joachim-bauch.de>
|
||||||
|
|
39
LICENSE
39
LICENSE
|
@ -1,27 +1,22 @@
|
||||||
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
|
Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved.
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without
|
Redistribution and use in source and binary forms, with or without
|
||||||
modification, are permitted provided that the following conditions are
|
modification, are permitted provided that the following conditions are met:
|
||||||
met:
|
|
||||||
|
|
||||||
* Redistributions of source code must retain the above copyright
|
Redistributions of source code must retain the above copyright notice, this
|
||||||
notice, this list of conditions and the following disclaimer.
|
list of conditions and the following disclaimer.
|
||||||
* Redistributions in binary form must reproduce the above
|
|
||||||
copyright notice, this list of conditions and the following disclaimer
|
|
||||||
in the documentation and/or other materials provided with the
|
|
||||||
distribution.
|
|
||||||
* Neither the name of Google Inc. nor the names of its
|
|
||||||
contributors may be used to endorse or promote products derived from
|
|
||||||
this software without specific prior written permission.
|
|
||||||
|
|
||||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
Redistributions in binary form must reproduce the above copyright notice,
|
||||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
this list of conditions and the following disclaimer in the documentation
|
||||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
and/or other materials provided with the distribution.
|
||||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
||||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
34
Makefile
34
Makefile
|
@ -1,34 +0,0 @@
|
||||||
GO_LINT=$(shell which golangci-lint 2> /dev/null || echo '')
|
|
||||||
GO_LINT_URI=github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
|
||||||
|
|
||||||
GO_SEC=$(shell which gosec 2> /dev/null || echo '')
|
|
||||||
GO_SEC_URI=github.com/securego/gosec/v2/cmd/gosec@latest
|
|
||||||
|
|
||||||
GO_VULNCHECK=$(shell which govulncheck 2> /dev/null || echo '')
|
|
||||||
GO_VULNCHECK_URI=golang.org/x/vuln/cmd/govulncheck@latest
|
|
||||||
|
|
||||||
.PHONY: golangci-lint
|
|
||||||
golangci-lint:
|
|
||||||
$(if $(GO_LINT), ,go install $(GO_LINT_URI))
|
|
||||||
@echo "##### Running golangci-lint"
|
|
||||||
golangci-lint run -v
|
|
||||||
|
|
||||||
.PHONY: gosec
|
|
||||||
gosec:
|
|
||||||
$(if $(GO_SEC), ,go install $(GO_SEC_URI))
|
|
||||||
@echo "##### Running gosec"
|
|
||||||
gosec -exclude-dir examples ./...
|
|
||||||
|
|
||||||
.PHONY: govulncheck
|
|
||||||
govulncheck:
|
|
||||||
$(if $(GO_VULNCHECK), ,go install $(GO_VULNCHECK_URI))
|
|
||||||
@echo "##### Running govulncheck"
|
|
||||||
govulncheck ./...
|
|
||||||
|
|
||||||
.PHONY: verify
|
|
||||||
verify: golangci-lint gosec govulncheck
|
|
||||||
|
|
||||||
.PHONY: test
|
|
||||||
test:
|
|
||||||
@echo "##### Running tests"
|
|
||||||
go test -race -cover -coverprofile=coverage.coverprofile -covermode=atomic -v ./...
|
|
25
README.md
25
README.md
|
@ -1,23 +1,19 @@
|
||||||
# gorilla/websocket
|
# Gorilla WebSocket
|
||||||
|
|
||||||
![testing](https://github.com/gorilla/websocket/actions/workflows/test.yml/badge.svg)
|
[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
|
||||||
[![codecov](https://codecov.io/github/gorilla/websocket/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/websocket)
|
[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket)
|
||||||
[![godoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
|
|
||||||
[![sourcegraph](https://sourcegraph.com/github.com/gorilla/websocket/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/websocket?badge)
|
|
||||||
|
|
||||||
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
|
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
|
||||||
|
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
|
||||||
![Gorilla Logo](https://github.com/gorilla/.github/assets/53367916/d92caabf-98e0-473e-bfbf-ab554ba435e5)
|
|
||||||
|
|
||||||
|
|
||||||
### Documentation
|
### Documentation
|
||||||
|
|
||||||
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
|
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
|
||||||
* [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat)
|
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
|
||||||
* [Command example](https://github.com/gorilla/websocket/tree/main/examples/command)
|
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
|
||||||
* [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo)
|
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
|
||||||
* [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch)
|
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
|
||||||
* [Write buffer pool example](https://github.com/gorilla/websocket/tree/main/examples/bufferpool)
|
|
||||||
|
|
||||||
### Status
|
### Status
|
||||||
|
|
||||||
|
@ -33,4 +29,5 @@ package API is stable.
|
||||||
|
|
||||||
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
|
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
|
||||||
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
|
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
|
||||||
subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn).
|
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
|
||||||
|
|
||||||
|
|
13
client.go
13
client.go
|
@ -11,14 +11,13 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrBadHandshake is returned when the server response to opening handshake is
|
// ErrBadHandshake is returned when the server response to opening handshake is
|
||||||
|
@ -309,7 +308,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if proxyURL != nil {
|
if proxyURL != nil {
|
||||||
dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial))
|
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -400,7 +399,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusSwitchingProtocols ||
|
if resp.StatusCode != 101 ||
|
||||||
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
|
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
|
||||||
!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
|
!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
|
||||||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
|
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
|
||||||
|
@ -409,7 +408,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
// debugging.
|
// debugging.
|
||||||
buf := make([]byte, 1024)
|
buf := make([]byte, 1024)
|
||||||
n, _ := io.ReadFull(resp.Body, buf)
|
n, _ := io.ReadFull(resp.Body, buf)
|
||||||
resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
|
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
|
||||||
return nil, resp, ErrBadHandshake
|
return nil, resp, ErrBadHandshake
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -427,7 +426,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
||||||
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
||||||
|
|
||||||
netConn.SetDeadline(time.Time{})
|
netConn.SetDeadline(time.Time{})
|
||||||
|
@ -437,7 +436,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
|
|
||||||
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return &tls.Config{MinVersion: tls.VersionTLS12}
|
return &tls.Config{}
|
||||||
}
|
}
|
||||||
return cfg.Clone()
|
return cfg.Clone()
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -148,7 +149,6 @@ func sendRecv(t *testing.T, ws *Conn) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyDial(t *testing.T) {
|
func TestProxyDial(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
@ -187,7 +187,6 @@ func TestProxyDial(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyAuthorizationDial(t *testing.T) {
|
func TestProxyAuthorizationDial(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -228,7 +227,6 @@ func TestProxyAuthorizationDial(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDial(t *testing.T) {
|
func TestDial(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -241,7 +239,6 @@ func TestDial(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDialCookieJar(t *testing.T) {
|
func TestDialCookieJar(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -304,7 +301,6 @@ func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDialTLS(t *testing.T) {
|
func TestDialTLS(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := newTLSServer(t)
|
s := newTLSServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -319,7 +315,6 @@ func TestDialTLS(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDialTimeout(t *testing.T) {
|
func TestDialTimeout(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -376,7 +371,6 @@ func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr()
|
||||||
func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
|
func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
|
||||||
|
|
||||||
func TestHandshakeTimeout(t *testing.T) {
|
func TestHandshakeTimeout(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -393,7 +387,6 @@ func TestHandshakeTimeout(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandshakeTimeoutInContext(t *testing.T) {
|
func TestHandshakeTimeoutInContext(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -415,7 +408,6 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDialBadScheme(t *testing.T) {
|
func TestDialBadScheme(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -427,7 +419,6 @@ func TestDialBadScheme(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDialBadOrigin(t *testing.T) {
|
func TestDialBadOrigin(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -445,7 +436,6 @@ func TestDialBadOrigin(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDialBadHeader(t *testing.T) {
|
func TestDialBadHeader(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -465,7 +455,6 @@ func TestDialBadHeader(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBadMethod(t *testing.T) {
|
func TestBadMethod(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ws, err := cstUpgrader.Upgrade(w, r, nil)
|
ws, err := cstUpgrader.Upgrade(w, r, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -494,7 +483,6 @@ func TestBadMethod(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDialExtraTokensInRespHeaders(t *testing.T) {
|
func TestDialExtraTokensInRespHeaders(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
||||||
w.Header().Set("Upgrade", "foo, websocket")
|
w.Header().Set("Upgrade", "foo, websocket")
|
||||||
|
@ -512,7 +500,6 @@ func TestDialExtraTokensInRespHeaders(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandshake(t *testing.T) {
|
func TestHandshake(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -539,7 +526,6 @@ func TestHandshake(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRespOnBadHandshake(t *testing.T) {
|
func TestRespOnBadHandshake(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
const expectedStatus = http.StatusGone
|
const expectedStatus = http.StatusGone
|
||||||
const expectedBody = "This is the response body."
|
const expectedBody = "This is the response body."
|
||||||
|
|
||||||
|
@ -563,7 +549,7 @@ func TestRespOnBadHandshake(t *testing.T) {
|
||||||
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
|
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := io.ReadAll(resp.Body)
|
p, err := ioutil.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ReadFull(resp.Body) returned error %v", err)
|
t.Fatalf("ReadFull(resp.Body) returned error %v", err)
|
||||||
}
|
}
|
||||||
|
@ -584,7 +570,6 @@ func (w testLogWriter) Write(p []byte) (int, error) {
|
||||||
|
|
||||||
// TestHost tests handling of host names and confirms that it matches net/http.
|
// TestHost tests handling of host names and confirms that it matches net/http.
|
||||||
func TestHost(t *testing.T) {
|
func TestHost(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
upgrader := Upgrader{}
|
upgrader := Upgrader{}
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -766,7 +751,6 @@ func TestHost(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDialCompression(t *testing.T) {
|
func TestDialCompression(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -781,7 +765,6 @@ func TestDialCompression(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSocksProxyDial(t *testing.T) {
|
func TestSocksProxyDial(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -861,7 +844,6 @@ func TestSocksProxyDial(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTracingDialWithContext(t *testing.T) {
|
func TestTracingDialWithContext(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
|
var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
|
||||||
trace := &httptrace.ClientTrace{
|
trace := &httptrace.ClientTrace{
|
||||||
|
@ -921,7 +903,6 @@ func TestTracingDialWithContext(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEmptyTracingDialWithContext(t *testing.T) {
|
func TestEmptyTracingDialWithContext(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
trace := &httptrace.ClientTrace{}
|
trace := &httptrace.ClientTrace{}
|
||||||
ctx := httptrace.WithClientTrace(context.Background(), trace)
|
ctx := httptrace.WithClientTrace(context.Background(), trace)
|
||||||
|
@ -943,7 +924,6 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
|
||||||
|
|
||||||
// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
|
// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
|
||||||
func TestNetDialConnect(t *testing.T) {
|
func TestNetDialConnect(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
upgrader := Upgrader{}
|
upgrader := Upgrader{}
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -1119,7 +1099,6 @@ func TestNetDialConnect(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func TestNextProtos(t *testing.T) {
|
func TestNextProtos(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
ts := httptest.NewUnstartedServer(
|
ts := httptest.NewUnstartedServer(
|
||||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,7 +20,6 @@ var hostPortNoPortTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHostPortNoPort(t *testing.T) {
|
func TestHostPortNoPort(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for _, tt := range hostPortNoPortTests {
|
for _, tt := range hostPortNoPortTests {
|
||||||
hostPort, hostNoPort := hostPortNoPort(tt.u)
|
hostPort, hostNoPort := hostPortNoPort(tt.u)
|
||||||
if hostPort != tt.hostPort {
|
if hostPort != tt.hostPort {
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,7 +13,6 @@ type nopCloser struct{ io.Writer }
|
||||||
func (nopCloser) Close() error { return nil }
|
func (nopCloser) Close() error { return nil }
|
||||||
|
|
||||||
func TestTruncWriter(t *testing.T) {
|
func TestTruncWriter(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
|
const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
|
||||||
for n := 1; n <= 10; n++ {
|
for n := 1; n <= 10; n++ {
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
|
@ -42,7 +42,7 @@ func textMessages(num int) [][]byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkWriteNoCompression(b *testing.B) {
|
func BenchmarkWriteNoCompression(b *testing.B) {
|
||||||
w := io.Discard
|
w := ioutil.Discard
|
||||||
c := newTestConn(nil, w, false)
|
c := newTestConn(nil, w, false)
|
||||||
messages := textMessages(100)
|
messages := textMessages(100)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
@ -53,7 +53,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkWriteWithCompression(b *testing.B) {
|
func BenchmarkWriteWithCompression(b *testing.B) {
|
||||||
w := io.Discard
|
w := ioutil.Discard
|
||||||
c := newTestConn(nil, w, false)
|
c := newTestConn(nil, w, false)
|
||||||
messages := textMessages(100)
|
messages := textMessages(100)
|
||||||
c.enableWriteCompression = true
|
c.enableWriteCompression = true
|
||||||
|
@ -66,7 +66,6 @@ func BenchmarkWriteWithCompression(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidCompressionLevel(t *testing.T) {
|
func TestValidCompressionLevel(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
c := newTestConn(nil, nil, false)
|
c := newTestConn(nil, nil, false)
|
||||||
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
|
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
|
||||||
if err := c.SetCompressionLevel(level); err == nil {
|
if err := c.SetCompressionLevel(level); err == nil {
|
||||||
|
|
72
conn.go
72
conn.go
|
@ -6,10 +6,11 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"crypto/rand"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -180,16 +181,16 @@ var (
|
||||||
errInvalidControlFrame = errors.New("websocket: invalid control frame")
|
errInvalidControlFrame = errors.New("websocket: invalid control frame")
|
||||||
)
|
)
|
||||||
|
|
||||||
// maskRand is an io.Reader for generating mask bytes. The reader is initialized
|
|
||||||
// to crypto/rand Reader. Tests swap the reader to a math/rand reader for
|
|
||||||
// reproducible results.
|
|
||||||
var maskRand = rand.Reader
|
|
||||||
|
|
||||||
// newMaskKey returns a new 32 bit value for masking client frames.
|
|
||||||
func newMaskKey() [4]byte {
|
func newMaskKey() [4]byte {
|
||||||
var k [4]byte
|
n := rand.Uint32()
|
||||||
_, _ = io.ReadFull(maskRand, k[:])
|
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
|
||||||
return k
|
}
|
||||||
|
|
||||||
|
func hideTempErr(err error) error {
|
||||||
|
if e, ok := err.(net.Error); ok && e.Temporary() {
|
||||||
|
err = &netError{msg: e.Error(), timeout: e.Timeout()}
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func isControl(frameType int) bool {
|
func isControl(frameType int) bool {
|
||||||
|
@ -357,6 +358,7 @@ func (c *Conn) RemoteAddr() net.Addr {
|
||||||
// Write methods
|
// Write methods
|
||||||
|
|
||||||
func (c *Conn) writeFatal(err error) error {
|
func (c *Conn) writeFatal(err error) error {
|
||||||
|
err = hideTempErr(err)
|
||||||
c.writeErrMu.Lock()
|
c.writeErrMu.Lock()
|
||||||
if c.writeErr == nil {
|
if c.writeErr == nil {
|
||||||
c.writeErr = err
|
c.writeErr = err
|
||||||
|
@ -434,27 +436,21 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
||||||
maskBytes(key, 0, buf[6:])
|
maskBytes(key, 0, buf[6:])
|
||||||
}
|
}
|
||||||
|
|
||||||
if deadline.IsZero() {
|
d := 1000 * time.Hour
|
||||||
// No timeout for zero time.
|
if !deadline.IsZero() {
|
||||||
<-c.mu
|
d = deadline.Sub(time.Now())
|
||||||
} else {
|
|
||||||
d := time.Until(deadline)
|
|
||||||
if d < 0 {
|
if d < 0 {
|
||||||
return errWriteTimeout
|
return errWriteTimeout
|
||||||
}
|
}
|
||||||
select {
|
|
||||||
case <-c.mu:
|
|
||||||
default:
|
|
||||||
timer := time.NewTimer(d)
|
|
||||||
select {
|
|
||||||
case <-c.mu:
|
|
||||||
timer.Stop()
|
|
||||||
case <-timer.C:
|
|
||||||
return errWriteTimeout
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
timer := time.NewTimer(d)
|
||||||
|
select {
|
||||||
|
case <-c.mu:
|
||||||
|
timer.Stop()
|
||||||
|
case <-timer.C:
|
||||||
|
return errWriteTimeout
|
||||||
|
}
|
||||||
defer func() { c.mu <- struct{}{} }()
|
defer func() { c.mu <- struct{}{} }()
|
||||||
|
|
||||||
c.writeErrMu.Lock()
|
c.writeErrMu.Lock()
|
||||||
|
@ -799,7 +795,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
// 1. Skip remainder of previous frame.
|
// 1. Skip remainder of previous frame.
|
||||||
|
|
||||||
if c.readRemaining > 0 {
|
if c.readRemaining > 0 {
|
||||||
if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
|
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1012,7 +1008,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
||||||
for c.readErr == nil {
|
for c.readErr == nil {
|
||||||
frameType, err := c.advanceFrame()
|
frameType, err := c.advanceFrame()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.readErr = err
|
c.readErr = hideTempErr(err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1052,7 +1048,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
||||||
b = b[:c.readRemaining]
|
b = b[:c.readRemaining]
|
||||||
}
|
}
|
||||||
n, err := c.br.Read(b)
|
n, err := c.br.Read(b)
|
||||||
c.readErr = err
|
c.readErr = hideTempErr(err)
|
||||||
if c.isServer {
|
if c.isServer {
|
||||||
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
|
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
|
||||||
}
|
}
|
||||||
|
@ -1073,7 +1069,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
||||||
frameType, err := c.advanceFrame()
|
frameType, err := c.advanceFrame()
|
||||||
switch {
|
switch {
|
||||||
case err != nil:
|
case err != nil:
|
||||||
c.readErr = err
|
c.readErr = hideTempErr(err)
|
||||||
case frameType == TextMessage || frameType == BinaryMessage:
|
case frameType == TextMessage || frameType == BinaryMessage:
|
||||||
c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
||||||
}
|
}
|
||||||
|
@ -1098,7 +1094,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return messageType, nil, err
|
return messageType, nil, err
|
||||||
}
|
}
|
||||||
p, err = io.ReadAll(r)
|
p, err = ioutil.ReadAll(r)
|
||||||
return messageType, p, err
|
return messageType, p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1165,7 +1161,7 @@ func (c *Conn) SetPingHandler(h func(appData string) error) {
|
||||||
err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
|
err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
|
||||||
if err == ErrCloseSent {
|
if err == ErrCloseSent {
|
||||||
return nil
|
return nil
|
||||||
} else if _, ok := err.(net.Error); ok {
|
} else if e, ok := err.(net.Error); ok && e.Temporary() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
@ -1240,15 +1236,3 @@ func FormatCloseMessage(closeCode int, text string) []byte {
|
||||||
copy(buf[2:], text)
|
copy(buf[2:], text)
|
||||||
return buf
|
return buf
|
||||||
}
|
}
|
||||||
|
|
||||||
var messageTypes = map[int]string{
|
|
||||||
TextMessage: "TextMessage",
|
|
||||||
BinaryMessage: "BinaryMessage",
|
|
||||||
CloseMessage: "CloseMessage",
|
|
||||||
PingMessage: "PingMessage",
|
|
||||||
PongMessage: "PongMessage",
|
|
||||||
}
|
|
||||||
|
|
||||||
func FormatMessageType(mt int) string {
|
|
||||||
return messageTypes[mt]
|
|
||||||
}
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -44,7 +45,7 @@ func newBroadcastConn(c *Conn) *broadcastConn {
|
||||||
|
|
||||||
func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
|
func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
|
||||||
bench := &broadcastBench{
|
bench := &broadcastBench{
|
||||||
w: io.Discard,
|
w: ioutil.Discard,
|
||||||
doneCh: make(chan struct{}),
|
doneCh: make(chan struct{}),
|
||||||
closeCh: make(chan struct{}),
|
closeCh: make(chan struct{}),
|
||||||
usePrepared: usePrepared,
|
usePrepared: usePrepared,
|
||||||
|
|
130
conn_test.go
130
conn_test.go
|
@ -10,6 +10,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -54,7 +55,6 @@ func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFraming(t *testing.T) {
|
func TestFraming(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
frameSizes := []int{
|
frameSizes := []int{
|
||||||
0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
|
0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
|
||||||
// 65536, 65537
|
// 65536, 65537
|
||||||
|
@ -125,7 +125,7 @@ func TestFraming(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("frame size: %d", n)
|
t.Logf("frame size: %d", n)
|
||||||
rbuf, err := io.ReadAll(r)
|
rbuf, err := ioutil.ReadAll(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
|
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
|
||||||
continue
|
continue
|
||||||
|
@ -149,49 +149,7 @@ func TestFraming(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteControlDeadline(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
message := []byte("hello")
|
|
||||||
var connBuf bytes.Buffer
|
|
||||||
c := newTestConn(nil, &connBuf, true)
|
|
||||||
if err := c.WriteControl(PongMessage, message, time.Time{}); err != nil {
|
|
||||||
t.Errorf("WriteControl(..., zero deadline) = %v, want nil", err)
|
|
||||||
}
|
|
||||||
if err := c.WriteControl(PongMessage, message, time.Now().Add(time.Second)); err != nil {
|
|
||||||
t.Errorf("WriteControl(..., future deadline) = %v, want nil", err)
|
|
||||||
}
|
|
||||||
if err := c.WriteControl(PongMessage, message, time.Now().Add(-time.Second)); err == nil {
|
|
||||||
t.Errorf("WriteControl(..., past deadline) = nil, want timeout error")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConcurrencyWriteControl(t *testing.T) {
|
|
||||||
const message = "this is a ping/pong messsage"
|
|
||||||
loop := 10
|
|
||||||
workers := 10
|
|
||||||
for i := 0; i < loop; i++ {
|
|
||||||
var connBuf bytes.Buffer
|
|
||||||
|
|
||||||
wg := sync.WaitGroup{}
|
|
||||||
wc := newTestConn(nil, &connBuf, true)
|
|
||||||
|
|
||||||
for i := 0; i < workers; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
|
|
||||||
t.Errorf("concurrently wc.WriteControl() returned %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
wc.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestControl(t *testing.T) {
|
func TestControl(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
const message = "this is a ping/pong messsage"
|
const message = "this is a ping/pong messsage"
|
||||||
for _, isServer := range []bool{true, false} {
|
for _, isServer := range []bool{true, false} {
|
||||||
for _, isWriteControl := range []bool{true, false} {
|
for _, isWriteControl := range []bool{true, false} {
|
||||||
|
@ -243,7 +201,6 @@ func (p *simpleBufferPool) Put(v interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteBufferPool(t *testing.T) {
|
func TestWriteBufferPool(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
const message = "Now is the time for all good people to come to the aid of the party."
|
const message = "Now is the time for all good people to come to the aid of the party."
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
|
@ -322,7 +279,6 @@ func TestWriteBufferPool(t *testing.T) {
|
||||||
|
|
||||||
// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
|
// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
|
||||||
func TestWriteBufferPoolSync(t *testing.T) {
|
func TestWriteBufferPoolSync(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
var pool sync.Pool
|
var pool sync.Pool
|
||||||
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
|
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
|
||||||
|
@ -351,7 +307,6 @@ func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error
|
||||||
// TestWriteBufferPoolError ensures that buffer is returned to pool after error
|
// TestWriteBufferPoolError ensures that buffer is returned to pool after error
|
||||||
// on write.
|
// on write.
|
||||||
func TestWriteBufferPoolError(t *testing.T) {
|
func TestWriteBufferPoolError(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
// Part 1: Test NextWriter/Write/Close
|
// Part 1: Test NextWriter/Write/Close
|
||||||
|
|
||||||
|
@ -395,7 +350,6 @@ func TestWriteBufferPoolError(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
const bufSize = 512
|
const bufSize = 512
|
||||||
|
|
||||||
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
|
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
|
||||||
|
@ -413,7 +367,7 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
t.Fatalf("NextReader() returned %d, %v", op, err)
|
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(io.Discard, r)
|
_, err = io.Copy(ioutil.Discard, r)
|
||||||
if !reflect.DeepEqual(err, expectedErr) {
|
if !reflect.DeepEqual(err, expectedErr) {
|
||||||
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
|
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
|
||||||
}
|
}
|
||||||
|
@ -424,7 +378,6 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEOFWithinFrame(t *testing.T) {
|
func TestEOFWithinFrame(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
const bufSize = 64
|
const bufSize = 64
|
||||||
|
|
||||||
for n := 0; ; n++ {
|
for n := 0; ; n++ {
|
||||||
|
@ -448,7 +401,7 @@ func TestEOFWithinFrame(t *testing.T) {
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
|
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(io.Discard, r)
|
_, err = io.Copy(ioutil.Discard, r)
|
||||||
if err != errUnexpectedEOF {
|
if err != errUnexpectedEOF {
|
||||||
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
|
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
|
||||||
}
|
}
|
||||||
|
@ -460,7 +413,6 @@ func TestEOFWithinFrame(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEOFBeforeFinalFrame(t *testing.T) {
|
func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
const bufSize = 512
|
const bufSize = 512
|
||||||
|
|
||||||
var b1, b2 bytes.Buffer
|
var b1, b2 bytes.Buffer
|
||||||
|
@ -474,7 +426,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
t.Fatalf("NextReader() returned %d, %v", op, err)
|
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(io.Discard, r)
|
_, err = io.Copy(ioutil.Discard, r)
|
||||||
if err != errUnexpectedEOF {
|
if err != errUnexpectedEOF {
|
||||||
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
|
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
|
||||||
}
|
}
|
||||||
|
@ -485,7 +437,6 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
wc := newTestConn(nil, &bytes.Buffer{}, false)
|
wc := newTestConn(nil, &bytes.Buffer{}, false)
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
io.WriteString(w, "hello")
|
io.WriteString(w, "hello")
|
||||||
|
@ -511,29 +462,7 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteHandlerDoesNotReturnErrCloseSent(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
var b1, b2 bytes.Buffer
|
|
||||||
|
|
||||||
client := newTestConn(&b2, &b1, false)
|
|
||||||
server := newTestConn(&b1, &b2, true)
|
|
||||||
|
|
||||||
msg := FormatCloseMessage(CloseNormalClosure, "")
|
|
||||||
if err := client.WriteMessage(CloseMessage, msg); err != nil {
|
|
||||||
t.Fatalf("unexpected error when writing close message, %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, _, err := server.NextReader(); !IsCloseError(err, 1000) {
|
|
||||||
t.Fatalf("server expects a close message, %v returned", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, _, err := client.NextReader(); !IsCloseError(err, 1000) {
|
|
||||||
t.Fatalf("client expects a close message, %v returned", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadLimit(t *testing.T) {
|
func TestReadLimit(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
t.Run("Test ReadLimit is enforced", func(t *testing.T) {
|
t.Run("Test ReadLimit is enforced", func(t *testing.T) {
|
||||||
const readLimit = 512
|
const readLimit = 512
|
||||||
message := make([]byte, readLimit+1)
|
message := make([]byte, readLimit+1)
|
||||||
|
@ -561,7 +490,7 @@ func TestReadLimit(t *testing.T) {
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
t.Fatalf("2: NextReader() returned %d, %v", op, err)
|
t.Fatalf("2: NextReader() returned %d, %v", op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(io.Discard, r)
|
_, err = io.Copy(ioutil.Discard, r)
|
||||||
if err != ErrReadLimit {
|
if err != ErrReadLimit {
|
||||||
t.Fatalf("io.Copy() returned %v", err)
|
t.Fatalf("io.Copy() returned %v", err)
|
||||||
}
|
}
|
||||||
|
@ -624,7 +553,6 @@ func TestReadLimit(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddrs(t *testing.T) {
|
func TestAddrs(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
c := newTestConn(nil, nil, true)
|
c := newTestConn(nil, nil, true)
|
||||||
if c.LocalAddr() != localAddr {
|
if c.LocalAddr() != localAddr {
|
||||||
t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
|
t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
|
||||||
|
@ -635,7 +563,6 @@ func TestAddrs(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeprecatedUnderlyingConn(t *testing.T) {
|
func TestDeprecatedUnderlyingConn(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
var b1, b2 bytes.Buffer
|
var b1, b2 bytes.Buffer
|
||||||
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
||||||
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
|
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
|
||||||
|
@ -646,7 +573,6 @@ func TestDeprecatedUnderlyingConn(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNetConn(t *testing.T) {
|
func TestNetConn(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
var b1, b2 bytes.Buffer
|
var b1, b2 bytes.Buffer
|
||||||
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
||||||
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
|
c := newConn(fc, true, 1024, 1024, nil, nil, nil)
|
||||||
|
@ -657,7 +583,6 @@ func TestNetConn(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBufioReadBytes(t *testing.T) {
|
func TestBufioReadBytes(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
// Test calling bufio.ReadBytes for value longer than read buffer size.
|
// Test calling bufio.ReadBytes for value longer than read buffer size.
|
||||||
|
|
||||||
m := make([]byte, 512)
|
m := make([]byte, 512)
|
||||||
|
@ -698,7 +623,6 @@ var closeErrorTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCloseError(t *testing.T) {
|
func TestCloseError(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for _, tt := range closeErrorTests {
|
for _, tt := range closeErrorTests {
|
||||||
ok := IsCloseError(tt.err, tt.codes...)
|
ok := IsCloseError(tt.err, tt.codes...)
|
||||||
if ok != tt.ok {
|
if ok != tt.ok {
|
||||||
|
@ -719,7 +643,6 @@ var unexpectedCloseErrorTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUnexpectedCloseErrors(t *testing.T) {
|
func TestUnexpectedCloseErrors(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for _, tt := range unexpectedCloseErrorTests {
|
for _, tt := range unexpectedCloseErrorTests {
|
||||||
ok := IsUnexpectedCloseError(tt.err, tt.codes...)
|
ok := IsUnexpectedCloseError(tt.err, tt.codes...)
|
||||||
if ok != tt.ok {
|
if ok != tt.ok {
|
||||||
|
@ -741,7 +664,6 @@ func (w blockingWriter) Write(p []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConcurrentWritePanic(t *testing.T) {
|
func TestConcurrentWritePanic(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
||||||
c := newTestConn(nil, w, false)
|
c := newTestConn(nil, w, false)
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -769,7 +691,6 @@ func (r failingReader) Read(p []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFailedConnectionReadPanic(t *testing.T) {
|
func TestFailedConnectionReadPanic(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
c := newTestConn(failingReader{}, nil, false)
|
c := newTestConn(failingReader{}, nil, false)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -783,42 +704,3 @@ func TestFailedConnectionReadPanic(t *testing.T) {
|
||||||
}
|
}
|
||||||
t.Fatal("should not get here")
|
t.Fatal("should not get here")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFormatMessageType(t *testing.T) {
|
|
||||||
str := FormatMessageType(TextMessage)
|
|
||||||
if str != messageTypes[TextMessage] {
|
|
||||||
t.Error("failed to format message type")
|
|
||||||
}
|
|
||||||
|
|
||||||
str = FormatMessageType(CloseMessage)
|
|
||||||
if str != messageTypes[CloseMessage] {
|
|
||||||
t.Error("failed to format message type")
|
|
||||||
}
|
|
||||||
|
|
||||||
str = FormatMessageType(123)
|
|
||||||
if str != messageTypes[123] {
|
|
||||||
t.Error("failed to format message type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeNetClosedReader struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r fakeNetClosedReader) Read([]byte) (int, error) {
|
|
||||||
return 0, net.ErrClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectionClosed(t *testing.T) {
|
|
||||||
var b1, b2 bytes.Buffer
|
|
||||||
|
|
||||||
client := newTestConn(fakeNetClosedReader{}, &b1, false)
|
|
||||||
server := newTestConn(fakeNetClosedReader{}, &b2, true)
|
|
||||||
|
|
||||||
if _, _, err := server.NextReader(); !errors.Is(err, net.ErrClosed) {
|
|
||||||
t.Fatalf("server expects a net.ErrClosed error, %v returned", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, _, err := client.NextReader(); !errors.Is(err, net.ErrClosed) {
|
|
||||||
t.Fatalf("client expects a net.ErrClosed error, %v returned", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -42,4 +42,4 @@ func processMessage(mt int, p []byte) {}
|
||||||
|
|
||||||
// TestX prevents godoc from showing this entire file in the example. Remove
|
// TestX prevents godoc from showing this entire file in the example. Remove
|
||||||
// this function when a second example is added.
|
// this function when a second example is added.
|
||||||
func TestX(t *testing.T) { t.Parallel() }
|
func TestX(t *testing.T) {}
|
||||||
|
|
|
@ -84,7 +84,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// echoReadAll echoes messages from the client by reading the entire message
|
// echoReadAll echoes messages from the client by reading the entire message
|
||||||
// with io.ReadAll.
|
// with ioutil.ReadAll.
|
||||||
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) {
|
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) {
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,89 +0,0 @@
|
||||||
//go:build ignore
|
|
||||||
// +build ignore
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"flag"
|
|
||||||
"log"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
var addr = flag.String("addr", "localhost:8080", "http service address")
|
|
||||||
|
|
||||||
func runNewConn(wg *sync.WaitGroup) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
interrupt := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(interrupt, os.Interrupt)
|
|
||||||
|
|
||||||
u := url.URL{Scheme: "ws", Host: *addr, Path: "/ws"}
|
|
||||||
log.Printf("connecting to %s", u.String())
|
|
||||||
c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("dial:", err)
|
|
||||||
}
|
|
||||||
defer c.Close()
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer close(done)
|
|
||||||
for {
|
|
||||||
_, message, err := c.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
log.Println("read:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("recv: %s", message)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
ticker := time.NewTicker(time.Minute * 5)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
case t := <-ticker.C:
|
|
||||||
err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
|
|
||||||
if err != nil {
|
|
||||||
log.Println("write:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case <-interrupt:
|
|
||||||
log.Println("interrupt")
|
|
||||||
|
|
||||||
// Cleanly close the connection by sending a close message and then
|
|
||||||
// waiting (with timeout) for the server to close the connection.
|
|
||||||
err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
|
||||||
if err != nil {
|
|
||||||
log.Println("write close:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
flag.Parse()
|
|
||||||
log.SetFlags(0)
|
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go runNewConn(wg)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
|
@ -1,55 +0,0 @@
|
||||||
//go:build ignore
|
|
||||||
// +build ignore
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"flag"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
_ "net/http/pprof"
|
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
var addr = flag.String("addr", "localhost:8080", "http service address")
|
|
||||||
|
|
||||||
var upgrader = websocket.Upgrader{
|
|
||||||
ReadBufferSize: 256,
|
|
||||||
WriteBufferSize: 256,
|
|
||||||
WriteBufferPool: &sync.Pool{},
|
|
||||||
}
|
|
||||||
|
|
||||||
func process(c *websocket.Conn) {
|
|
||||||
defer c.Close()
|
|
||||||
for {
|
|
||||||
_, message, err := c.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
log.Println("read:", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
log.Printf("recv: %s", message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func handler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
c, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Print("upgrade:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process connection in a new goroutine
|
|
||||||
go process(c)
|
|
||||||
|
|
||||||
// Let the http handler return, the 8k buffer created by it will be garbage collected
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
flag.Parse()
|
|
||||||
log.SetFlags(0)
|
|
||||||
http.HandleFunc("/ws", handler)
|
|
||||||
log.Fatal(http.ListenAndServe(*addr, nil))
|
|
||||||
}
|
|
|
@ -38,7 +38,7 @@ sends them to the hub.
|
||||||
### Hub
|
### Hub
|
||||||
|
|
||||||
The code for the `Hub` type is in
|
The code for the `Hub` type is in
|
||||||
[hub.go](https://github.com/gorilla/websocket/blob/main/examples/chat/hub.go).
|
[hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go).
|
||||||
The application's `main` function starts the hub's `run` method as a goroutine.
|
The application's `main` function starts the hub's `run` method as a goroutine.
|
||||||
Clients send requests to the hub using the `register`, `unregister` and
|
Clients send requests to the hub using the `register`, `unregister` and
|
||||||
`broadcast` channels.
|
`broadcast` channels.
|
||||||
|
@ -57,7 +57,7 @@ unregisters the client and closes the websocket.
|
||||||
|
|
||||||
### Client
|
### Client
|
||||||
|
|
||||||
The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/main/examples/chat/client.go).
|
The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go).
|
||||||
|
|
||||||
The `serveWs` function is registered by the application's `main` function as
|
The `serveWs` function is registered by the application's `main` function as
|
||||||
an HTTP handler. The handler upgrades the HTTP connection to the WebSocket
|
an HTTP handler. The handler upgrades the HTTP connection to the WebSocket
|
||||||
|
@ -85,7 +85,7 @@ network.
|
||||||
|
|
||||||
## Frontend
|
## Frontend
|
||||||
|
|
||||||
The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/main/examples/chat/home.html).
|
The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html).
|
||||||
|
|
||||||
On document load, the script checks for websocket functionality in the browser.
|
On document load, the script checks for websocket functionality in the browser.
|
||||||
If websocket functionality is available, then the script opens a connection to
|
If websocket functionality is available, then the script opens a connection to
|
||||||
|
|
|
@ -57,6 +57,8 @@ func pumpStdin(ws *websocket.Conn, w io.Writer) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) {
|
func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) {
|
||||||
|
defer func() {
|
||||||
|
}()
|
||||||
s := bufio.NewScanner(r)
|
s := bufio.NewScanner(r)
|
||||||
for s.Scan() {
|
for s.Scan() {
|
||||||
ws.SetWriteDeadline(time.Now().Add(writeWait))
|
ws.SetWriteDeadline(time.Now().Add(writeWait))
|
||||||
|
|
|
@ -41,12 +41,12 @@ func main() {
|
||||||
go func() {
|
go func() {
|
||||||
defer close(done)
|
defer close(done)
|
||||||
for {
|
for {
|
||||||
mt, message, err := c.ReadMessage()
|
_, message, err := c.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("read:", err)
|
log.Println("read:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("recv: %s, type: %s", message, websocket.FormatMessageType(mt))
|
log.Printf("recv: %s", message)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
@ -33,8 +33,7 @@ func echo(w http.ResponseWriter, r *http.Request) {
|
||||||
log.Println("read:", err)
|
log.Println("read:", err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
log.Printf("recv: %s", message)
|
||||||
log.Printf("recv: %s, type: %s", message, websocket.FormatMessageType(mt))
|
|
||||||
err = c.WriteMessage(mt, message)
|
err = c.WriteMessage(mt, message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("write:", err)
|
log.Println("write:", err)
|
||||||
|
|
|
@ -7,10 +7,10 @@ package main
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"html/template"
|
"html/template"
|
||||||
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) {
|
||||||
if !fi.ModTime().After(lastMod) {
|
if !fi.ModTime().After(lastMod) {
|
||||||
return nil, lastMod, nil
|
return nil, lastMod, nil
|
||||||
}
|
}
|
||||||
p, err := os.ReadFile(filepath.Clean(filename))
|
p, err := ioutil.ReadFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fi.ModTime(), err
|
return nil, fi.ModTime(), err
|
||||||
}
|
}
|
||||||
|
|
4
go.mod
4
go.mod
|
@ -1,5 +1,3 @@
|
||||||
module github.com/gorilla/websocket
|
module github.com/gorilla/websocket
|
||||||
|
|
||||||
go 1.20
|
go 1.12
|
||||||
|
|
||||||
require golang.org/x/net v0.23.0
|
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -1,2 +0,0 @@
|
||||||
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
|
|
||||||
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJoinMessages(t *testing.T) {
|
func TestJoinMessages(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"}
|
messages := []string{"a", "bc", "def", "ghij", "klmno", "0", "12", "345", "6789"}
|
||||||
for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} {
|
for _, readChunk := range []int{1, 2, 3, 4, 5, 6, 7} {
|
||||||
for _, term := range []string{"", ","} {
|
for _, term := range []string{"", ","} {
|
||||||
|
|
|
@ -13,7 +13,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJSON(t *testing.T) {
|
func TestJSON(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
wc := newTestConn(nil, &buf, true)
|
wc := newTestConn(nil, &buf, true)
|
||||||
rc := newTestConn(&buf, nil, false)
|
rc := newTestConn(&buf, nil, false)
|
||||||
|
@ -39,7 +38,6 @@ func TestJSON(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPartialJSONRead(t *testing.T) {
|
func TestPartialJSONRead(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
var buf0, buf1 bytes.Buffer
|
var buf0, buf1 bytes.Buffer
|
||||||
wc := newTestConn(nil, &buf0, true)
|
wc := newTestConn(nil, &buf0, true)
|
||||||
rc := newTestConn(&buf0, &buf1, false)
|
rc := newTestConn(&buf0, &buf1, false)
|
||||||
|
@ -93,7 +91,6 @@ func TestPartialJSONRead(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeprecatedJSON(t *testing.T) {
|
func TestDeprecatedJSON(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
wc := newTestConn(nil, &buf, true)
|
wc := newTestConn(nil, &buf, true)
|
||||||
rc := newTestConn(&buf, nil, false)
|
rc := newTestConn(&buf, nil, false)
|
||||||
|
|
|
@ -29,7 +29,6 @@ func notzero(b []byte) int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMaskBytes(t *testing.T) {
|
func TestMaskBytes(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
key := [4]byte{1, 2, 3, 4}
|
key := [4]byte{1, 2, 3, 4}
|
||||||
for size := 1; size <= 1024; size++ {
|
for size := 1; size <= 1024; size++ {
|
||||||
for align := 0; align < wordSize; align++ {
|
for align := 0; align < wordSize; align++ {
|
||||||
|
|
|
@ -33,11 +33,6 @@ var preparedMessageTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPreparedMessage(t *testing.T) {
|
func TestPreparedMessage(t *testing.T) {
|
||||||
testRand := rand.New(rand.NewSource(99))
|
|
||||||
prevMaskRand := maskRand
|
|
||||||
maskRand = testRand
|
|
||||||
defer func() { maskRand = prevMaskRand }()
|
|
||||||
|
|
||||||
for _, tt := range preparedMessageTests {
|
for _, tt := range preparedMessageTests {
|
||||||
var data = []byte("this is a test")
|
var data = []byte("this is a test")
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
|
@ -48,7 +43,7 @@ func TestPreparedMessage(t *testing.T) {
|
||||||
c.SetCompressionLevel(tt.compressionLevel)
|
c.SetCompressionLevel(tt.compressionLevel)
|
||||||
|
|
||||||
// Seed random number generator for consistent frame mask.
|
// Seed random number generator for consistent frame mask.
|
||||||
testRand.Seed(1234)
|
rand.Seed(1234)
|
||||||
|
|
||||||
if err := c.WriteMessage(tt.messageType, data); err != nil {
|
if err := c.WriteMessage(tt.messageType, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -64,7 +59,7 @@ func TestPreparedMessage(t *testing.T) {
|
||||||
copy(data, "hello world")
|
copy(data, "hello world")
|
||||||
|
|
||||||
// Seed random number generator for consistent frame mask.
|
// Seed random number generator for consistent frame mask.
|
||||||
testRand.Seed(1234)
|
rand.Seed(1234)
|
||||||
|
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
if err := c.WritePreparedMessage(pm); err != nil {
|
if err := c.WritePreparedMessage(pm); err != nil {
|
||||||
|
|
6
proxy.go
6
proxy.go
|
@ -12,8 +12,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type netDialerFunc func(network, addr string) (net.Conn, error)
|
type netDialerFunc func(network, addr string) (net.Conn, error)
|
||||||
|
@ -23,7 +21,7 @@ func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
proxy.RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy.Dialer) (proxy.Dialer, error) {
|
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
||||||
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
|
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -70,7 +68,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != 200 {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
f := strings.SplitN(resp.Status, " ", 2)
|
f := strings.SplitN(resp.Status, " ", 2)
|
||||||
return nil, errors.New(f[1])
|
return nil, errors.New(f[1])
|
||||||
|
|
12
server.go
12
server.go
|
@ -33,7 +33,6 @@ type Upgrader struct {
|
||||||
// size is zero, then buffers allocated by the HTTP server are used. The
|
// size is zero, then buffers allocated by the HTTP server are used. The
|
||||||
// I/O buffer sizes do not limit the size of the messages that can be sent
|
// I/O buffer sizes do not limit the size of the messages that can be sent
|
||||||
// or received.
|
// or received.
|
||||||
// The default value is 4096 bytes, 4kb.
|
|
||||||
ReadBufferSize, WriteBufferSize int
|
ReadBufferSize, WriteBufferSize int
|
||||||
|
|
||||||
// WriteBufferPool is a pool of buffers for write operations. If the value
|
// WriteBufferPool is a pool of buffers for write operations. If the value
|
||||||
|
@ -102,8 +101,8 @@ func checkSameOrigin(r *http.Request) bool {
|
||||||
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
|
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
|
||||||
if u.Subprotocols != nil {
|
if u.Subprotocols != nil {
|
||||||
clientProtocols := Subprotocols(r)
|
clientProtocols := Subprotocols(r)
|
||||||
for _, clientProtocol := range clientProtocols {
|
for _, serverProtocol := range u.Subprotocols {
|
||||||
for _, serverProtocol := range u.Subprotocols {
|
for _, clientProtocol := range clientProtocols {
|
||||||
if clientProtocol == serverProtocol {
|
if clientProtocol == serverProtocol {
|
||||||
return clientProtocol
|
return clientProtocol
|
||||||
}
|
}
|
||||||
|
@ -173,7 +172,12 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
netConn, brw, err := http.NewResponseController(w).Hijack()
|
h, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
|
||||||
|
}
|
||||||
|
var brw *bufio.ReadWriter
|
||||||
|
netConn, brw, err := h.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
|
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,10 +7,8 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -29,7 +27,6 @@ var subprotocolTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSubprotocols(t *testing.T) {
|
func TestSubprotocols(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for _, st := range subprotocolTests {
|
for _, st := range subprotocolTests {
|
||||||
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}}
|
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {st.h}}}
|
||||||
protocols := Subprotocols(&r)
|
protocols := Subprotocols(&r)
|
||||||
|
@ -49,7 +46,6 @@ var isWebSocketUpgradeTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsWebSocketUpgrade(t *testing.T) {
|
func TestIsWebSocketUpgrade(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for _, tt := range isWebSocketUpgradeTests {
|
for _, tt := range isWebSocketUpgradeTests {
|
||||||
ok := IsWebSocketUpgrade(&http.Request{Header: tt.h})
|
ok := IsWebSocketUpgrade(&http.Request{Header: tt.h})
|
||||||
if tt.ok != ok {
|
if tt.ok != ok {
|
||||||
|
@ -58,37 +54,6 @@ func TestIsWebSocketUpgrade(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSubProtocolSelection(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
upgrader := Upgrader{
|
|
||||||
Subprotocols: []string{"foo", "bar", "baz"},
|
|
||||||
}
|
|
||||||
|
|
||||||
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}}
|
|
||||||
s := upgrader.selectSubprotocol(&r, nil)
|
|
||||||
if s != "foo" {
|
|
||||||
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo")
|
|
||||||
}
|
|
||||||
|
|
||||||
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}}
|
|
||||||
s = upgrader.selectSubprotocol(&r, nil)
|
|
||||||
if s != "bar" {
|
|
||||||
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar")
|
|
||||||
}
|
|
||||||
|
|
||||||
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}}
|
|
||||||
s = upgrader.selectSubprotocol(&r, nil)
|
|
||||||
if s != "baz" {
|
|
||||||
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz")
|
|
||||||
}
|
|
||||||
|
|
||||||
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}}
|
|
||||||
s = upgrader.selectSubprotocol(&r, nil)
|
|
||||||
if s != "" {
|
|
||||||
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var checkSameOriginTests = []struct {
|
var checkSameOriginTests = []struct {
|
||||||
ok bool
|
ok bool
|
||||||
r *http.Request
|
r *http.Request
|
||||||
|
@ -99,7 +64,6 @@ var checkSameOriginTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCheckSameOrigin(t *testing.T) {
|
func TestCheckSameOrigin(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for _, tt := range checkSameOriginTests {
|
for _, tt := range checkSameOriginTests {
|
||||||
ok := checkSameOrigin(tt.r)
|
ok := checkSameOrigin(tt.r)
|
||||||
if tt.ok != ok {
|
if tt.ok != ok {
|
||||||
|
@ -126,7 +90,6 @@ var bufioReuseTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBufioReuse(t *testing.T) {
|
func TestBufioReuse(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for i, tt := range bufioReuseTests {
|
for i, tt := range bufioReuseTests {
|
||||||
br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
|
br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
|
||||||
bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
|
bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
|
||||||
|
@ -154,23 +117,3 @@ func TestBufioReuse(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHijack_NotSupported(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
||||||
req.Header.Set("Upgrade", "websocket")
|
|
||||||
req.Header.Set("Connection", "upgrade")
|
|
||||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
|
||||||
req.Header.Set("Sec-Websocket-Version", "13")
|
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
|
||||||
|
|
||||||
upgrader := Upgrader{}
|
|
||||||
_, err := upgrader.Upgrade(recorder, req, nil)
|
|
||||||
|
|
||||||
if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
|
|
||||||
t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
|
|
||||||
t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
//go:build go1.17
|
||||||
|
// +build go1.17
|
||||||
|
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -21,7 +21,6 @@ var equalASCIIFoldTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEqualASCIIFold(t *testing.T) {
|
func TestEqualASCIIFold(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for _, tt := range equalASCIIFoldTests {
|
for _, tt := range equalASCIIFoldTests {
|
||||||
eq := equalASCIIFold(tt.s, tt.t)
|
eq := equalASCIIFold(tt.s, tt.t)
|
||||||
if eq != tt.eq {
|
if eq != tt.eq {
|
||||||
|
@ -45,7 +44,6 @@ var tokenListContainsValueTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTokenListContainsValue(t *testing.T) {
|
func TestTokenListContainsValue(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for _, tt := range tokenListContainsValueTests {
|
for _, tt := range tokenListContainsValueTests {
|
||||||
h := http.Header{"Upgrade": {tt.value}}
|
h := http.Header{"Upgrade": {tt.value}}
|
||||||
ok := tokenListContainsValue(h, "Upgrade", "websocket")
|
ok := tokenListContainsValue(h, "Upgrade", "websocket")
|
||||||
|
@ -66,7 +64,6 @@ var isValidChallengeKeyTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsValidChallengeKey(t *testing.T) {
|
func TestIsValidChallengeKey(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for _, tt := range isValidChallengeKeyTests {
|
for _, tt := range isValidChallengeKeyTests {
|
||||||
ok := isValidChallengeKey(tt.key)
|
ok := isValidChallengeKey(tt.key)
|
||||||
if ok != tt.ok {
|
if ok != tt.ok {
|
||||||
|
@ -108,7 +105,6 @@ var parseExtensionTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseExtensions(t *testing.T) {
|
func TestParseExtensions(t *testing.T) {
|
||||||
t.Parallel()
|
|
||||||
for _, tt := range parseExtensionTests {
|
for _, tt := range parseExtensionTests {
|
||||||
h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}}
|
h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}}
|
||||||
extensions := parseExtensions(h)
|
extensions := parseExtensions(h)
|
||||||
|
|
|
@ -1,27 +0,0 @@
|
||||||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without
|
|
||||||
modification, are permitted provided that the following conditions are
|
|
||||||
met:
|
|
||||||
|
|
||||||
* Redistributions of source code must retain the above copyright
|
|
||||||
notice, this list of conditions and the following disclaimer.
|
|
||||||
* Redistributions in binary form must reproduce the above
|
|
||||||
copyright notice, this list of conditions and the following disclaimer
|
|
||||||
in the documentation and/or other materials provided with the
|
|
||||||
distribution.
|
|
||||||
* Neither the name of Google Inc. nor the names of its
|
|
||||||
contributors may be used to endorse or promote products derived from
|
|
||||||
this software without specific prior written permission.
|
|
||||||
|
|
||||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
||||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
||||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
||||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
||||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
|
||||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
|
||||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
||||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
||||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
||||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
@ -1,22 +0,0 @@
|
||||||
Additional IP Rights Grant (Patents)
|
|
||||||
|
|
||||||
"This implementation" means the copyrightable works distributed by
|
|
||||||
Google as part of the Go project.
|
|
||||||
|
|
||||||
Google hereby grants to You a perpetual, worldwide, non-exclusive,
|
|
||||||
no-charge, royalty-free, irrevocable (except as stated in this section)
|
|
||||||
patent license to make, have made, use, offer to sell, sell, import,
|
|
||||||
transfer and otherwise run, modify and propagate the contents of this
|
|
||||||
implementation of Go, where such license applies only to those patent
|
|
||||||
claims, both currently owned or controlled by Google and acquired in
|
|
||||||
the future, licensable by Google that are necessarily infringed by this
|
|
||||||
implementation of Go. This grant does not include claims that would be
|
|
||||||
infringed only as a consequence of further modification of this
|
|
||||||
implementation. If you or your agent or exclusive licensee institute or
|
|
||||||
order or agree to the institution of patent litigation against any
|
|
||||||
entity (including a cross-claim or counterclaim in a lawsuit) alleging
|
|
||||||
that this implementation of Go or any code incorporated within this
|
|
||||||
implementation of Go constitutes direct or contributory patent
|
|
||||||
infringement, or inducement of patent infringement, then any patent
|
|
||||||
rights granted to you under this License for this implementation of Go
|
|
||||||
shall terminate as of the date such litigation is filed.
|
|
|
@ -1,168 +0,0 @@
|
||||||
// Copyright 2018 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package socks
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
noDeadline = time.Time{}
|
|
||||||
aLongTimeAgo = time.Unix(1, 0)
|
|
||||||
)
|
|
||||||
|
|
||||||
func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
|
|
||||||
host, port, err := splitHostPort(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
|
|
||||||
c.SetDeadline(deadline)
|
|
||||||
defer c.SetDeadline(noDeadline)
|
|
||||||
}
|
|
||||||
if ctx != context.Background() {
|
|
||||||
errCh := make(chan error, 1)
|
|
||||||
done := make(chan struct{})
|
|
||||||
defer func() {
|
|
||||||
close(done)
|
|
||||||
if ctxErr == nil {
|
|
||||||
ctxErr = <-errCh
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
c.SetDeadline(aLongTimeAgo)
|
|
||||||
errCh <- ctx.Err()
|
|
||||||
case <-done:
|
|
||||||
errCh <- nil
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
|
|
||||||
b = append(b, Version5)
|
|
||||||
if len(d.AuthMethods) == 0 || d.Authenticate == nil {
|
|
||||||
b = append(b, 1, byte(AuthMethodNotRequired))
|
|
||||||
} else {
|
|
||||||
ams := d.AuthMethods
|
|
||||||
if len(ams) > 255 {
|
|
||||||
return nil, errors.New("too many authentication methods")
|
|
||||||
}
|
|
||||||
b = append(b, byte(len(ams)))
|
|
||||||
for _, am := range ams {
|
|
||||||
b = append(b, byte(am))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if _, ctxErr = c.Write(b); ctxErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if b[0] != Version5 {
|
|
||||||
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
|
|
||||||
}
|
|
||||||
am := AuthMethod(b[1])
|
|
||||||
if am == AuthMethodNoAcceptableMethods {
|
|
||||||
return nil, errors.New("no acceptable authentication methods")
|
|
||||||
}
|
|
||||||
if d.Authenticate != nil {
|
|
||||||
if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
b = b[:0]
|
|
||||||
b = append(b, Version5, byte(d.cmd), 0)
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
|
||||||
b = append(b, AddrTypeIPv4)
|
|
||||||
b = append(b, ip4...)
|
|
||||||
} else if ip6 := ip.To16(); ip6 != nil {
|
|
||||||
b = append(b, AddrTypeIPv6)
|
|
||||||
b = append(b, ip6...)
|
|
||||||
} else {
|
|
||||||
return nil, errors.New("unknown address type")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if len(host) > 255 {
|
|
||||||
return nil, errors.New("FQDN too long")
|
|
||||||
}
|
|
||||||
b = append(b, AddrTypeFQDN)
|
|
||||||
b = append(b, byte(len(host)))
|
|
||||||
b = append(b, host...)
|
|
||||||
}
|
|
||||||
b = append(b, byte(port>>8), byte(port))
|
|
||||||
if _, ctxErr = c.Write(b); ctxErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if b[0] != Version5 {
|
|
||||||
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
|
|
||||||
}
|
|
||||||
if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
|
|
||||||
return nil, errors.New("unknown error " + cmdErr.String())
|
|
||||||
}
|
|
||||||
if b[2] != 0 {
|
|
||||||
return nil, errors.New("non-zero reserved field")
|
|
||||||
}
|
|
||||||
l := 2
|
|
||||||
var a Addr
|
|
||||||
switch b[3] {
|
|
||||||
case AddrTypeIPv4:
|
|
||||||
l += net.IPv4len
|
|
||||||
a.IP = make(net.IP, net.IPv4len)
|
|
||||||
case AddrTypeIPv6:
|
|
||||||
l += net.IPv6len
|
|
||||||
a.IP = make(net.IP, net.IPv6len)
|
|
||||||
case AddrTypeFQDN:
|
|
||||||
if _, err := io.ReadFull(c, b[:1]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
l += int(b[0])
|
|
||||||
default:
|
|
||||||
return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
|
|
||||||
}
|
|
||||||
if cap(b) < l {
|
|
||||||
b = make([]byte, l)
|
|
||||||
} else {
|
|
||||||
b = b[:l]
|
|
||||||
}
|
|
||||||
if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if a.IP != nil {
|
|
||||||
copy(a.IP, b)
|
|
||||||
} else {
|
|
||||||
a.Name = string(b[:len(b)-2])
|
|
||||||
}
|
|
||||||
a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
|
|
||||||
return &a, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func splitHostPort(address string) (string, int, error) {
|
|
||||||
host, port, err := net.SplitHostPort(address)
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, err
|
|
||||||
}
|
|
||||||
portnum, err := strconv.Atoi(port)
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, err
|
|
||||||
}
|
|
||||||
if 1 > portnum || portnum > 0xffff {
|
|
||||||
return "", 0, errors.New("port number out of range " + port)
|
|
||||||
}
|
|
||||||
return host, portnum, nil
|
|
||||||
}
|
|
|
@ -1,317 +0,0 @@
|
||||||
// Copyright 2018 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// Package socks provides a SOCKS version 5 client implementation.
|
|
||||||
//
|
|
||||||
// SOCKS protocol version 5 is defined in RFC 1928.
|
|
||||||
// Username/Password authentication for SOCKS version 5 is defined in
|
|
||||||
// RFC 1929.
|
|
||||||
package socks
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A Command represents a SOCKS command.
|
|
||||||
type Command int
|
|
||||||
|
|
||||||
func (cmd Command) String() string {
|
|
||||||
switch cmd {
|
|
||||||
case CmdConnect:
|
|
||||||
return "socks connect"
|
|
||||||
case cmdBind:
|
|
||||||
return "socks bind"
|
|
||||||
default:
|
|
||||||
return "socks " + strconv.Itoa(int(cmd))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// An AuthMethod represents a SOCKS authentication method.
|
|
||||||
type AuthMethod int
|
|
||||||
|
|
||||||
// A Reply represents a SOCKS command reply code.
|
|
||||||
type Reply int
|
|
||||||
|
|
||||||
func (code Reply) String() string {
|
|
||||||
switch code {
|
|
||||||
case StatusSucceeded:
|
|
||||||
return "succeeded"
|
|
||||||
case 0x01:
|
|
||||||
return "general SOCKS server failure"
|
|
||||||
case 0x02:
|
|
||||||
return "connection not allowed by ruleset"
|
|
||||||
case 0x03:
|
|
||||||
return "network unreachable"
|
|
||||||
case 0x04:
|
|
||||||
return "host unreachable"
|
|
||||||
case 0x05:
|
|
||||||
return "connection refused"
|
|
||||||
case 0x06:
|
|
||||||
return "TTL expired"
|
|
||||||
case 0x07:
|
|
||||||
return "command not supported"
|
|
||||||
case 0x08:
|
|
||||||
return "address type not supported"
|
|
||||||
default:
|
|
||||||
return "unknown code: " + strconv.Itoa(int(code))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wire protocol constants.
|
|
||||||
const (
|
|
||||||
Version5 = 0x05
|
|
||||||
|
|
||||||
AddrTypeIPv4 = 0x01
|
|
||||||
AddrTypeFQDN = 0x03
|
|
||||||
AddrTypeIPv6 = 0x04
|
|
||||||
|
|
||||||
CmdConnect Command = 0x01 // establishes an active-open forward proxy connection
|
|
||||||
cmdBind Command = 0x02 // establishes a passive-open forward proxy connection
|
|
||||||
|
|
||||||
AuthMethodNotRequired AuthMethod = 0x00 // no authentication required
|
|
||||||
AuthMethodUsernamePassword AuthMethod = 0x02 // use username/password
|
|
||||||
AuthMethodNoAcceptableMethods AuthMethod = 0xff // no acceptable authentication methods
|
|
||||||
|
|
||||||
StatusSucceeded Reply = 0x00
|
|
||||||
)
|
|
||||||
|
|
||||||
// An Addr represents a SOCKS-specific address.
|
|
||||||
// Either Name or IP is used exclusively.
|
|
||||||
type Addr struct {
|
|
||||||
Name string // fully-qualified domain name
|
|
||||||
IP net.IP
|
|
||||||
Port int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Addr) Network() string { return "socks" }
|
|
||||||
|
|
||||||
func (a *Addr) String() string {
|
|
||||||
if a == nil {
|
|
||||||
return "<nil>"
|
|
||||||
}
|
|
||||||
port := strconv.Itoa(a.Port)
|
|
||||||
if a.IP == nil {
|
|
||||||
return net.JoinHostPort(a.Name, port)
|
|
||||||
}
|
|
||||||
return net.JoinHostPort(a.IP.String(), port)
|
|
||||||
}
|
|
||||||
|
|
||||||
// A Conn represents a forward proxy connection.
|
|
||||||
type Conn struct {
|
|
||||||
net.Conn
|
|
||||||
|
|
||||||
boundAddr net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// BoundAddr returns the address assigned by the proxy server for
|
|
||||||
// connecting to the command target address from the proxy server.
|
|
||||||
func (c *Conn) BoundAddr() net.Addr {
|
|
||||||
if c == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return c.boundAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
// A Dialer holds SOCKS-specific options.
|
|
||||||
type Dialer struct {
|
|
||||||
cmd Command // either CmdConnect or cmdBind
|
|
||||||
proxyNetwork string // network between a proxy server and a client
|
|
||||||
proxyAddress string // proxy server address
|
|
||||||
|
|
||||||
// ProxyDial specifies the optional dial function for
|
|
||||||
// establishing the transport connection.
|
|
||||||
ProxyDial func(context.Context, string, string) (net.Conn, error)
|
|
||||||
|
|
||||||
// AuthMethods specifies the list of request authentication
|
|
||||||
// methods.
|
|
||||||
// If empty, SOCKS client requests only AuthMethodNotRequired.
|
|
||||||
AuthMethods []AuthMethod
|
|
||||||
|
|
||||||
// Authenticate specifies the optional authentication
|
|
||||||
// function. It must be non-nil when AuthMethods is not empty.
|
|
||||||
// It must return an error when the authentication is failed.
|
|
||||||
Authenticate func(context.Context, io.ReadWriter, AuthMethod) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialContext connects to the provided address on the provided
|
|
||||||
// network.
|
|
||||||
//
|
|
||||||
// The returned error value may be a net.OpError. When the Op field of
|
|
||||||
// net.OpError contains "socks", the Source field contains a proxy
|
|
||||||
// server address and the Addr field contains a command target
|
|
||||||
// address.
|
|
||||||
//
|
|
||||||
// See func Dial of the net package of standard library for a
|
|
||||||
// description of the network and address parameters.
|
|
||||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
if err := d.validateTarget(network, address); err != nil {
|
|
||||||
proxy, dst, _ := d.pathAddrs(address)
|
|
||||||
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
|
||||||
}
|
|
||||||
if ctx == nil {
|
|
||||||
proxy, dst, _ := d.pathAddrs(address)
|
|
||||||
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
var c net.Conn
|
|
||||||
if d.ProxyDial != nil {
|
|
||||||
c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress)
|
|
||||||
} else {
|
|
||||||
var dd net.Dialer
|
|
||||||
c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
proxy, dst, _ := d.pathAddrs(address)
|
|
||||||
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
|
||||||
}
|
|
||||||
a, err := d.connect(ctx, c, address)
|
|
||||||
if err != nil {
|
|
||||||
c.Close()
|
|
||||||
proxy, dst, _ := d.pathAddrs(address)
|
|
||||||
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
|
||||||
}
|
|
||||||
return &Conn{Conn: c, boundAddr: a}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialWithConn initiates a connection from SOCKS server to the target
|
|
||||||
// network and address using the connection c that is already
|
|
||||||
// connected to the SOCKS server.
|
|
||||||
//
|
|
||||||
// It returns the connection's local address assigned by the SOCKS
|
|
||||||
// server.
|
|
||||||
func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) {
|
|
||||||
if err := d.validateTarget(network, address); err != nil {
|
|
||||||
proxy, dst, _ := d.pathAddrs(address)
|
|
||||||
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
|
||||||
}
|
|
||||||
if ctx == nil {
|
|
||||||
proxy, dst, _ := d.pathAddrs(address)
|
|
||||||
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
|
|
||||||
}
|
|
||||||
a, err := d.connect(ctx, c, address)
|
|
||||||
if err != nil {
|
|
||||||
proxy, dst, _ := d.pathAddrs(address)
|
|
||||||
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
|
||||||
}
|
|
||||||
return a, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial connects to the provided address on the provided network.
|
|
||||||
//
|
|
||||||
// Unlike DialContext, it returns a raw transport connection instead
|
|
||||||
// of a forward proxy connection.
|
|
||||||
//
|
|
||||||
// Deprecated: Use DialContext or DialWithConn instead.
|
|
||||||
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
|
||||||
if err := d.validateTarget(network, address); err != nil {
|
|
||||||
proxy, dst, _ := d.pathAddrs(address)
|
|
||||||
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
var c net.Conn
|
|
||||||
if d.ProxyDial != nil {
|
|
||||||
c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress)
|
|
||||||
} else {
|
|
||||||
c, err = net.Dial(d.proxyNetwork, d.proxyAddress)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
proxy, dst, _ := d.pathAddrs(address)
|
|
||||||
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
|
||||||
}
|
|
||||||
if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil {
|
|
||||||
c.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Dialer) validateTarget(network, address string) error {
|
|
||||||
switch network {
|
|
||||||
case "tcp", "tcp6", "tcp4":
|
|
||||||
default:
|
|
||||||
return errors.New("network not implemented")
|
|
||||||
}
|
|
||||||
switch d.cmd {
|
|
||||||
case CmdConnect, cmdBind:
|
|
||||||
default:
|
|
||||||
return errors.New("command not implemented")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {
|
|
||||||
for i, s := range []string{d.proxyAddress, address} {
|
|
||||||
host, port, err := splitHostPort(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
a := &Addr{Port: port}
|
|
||||||
a.IP = net.ParseIP(host)
|
|
||||||
if a.IP == nil {
|
|
||||||
a.Name = host
|
|
||||||
}
|
|
||||||
if i == 0 {
|
|
||||||
proxy = a
|
|
||||||
} else {
|
|
||||||
dst = a
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDialer returns a new Dialer that dials through the provided
|
|
||||||
// proxy server's network and address.
|
|
||||||
func NewDialer(network, address string) *Dialer {
|
|
||||||
return &Dialer{proxyNetwork: network, proxyAddress: address, cmd: CmdConnect}
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
authUsernamePasswordVersion = 0x01
|
|
||||||
authStatusSucceeded = 0x00
|
|
||||||
)
|
|
||||||
|
|
||||||
// UsernamePassword are the credentials for the username/password
|
|
||||||
// authentication method.
|
|
||||||
type UsernamePassword struct {
|
|
||||||
Username string
|
|
||||||
Password string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authenticate authenticates a pair of username and password with the
|
|
||||||
// proxy server.
|
|
||||||
func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth AuthMethod) error {
|
|
||||||
switch auth {
|
|
||||||
case AuthMethodNotRequired:
|
|
||||||
return nil
|
|
||||||
case AuthMethodUsernamePassword:
|
|
||||||
if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) > 255 {
|
|
||||||
return errors.New("invalid username/password")
|
|
||||||
}
|
|
||||||
b := []byte{authUsernamePasswordVersion}
|
|
||||||
b = append(b, byte(len(up.Username)))
|
|
||||||
b = append(b, up.Username...)
|
|
||||||
b = append(b, byte(len(up.Password)))
|
|
||||||
b = append(b, up.Password...)
|
|
||||||
// TODO(mikio): handle IO deadlines and cancelation if
|
|
||||||
// necessary
|
|
||||||
if _, err := rw.Write(b); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := io.ReadFull(rw, b[:2]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if b[0] != authUsernamePasswordVersion {
|
|
||||||
return errors.New("invalid username/password version")
|
|
||||||
}
|
|
||||||
if b[1] != authStatusSucceeded {
|
|
||||||
return errors.New("username/password authentication failed")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return errors.New("unsupported authentication method " + strconv.Itoa(int(auth)))
|
|
||||||
}
|
|
|
@ -1,54 +0,0 @@
|
||||||
// Copyright 2019 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A ContextDialer dials using a context.
|
|
||||||
type ContextDialer interface {
|
|
||||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial works like DialContext on net.Dialer but using a dialer returned by FromEnvironment.
|
|
||||||
//
|
|
||||||
// The passed ctx is only used for returning the Conn, not the lifetime of the Conn.
|
|
||||||
//
|
|
||||||
// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer
|
|
||||||
// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout.
|
|
||||||
//
|
|
||||||
// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
|
|
||||||
func Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
d := FromEnvironment()
|
|
||||||
if xd, ok := d.(ContextDialer); ok {
|
|
||||||
return xd.DialContext(ctx, network, address)
|
|
||||||
}
|
|
||||||
return dialContext(ctx, d, network, address)
|
|
||||||
}
|
|
||||||
|
|
||||||
// WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout
|
|
||||||
// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
|
|
||||||
func dialContext(ctx context.Context, d Dialer, network, address string) (net.Conn, error) {
|
|
||||||
var (
|
|
||||||
conn net.Conn
|
|
||||||
done = make(chan struct{}, 1)
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
go func() {
|
|
||||||
conn, err = d.Dial(network, address)
|
|
||||||
close(done)
|
|
||||||
if conn != nil && ctx.Err() != nil {
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
err = ctx.Err()
|
|
||||||
case <-done:
|
|
||||||
}
|
|
||||||
return conn, err
|
|
||||||
}
|
|
|
@ -1,31 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
type direct struct{}
|
|
||||||
|
|
||||||
// Direct implements Dialer by making network connections directly using net.Dial or net.DialContext.
|
|
||||||
var Direct = direct{}
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ Dialer = Direct
|
|
||||||
_ ContextDialer = Direct
|
|
||||||
)
|
|
||||||
|
|
||||||
// Dial directly invokes net.Dial with the supplied parameters.
|
|
||||||
func (direct) Dial(network, addr string) (net.Conn, error) {
|
|
||||||
return net.Dial(network, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialContext instantiates a net.Dialer and invokes its DialContext receiver with the supplied parameters.
|
|
||||||
func (direct) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
var d net.Dialer
|
|
||||||
return d.DialContext(ctx, network, addr)
|
|
||||||
}
|
|
|
@ -1,155 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A PerHost directs connections to a default Dialer unless the host name
|
|
||||||
// requested matches one of a number of exceptions.
|
|
||||||
type PerHost struct {
|
|
||||||
def, bypass Dialer
|
|
||||||
|
|
||||||
bypassNetworks []*net.IPNet
|
|
||||||
bypassIPs []net.IP
|
|
||||||
bypassZones []string
|
|
||||||
bypassHosts []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPerHost returns a PerHost Dialer that directs connections to either
|
|
||||||
// defaultDialer or bypass, depending on whether the connection matches one of
|
|
||||||
// the configured rules.
|
|
||||||
func NewPerHost(defaultDialer, bypass Dialer) *PerHost {
|
|
||||||
return &PerHost{
|
|
||||||
def: defaultDialer,
|
|
||||||
bypass: bypass,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial connects to the address addr on the given network through either
|
|
||||||
// defaultDialer or bypass.
|
|
||||||
func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
|
|
||||||
host, _, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.dialerForRequest(host).Dial(network, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialContext connects to the address addr on the given network through either
|
|
||||||
// defaultDialer or bypass.
|
|
||||||
func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) {
|
|
||||||
host, _, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
d := p.dialerForRequest(host)
|
|
||||||
if x, ok := d.(ContextDialer); ok {
|
|
||||||
return x.DialContext(ctx, network, addr)
|
|
||||||
}
|
|
||||||
return dialContext(ctx, d, network, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *PerHost) dialerForRequest(host string) Dialer {
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
for _, net := range p.bypassNetworks {
|
|
||||||
if net.Contains(ip) {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, bypassIP := range p.bypassIPs {
|
|
||||||
if bypassIP.Equal(ip) {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p.def
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, zone := range p.bypassZones {
|
|
||||||
if strings.HasSuffix(host, zone) {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
if host == zone[1:] {
|
|
||||||
// For a zone ".example.com", we match "example.com"
|
|
||||||
// too.
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, bypassHost := range p.bypassHosts {
|
|
||||||
if bypassHost == host {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p.def
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFromString parses a string that contains comma-separated values
|
|
||||||
// specifying hosts that should use the bypass proxy. Each value is either an
|
|
||||||
// IP address, a CIDR range, a zone (*.example.com) or a host name
|
|
||||||
// (localhost). A best effort is made to parse the string and errors are
|
|
||||||
// ignored.
|
|
||||||
func (p *PerHost) AddFromString(s string) {
|
|
||||||
hosts := strings.Split(s, ",")
|
|
||||||
for _, host := range hosts {
|
|
||||||
host = strings.TrimSpace(host)
|
|
||||||
if len(host) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.Contains(host, "/") {
|
|
||||||
// We assume that it's a CIDR address like 127.0.0.0/8
|
|
||||||
if _, net, err := net.ParseCIDR(host); err == nil {
|
|
||||||
p.AddNetwork(net)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
p.AddIP(ip)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(host, "*.") {
|
|
||||||
p.AddZone(host[1:])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
p.AddHost(host)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddIP specifies an IP address that will use the bypass proxy. Note that
|
|
||||||
// this will only take effect if a literal IP address is dialed. A connection
|
|
||||||
// to a named host will never match an IP.
|
|
||||||
func (p *PerHost) AddIP(ip net.IP) {
|
|
||||||
p.bypassIPs = append(p.bypassIPs, ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddNetwork specifies an IP range that will use the bypass proxy. Note that
|
|
||||||
// this will only take effect if a literal IP address is dialed. A connection
|
|
||||||
// to a named host will never match.
|
|
||||||
func (p *PerHost) AddNetwork(net *net.IPNet) {
|
|
||||||
p.bypassNetworks = append(p.bypassNetworks, net)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
|
|
||||||
// "example.com" matches "example.com" and all of its subdomains.
|
|
||||||
func (p *PerHost) AddZone(zone string) {
|
|
||||||
if strings.HasSuffix(zone, ".") {
|
|
||||||
zone = zone[:len(zone)-1]
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(zone, ".") {
|
|
||||||
zone = "." + zone
|
|
||||||
}
|
|
||||||
p.bypassZones = append(p.bypassZones, zone)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddHost specifies a host name that will use the bypass proxy.
|
|
||||||
func (p *PerHost) AddHost(host string) {
|
|
||||||
if strings.HasSuffix(host, ".") {
|
|
||||||
host = host[:len(host)-1]
|
|
||||||
}
|
|
||||||
p.bypassHosts = append(p.bypassHosts, host)
|
|
||||||
}
|
|
|
@ -1,149 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// Package proxy provides support for a variety of protocols to proxy network
|
|
||||||
// data.
|
|
||||||
package proxy // import "golang.org/x/net/proxy"
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A Dialer is a means to establish a connection.
|
|
||||||
// Custom dialers should also implement ContextDialer.
|
|
||||||
type Dialer interface {
|
|
||||||
// Dial connects to the given address via the proxy.
|
|
||||||
Dial(network, addr string) (c net.Conn, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auth contains authentication parameters that specific Dialers may require.
|
|
||||||
type Auth struct {
|
|
||||||
User, Password string
|
|
||||||
}
|
|
||||||
|
|
||||||
// FromEnvironment returns the dialer specified by the proxy-related
|
|
||||||
// variables in the environment and makes underlying connections
|
|
||||||
// directly.
|
|
||||||
func FromEnvironment() Dialer {
|
|
||||||
return FromEnvironmentUsing(Direct)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FromEnvironmentUsing returns the dialer specify by the proxy-related
|
|
||||||
// variables in the environment and makes underlying connections
|
|
||||||
// using the provided forwarding Dialer (for instance, a *net.Dialer
|
|
||||||
// with desired configuration).
|
|
||||||
func FromEnvironmentUsing(forward Dialer) Dialer {
|
|
||||||
allProxy := allProxyEnv.Get()
|
|
||||||
if len(allProxy) == 0 {
|
|
||||||
return forward
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyURL, err := url.Parse(allProxy)
|
|
||||||
if err != nil {
|
|
||||||
return forward
|
|
||||||
}
|
|
||||||
proxy, err := FromURL(proxyURL, forward)
|
|
||||||
if err != nil {
|
|
||||||
return forward
|
|
||||||
}
|
|
||||||
|
|
||||||
noProxy := noProxyEnv.Get()
|
|
||||||
if len(noProxy) == 0 {
|
|
||||||
return proxy
|
|
||||||
}
|
|
||||||
|
|
||||||
perHost := NewPerHost(proxy, forward)
|
|
||||||
perHost.AddFromString(noProxy)
|
|
||||||
return perHost
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxySchemes is a map from URL schemes to a function that creates a Dialer
|
|
||||||
// from a URL with such a scheme.
|
|
||||||
var proxySchemes map[string]func(*url.URL, Dialer) (Dialer, error)
|
|
||||||
|
|
||||||
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
|
|
||||||
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
|
|
||||||
// by FromURL.
|
|
||||||
func RegisterDialerType(scheme string, f func(*url.URL, Dialer) (Dialer, error)) {
|
|
||||||
if proxySchemes == nil {
|
|
||||||
proxySchemes = make(map[string]func(*url.URL, Dialer) (Dialer, error))
|
|
||||||
}
|
|
||||||
proxySchemes[scheme] = f
|
|
||||||
}
|
|
||||||
|
|
||||||
// FromURL returns a Dialer given a URL specification and an underlying
|
|
||||||
// Dialer for it to make network requests.
|
|
||||||
func FromURL(u *url.URL, forward Dialer) (Dialer, error) {
|
|
||||||
var auth *Auth
|
|
||||||
if u.User != nil {
|
|
||||||
auth = new(Auth)
|
|
||||||
auth.User = u.User.Username()
|
|
||||||
if p, ok := u.User.Password(); ok {
|
|
||||||
auth.Password = p
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch u.Scheme {
|
|
||||||
case "socks5", "socks5h":
|
|
||||||
addr := u.Hostname()
|
|
||||||
port := u.Port()
|
|
||||||
if port == "" {
|
|
||||||
port = "1080"
|
|
||||||
}
|
|
||||||
return SOCKS5("tcp", net.JoinHostPort(addr, port), auth, forward)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the scheme doesn't match any of the built-in schemes, see if it
|
|
||||||
// was registered by another package.
|
|
||||||
if proxySchemes != nil {
|
|
||||||
if f, ok := proxySchemes[u.Scheme]; ok {
|
|
||||||
return f(u, forward)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
allProxyEnv = &envOnce{
|
|
||||||
names: []string{"ALL_PROXY", "all_proxy"},
|
|
||||||
}
|
|
||||||
noProxyEnv = &envOnce{
|
|
||||||
names: []string{"NO_PROXY", "no_proxy"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// envOnce looks up an environment variable (optionally by multiple
|
|
||||||
// names) once. It mitigates expensive lookups on some platforms
|
|
||||||
// (e.g. Windows).
|
|
||||||
// (Borrowed from net/http/transport.go)
|
|
||||||
type envOnce struct {
|
|
||||||
names []string
|
|
||||||
once sync.Once
|
|
||||||
val string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *envOnce) Get() string {
|
|
||||||
e.once.Do(e.init)
|
|
||||||
return e.val
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *envOnce) init() {
|
|
||||||
for _, n := range e.names {
|
|
||||||
e.val = os.Getenv(n)
|
|
||||||
if e.val != "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// reset is used by tests
|
|
||||||
func (e *envOnce) reset() {
|
|
||||||
e.once = sync.Once{}
|
|
||||||
e.val = ""
|
|
||||||
}
|
|
|
@ -1,42 +0,0 @@
|
||||||
// Copyright 2011 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"golang.org/x/net/internal/socks"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given
|
|
||||||
// address with an optional username and password.
|
|
||||||
// See RFC 1928 and RFC 1929.
|
|
||||||
func SOCKS5(network, address string, auth *Auth, forward Dialer) (Dialer, error) {
|
|
||||||
d := socks.NewDialer(network, address)
|
|
||||||
if forward != nil {
|
|
||||||
if f, ok := forward.(ContextDialer); ok {
|
|
||||||
d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) {
|
|
||||||
return f.DialContext(ctx, network, address)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) {
|
|
||||||
return dialContext(ctx, forward, network, address)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if auth != nil {
|
|
||||||
up := socks.UsernamePassword{
|
|
||||||
Username: auth.User,
|
|
||||||
Password: auth.Password,
|
|
||||||
}
|
|
||||||
d.AuthMethods = []socks.AuthMethod{
|
|
||||||
socks.AuthMethodNotRequired,
|
|
||||||
socks.AuthMethodUsernamePassword,
|
|
||||||
}
|
|
||||||
d.Authenticate = up.Authenticate
|
|
||||||
}
|
|
||||||
return d, nil
|
|
||||||
}
|
|
|
@ -1,4 +0,0 @@
|
||||||
# golang.org/x/net v0.23.0
|
|
||||||
## explicit; go 1.18
|
|
||||||
golang.org/x/net/internal/socks
|
|
||||||
golang.org/x/net/proxy
|
|
Loading…
Reference in New Issue