mirror of https://github.com/gorilla/websocket.git
Update go version & add verification/testing tools (#840)
Fixes # **Summary of Changes** 1. 2. 3. > PS: Make sure your PR includes/updates tests! If you need help with this part, just ask!
This commit is contained in:
parent
80393295c1
commit
666c197fc9
|
@ -1,70 +0,0 @@
|
||||||
version: 2.1
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
"test":
|
|
||||||
parameters:
|
|
||||||
version:
|
|
||||||
type: string
|
|
||||||
default: "latest"
|
|
||||||
golint:
|
|
||||||
type: boolean
|
|
||||||
default: true
|
|
||||||
modules:
|
|
||||||
type: boolean
|
|
||||||
default: true
|
|
||||||
goproxy:
|
|
||||||
type: string
|
|
||||||
default: ""
|
|
||||||
docker:
|
|
||||||
- image: "cimg/go:<< parameters.version >>"
|
|
||||||
working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket
|
|
||||||
environment:
|
|
||||||
GO111MODULE: "on"
|
|
||||||
GOPROXY: "<< parameters.goproxy >>"
|
|
||||||
steps:
|
|
||||||
- checkout
|
|
||||||
- run:
|
|
||||||
name: "Print the Go version"
|
|
||||||
command: >
|
|
||||||
go version
|
|
||||||
- run:
|
|
||||||
name: "Fetch dependencies"
|
|
||||||
command: >
|
|
||||||
if [[ << parameters.modules >> = true ]]; then
|
|
||||||
go mod download
|
|
||||||
export GO111MODULE=on
|
|
||||||
else
|
|
||||||
go get -v ./...
|
|
||||||
fi
|
|
||||||
# Only run gofmt, vet & lint against the latest Go version
|
|
||||||
- run:
|
|
||||||
name: "Run golint"
|
|
||||||
command: >
|
|
||||||
if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then
|
|
||||||
go get -u golang.org/x/lint/golint
|
|
||||||
golint ./...
|
|
||||||
fi
|
|
||||||
- run:
|
|
||||||
name: "Run gofmt"
|
|
||||||
command: >
|
|
||||||
if [[ << parameters.version >> = "latest" ]]; then
|
|
||||||
diff -u <(echo -n) <(gofmt -d -e .)
|
|
||||||
fi
|
|
||||||
- run:
|
|
||||||
name: "Run go vet"
|
|
||||||
command: >
|
|
||||||
if [[ << parameters.version >> = "latest" ]]; then
|
|
||||||
go vet -v ./...
|
|
||||||
fi
|
|
||||||
- run:
|
|
||||||
name: "Run go test (+ race detector)"
|
|
||||||
command: >
|
|
||||||
go test -v -race ./...
|
|
||||||
|
|
||||||
workflows:
|
|
||||||
tests:
|
|
||||||
jobs:
|
|
||||||
- test:
|
|
||||||
matrix:
|
|
||||||
parameters:
|
|
||||||
version: ["1.18", "1.17", "1.16"]
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
; https://editorconfig.org/
|
||||||
|
|
||||||
|
root = true
|
||||||
|
|
||||||
|
[*]
|
||||||
|
insert_final_newline = true
|
||||||
|
charset = utf-8
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
indent_style = space
|
||||||
|
indent_size = 2
|
||||||
|
|
||||||
|
[{Makefile,go.mod,go.sum,*.go,.gitmodules}]
|
||||||
|
indent_style = tab
|
||||||
|
indent_size = 4
|
||||||
|
|
||||||
|
[*.md]
|
||||||
|
indent_size = 4
|
||||||
|
trim_trailing_whitespace = false
|
||||||
|
|
||||||
|
eclint_indent_style = unset
|
|
@ -1,7 +0,0 @@
|
||||||
# Config for https://github.com/apps/release-drafter
|
|
||||||
template: |
|
|
||||||
|
|
||||||
<summary of changes here>
|
|
||||||
|
|
||||||
## CHANGELOG
|
|
||||||
$CHANGES
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
# Add issues or pull-requests created to the project.
|
||||||
|
name: Add issue or pull request to Project
|
||||||
|
|
||||||
|
on:
|
||||||
|
issues:
|
||||||
|
types:
|
||||||
|
- opened
|
||||||
|
pull_request_target:
|
||||||
|
types:
|
||||||
|
- opened
|
||||||
|
- reopened
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
add-to-project:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Add issue to project
|
||||||
|
uses: actions/add-to-project@v0.5.0
|
||||||
|
with:
|
||||||
|
project-url: https://github.com/orgs/gorilla/projects/4
|
||||||
|
github-token: ${{ secrets.ADD_TO_PROJECT_TOKEN }}
|
|
@ -0,0 +1,55 @@
|
||||||
|
name: CI
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
verify-and-test:
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
go: ['1.19','1.20']
|
||||||
|
os: [ubuntu-latest, macos-latest, windows-latest]
|
||||||
|
fail-fast: true
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout Code
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Setup Go ${{ matrix.go }}
|
||||||
|
uses: actions/setup-go@v4
|
||||||
|
with:
|
||||||
|
go-version: ${{ matrix.go }}
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Run GolangCI-Lint
|
||||||
|
uses: golangci/golangci-lint-action@v3
|
||||||
|
with:
|
||||||
|
version: v1.53
|
||||||
|
args: --timeout=5m
|
||||||
|
|
||||||
|
- name: Run GoSec
|
||||||
|
if: matrix.os == 'ubuntu-latest'
|
||||||
|
uses: securego/gosec@master
|
||||||
|
with:
|
||||||
|
args: -exclude-dir examples ./...
|
||||||
|
|
||||||
|
- name: Run GoVulnCheck
|
||||||
|
uses: golang/govulncheck-action@v1
|
||||||
|
with:
|
||||||
|
go-version-input: ${{ matrix.go }}
|
||||||
|
go-package: ./...
|
||||||
|
|
||||||
|
- name: Run Tests
|
||||||
|
run: go test -race -cover -coverprofile=coverage -covermode=atomic -v ./...
|
||||||
|
|
||||||
|
- name: Upload coverage to Codecov
|
||||||
|
uses: codecov/codecov-action@v3
|
||||||
|
with:
|
||||||
|
files: ./coverage
|
|
@ -1,25 +1 @@
|
||||||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
coverage.coverprofile
|
||||||
*.o
|
|
||||||
*.a
|
|
||||||
*.so
|
|
||||||
|
|
||||||
# Folders
|
|
||||||
_obj
|
|
||||||
_test
|
|
||||||
|
|
||||||
# Architecture specific extensions/prefixes
|
|
||||||
*.[568vq]
|
|
||||||
[568vq].out
|
|
||||||
|
|
||||||
*.cgo1.go
|
|
||||||
*.cgo2.c
|
|
||||||
_cgo_defun.c
|
|
||||||
_cgo_gotypes.go
|
|
||||||
_cgo_export.*
|
|
||||||
|
|
||||||
_testmain.go
|
|
||||||
|
|
||||||
*.exe
|
|
||||||
|
|
||||||
.idea/
|
|
||||||
*.iml
|
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
run:
|
||||||
|
skip-dirs:
|
||||||
|
- examples/*.go
|
9
AUTHORS
9
AUTHORS
|
@ -1,9 +0,0 @@
|
||||||
# This is the official list of Gorilla WebSocket authors for copyright
|
|
||||||
# purposes.
|
|
||||||
#
|
|
||||||
# Please keep the list sorted.
|
|
||||||
|
|
||||||
Gary Burd <gary@beagledreams.com>
|
|
||||||
Google LLC (https://opensource.google.com/)
|
|
||||||
Joachim Bauch <mail@joachim-bauch.de>
|
|
||||||
|
|
39
LICENSE
39
LICENSE
|
@ -1,22 +1,27 @@
|
||||||
Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved.
|
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without
|
Redistribution and use in source and binary forms, with or without
|
||||||
modification, are permitted provided that the following conditions are met:
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
Redistributions of source code must retain the above copyright notice, this
|
* Redistributions of source code must retain the above copyright
|
||||||
list of conditions and the following disclaimer.
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
Redistributions in binary form must reproduce the above copyright notice,
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
this list of conditions and the following disclaimer in the documentation
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
and/or other materials provided with the distribution.
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
||||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
||||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
GO_LINT=$(shell which golangci-lint 2> /dev/null || echo '')
|
||||||
|
GO_LINT_URI=github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||||
|
|
||||||
|
GO_SEC=$(shell which gosec 2> /dev/null || echo '')
|
||||||
|
GO_SEC_URI=github.com/securego/gosec/v2/cmd/gosec@latest
|
||||||
|
|
||||||
|
GO_VULNCHECK=$(shell which govulncheck 2> /dev/null || echo '')
|
||||||
|
GO_VULNCHECK_URI=golang.org/x/vuln/cmd/govulncheck@latest
|
||||||
|
|
||||||
|
.PHONY: golangci-lint
|
||||||
|
golangci-lint:
|
||||||
|
$(if $(GO_LINT), ,go install $(GO_LINT_URI))
|
||||||
|
@echo "##### Running golangci-lint"
|
||||||
|
golangci-lint run -v
|
||||||
|
|
||||||
|
.PHONY: gosec
|
||||||
|
gosec:
|
||||||
|
$(if $(GO_SEC), ,go install $(GO_SEC_URI))
|
||||||
|
@echo "##### Running gosec"
|
||||||
|
gosec -exclude-dir examples ./...
|
||||||
|
|
||||||
|
.PHONY: govulncheck
|
||||||
|
govulncheck:
|
||||||
|
$(if $(GO_VULNCHECK), ,go install $(GO_VULNCHECK_URI))
|
||||||
|
@echo "##### Running govulncheck"
|
||||||
|
govulncheck ./...
|
||||||
|
|
||||||
|
.PHONY: verify
|
||||||
|
verify: golangci-lint gosec govulncheck
|
||||||
|
|
||||||
|
.PHONY: test
|
||||||
|
test:
|
||||||
|
@echo "##### Running tests"
|
||||||
|
go test -race -cover -coverprofile=coverage.coverprofile -covermode=atomic -v ./...
|
14
README.md
14
README.md
|
@ -1,10 +1,13 @@
|
||||||
# Gorilla WebSocket
|
# gorilla/websocket
|
||||||
|
|
||||||
[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
|
![testing](https://github.com/gorilla/websocket/actions/workflows/test.yml/badge.svg)
|
||||||
[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket)
|
[![codecov](https://codecov.io/github/gorilla/websocket/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/websocket)
|
||||||
|
[![godoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
|
||||||
|
[![sourcegraph](https://sourcegraph.com/github.com/gorilla/websocket/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/websocket?badge)
|
||||||
|
|
||||||
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
|
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
|
||||||
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
|
|
||||||
|
![Gorilla Logo](https://github.com/gorilla/.github/assets/53367916/d92caabf-98e0-473e-bfbf-ab554ba435e5)
|
||||||
|
|
||||||
|
|
||||||
### Documentation
|
### Documentation
|
||||||
|
@ -31,4 +34,3 @@ package API is stable.
|
||||||
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
|
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
|
||||||
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
|
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
|
||||||
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
|
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
|
||||||
|
|
||||||
|
|
26
client.go
26
client.go
|
@ -11,13 +11,16 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"log"
|
||||||
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrBadHandshake is returned when the server response to opening handshake is
|
// ErrBadHandshake is returned when the server response to opening handshake is
|
||||||
|
@ -225,6 +228,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
k == "Connection" ||
|
k == "Connection" ||
|
||||||
k == "Sec-Websocket-Key" ||
|
k == "Sec-Websocket-Key" ||
|
||||||
k == "Sec-Websocket-Version" ||
|
k == "Sec-Websocket-Version" ||
|
||||||
|
//#nosec G101 (CWE-798): Potential HTTP request smuggling via parameter pollution
|
||||||
k == "Sec-Websocket-Extensions" ||
|
k == "Sec-Websocket-Extensions" ||
|
||||||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
|
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
|
||||||
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
|
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
|
||||||
|
@ -290,7 +294,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
}
|
}
|
||||||
err = c.SetDeadline(deadline)
|
err = c.SetDeadline(deadline)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Close()
|
if err := c.Close(); err != nil {
|
||||||
|
log.Printf("websocket: failed to close network connection: %v", err)
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return c, nil
|
return c, nil
|
||||||
|
@ -304,7 +310,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if proxyURL != nil {
|
if proxyURL != nil {
|
||||||
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
|
dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -330,7 +336,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if netConn != nil {
|
if netConn != nil {
|
||||||
netConn.Close()
|
if err := netConn.Close(); err != nil {
|
||||||
|
log.Printf("websocket: failed to close network connection: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -400,7 +408,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
// debugging.
|
// debugging.
|
||||||
buf := make([]byte, 1024)
|
buf := make([]byte, 1024)
|
||||||
n, _ := io.ReadFull(resp.Body, buf)
|
n, _ := io.ReadFull(resp.Body, buf)
|
||||||
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
|
resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
|
||||||
return nil, resp, ErrBadHandshake
|
return nil, resp, ErrBadHandshake
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -418,17 +426,19 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
||||||
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
||||||
|
|
||||||
netConn.SetDeadline(time.Time{})
|
if err := netConn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
netConn = nil // to avoid close in defer.
|
netConn = nil // to avoid close in defer.
|
||||||
return conn, resp, nil
|
return conn, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return &tls.Config{}
|
return &tls.Config{MinVersion: tls.VersionTLS12}
|
||||||
}
|
}
|
||||||
return cfg.Clone()
|
return cfg.Clone()
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -50,7 +49,6 @@ type cstHandler struct{ *testing.T }
|
||||||
type cstServer struct {
|
type cstServer struct {
|
||||||
*httptest.Server
|
*httptest.Server
|
||||||
URL string
|
URL string
|
||||||
t *testing.T
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -94,7 +92,6 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
|
ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("Upgrade: %v", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer ws.Close()
|
defer ws.Close()
|
||||||
|
@ -106,20 +103,16 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
op, rd, err := ws.NextReader()
|
op, rd, err := ws.NextReader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("NextReader: %v", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
wr, err := ws.NextWriter(op)
|
wr, err := ws.NextWriter(op)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("NextWriter: %v", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, err = io.Copy(wr, rd); err != nil {
|
if _, err = io.Copy(wr, rd); err != nil {
|
||||||
t.Logf("NextWriter: %v", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := wr.Close(); err != nil {
|
if err := wr.Close(); err != nil {
|
||||||
t.Logf("Close: %v", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -531,7 +524,9 @@ func TestRespOnBadHandshake(t *testing.T) {
|
||||||
|
|
||||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(expectedStatus)
|
w.WriteHeader(expectedStatus)
|
||||||
io.WriteString(w, expectedBody)
|
if _, err := io.WriteString(w, expectedBody); err != nil {
|
||||||
|
t.Fatalf("WriteString: %v", err)
|
||||||
|
}
|
||||||
}))
|
}))
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -549,7 +544,7 @@ func TestRespOnBadHandshake(t *testing.T) {
|
||||||
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
|
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := ioutil.ReadAll(resp.Body)
|
p, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ReadFull(resp.Body) returned error %v", err)
|
t.Fatalf("ReadFull(resp.Body) returned error %v", err)
|
||||||
}
|
}
|
||||||
|
@ -564,7 +559,6 @@ type testLogWriter struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w testLogWriter) Write(p []byte) (int, error) {
|
func (w testLogWriter) Write(p []byte) (int, error) {
|
||||||
w.t.Logf("%s", p)
|
|
||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -781,7 +775,10 @@ func TestSocksProxyDial(t *testing.T) {
|
||||||
}
|
}
|
||||||
defer c1.Close()
|
defer c1.Close()
|
||||||
|
|
||||||
c1.SetDeadline(time.Now().Add(30 * time.Second))
|
if err := c1.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||||
|
t.Errorf("set deadline failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
buf := make([]byte, 32)
|
buf := make([]byte, 32)
|
||||||
if _, err := io.ReadFull(c1, buf[:3]); err != nil {
|
if _, err := io.ReadFull(c1, buf[:3]); err != nil {
|
||||||
|
@ -820,10 +817,15 @@ func TestSocksProxyDial(t *testing.T) {
|
||||||
defer c2.Close()
|
defer c2.Close()
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
io.Copy(c1, c2)
|
if _, err := io.Copy(c1, c2); err != nil {
|
||||||
|
t.Errorf("copy failed: %v", err)
|
||||||
|
}
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
io.Copy(c2, c1)
|
if _, err := io.Copy(c2, c1); err != nil {
|
||||||
|
t.Errorf("copy failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
<-done
|
<-done
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"compress/flate"
|
"compress/flate"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
@ -33,7 +34,9 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
|
||||||
"\x01\x00\x00\xff\xff"
|
"\x01\x00\x00\xff\xff"
|
||||||
|
|
||||||
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
fr, _ := flateReaderPool.Get().(io.ReadCloser)
|
||||||
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
|
if err := fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
return &flateReadWrapper{fr}
|
return &flateReadWrapper{fr}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,7 +135,9 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) {
|
||||||
// Preemptively place the reader back in the pool. This helps with
|
// Preemptively place the reader back in the pool. This helps with
|
||||||
// scenarios where the application does not call NextReader() soon after
|
// scenarios where the application does not call NextReader() soon after
|
||||||
// this final read.
|
// this final read.
|
||||||
r.Close()
|
if err := r.Close(); err != nil {
|
||||||
|
log.Printf("websocket: flateReadWrapper.Close() returned error: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,7 +22,9 @@ func TestTruncWriter(t *testing.T) {
|
||||||
if m > n {
|
if m > n {
|
||||||
m = n
|
m = n
|
||||||
}
|
}
|
||||||
w.Write(p[:m])
|
if _, err := w.Write(p[:m]); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
p = p[m:]
|
p = p[m:]
|
||||||
}
|
}
|
||||||
if b.String() != data[:len(data)-len(w.p)] {
|
if b.String() != data[:len(data)-len(w.p)] {
|
||||||
|
@ -42,25 +43,29 @@ func textMessages(num int) [][]byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkWriteNoCompression(b *testing.B) {
|
func BenchmarkWriteNoCompression(b *testing.B) {
|
||||||
w := ioutil.Discard
|
w := io.Discard
|
||||||
c := newTestConn(nil, w, false)
|
c := newTestConn(nil, w, false)
|
||||||
messages := textMessages(100)
|
messages := textMessages(100)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
c.WriteMessage(TextMessage, messages[i%len(messages)])
|
if err := c.WriteMessage(TextMessage, messages[i%len(messages)]); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkWriteWithCompression(b *testing.B) {
|
func BenchmarkWriteWithCompression(b *testing.B) {
|
||||||
w := ioutil.Discard
|
w := io.Discard
|
||||||
c := newTestConn(nil, w, false)
|
c := newTestConn(nil, w, false)
|
||||||
messages := textMessages(100)
|
messages := textMessages(100)
|
||||||
c.enableWriteCompression = true
|
c.enableWriteCompression = true
|
||||||
c.newCompressionWriter = compressNoContextTakeover
|
c.newCompressionWriter = compressNoContextTakeover
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
c.WriteMessage(TextMessage, messages[i%len(messages)])
|
if err := c.WriteMessage(TextMessage, messages[i%len(messages)]); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
}
|
}
|
||||||
|
|
75
conn.go
75
conn.go
|
@ -6,11 +6,11 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"crypto/rand"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"log"
|
||||||
"math/rand"
|
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -181,13 +181,20 @@ var (
|
||||||
errInvalidControlFrame = errors.New("websocket: invalid control frame")
|
errInvalidControlFrame = errors.New("websocket: invalid control frame")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// maskRand is an io.Reader for generating mask bytes. The reader is initialized
|
||||||
|
// to crypto/rand Reader. Tests swap the reader to a math/rand reader for
|
||||||
|
// reproducible results.
|
||||||
|
var maskRand = rand.Reader
|
||||||
|
|
||||||
|
// newMaskKey returns a new 32 bit value for masking client frames.
|
||||||
func newMaskKey() [4]byte {
|
func newMaskKey() [4]byte {
|
||||||
n := rand.Uint32()
|
var k [4]byte
|
||||||
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
|
_, _ = io.ReadFull(maskRand, k[:])
|
||||||
|
return k
|
||||||
}
|
}
|
||||||
|
|
||||||
func hideTempErr(err error) error {
|
func hideTempErr(err error) error {
|
||||||
if e, ok := err.(net.Error); ok && e.Temporary() {
|
if e, ok := err.(net.Error); ok {
|
||||||
err = &netError{msg: e.Error(), timeout: e.Timeout()}
|
err = &netError{msg: e.Error(), timeout: e.Timeout()}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
@ -372,7 +379,9 @@ func (c *Conn) read(n int) ([]byte, error) {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
err = errUnexpectedEOF
|
err = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
c.br.Discard(len(p))
|
if _, err := c.br.Discard(len(p)); err != nil {
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
return p, err
|
return p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -387,7 +396,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.conn.SetWriteDeadline(deadline)
|
if err := c.conn.SetWriteDeadline(deadline); err != nil {
|
||||||
|
return c.writeFatal(err)
|
||||||
|
}
|
||||||
if len(buf1) == 0 {
|
if len(buf1) == 0 {
|
||||||
_, err = c.conn.Write(buf0)
|
_, err = c.conn.Write(buf0)
|
||||||
} else {
|
} else {
|
||||||
|
@ -397,7 +408,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
|
||||||
return c.writeFatal(err)
|
return c.writeFatal(err)
|
||||||
}
|
}
|
||||||
if frameType == CloseMessage {
|
if frameType == CloseMessage {
|
||||||
c.writeFatal(ErrCloseSent)
|
_ = c.writeFatal(ErrCloseSent)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -438,7 +449,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
||||||
|
|
||||||
d := 1000 * time.Hour
|
d := 1000 * time.Hour
|
||||||
if !deadline.IsZero() {
|
if !deadline.IsZero() {
|
||||||
d = deadline.Sub(time.Now())
|
d = time.Until(deadline)
|
||||||
if d < 0 {
|
if d < 0 {
|
||||||
return errWriteTimeout
|
return errWriteTimeout
|
||||||
}
|
}
|
||||||
|
@ -460,13 +471,15 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.conn.SetWriteDeadline(deadline)
|
if err := c.conn.SetWriteDeadline(deadline); err != nil {
|
||||||
|
return c.writeFatal(err)
|
||||||
|
}
|
||||||
_, err = c.conn.Write(buf)
|
_, err = c.conn.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.writeFatal(err)
|
return c.writeFatal(err)
|
||||||
}
|
}
|
||||||
if messageType == CloseMessage {
|
if messageType == CloseMessage {
|
||||||
c.writeFatal(ErrCloseSent)
|
_ = c.writeFatal(ErrCloseSent)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -477,7 +490,9 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
|
||||||
// probably better to return an error in this situation, but we cannot
|
// probably better to return an error in this situation, but we cannot
|
||||||
// change this without breaking existing applications.
|
// change this without breaking existing applications.
|
||||||
if c.writer != nil {
|
if c.writer != nil {
|
||||||
c.writer.Close()
|
if err := c.writer.Close(); err != nil {
|
||||||
|
log.Printf("websocket: discarding writer close error: %v", err)
|
||||||
|
}
|
||||||
c.writer = nil
|
c.writer = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -630,7 +645,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if final {
|
if final {
|
||||||
w.endMessage(errWriteClosed)
|
_ = w.endMessage(errWriteClosed)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -795,7 +810,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
// 1. Skip remainder of previous frame.
|
// 1. Skip remainder of previous frame.
|
||||||
|
|
||||||
if c.readRemaining > 0 {
|
if c.readRemaining > 0 {
|
||||||
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
|
if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -817,7 +832,9 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
rsv2 := p[0]&rsv2Bit != 0
|
rsv2 := p[0]&rsv2Bit != 0
|
||||||
rsv3 := p[0]&rsv3Bit != 0
|
rsv3 := p[0]&rsv3Bit != 0
|
||||||
mask := p[1]&maskBit != 0
|
mask := p[1]&maskBit != 0
|
||||||
c.setReadRemaining(int64(p[1] & 0x7f))
|
if err := c.setReadRemaining(int64(p[1] & 0x7f)); err != nil {
|
||||||
|
return noFrame, err
|
||||||
|
}
|
||||||
|
|
||||||
c.readDecompress = false
|
c.readDecompress = false
|
||||||
if rsv1 {
|
if rsv1 {
|
||||||
|
@ -922,7 +939,9 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.readLimit > 0 && c.readLength > c.readLimit {
|
if c.readLimit > 0 && c.readLength > c.readLimit {
|
||||||
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
|
if err := c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)); err != nil {
|
||||||
|
return noFrame, err
|
||||||
|
}
|
||||||
return noFrame, ErrReadLimit
|
return noFrame, ErrReadLimit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -934,7 +953,9 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
var payload []byte
|
var payload []byte
|
||||||
if c.readRemaining > 0 {
|
if c.readRemaining > 0 {
|
||||||
payload, err = c.read(int(c.readRemaining))
|
payload, err = c.read(int(c.readRemaining))
|
||||||
c.setReadRemaining(0)
|
if err := c.setReadRemaining(0); err != nil {
|
||||||
|
return noFrame, err
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
|
@ -981,7 +1002,9 @@ func (c *Conn) handleProtocolError(message string) error {
|
||||||
if len(data) > maxControlFramePayloadSize {
|
if len(data) > maxControlFramePayloadSize {
|
||||||
data = data[:maxControlFramePayloadSize]
|
data = data[:maxControlFramePayloadSize]
|
||||||
}
|
}
|
||||||
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
|
if err := c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return errors.New("websocket: " + message)
|
return errors.New("websocket: " + message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -998,7 +1021,9 @@ func (c *Conn) handleProtocolError(message string) error {
|
||||||
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
||||||
// Close previous reader, only relevant for decompression.
|
// Close previous reader, only relevant for decompression.
|
||||||
if c.reader != nil {
|
if c.reader != nil {
|
||||||
c.reader.Close()
|
if err := c.reader.Close(); err != nil {
|
||||||
|
log.Printf("websocket: discarding reader close error: %v", err)
|
||||||
|
}
|
||||||
c.reader = nil
|
c.reader = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1054,7 +1079,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
||||||
}
|
}
|
||||||
rem := c.readRemaining
|
rem := c.readRemaining
|
||||||
rem -= int64(n)
|
rem -= int64(n)
|
||||||
c.setReadRemaining(rem)
|
if err := c.setReadRemaining(rem); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
if c.readRemaining > 0 && c.readErr == io.EOF {
|
if c.readRemaining > 0 && c.readErr == io.EOF {
|
||||||
c.readErr = errUnexpectedEOF
|
c.readErr = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
@ -1094,7 +1121,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return messageType, nil, err
|
return messageType, nil, err
|
||||||
}
|
}
|
||||||
p, err = ioutil.ReadAll(r)
|
p, err = io.ReadAll(r)
|
||||||
return messageType, p, err
|
return messageType, p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1136,7 +1163,9 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
|
||||||
if h == nil {
|
if h == nil {
|
||||||
h = func(code int, text string) error {
|
h = func(code int, text string) error {
|
||||||
message := FormatCloseMessage(code, "")
|
message := FormatCloseMessage(code, "")
|
||||||
c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
|
if err := c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1161,7 +1190,7 @@ func (c *Conn) SetPingHandler(h func(appData string) error) {
|
||||||
err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
|
err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
|
||||||
if err == ErrCloseSent {
|
if err == ErrCloseSent {
|
||||||
return nil
|
return nil
|
||||||
} else if e, ok := err.(net.Error); ok && e.Temporary() {
|
} else if _, ok := err.(net.Error); ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -6,7 +6,6 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -45,7 +44,7 @@ func newBroadcastConn(c *Conn) *broadcastConn {
|
||||||
|
|
||||||
func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
|
func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
|
||||||
bench := &broadcastBench{
|
bench := &broadcastBench{
|
||||||
w: ioutil.Discard,
|
w: io.Discard,
|
||||||
doneCh: make(chan struct{}),
|
doneCh: make(chan struct{}),
|
||||||
closeCh: make(chan struct{}),
|
closeCh: make(chan struct{}),
|
||||||
usePrepared: usePrepared,
|
usePrepared: usePrepared,
|
||||||
|
@ -70,9 +69,13 @@ func (b *broadcastBench) makeConns(numConns int) {
|
||||||
select {
|
select {
|
||||||
case msg := <-c.msgCh:
|
case msg := <-c.msgCh:
|
||||||
if msg.prepared != nil {
|
if msg.prepared != nil {
|
||||||
c.conn.WritePreparedMessage(msg.prepared)
|
if err := c.conn.WritePreparedMessage(msg.prepared); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
c.conn.WriteMessage(TextMessage, msg.payload)
|
if err := c.conn.WriteMessage(TextMessage, msg.payload); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
val := atomic.AddInt32(&b.count, 1)
|
val := atomic.AddInt32(&b.count, 1)
|
||||||
if val%int32(numConns) == 0 {
|
if val%int32(numConns) == 0 {
|
||||||
|
|
75
conn_test.go
75
conn_test.go
|
@ -10,7 +10,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -124,8 +123,7 @@ func TestFraming(t *testing.T) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("frame size: %d", n)
|
rbuf, err := io.ReadAll(r)
|
||||||
rbuf, err := ioutil.ReadAll(r)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
|
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
|
||||||
continue
|
continue
|
||||||
|
@ -158,7 +156,10 @@ func TestControl(t *testing.T) {
|
||||||
wc := newTestConn(nil, &connBuf, isServer)
|
wc := newTestConn(nil, &connBuf, isServer)
|
||||||
rc := newTestConn(&connBuf, nil, !isServer)
|
rc := newTestConn(&connBuf, nil, !isServer)
|
||||||
if isWriteControl {
|
if isWriteControl {
|
||||||
wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
|
if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
|
||||||
|
t.Errorf("%s: wc.WriteControl() returned %v", name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
w, err := wc.NextWriter(PongMessage)
|
w, err := wc.NextWriter(PongMessage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -175,7 +176,9 @@ func TestControl(t *testing.T) {
|
||||||
}
|
}
|
||||||
var actualMessage string
|
var actualMessage string
|
||||||
rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
|
rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
|
||||||
rc.NextReader()
|
if _, _, err := rc.NextReader(); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if actualMessage != message {
|
if actualMessage != message {
|
||||||
t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
|
t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
|
||||||
continue
|
continue
|
||||||
|
@ -359,15 +362,19 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||||||
rc := newTestConn(&b1, &b2, true)
|
rc := newTestConn(&b1, &b2, true)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(make([]byte, bufSize+bufSize/2))
|
if _, err := w.Write(make([]byte, bufSize+bufSize/2)); err != nil {
|
||||||
wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
|
t.Fatalf("w.Write() returned %v", err)
|
||||||
|
}
|
||||||
|
if err := wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)); err != nil {
|
||||||
|
t.Fatalf("wc.WriteControl() returned %v", err)
|
||||||
|
}
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
op, r, err := rc.NextReader()
|
op, r, err := rc.NextReader()
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
t.Fatalf("NextReader() returned %d, %v", op, err)
|
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(ioutil.Discard, r)
|
_, err = io.Copy(io.Discard, r)
|
||||||
if !reflect.DeepEqual(err, expectedErr) {
|
if !reflect.DeepEqual(err, expectedErr) {
|
||||||
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
|
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
|
||||||
}
|
}
|
||||||
|
@ -386,7 +393,9 @@ func TestEOFWithinFrame(t *testing.T) {
|
||||||
rc := newTestConn(&b, nil, true)
|
rc := newTestConn(&b, nil, true)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(make([]byte, bufSize))
|
if _, err := w.Write(make([]byte, bufSize)); err != nil {
|
||||||
|
t.Fatalf("%d: w.Write() returned %v", n, err)
|
||||||
|
}
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
if n >= b.Len() {
|
if n >= b.Len() {
|
||||||
|
@ -401,7 +410,7 @@ func TestEOFWithinFrame(t *testing.T) {
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
|
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(ioutil.Discard, r)
|
_, err = io.Copy(io.Discard, r)
|
||||||
if err != errUnexpectedEOF {
|
if err != errUnexpectedEOF {
|
||||||
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
|
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
|
||||||
}
|
}
|
||||||
|
@ -420,13 +429,15 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
rc := newTestConn(&b1, &b2, true)
|
rc := newTestConn(&b1, &b2, true)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(make([]byte, bufSize+bufSize/2))
|
if _, err := w.Write(make([]byte, bufSize+bufSize/2)); err != nil {
|
||||||
|
t.Fatalf("w.Write() returned %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
op, r, err := rc.NextReader()
|
op, r, err := rc.NextReader()
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
t.Fatalf("NextReader() returned %d, %v", op, err)
|
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(ioutil.Discard, r)
|
_, err = io.Copy(io.Discard, r)
|
||||||
if err != errUnexpectedEOF {
|
if err != errUnexpectedEOF {
|
||||||
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
|
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
|
||||||
}
|
}
|
||||||
|
@ -439,7 +450,9 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||||
wc := newTestConn(nil, &bytes.Buffer{}, false)
|
wc := newTestConn(nil, &bytes.Buffer{}, false)
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
io.WriteString(w, "hello")
|
if _, err := io.WriteString(w, "hello"); err != nil {
|
||||||
|
t.Fatalf("unexpected error writing, %v", err)
|
||||||
|
}
|
||||||
if err := w.Close(); err != nil {
|
if err := w.Close(); err != nil {
|
||||||
t.Fatalf("unxpected error closing message writer, %v", err)
|
t.Fatalf("unxpected error closing message writer, %v", err)
|
||||||
}
|
}
|
||||||
|
@ -449,7 +462,9 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
w, _ = wc.NextWriter(BinaryMessage)
|
w, _ = wc.NextWriter(BinaryMessage)
|
||||||
io.WriteString(w, "hello")
|
if _, err := io.WriteString(w, "hello"); err != nil {
|
||||||
|
t.Fatalf("unexpected error writing after getting new writer, %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// close w by getting next writer
|
// close w by getting next writer
|
||||||
_, err := wc.NextWriter(BinaryMessage)
|
_, err := wc.NextWriter(BinaryMessage)
|
||||||
|
@ -474,13 +489,21 @@ func TestReadLimit(t *testing.T) {
|
||||||
|
|
||||||
// Send message at the limit with interleaved pong.
|
// Send message at the limit with interleaved pong.
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(message[:readLimit-1])
|
if _, err := w.Write(message[:readLimit-1]); err != nil {
|
||||||
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
|
t.Fatalf("w.WriteMessage() returned %v", err)
|
||||||
w.Write(message[:1])
|
}
|
||||||
|
if err := wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)); err != nil {
|
||||||
|
t.Fatalf("wc.WriteControl() returned %v", err)
|
||||||
|
}
|
||||||
|
if _, err := w.Write(message[:1]); err != nil {
|
||||||
|
t.Fatalf("w.Write() returned %v", err)
|
||||||
|
}
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
// Send message larger than the limit.
|
// Send message larger than the limit.
|
||||||
wc.WriteMessage(BinaryMessage, message[:readLimit+1])
|
if err := wc.WriteMessage(BinaryMessage, message[:readLimit+1]); err != nil {
|
||||||
|
t.Fatalf("wc.WriteMessage() returned %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
op, _, err := rc.NextReader()
|
op, _, err := rc.NextReader()
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
|
@ -490,7 +513,7 @@ func TestReadLimit(t *testing.T) {
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
t.Fatalf("2: NextReader() returned %d, %v", op, err)
|
t.Fatalf("2: NextReader() returned %d, %v", op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(ioutil.Discard, r)
|
_, err = io.Copy(io.Discard, r)
|
||||||
if err != ErrReadLimit {
|
if err != ErrReadLimit {
|
||||||
t.Fatalf("io.Copy() returned %v", err)
|
t.Fatalf("io.Copy() returned %v", err)
|
||||||
}
|
}
|
||||||
|
@ -593,7 +616,9 @@ func TestBufioReadBytes(t *testing.T) {
|
||||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
|
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
|
||||||
|
|
||||||
w, _ := wc.NextWriter(BinaryMessage)
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
w.Write(m)
|
if _, err := w.Write(m); err != nil {
|
||||||
|
t.Fatalf("w.Write() returned %v", err)
|
||||||
|
}
|
||||||
w.Close()
|
w.Close()
|
||||||
|
|
||||||
op, r, err := rc.NextReader()
|
op, r, err := rc.NextReader()
|
||||||
|
@ -667,7 +692,9 @@ func TestConcurrentWritePanic(t *testing.T) {
|
||||||
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
w := blockingWriter{make(chan struct{}), make(chan struct{})}
|
||||||
c := newTestConn(nil, w, false)
|
c := newTestConn(nil, w, false)
|
||||||
go func() {
|
go func() {
|
||||||
c.WriteMessage(TextMessage, []byte{})
|
if err := c.WriteMessage(TextMessage, []byte{}); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// wait for goroutine to block in write.
|
// wait for goroutine to block in write.
|
||||||
|
@ -680,7 +707,9 @@ func TestConcurrentWritePanic(t *testing.T) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.WriteMessage(TextMessage, []byte{})
|
if err := c.WriteMessage(TextMessage, []byte{}); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
t.Fatal("should not get here")
|
t.Fatal("should not get here")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -700,7 +729,7 @@ func TestFailedConnectionReadPanic(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for i := 0; i < 20000; i++ {
|
for i := 0; i < 20000; i++ {
|
||||||
c.ReadMessage()
|
_, _, _ = c.ReadMessage()
|
||||||
}
|
}
|
||||||
t.Fatal("should not get here")
|
t.Fatal("should not get here")
|
||||||
}
|
}
|
||||||
|
|
|
@ -178,7 +178,11 @@ func main() {
|
||||||
http.HandleFunc("/r", echoReadAllWriter)
|
http.HandleFunc("/r", echoReadAllWriter)
|
||||||
http.HandleFunc("/m", echoReadAllWriteMessage)
|
http.HandleFunc("/m", echoReadAllWriteMessage)
|
||||||
http.HandleFunc("/p", echoReadAllWritePreparedMessage)
|
http.HandleFunc("/p", echoReadAllWritePreparedMessage)
|
||||||
err := http.ListenAndServe(*addr, nil)
|
server := &http.Server{
|
||||||
|
Addr: *addr,
|
||||||
|
ReadHeaderTimeout: 3 * time.Second,
|
||||||
|
}
|
||||||
|
err := server.ListenAndServe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("ListenAndServe: ", err)
|
log.Fatal("ListenAndServe: ", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"flag"
|
"flag"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var addr = flag.String("addr", ":8080", "http service address")
|
var addr = flag.String("addr", ":8080", "http service address")
|
||||||
|
@ -33,7 +34,11 @@ func main() {
|
||||||
http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
|
http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
|
||||||
serveWs(hub, w, r)
|
serveWs(hub, w, r)
|
||||||
})
|
})
|
||||||
err := http.ListenAndServe(*addr, nil)
|
server := &http.Server{
|
||||||
|
Addr: *addr,
|
||||||
|
ReadHeaderTimeout: 3 * time.Second,
|
||||||
|
}
|
||||||
|
err := server.ListenAndServe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("ListenAndServe: ", err)
|
log.Fatal("ListenAndServe: ", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -189,5 +189,9 @@ func main() {
|
||||||
}
|
}
|
||||||
http.HandleFunc("/", serveHome)
|
http.HandleFunc("/", serveHome)
|
||||||
http.HandleFunc("/ws", serveWs)
|
http.HandleFunc("/ws", serveWs)
|
||||||
log.Fatal(http.ListenAndServe(*addr, nil))
|
server := &http.Server{
|
||||||
|
Addr: *addr,
|
||||||
|
ReadHeaderTimeout: 3 * time.Second,
|
||||||
|
}
|
||||||
|
log.Fatal(server.ListenAndServe())
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -49,7 +50,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) {
|
||||||
if !fi.ModTime().After(lastMod) {
|
if !fi.ModTime().After(lastMod) {
|
||||||
return nil, lastMod, nil
|
return nil, lastMod, nil
|
||||||
}
|
}
|
||||||
p, err := ioutil.ReadFile(filename)
|
p, err := ioutil.ReadFile(filepath.Clean(filename))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fi.ModTime(), err
|
return nil, fi.ModTime(), err
|
||||||
}
|
}
|
||||||
|
@ -163,7 +164,11 @@ func main() {
|
||||||
filename = flag.Args()[0]
|
filename = flag.Args()[0]
|
||||||
http.HandleFunc("/", serveHome)
|
http.HandleFunc("/", serveHome)
|
||||||
http.HandleFunc("/ws", serveWs)
|
http.HandleFunc("/ws", serveWs)
|
||||||
if err := http.ListenAndServe(*addr, nil); err != nil {
|
server := &http.Server{
|
||||||
|
Addr: *addr,
|
||||||
|
ReadHeaderTimeout: 3 * time.Second,
|
||||||
|
}
|
||||||
|
if err := server.ListenAndServe(); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
4
go.mod
4
go.mod
|
@ -1,3 +1,5 @@
|
||||||
module github.com/gorilla/websocket
|
module github.com/gorilla/websocket
|
||||||
|
|
||||||
go 1.12
|
go 1.19
|
||||||
|
|
||||||
|
require golang.org/x/net v0.14.0
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -0,0 +1,2 @@
|
||||||
|
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
|
||||||
|
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
|
|
@ -19,7 +19,9 @@ func TestJoinMessages(t *testing.T) {
|
||||||
wc := newTestConn(nil, &connBuf, true)
|
wc := newTestConn(nil, &connBuf, true)
|
||||||
rc := newTestConn(&connBuf, nil, false)
|
rc := newTestConn(&connBuf, nil, false)
|
||||||
for _, m := range messages {
|
for _, m := range messages {
|
||||||
wc.WriteMessage(BinaryMessage, []byte(m))
|
if err := wc.WriteMessage(BinaryMessage, []byte(m)); err != nil {
|
||||||
|
t.Fatalf("write %q: %v", m, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var result bytes.Buffer
|
var result bytes.Buffer
|
||||||
|
|
4
mask.go
4
mask.go
|
@ -9,6 +9,7 @@ package websocket
|
||||||
|
|
||||||
import "unsafe"
|
import "unsafe"
|
||||||
|
|
||||||
|
// #nosec G103 -- (CWE-242) Has been audited
|
||||||
const wordSize = int(unsafe.Sizeof(uintptr(0)))
|
const wordSize = int(unsafe.Sizeof(uintptr(0)))
|
||||||
|
|
||||||
func maskBytes(key [4]byte, pos int, b []byte) int {
|
func maskBytes(key [4]byte, pos int, b []byte) int {
|
||||||
|
@ -22,6 +23,7 @@ func maskBytes(key [4]byte, pos int, b []byte) int {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mask one byte at a time to word boundary.
|
// Mask one byte at a time to word boundary.
|
||||||
|
//#nosec G103 -- (CWE-242) Has been audited
|
||||||
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
|
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
|
||||||
n = wordSize - n
|
n = wordSize - n
|
||||||
for i := range b[:n] {
|
for i := range b[:n] {
|
||||||
|
@ -36,11 +38,13 @@ func maskBytes(key [4]byte, pos int, b []byte) int {
|
||||||
for i := range k {
|
for i := range k {
|
||||||
k[i] = key[(pos+i)&3]
|
k[i] = key[(pos+i)&3]
|
||||||
}
|
}
|
||||||
|
//#nosec G103 -- (CWE-242) Has been audited
|
||||||
kw := *(*uintptr)(unsafe.Pointer(&k))
|
kw := *(*uintptr)(unsafe.Pointer(&k))
|
||||||
|
|
||||||
// Mask one word at a time.
|
// Mask one word at a time.
|
||||||
n := (len(b) / wordSize) * wordSize
|
n := (len(b) / wordSize) * wordSize
|
||||||
for i := 0; i < n; i += wordSize {
|
for i := 0; i < n; i += wordSize {
|
||||||
|
//#nosec G103 -- (CWE-242) Has been audited
|
||||||
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
|
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,11 @@ var preparedMessageTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPreparedMessage(t *testing.T) {
|
func TestPreparedMessage(t *testing.T) {
|
||||||
|
testRand := rand.New(rand.NewSource(99))
|
||||||
|
prevMaskRand := maskRand
|
||||||
|
maskRand = testRand
|
||||||
|
defer func() { maskRand = prevMaskRand }()
|
||||||
|
|
||||||
for _, tt := range preparedMessageTests {
|
for _, tt := range preparedMessageTests {
|
||||||
var data = []byte("this is a test")
|
var data = []byte("this is a test")
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
|
@ -40,10 +45,13 @@ func TestPreparedMessage(t *testing.T) {
|
||||||
if tt.enableWriteCompression {
|
if tt.enableWriteCompression {
|
||||||
c.newCompressionWriter = compressNoContextTakeover
|
c.newCompressionWriter = compressNoContextTakeover
|
||||||
}
|
}
|
||||||
c.SetCompressionLevel(tt.compressionLevel)
|
|
||||||
|
if err := c.SetCompressionLevel(tt.compressionLevel); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
// Seed random number generator for consistent frame mask.
|
// Seed random number generator for consistent frame mask.
|
||||||
rand.Seed(1234)
|
testRand.Seed(1234)
|
||||||
|
|
||||||
if err := c.WriteMessage(tt.messageType, data); err != nil {
|
if err := c.WriteMessage(tt.messageType, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -59,7 +67,7 @@ func TestPreparedMessage(t *testing.T) {
|
||||||
copy(data, "hello world")
|
copy(data, "hello world")
|
||||||
|
|
||||||
// Seed random number generator for consistent frame mask.
|
// Seed random number generator for consistent frame mask.
|
||||||
rand.Seed(1234)
|
testRand.Seed(1234)
|
||||||
|
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
if err := c.WritePreparedMessage(pm); err != nil {
|
if err := c.WritePreparedMessage(pm); err != nil {
|
||||||
|
@ -68,7 +76,7 @@ func TestPreparedMessage(t *testing.T) {
|
||||||
got := buf.String()
|
got := buf.String()
|
||||||
|
|
||||||
if got != want {
|
if got != want {
|
||||||
t.Errorf("write message != prepared message for %+v", tt)
|
t.Errorf("write message != prepared message, got %#v, want %#v", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
17
proxy.go
17
proxy.go
|
@ -8,10 +8,13 @@ import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
type netDialerFunc func(network, addr string) (net.Conn, error)
|
type netDialerFunc func(network, addr string) (net.Conn, error)
|
||||||
|
@ -21,7 +24,7 @@ func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
proxy.RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy.Dialer) (proxy.Dialer, error) {
|
||||||
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
|
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -55,7 +58,9 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := connectReq.Write(conn); err != nil {
|
if err := connectReq.Write(conn); err != nil {
|
||||||
conn.Close()
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Printf("httpProxyDialer: failed to close connection: %v", err)
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,12 +69,16 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
|
||||||
br := bufio.NewReader(conn)
|
br := bufio.NewReader(conn)
|
||||||
resp, err := http.ReadResponse(br, connectReq)
|
resp, err := http.ReadResponse(br, connectReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Printf("httpProxyDialer: failed to close connection: %v", err)
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
conn.Close()
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Printf("httpProxyDialer: failed to close connection: %v", err)
|
||||||
|
}
|
||||||
f := strings.SplitN(resp.Status, " ", 2)
|
f := strings.SplitN(resp.Status, " ", 2)
|
||||||
return nil, errors.New(f[1])
|
return nil, errors.New(f[1])
|
||||||
}
|
}
|
||||||
|
|
38
server.go
38
server.go
|
@ -8,6 +8,7 @@ import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -183,7 +184,9 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
}
|
}
|
||||||
|
|
||||||
if brw.Reader.Buffered() > 0 {
|
if brw.Reader.Buffered() > 0 {
|
||||||
netConn.Close()
|
if err := netConn.Close(); err != nil {
|
||||||
|
log.Printf("websocket: failed to close network connection: %v", err)
|
||||||
|
}
|
||||||
return nil, errors.New("websocket: client sent data before handshake is complete")
|
return nil, errors.New("websocket: client sent data before handshake is complete")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -248,17 +251,34 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
p = append(p, "\r\n"...)
|
p = append(p, "\r\n"...)
|
||||||
|
|
||||||
// Clear deadlines set by HTTP server.
|
// Clear deadlines set by HTTP server.
|
||||||
netConn.SetDeadline(time.Time{})
|
if err := netConn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
if err := netConn.Close(); err != nil {
|
||||||
|
log.Printf("websocket: failed to close network connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if u.HandshakeTimeout > 0 {
|
if u.HandshakeTimeout > 0 {
|
||||||
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
|
if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
|
||||||
|
if err := netConn.Close(); err != nil {
|
||||||
|
log.Printf("websocket: failed to close network connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if _, err = netConn.Write(p); err != nil {
|
if _, err = netConn.Write(p); err != nil {
|
||||||
netConn.Close()
|
if err := netConn.Close(); err != nil {
|
||||||
|
log.Printf("websocket: failed to close network connection: %v", err)
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if u.HandshakeTimeout > 0 {
|
if u.HandshakeTimeout > 0 {
|
||||||
netConn.SetWriteDeadline(time.Time{})
|
if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
|
||||||
|
if err := netConn.Close(); err != nil {
|
||||||
|
log.Printf("websocket: failed to close network connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
|
@ -356,8 +376,12 @@ func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
|
||||||
// bufio.Writer's underlying writer.
|
// bufio.Writer's underlying writer.
|
||||||
var wh writeHook
|
var wh writeHook
|
||||||
bw.Reset(&wh)
|
bw.Reset(&wh)
|
||||||
bw.WriteByte(0)
|
if err := bw.WriteByte(0); err != nil {
|
||||||
bw.Flush()
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := bw.Flush(); err != nil {
|
||||||
|
log.Printf("websocket: bufioWriterBuffer: Flush: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
bw.Reset(originalWriter)
|
bw.Reset(originalWriter)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,3 @@
|
||||||
//go:build go1.17
|
|
||||||
// +build go1.17
|
|
||||||
|
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -1,21 +0,0 @@
|
||||||
//go:build !go1.17
|
|
||||||
// +build !go1.17
|
|
||||||
|
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
)
|
|
||||||
|
|
||||||
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
|
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !cfg.InsecureSkipVerify {
|
|
||||||
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
4
util.go
4
util.go
|
@ -6,7 +6,7 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha1"
|
"crypto/sha1" //#nosec G505 -- (CWE-327) https://datatracker.ietf.org/doc/html/rfc6455#page-54
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -17,7 +17,7 @@ import (
|
||||||
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||||
|
|
||||||
func computeAcceptKey(challengeKey string) string {
|
func computeAcceptKey(challengeKey string) string {
|
||||||
h := sha1.New()
|
h := sha1.New() //#nosec G401 -- (CWE-326) https://datatracker.ietf.org/doc/html/rfc6455#page-54
|
||||||
h.Write([]byte(challengeKey))
|
h.Write([]byte(challengeKey))
|
||||||
h.Write(keyGUID)
|
h.Write(keyGUID)
|
||||||
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,22 @@
|
||||||
|
Additional IP Rights Grant (Patents)
|
||||||
|
|
||||||
|
"This implementation" means the copyrightable works distributed by
|
||||||
|
Google as part of the Go project.
|
||||||
|
|
||||||
|
Google hereby grants to You a perpetual, worldwide, non-exclusive,
|
||||||
|
no-charge, royalty-free, irrevocable (except as stated in this section)
|
||||||
|
patent license to make, have made, use, offer to sell, sell, import,
|
||||||
|
transfer and otherwise run, modify and propagate the contents of this
|
||||||
|
implementation of Go, where such license applies only to those patent
|
||||||
|
claims, both currently owned or controlled by Google and acquired in
|
||||||
|
the future, licensable by Google that are necessarily infringed by this
|
||||||
|
implementation of Go. This grant does not include claims that would be
|
||||||
|
infringed only as a consequence of further modification of this
|
||||||
|
implementation. If you or your agent or exclusive licensee institute or
|
||||||
|
order or agree to the institution of patent litigation against any
|
||||||
|
entity (including a cross-claim or counterclaim in a lawsuit) alleging
|
||||||
|
that this implementation of Go or any code incorporated within this
|
||||||
|
implementation of Go constitutes direct or contributory patent
|
||||||
|
infringement, or inducement of patent infringement, then any patent
|
||||||
|
rights granted to you under this License for this implementation of Go
|
||||||
|
shall terminate as of the date such litigation is filed.
|
|
@ -0,0 +1,168 @@
|
||||||
|
// Copyright 2018 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package socks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
noDeadline = time.Time{}
|
||||||
|
aLongTimeAgo = time.Unix(1, 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
|
||||||
|
host, port, err := splitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
|
||||||
|
c.SetDeadline(deadline)
|
||||||
|
defer c.SetDeadline(noDeadline)
|
||||||
|
}
|
||||||
|
if ctx != context.Background() {
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
done := make(chan struct{})
|
||||||
|
defer func() {
|
||||||
|
close(done)
|
||||||
|
if ctxErr == nil {
|
||||||
|
ctxErr = <-errCh
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
c.SetDeadline(aLongTimeAgo)
|
||||||
|
errCh <- ctx.Err()
|
||||||
|
case <-done:
|
||||||
|
errCh <- nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
|
||||||
|
b = append(b, Version5)
|
||||||
|
if len(d.AuthMethods) == 0 || d.Authenticate == nil {
|
||||||
|
b = append(b, 1, byte(AuthMethodNotRequired))
|
||||||
|
} else {
|
||||||
|
ams := d.AuthMethods
|
||||||
|
if len(ams) > 255 {
|
||||||
|
return nil, errors.New("too many authentication methods")
|
||||||
|
}
|
||||||
|
b = append(b, byte(len(ams)))
|
||||||
|
for _, am := range ams {
|
||||||
|
b = append(b, byte(am))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, ctxErr = c.Write(b); ctxErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if b[0] != Version5 {
|
||||||
|
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
|
||||||
|
}
|
||||||
|
am := AuthMethod(b[1])
|
||||||
|
if am == AuthMethodNoAcceptableMethods {
|
||||||
|
return nil, errors.New("no acceptable authentication methods")
|
||||||
|
}
|
||||||
|
if d.Authenticate != nil {
|
||||||
|
if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b = b[:0]
|
||||||
|
b = append(b, Version5, byte(d.cmd), 0)
|
||||||
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
|
if ip4 := ip.To4(); ip4 != nil {
|
||||||
|
b = append(b, AddrTypeIPv4)
|
||||||
|
b = append(b, ip4...)
|
||||||
|
} else if ip6 := ip.To16(); ip6 != nil {
|
||||||
|
b = append(b, AddrTypeIPv6)
|
||||||
|
b = append(b, ip6...)
|
||||||
|
} else {
|
||||||
|
return nil, errors.New("unknown address type")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if len(host) > 255 {
|
||||||
|
return nil, errors.New("FQDN too long")
|
||||||
|
}
|
||||||
|
b = append(b, AddrTypeFQDN)
|
||||||
|
b = append(b, byte(len(host)))
|
||||||
|
b = append(b, host...)
|
||||||
|
}
|
||||||
|
b = append(b, byte(port>>8), byte(port))
|
||||||
|
if _, ctxErr = c.Write(b); ctxErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if b[0] != Version5 {
|
||||||
|
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
|
||||||
|
}
|
||||||
|
if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
|
||||||
|
return nil, errors.New("unknown error " + cmdErr.String())
|
||||||
|
}
|
||||||
|
if b[2] != 0 {
|
||||||
|
return nil, errors.New("non-zero reserved field")
|
||||||
|
}
|
||||||
|
l := 2
|
||||||
|
var a Addr
|
||||||
|
switch b[3] {
|
||||||
|
case AddrTypeIPv4:
|
||||||
|
l += net.IPv4len
|
||||||
|
a.IP = make(net.IP, net.IPv4len)
|
||||||
|
case AddrTypeIPv6:
|
||||||
|
l += net.IPv6len
|
||||||
|
a.IP = make(net.IP, net.IPv6len)
|
||||||
|
case AddrTypeFQDN:
|
||||||
|
if _, err := io.ReadFull(c, b[:1]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
l += int(b[0])
|
||||||
|
default:
|
||||||
|
return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
|
||||||
|
}
|
||||||
|
if cap(b) < l {
|
||||||
|
b = make([]byte, l)
|
||||||
|
} else {
|
||||||
|
b = b[:l]
|
||||||
|
}
|
||||||
|
if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if a.IP != nil {
|
||||||
|
copy(a.IP, b)
|
||||||
|
} else {
|
||||||
|
a.Name = string(b[:len(b)-2])
|
||||||
|
}
|
||||||
|
a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
|
||||||
|
return &a, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitHostPort(address string) (string, int, error) {
|
||||||
|
host, port, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
portnum, err := strconv.Atoi(port)
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
if 1 > portnum || portnum > 0xffff {
|
||||||
|
return "", 0, errors.New("port number out of range " + port)
|
||||||
|
}
|
||||||
|
return host, portnum, nil
|
||||||
|
}
|
|
@ -0,0 +1,317 @@
|
||||||
|
// Copyright 2018 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package socks provides a SOCKS version 5 client implementation.
|
||||||
|
//
|
||||||
|
// SOCKS protocol version 5 is defined in RFC 1928.
|
||||||
|
// Username/Password authentication for SOCKS version 5 is defined in
|
||||||
|
// RFC 1929.
|
||||||
|
package socks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Command represents a SOCKS command.
|
||||||
|
type Command int
|
||||||
|
|
||||||
|
func (cmd Command) String() string {
|
||||||
|
switch cmd {
|
||||||
|
case CmdConnect:
|
||||||
|
return "socks connect"
|
||||||
|
case cmdBind:
|
||||||
|
return "socks bind"
|
||||||
|
default:
|
||||||
|
return "socks " + strconv.Itoa(int(cmd))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// An AuthMethod represents a SOCKS authentication method.
|
||||||
|
type AuthMethod int
|
||||||
|
|
||||||
|
// A Reply represents a SOCKS command reply code.
|
||||||
|
type Reply int
|
||||||
|
|
||||||
|
func (code Reply) String() string {
|
||||||
|
switch code {
|
||||||
|
case StatusSucceeded:
|
||||||
|
return "succeeded"
|
||||||
|
case 0x01:
|
||||||
|
return "general SOCKS server failure"
|
||||||
|
case 0x02:
|
||||||
|
return "connection not allowed by ruleset"
|
||||||
|
case 0x03:
|
||||||
|
return "network unreachable"
|
||||||
|
case 0x04:
|
||||||
|
return "host unreachable"
|
||||||
|
case 0x05:
|
||||||
|
return "connection refused"
|
||||||
|
case 0x06:
|
||||||
|
return "TTL expired"
|
||||||
|
case 0x07:
|
||||||
|
return "command not supported"
|
||||||
|
case 0x08:
|
||||||
|
return "address type not supported"
|
||||||
|
default:
|
||||||
|
return "unknown code: " + strconv.Itoa(int(code))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wire protocol constants.
|
||||||
|
const (
|
||||||
|
Version5 = 0x05
|
||||||
|
|
||||||
|
AddrTypeIPv4 = 0x01
|
||||||
|
AddrTypeFQDN = 0x03
|
||||||
|
AddrTypeIPv6 = 0x04
|
||||||
|
|
||||||
|
CmdConnect Command = 0x01 // establishes an active-open forward proxy connection
|
||||||
|
cmdBind Command = 0x02 // establishes a passive-open forward proxy connection
|
||||||
|
|
||||||
|
AuthMethodNotRequired AuthMethod = 0x00 // no authentication required
|
||||||
|
AuthMethodUsernamePassword AuthMethod = 0x02 // use username/password
|
||||||
|
AuthMethodNoAcceptableMethods AuthMethod = 0xff // no acceptable authentication methods
|
||||||
|
|
||||||
|
StatusSucceeded Reply = 0x00
|
||||||
|
)
|
||||||
|
|
||||||
|
// An Addr represents a SOCKS-specific address.
|
||||||
|
// Either Name or IP is used exclusively.
|
||||||
|
type Addr struct {
|
||||||
|
Name string // fully-qualified domain name
|
||||||
|
IP net.IP
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Addr) Network() string { return "socks" }
|
||||||
|
|
||||||
|
func (a *Addr) String() string {
|
||||||
|
if a == nil {
|
||||||
|
return "<nil>"
|
||||||
|
}
|
||||||
|
port := strconv.Itoa(a.Port)
|
||||||
|
if a.IP == nil {
|
||||||
|
return net.JoinHostPort(a.Name, port)
|
||||||
|
}
|
||||||
|
return net.JoinHostPort(a.IP.String(), port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Conn represents a forward proxy connection.
|
||||||
|
type Conn struct {
|
||||||
|
net.Conn
|
||||||
|
|
||||||
|
boundAddr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// BoundAddr returns the address assigned by the proxy server for
|
||||||
|
// connecting to the command target address from the proxy server.
|
||||||
|
func (c *Conn) BoundAddr() net.Addr {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.boundAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Dialer holds SOCKS-specific options.
|
||||||
|
type Dialer struct {
|
||||||
|
cmd Command // either CmdConnect or cmdBind
|
||||||
|
proxyNetwork string // network between a proxy server and a client
|
||||||
|
proxyAddress string // proxy server address
|
||||||
|
|
||||||
|
// ProxyDial specifies the optional dial function for
|
||||||
|
// establishing the transport connection.
|
||||||
|
ProxyDial func(context.Context, string, string) (net.Conn, error)
|
||||||
|
|
||||||
|
// AuthMethods specifies the list of request authentication
|
||||||
|
// methods.
|
||||||
|
// If empty, SOCKS client requests only AuthMethodNotRequired.
|
||||||
|
AuthMethods []AuthMethod
|
||||||
|
|
||||||
|
// Authenticate specifies the optional authentication
|
||||||
|
// function. It must be non-nil when AuthMethods is not empty.
|
||||||
|
// It must return an error when the authentication is failed.
|
||||||
|
Authenticate func(context.Context, io.ReadWriter, AuthMethod) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext connects to the provided address on the provided
|
||||||
|
// network.
|
||||||
|
//
|
||||||
|
// The returned error value may be a net.OpError. When the Op field of
|
||||||
|
// net.OpError contains "socks", the Source field contains a proxy
|
||||||
|
// server address and the Addr field contains a command target
|
||||||
|
// address.
|
||||||
|
//
|
||||||
|
// See func Dial of the net package of standard library for a
|
||||||
|
// description of the network and address parameters.
|
||||||
|
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
if err := d.validateTarget(network, address); err != nil {
|
||||||
|
proxy, dst, _ := d.pathAddrs(address)
|
||||||
|
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
proxy, dst, _ := d.pathAddrs(address)
|
||||||
|
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
var c net.Conn
|
||||||
|
if d.ProxyDial != nil {
|
||||||
|
c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress)
|
||||||
|
} else {
|
||||||
|
var dd net.Dialer
|
||||||
|
c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
proxy, dst, _ := d.pathAddrs(address)
|
||||||
|
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
||||||
|
}
|
||||||
|
a, err := d.connect(ctx, c, address)
|
||||||
|
if err != nil {
|
||||||
|
c.Close()
|
||||||
|
proxy, dst, _ := d.pathAddrs(address)
|
||||||
|
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
||||||
|
}
|
||||||
|
return &Conn{Conn: c, boundAddr: a}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialWithConn initiates a connection from SOCKS server to the target
|
||||||
|
// network and address using the connection c that is already
|
||||||
|
// connected to the SOCKS server.
|
||||||
|
//
|
||||||
|
// It returns the connection's local address assigned by the SOCKS
|
||||||
|
// server.
|
||||||
|
func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) {
|
||||||
|
if err := d.validateTarget(network, address); err != nil {
|
||||||
|
proxy, dst, _ := d.pathAddrs(address)
|
||||||
|
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
proxy, dst, _ := d.pathAddrs(address)
|
||||||
|
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
|
||||||
|
}
|
||||||
|
a, err := d.connect(ctx, c, address)
|
||||||
|
if err != nil {
|
||||||
|
proxy, dst, _ := d.pathAddrs(address)
|
||||||
|
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
||||||
|
}
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial connects to the provided address on the provided network.
|
||||||
|
//
|
||||||
|
// Unlike DialContext, it returns a raw transport connection instead
|
||||||
|
// of a forward proxy connection.
|
||||||
|
//
|
||||||
|
// Deprecated: Use DialContext or DialWithConn instead.
|
||||||
|
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
||||||
|
if err := d.validateTarget(network, address); err != nil {
|
||||||
|
proxy, dst, _ := d.pathAddrs(address)
|
||||||
|
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
var c net.Conn
|
||||||
|
if d.ProxyDial != nil {
|
||||||
|
c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress)
|
||||||
|
} else {
|
||||||
|
c, err = net.Dial(d.proxyNetwork, d.proxyAddress)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
proxy, dst, _ := d.pathAddrs(address)
|
||||||
|
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
|
||||||
|
}
|
||||||
|
if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil {
|
||||||
|
c.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dialer) validateTarget(network, address string) error {
|
||||||
|
switch network {
|
||||||
|
case "tcp", "tcp6", "tcp4":
|
||||||
|
default:
|
||||||
|
return errors.New("network not implemented")
|
||||||
|
}
|
||||||
|
switch d.cmd {
|
||||||
|
case CmdConnect, cmdBind:
|
||||||
|
default:
|
||||||
|
return errors.New("command not implemented")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {
|
||||||
|
for i, s := range []string{d.proxyAddress, address} {
|
||||||
|
host, port, err := splitHostPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
a := &Addr{Port: port}
|
||||||
|
a.IP = net.ParseIP(host)
|
||||||
|
if a.IP == nil {
|
||||||
|
a.Name = host
|
||||||
|
}
|
||||||
|
if i == 0 {
|
||||||
|
proxy = a
|
||||||
|
} else {
|
||||||
|
dst = a
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDialer returns a new Dialer that dials through the provided
|
||||||
|
// proxy server's network and address.
|
||||||
|
func NewDialer(network, address string) *Dialer {
|
||||||
|
return &Dialer{proxyNetwork: network, proxyAddress: address, cmd: CmdConnect}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
authUsernamePasswordVersion = 0x01
|
||||||
|
authStatusSucceeded = 0x00
|
||||||
|
)
|
||||||
|
|
||||||
|
// UsernamePassword are the credentials for the username/password
|
||||||
|
// authentication method.
|
||||||
|
type UsernamePassword struct {
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate authenticates a pair of username and password with the
|
||||||
|
// proxy server.
|
||||||
|
func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth AuthMethod) error {
|
||||||
|
switch auth {
|
||||||
|
case AuthMethodNotRequired:
|
||||||
|
return nil
|
||||||
|
case AuthMethodUsernamePassword:
|
||||||
|
if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) > 255 {
|
||||||
|
return errors.New("invalid username/password")
|
||||||
|
}
|
||||||
|
b := []byte{authUsernamePasswordVersion}
|
||||||
|
b = append(b, byte(len(up.Username)))
|
||||||
|
b = append(b, up.Username...)
|
||||||
|
b = append(b, byte(len(up.Password)))
|
||||||
|
b = append(b, up.Password...)
|
||||||
|
// TODO(mikio): handle IO deadlines and cancelation if
|
||||||
|
// necessary
|
||||||
|
if _, err := rw.Write(b); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := io.ReadFull(rw, b[:2]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if b[0] != authUsernamePasswordVersion {
|
||||||
|
return errors.New("invalid username/password version")
|
||||||
|
}
|
||||||
|
if b[1] != authStatusSucceeded {
|
||||||
|
return errors.New("username/password authentication failed")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errors.New("unsupported authentication method " + strconv.Itoa(int(auth)))
|
||||||
|
}
|
|
@ -0,0 +1,54 @@
|
||||||
|
// Copyright 2019 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A ContextDialer dials using a context.
|
||||||
|
type ContextDialer interface {
|
||||||
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial works like DialContext on net.Dialer but using a dialer returned by FromEnvironment.
|
||||||
|
//
|
||||||
|
// The passed ctx is only used for returning the Conn, not the lifetime of the Conn.
|
||||||
|
//
|
||||||
|
// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer
|
||||||
|
// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout.
|
||||||
|
//
|
||||||
|
// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
|
||||||
|
func Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
d := FromEnvironment()
|
||||||
|
if xd, ok := d.(ContextDialer); ok {
|
||||||
|
return xd.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
return dialContext(ctx, d, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout
|
||||||
|
// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
|
||||||
|
func dialContext(ctx context.Context, d Dialer, network, address string) (net.Conn, error) {
|
||||||
|
var (
|
||||||
|
conn net.Conn
|
||||||
|
done = make(chan struct{}, 1)
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
go func() {
|
||||||
|
conn, err = d.Dial(network, address)
|
||||||
|
close(done)
|
||||||
|
if conn != nil && ctx.Err() != nil {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
err = ctx.Err()
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
return conn, err
|
||||||
|
}
|
|
@ -0,0 +1,31 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type direct struct{}
|
||||||
|
|
||||||
|
// Direct implements Dialer by making network connections directly using net.Dial or net.DialContext.
|
||||||
|
var Direct = direct{}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ Dialer = Direct
|
||||||
|
_ ContextDialer = Direct
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dial directly invokes net.Dial with the supplied parameters.
|
||||||
|
func (direct) Dial(network, addr string) (net.Conn, error) {
|
||||||
|
return net.Dial(network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext instantiates a net.Dialer and invokes its DialContext receiver with the supplied parameters.
|
||||||
|
func (direct) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
var d net.Dialer
|
||||||
|
return d.DialContext(ctx, network, addr)
|
||||||
|
}
|
|
@ -0,0 +1,155 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A PerHost directs connections to a default Dialer unless the host name
|
||||||
|
// requested matches one of a number of exceptions.
|
||||||
|
type PerHost struct {
|
||||||
|
def, bypass Dialer
|
||||||
|
|
||||||
|
bypassNetworks []*net.IPNet
|
||||||
|
bypassIPs []net.IP
|
||||||
|
bypassZones []string
|
||||||
|
bypassHosts []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPerHost returns a PerHost Dialer that directs connections to either
|
||||||
|
// defaultDialer or bypass, depending on whether the connection matches one of
|
||||||
|
// the configured rules.
|
||||||
|
func NewPerHost(defaultDialer, bypass Dialer) *PerHost {
|
||||||
|
return &PerHost{
|
||||||
|
def: defaultDialer,
|
||||||
|
bypass: bypass,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial connects to the address addr on the given network through either
|
||||||
|
// defaultDialer or bypass.
|
||||||
|
func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.dialerForRequest(host).Dial(network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext connects to the address addr on the given network through either
|
||||||
|
// defaultDialer or bypass.
|
||||||
|
func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) {
|
||||||
|
host, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
d := p.dialerForRequest(host)
|
||||||
|
if x, ok := d.(ContextDialer); ok {
|
||||||
|
return x.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
return dialContext(ctx, d, network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PerHost) dialerForRequest(host string) Dialer {
|
||||||
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
|
for _, net := range p.bypassNetworks {
|
||||||
|
if net.Contains(ip) {
|
||||||
|
return p.bypass
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, bypassIP := range p.bypassIPs {
|
||||||
|
if bypassIP.Equal(ip) {
|
||||||
|
return p.bypass
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p.def
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, zone := range p.bypassZones {
|
||||||
|
if strings.HasSuffix(host, zone) {
|
||||||
|
return p.bypass
|
||||||
|
}
|
||||||
|
if host == zone[1:] {
|
||||||
|
// For a zone ".example.com", we match "example.com"
|
||||||
|
// too.
|
||||||
|
return p.bypass
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, bypassHost := range p.bypassHosts {
|
||||||
|
if bypassHost == host {
|
||||||
|
return p.bypass
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p.def
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFromString parses a string that contains comma-separated values
|
||||||
|
// specifying hosts that should use the bypass proxy. Each value is either an
|
||||||
|
// IP address, a CIDR range, a zone (*.example.com) or a host name
|
||||||
|
// (localhost). A best effort is made to parse the string and errors are
|
||||||
|
// ignored.
|
||||||
|
func (p *PerHost) AddFromString(s string) {
|
||||||
|
hosts := strings.Split(s, ",")
|
||||||
|
for _, host := range hosts {
|
||||||
|
host = strings.TrimSpace(host)
|
||||||
|
if len(host) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(host, "/") {
|
||||||
|
// We assume that it's a CIDR address like 127.0.0.0/8
|
||||||
|
if _, net, err := net.ParseCIDR(host); err == nil {
|
||||||
|
p.AddNetwork(net)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
|
p.AddIP(ip)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(host, "*.") {
|
||||||
|
p.AddZone(host[1:])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
p.AddHost(host)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddIP specifies an IP address that will use the bypass proxy. Note that
|
||||||
|
// this will only take effect if a literal IP address is dialed. A connection
|
||||||
|
// to a named host will never match an IP.
|
||||||
|
func (p *PerHost) AddIP(ip net.IP) {
|
||||||
|
p.bypassIPs = append(p.bypassIPs, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddNetwork specifies an IP range that will use the bypass proxy. Note that
|
||||||
|
// this will only take effect if a literal IP address is dialed. A connection
|
||||||
|
// to a named host will never match.
|
||||||
|
func (p *PerHost) AddNetwork(net *net.IPNet) {
|
||||||
|
p.bypassNetworks = append(p.bypassNetworks, net)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
|
||||||
|
// "example.com" matches "example.com" and all of its subdomains.
|
||||||
|
func (p *PerHost) AddZone(zone string) {
|
||||||
|
if strings.HasSuffix(zone, ".") {
|
||||||
|
zone = zone[:len(zone)-1]
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(zone, ".") {
|
||||||
|
zone = "." + zone
|
||||||
|
}
|
||||||
|
p.bypassZones = append(p.bypassZones, zone)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHost specifies a host name that will use the bypass proxy.
|
||||||
|
func (p *PerHost) AddHost(host string) {
|
||||||
|
if strings.HasSuffix(host, ".") {
|
||||||
|
host = host[:len(host)-1]
|
||||||
|
}
|
||||||
|
p.bypassHosts = append(p.bypassHosts, host)
|
||||||
|
}
|
|
@ -0,0 +1,149 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package proxy provides support for a variety of protocols to proxy network
|
||||||
|
// data.
|
||||||
|
package proxy // import "golang.org/x/net/proxy"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Dialer is a means to establish a connection.
|
||||||
|
// Custom dialers should also implement ContextDialer.
|
||||||
|
type Dialer interface {
|
||||||
|
// Dial connects to the given address via the proxy.
|
||||||
|
Dial(network, addr string) (c net.Conn, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth contains authentication parameters that specific Dialers may require.
|
||||||
|
type Auth struct {
|
||||||
|
User, Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromEnvironment returns the dialer specified by the proxy-related
|
||||||
|
// variables in the environment and makes underlying connections
|
||||||
|
// directly.
|
||||||
|
func FromEnvironment() Dialer {
|
||||||
|
return FromEnvironmentUsing(Direct)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromEnvironmentUsing returns the dialer specify by the proxy-related
|
||||||
|
// variables in the environment and makes underlying connections
|
||||||
|
// using the provided forwarding Dialer (for instance, a *net.Dialer
|
||||||
|
// with desired configuration).
|
||||||
|
func FromEnvironmentUsing(forward Dialer) Dialer {
|
||||||
|
allProxy := allProxyEnv.Get()
|
||||||
|
if len(allProxy) == 0 {
|
||||||
|
return forward
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, err := url.Parse(allProxy)
|
||||||
|
if err != nil {
|
||||||
|
return forward
|
||||||
|
}
|
||||||
|
proxy, err := FromURL(proxyURL, forward)
|
||||||
|
if err != nil {
|
||||||
|
return forward
|
||||||
|
}
|
||||||
|
|
||||||
|
noProxy := noProxyEnv.Get()
|
||||||
|
if len(noProxy) == 0 {
|
||||||
|
return proxy
|
||||||
|
}
|
||||||
|
|
||||||
|
perHost := NewPerHost(proxy, forward)
|
||||||
|
perHost.AddFromString(noProxy)
|
||||||
|
return perHost
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxySchemes is a map from URL schemes to a function that creates a Dialer
|
||||||
|
// from a URL with such a scheme.
|
||||||
|
var proxySchemes map[string]func(*url.URL, Dialer) (Dialer, error)
|
||||||
|
|
||||||
|
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
|
||||||
|
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
|
||||||
|
// by FromURL.
|
||||||
|
func RegisterDialerType(scheme string, f func(*url.URL, Dialer) (Dialer, error)) {
|
||||||
|
if proxySchemes == nil {
|
||||||
|
proxySchemes = make(map[string]func(*url.URL, Dialer) (Dialer, error))
|
||||||
|
}
|
||||||
|
proxySchemes[scheme] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromURL returns a Dialer given a URL specification and an underlying
|
||||||
|
// Dialer for it to make network requests.
|
||||||
|
func FromURL(u *url.URL, forward Dialer) (Dialer, error) {
|
||||||
|
var auth *Auth
|
||||||
|
if u.User != nil {
|
||||||
|
auth = new(Auth)
|
||||||
|
auth.User = u.User.Username()
|
||||||
|
if p, ok := u.User.Password(); ok {
|
||||||
|
auth.Password = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch u.Scheme {
|
||||||
|
case "socks5", "socks5h":
|
||||||
|
addr := u.Hostname()
|
||||||
|
port := u.Port()
|
||||||
|
if port == "" {
|
||||||
|
port = "1080"
|
||||||
|
}
|
||||||
|
return SOCKS5("tcp", net.JoinHostPort(addr, port), auth, forward)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the scheme doesn't match any of the built-in schemes, see if it
|
||||||
|
// was registered by another package.
|
||||||
|
if proxySchemes != nil {
|
||||||
|
if f, ok := proxySchemes[u.Scheme]; ok {
|
||||||
|
return f(u, forward)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
allProxyEnv = &envOnce{
|
||||||
|
names: []string{"ALL_PROXY", "all_proxy"},
|
||||||
|
}
|
||||||
|
noProxyEnv = &envOnce{
|
||||||
|
names: []string{"NO_PROXY", "no_proxy"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// envOnce looks up an environment variable (optionally by multiple
|
||||||
|
// names) once. It mitigates expensive lookups on some platforms
|
||||||
|
// (e.g. Windows).
|
||||||
|
// (Borrowed from net/http/transport.go)
|
||||||
|
type envOnce struct {
|
||||||
|
names []string
|
||||||
|
once sync.Once
|
||||||
|
val string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *envOnce) Get() string {
|
||||||
|
e.once.Do(e.init)
|
||||||
|
return e.val
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *envOnce) init() {
|
||||||
|
for _, n := range e.names {
|
||||||
|
e.val = os.Getenv(n)
|
||||||
|
if e.val != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset is used by tests
|
||||||
|
func (e *envOnce) reset() {
|
||||||
|
e.once = sync.Once{}
|
||||||
|
e.val = ""
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
// Copyright 2011 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"golang.org/x/net/internal/socks"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given
|
||||||
|
// address with an optional username and password.
|
||||||
|
// See RFC 1928 and RFC 1929.
|
||||||
|
func SOCKS5(network, address string, auth *Auth, forward Dialer) (Dialer, error) {
|
||||||
|
d := socks.NewDialer(network, address)
|
||||||
|
if forward != nil {
|
||||||
|
if f, ok := forward.(ContextDialer); ok {
|
||||||
|
d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
|
return f.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) {
|
||||||
|
return dialContext(ctx, forward, network, address)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if auth != nil {
|
||||||
|
up := socks.UsernamePassword{
|
||||||
|
Username: auth.User,
|
||||||
|
Password: auth.Password,
|
||||||
|
}
|
||||||
|
d.AuthMethods = []socks.AuthMethod{
|
||||||
|
socks.AuthMethodNotRequired,
|
||||||
|
socks.AuthMethodUsernamePassword,
|
||||||
|
}
|
||||||
|
d.Authenticate = up.Authenticate
|
||||||
|
}
|
||||||
|
return d, nil
|
||||||
|
}
|
|
@ -0,0 +1,4 @@
|
||||||
|
# golang.org/x/net v0.14.0
|
||||||
|
## explicit; go 1.17
|
||||||
|
golang.org/x/net/internal/socks
|
||||||
|
golang.org/x/net/proxy
|
473
x_net_proxy.go
473
x_net_proxy.go
|
@ -1,473 +0,0 @@
|
||||||
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
|
|
||||||
//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy
|
|
||||||
|
|
||||||
// Package proxy provides support for a variety of protocols to proxy network
|
|
||||||
// data.
|
|
||||||
//
|
|
||||||
|
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type proxy_direct struct{}
|
|
||||||
|
|
||||||
// Direct is a direct proxy: one that makes network connections directly.
|
|
||||||
var proxy_Direct = proxy_direct{}
|
|
||||||
|
|
||||||
func (proxy_direct) Dial(network, addr string) (net.Conn, error) {
|
|
||||||
return net.Dial(network, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// A PerHost directs connections to a default Dialer unless the host name
|
|
||||||
// requested matches one of a number of exceptions.
|
|
||||||
type proxy_PerHost struct {
|
|
||||||
def, bypass proxy_Dialer
|
|
||||||
|
|
||||||
bypassNetworks []*net.IPNet
|
|
||||||
bypassIPs []net.IP
|
|
||||||
bypassZones []string
|
|
||||||
bypassHosts []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPerHost returns a PerHost Dialer that directs connections to either
|
|
||||||
// defaultDialer or bypass, depending on whether the connection matches one of
|
|
||||||
// the configured rules.
|
|
||||||
func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost {
|
|
||||||
return &proxy_PerHost{
|
|
||||||
def: defaultDialer,
|
|
||||||
bypass: bypass,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial connects to the address addr on the given network through either
|
|
||||||
// defaultDialer or bypass.
|
|
||||||
func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) {
|
|
||||||
host, _, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.dialerForRequest(host).Dial(network, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer {
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
for _, net := range p.bypassNetworks {
|
|
||||||
if net.Contains(ip) {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, bypassIP := range p.bypassIPs {
|
|
||||||
if bypassIP.Equal(ip) {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p.def
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, zone := range p.bypassZones {
|
|
||||||
if strings.HasSuffix(host, zone) {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
if host == zone[1:] {
|
|
||||||
// For a zone ".example.com", we match "example.com"
|
|
||||||
// too.
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, bypassHost := range p.bypassHosts {
|
|
||||||
if bypassHost == host {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p.def
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFromString parses a string that contains comma-separated values
|
|
||||||
// specifying hosts that should use the bypass proxy. Each value is either an
|
|
||||||
// IP address, a CIDR range, a zone (*.example.com) or a host name
|
|
||||||
// (localhost). A best effort is made to parse the string and errors are
|
|
||||||
// ignored.
|
|
||||||
func (p *proxy_PerHost) AddFromString(s string) {
|
|
||||||
hosts := strings.Split(s, ",")
|
|
||||||
for _, host := range hosts {
|
|
||||||
host = strings.TrimSpace(host)
|
|
||||||
if len(host) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.Contains(host, "/") {
|
|
||||||
// We assume that it's a CIDR address like 127.0.0.0/8
|
|
||||||
if _, net, err := net.ParseCIDR(host); err == nil {
|
|
||||||
p.AddNetwork(net)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
p.AddIP(ip)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(host, "*.") {
|
|
||||||
p.AddZone(host[1:])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
p.AddHost(host)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddIP specifies an IP address that will use the bypass proxy. Note that
|
|
||||||
// this will only take effect if a literal IP address is dialed. A connection
|
|
||||||
// to a named host will never match an IP.
|
|
||||||
func (p *proxy_PerHost) AddIP(ip net.IP) {
|
|
||||||
p.bypassIPs = append(p.bypassIPs, ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddNetwork specifies an IP range that will use the bypass proxy. Note that
|
|
||||||
// this will only take effect if a literal IP address is dialed. A connection
|
|
||||||
// to a named host will never match.
|
|
||||||
func (p *proxy_PerHost) AddNetwork(net *net.IPNet) {
|
|
||||||
p.bypassNetworks = append(p.bypassNetworks, net)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
|
|
||||||
// "example.com" matches "example.com" and all of its subdomains.
|
|
||||||
func (p *proxy_PerHost) AddZone(zone string) {
|
|
||||||
if strings.HasSuffix(zone, ".") {
|
|
||||||
zone = zone[:len(zone)-1]
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(zone, ".") {
|
|
||||||
zone = "." + zone
|
|
||||||
}
|
|
||||||
p.bypassZones = append(p.bypassZones, zone)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddHost specifies a host name that will use the bypass proxy.
|
|
||||||
func (p *proxy_PerHost) AddHost(host string) {
|
|
||||||
if strings.HasSuffix(host, ".") {
|
|
||||||
host = host[:len(host)-1]
|
|
||||||
}
|
|
||||||
p.bypassHosts = append(p.bypassHosts, host)
|
|
||||||
}
|
|
||||||
|
|
||||||
// A Dialer is a means to establish a connection.
|
|
||||||
type proxy_Dialer interface {
|
|
||||||
// Dial connects to the given address via the proxy.
|
|
||||||
Dial(network, addr string) (c net.Conn, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auth contains authentication parameters that specific Dialers may require.
|
|
||||||
type proxy_Auth struct {
|
|
||||||
User, Password string
|
|
||||||
}
|
|
||||||
|
|
||||||
// FromEnvironment returns the dialer specified by the proxy related variables in
|
|
||||||
// the environment.
|
|
||||||
func proxy_FromEnvironment() proxy_Dialer {
|
|
||||||
allProxy := proxy_allProxyEnv.Get()
|
|
||||||
if len(allProxy) == 0 {
|
|
||||||
return proxy_Direct
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyURL, err := url.Parse(allProxy)
|
|
||||||
if err != nil {
|
|
||||||
return proxy_Direct
|
|
||||||
}
|
|
||||||
proxy, err := proxy_FromURL(proxyURL, proxy_Direct)
|
|
||||||
if err != nil {
|
|
||||||
return proxy_Direct
|
|
||||||
}
|
|
||||||
|
|
||||||
noProxy := proxy_noProxyEnv.Get()
|
|
||||||
if len(noProxy) == 0 {
|
|
||||||
return proxy
|
|
||||||
}
|
|
||||||
|
|
||||||
perHost := proxy_NewPerHost(proxy, proxy_Direct)
|
|
||||||
perHost.AddFromString(noProxy)
|
|
||||||
return perHost
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxySchemes is a map from URL schemes to a function that creates a Dialer
|
|
||||||
// from a URL with such a scheme.
|
|
||||||
var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)
|
|
||||||
|
|
||||||
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
|
|
||||||
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
|
|
||||||
// by FromURL.
|
|
||||||
func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) {
|
|
||||||
if proxy_proxySchemes == nil {
|
|
||||||
proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error))
|
|
||||||
}
|
|
||||||
proxy_proxySchemes[scheme] = f
|
|
||||||
}
|
|
||||||
|
|
||||||
// FromURL returns a Dialer given a URL specification and an underlying
|
|
||||||
// Dialer for it to make network requests.
|
|
||||||
func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) {
|
|
||||||
var auth *proxy_Auth
|
|
||||||
if u.User != nil {
|
|
||||||
auth = new(proxy_Auth)
|
|
||||||
auth.User = u.User.Username()
|
|
||||||
if p, ok := u.User.Password(); ok {
|
|
||||||
auth.Password = p
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch u.Scheme {
|
|
||||||
case "socks5":
|
|
||||||
return proxy_SOCKS5("tcp", u.Host, auth, forward)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the scheme doesn't match any of the built-in schemes, see if it
|
|
||||||
// was registered by another package.
|
|
||||||
if proxy_proxySchemes != nil {
|
|
||||||
if f, ok := proxy_proxySchemes[u.Scheme]; ok {
|
|
||||||
return f(u, forward)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
proxy_allProxyEnv = &proxy_envOnce{
|
|
||||||
names: []string{"ALL_PROXY", "all_proxy"},
|
|
||||||
}
|
|
||||||
proxy_noProxyEnv = &proxy_envOnce{
|
|
||||||
names: []string{"NO_PROXY", "no_proxy"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// envOnce looks up an environment variable (optionally by multiple
|
|
||||||
// names) once. It mitigates expensive lookups on some platforms
|
|
||||||
// (e.g. Windows).
|
|
||||||
// (Borrowed from net/http/transport.go)
|
|
||||||
type proxy_envOnce struct {
|
|
||||||
names []string
|
|
||||||
once sync.Once
|
|
||||||
val string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *proxy_envOnce) Get() string {
|
|
||||||
e.once.Do(e.init)
|
|
||||||
return e.val
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *proxy_envOnce) init() {
|
|
||||||
for _, n := range e.names {
|
|
||||||
e.val = os.Getenv(n)
|
|
||||||
if e.val != "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
|
|
||||||
// with an optional username and password. See RFC 1928 and RFC 1929.
|
|
||||||
func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) {
|
|
||||||
s := &proxy_socks5{
|
|
||||||
network: network,
|
|
||||||
addr: addr,
|
|
||||||
forward: forward,
|
|
||||||
}
|
|
||||||
if auth != nil {
|
|
||||||
s.user = auth.User
|
|
||||||
s.password = auth.Password
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type proxy_socks5 struct {
|
|
||||||
user, password string
|
|
||||||
network, addr string
|
|
||||||
forward proxy_Dialer
|
|
||||||
}
|
|
||||||
|
|
||||||
const proxy_socks5Version = 5
|
|
||||||
|
|
||||||
const (
|
|
||||||
proxy_socks5AuthNone = 0
|
|
||||||
proxy_socks5AuthPassword = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
const proxy_socks5Connect = 1
|
|
||||||
|
|
||||||
const (
|
|
||||||
proxy_socks5IP4 = 1
|
|
||||||
proxy_socks5Domain = 3
|
|
||||||
proxy_socks5IP6 = 4
|
|
||||||
)
|
|
||||||
|
|
||||||
var proxy_socks5Errors = []string{
|
|
||||||
"",
|
|
||||||
"general failure",
|
|
||||||
"connection forbidden",
|
|
||||||
"network unreachable",
|
|
||||||
"host unreachable",
|
|
||||||
"connection refused",
|
|
||||||
"TTL expired",
|
|
||||||
"command not supported",
|
|
||||||
"address type not supported",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial connects to the address addr on the given network via the SOCKS5 proxy.
|
|
||||||
func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) {
|
|
||||||
switch network {
|
|
||||||
case "tcp", "tcp6", "tcp4":
|
|
||||||
default:
|
|
||||||
return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := s.forward.Dial(s.network, s.addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := s.connect(conn, addr); err != nil {
|
|
||||||
conn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// connect takes an existing connection to a socks5 proxy server,
|
|
||||||
// and commands the server to extend that connection to target,
|
|
||||||
// which must be a canonical address with a host and port.
|
|
||||||
func (s *proxy_socks5) connect(conn net.Conn, target string) error {
|
|
||||||
host, portStr, err := net.SplitHostPort(target)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
port, err := strconv.Atoi(portStr)
|
|
||||||
if err != nil {
|
|
||||||
return errors.New("proxy: failed to parse port number: " + portStr)
|
|
||||||
}
|
|
||||||
if port < 1 || port > 0xffff {
|
|
||||||
return errors.New("proxy: port number out of range: " + portStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// the size here is just an estimate
|
|
||||||
buf := make([]byte, 0, 6+len(host))
|
|
||||||
|
|
||||||
buf = append(buf, proxy_socks5Version)
|
|
||||||
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
|
|
||||||
buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword)
|
|
||||||
} else {
|
|
||||||
buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := conn.Write(buf); err != nil {
|
|
||||||
return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
|
||||||
return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
if buf[0] != 5 {
|
|
||||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
|
|
||||||
}
|
|
||||||
if buf[1] == 0xff {
|
|
||||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 1929
|
|
||||||
if buf[1] == proxy_socks5AuthPassword {
|
|
||||||
buf = buf[:0]
|
|
||||||
buf = append(buf, 1 /* password protocol version */)
|
|
||||||
buf = append(buf, uint8(len(s.user)))
|
|
||||||
buf = append(buf, s.user...)
|
|
||||||
buf = append(buf, uint8(len(s.password)))
|
|
||||||
buf = append(buf, s.password...)
|
|
||||||
|
|
||||||
if _, err := conn.Write(buf); err != nil {
|
|
||||||
return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
|
||||||
return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if buf[1] != 0 {
|
|
||||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
buf = buf[:0]
|
|
||||||
buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */)
|
|
||||||
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
|
||||||
buf = append(buf, proxy_socks5IP4)
|
|
||||||
ip = ip4
|
|
||||||
} else {
|
|
||||||
buf = append(buf, proxy_socks5IP6)
|
|
||||||
}
|
|
||||||
buf = append(buf, ip...)
|
|
||||||
} else {
|
|
||||||
if len(host) > 255 {
|
|
||||||
return errors.New("proxy: destination host name too long: " + host)
|
|
||||||
}
|
|
||||||
buf = append(buf, proxy_socks5Domain)
|
|
||||||
buf = append(buf, byte(len(host)))
|
|
||||||
buf = append(buf, host...)
|
|
||||||
}
|
|
||||||
buf = append(buf, byte(port>>8), byte(port))
|
|
||||||
|
|
||||||
if _, err := conn.Write(buf); err != nil {
|
|
||||||
return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
|
|
||||||
return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
failure := "unknown error"
|
|
||||||
if int(buf[1]) < len(proxy_socks5Errors) {
|
|
||||||
failure = proxy_socks5Errors[buf[1]]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(failure) > 0 {
|
|
||||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
|
|
||||||
}
|
|
||||||
|
|
||||||
bytesToDiscard := 0
|
|
||||||
switch buf[3] {
|
|
||||||
case proxy_socks5IP4:
|
|
||||||
bytesToDiscard = net.IPv4len
|
|
||||||
case proxy_socks5IP6:
|
|
||||||
bytesToDiscard = net.IPv6len
|
|
||||||
case proxy_socks5Domain:
|
|
||||||
_, err := io.ReadFull(conn, buf[:1])
|
|
||||||
if err != nil {
|
|
||||||
return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
bytesToDiscard = int(buf[0])
|
|
||||||
default:
|
|
||||||
return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cap(buf) < bytesToDiscard {
|
|
||||||
buf = make([]byte, bytesToDiscard)
|
|
||||||
} else {
|
|
||||||
buf = buf[:bytesToDiscard]
|
|
||||||
}
|
|
||||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
|
||||||
return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Also need to discard the port number
|
|
||||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
|
||||||
return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
Loading…
Reference in New Issue