Update go version & add verification/testing tools (#840)

Fixes #

**Summary of Changes**

1.
2. 
3.

> PS: Make sure your PR includes/updates tests! If you need help with
this part, just ask!
This commit is contained in:
Corey Daley 2023-08-26 16:01:45 -04:00 committed by GitHub
parent 80393295c1
commit 666c197fc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
43 changed files with 1388 additions and 734 deletions

View File

@ -1,70 +0,0 @@
version: 2.1
jobs:
"test":
parameters:
version:
type: string
default: "latest"
golint:
type: boolean
default: true
modules:
type: boolean
default: true
goproxy:
type: string
default: ""
docker:
- image: "cimg/go:<< parameters.version >>"
working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket
environment:
GO111MODULE: "on"
GOPROXY: "<< parameters.goproxy >>"
steps:
- checkout
- run:
name: "Print the Go version"
command: >
go version
- run:
name: "Fetch dependencies"
command: >
if [[ << parameters.modules >> = true ]]; then
go mod download
export GO111MODULE=on
else
go get -v ./...
fi
# Only run gofmt, vet & lint against the latest Go version
- run:
name: "Run golint"
command: >
if [ << parameters.version >> = "latest" ] && [ << parameters.golint >> = true ]; then
go get -u golang.org/x/lint/golint
golint ./...
fi
- run:
name: "Run gofmt"
command: >
if [[ << parameters.version >> = "latest" ]]; then
diff -u <(echo -n) <(gofmt -d -e .)
fi
- run:
name: "Run go vet"
command: >
if [[ << parameters.version >> = "latest" ]]; then
go vet -v ./...
fi
- run:
name: "Run go test (+ race detector)"
command: >
go test -v -race ./...
workflows:
tests:
jobs:
- test:
matrix:
parameters:
version: ["1.18", "1.17", "1.16"]

20
.editorconfig Normal file
View File

@ -0,0 +1,20 @@
; https://editorconfig.org/
root = true
[*]
insert_final_newline = true
charset = utf-8
trim_trailing_whitespace = true
indent_style = space
indent_size = 2
[{Makefile,go.mod,go.sum,*.go,.gitmodules}]
indent_style = tab
indent_size = 4
[*.md]
indent_size = 4
trim_trailing_whitespace = false
eclint_indent_style = unset

View File

@ -1,7 +0,0 @@
# Config for https://github.com/apps/release-drafter
template: |
<summary of changes here>
## CHANGELOG
$CHANGES

21
.github/workflows/issues.yml vendored Normal file
View File

@ -0,0 +1,21 @@
# Add issues or pull-requests created to the project.
name: Add issue or pull request to Project
on:
issues:
types:
- opened
pull_request_target:
types:
- opened
- reopened
jobs:
add-to-project:
runs-on: ubuntu-latest
steps:
- name: Add issue to project
uses: actions/add-to-project@v0.5.0
with:
project-url: https://github.com/orgs/gorilla/projects/4
github-token: ${{ secrets.ADD_TO_PROJECT_TOKEN }}

55
.github/workflows/test.yml vendored Normal file
View File

@ -0,0 +1,55 @@
name: CI
on:
push:
branches:
- main
pull_request:
branches:
- main
permissions:
contents: read
jobs:
verify-and-test:
strategy:
matrix:
go: ['1.19','1.20']
os: [ubuntu-latest, macos-latest, windows-latest]
fail-fast: true
runs-on: ${{ matrix.os }}
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Setup Go ${{ matrix.go }}
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go }}
cache: false
- name: Run GolangCI-Lint
uses: golangci/golangci-lint-action@v3
with:
version: v1.53
args: --timeout=5m
- name: Run GoSec
if: matrix.os == 'ubuntu-latest'
uses: securego/gosec@master
with:
args: -exclude-dir examples ./...
- name: Run GoVulnCheck
uses: golang/govulncheck-action@v1
with:
go-version-input: ${{ matrix.go }}
go-package: ./...
- name: Run Tests
run: go test -race -cover -coverprofile=coverage -covermode=atomic -v ./...
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
files: ./coverage

26
.gitignore vendored
View File

@ -1,25 +1 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects) coverage.coverprofile
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
.idea/
*.iml

3
.golangci.yml Normal file
View File

@ -0,0 +1,3 @@
run:
skip-dirs:
- examples/*.go

View File

@ -1,9 +0,0 @@
# This is the official list of Gorilla WebSocket authors for copyright
# purposes.
#
# Please keep the list sorted.
Gary Burd <gary@beagledreams.com>
Google LLC (https://opensource.google.com/)
Joachim Bauch <mail@joachim-bauch.de>

39
LICENSE
View File

@ -1,22 +1,27 @@
Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved. Copyright (c) 2023 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met: modification, are permitted provided that the following conditions are
met:
Redistributions of source code must retain the above copyright notice, this * Redistributions of source code must retain the above copyright
list of conditions and the following disclaimer. notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
Redistributions in binary form must reproduce the above copyright notice, THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
this list of conditions and the following disclaimer in the documentation "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
and/or other materials provided with the distribution. LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

34
Makefile Normal file
View File

@ -0,0 +1,34 @@
GO_LINT=$(shell which golangci-lint 2> /dev/null || echo '')
GO_LINT_URI=github.com/golangci/golangci-lint/cmd/golangci-lint@latest
GO_SEC=$(shell which gosec 2> /dev/null || echo '')
GO_SEC_URI=github.com/securego/gosec/v2/cmd/gosec@latest
GO_VULNCHECK=$(shell which govulncheck 2> /dev/null || echo '')
GO_VULNCHECK_URI=golang.org/x/vuln/cmd/govulncheck@latest
.PHONY: golangci-lint
golangci-lint:
$(if $(GO_LINT), ,go install $(GO_LINT_URI))
@echo "##### Running golangci-lint"
golangci-lint run -v
.PHONY: gosec
gosec:
$(if $(GO_SEC), ,go install $(GO_SEC_URI))
@echo "##### Running gosec"
gosec -exclude-dir examples ./...
.PHONY: govulncheck
govulncheck:
$(if $(GO_VULNCHECK), ,go install $(GO_VULNCHECK_URI))
@echo "##### Running govulncheck"
govulncheck ./...
.PHONY: verify
verify: golangci-lint gosec govulncheck
.PHONY: test
test:
@echo "##### Running tests"
go test -race -cover -coverprofile=coverage.coverprofile -covermode=atomic -v ./...

View File

@ -1,10 +1,13 @@
# Gorilla WebSocket # gorilla/websocket
[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket) ![testing](https://github.com/gorilla/websocket/actions/workflows/test.yml/badge.svg)
[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket) [![codecov](https://codecov.io/github/gorilla/websocket/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/websocket)
[![godoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
[![sourcegraph](https://sourcegraph.com/github.com/gorilla/websocket/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/websocket?badge)
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
![Gorilla Logo](https://github.com/gorilla/.github/assets/53367916/d92caabf-98e0-473e-bfbf-ab554ba435e5)
### Documentation ### Documentation
@ -31,4 +34,3 @@ package API is stable.
The Gorilla WebSocket package passes the server tests in the [Autobahn Test The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).

View File

@ -11,13 +11,16 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "log"
"net" "net"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
"net/url" "net/url"
"strings" "strings"
"time" "time"
"golang.org/x/net/proxy"
) )
// ErrBadHandshake is returned when the server response to opening handshake is // ErrBadHandshake is returned when the server response to opening handshake is
@ -225,6 +228,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
k == "Connection" || k == "Connection" ||
k == "Sec-Websocket-Key" || k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" || k == "Sec-Websocket-Version" ||
//#nosec G101 (CWE-798): Potential HTTP request smuggling via parameter pollution
k == "Sec-Websocket-Extensions" || k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
@ -290,7 +294,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
} }
err = c.SetDeadline(deadline) err = c.SetDeadline(deadline)
if err != nil { if err != nil {
c.Close() if err := c.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
return nil, err return nil, err
} }
return c, nil return c, nil
@ -304,7 +310,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return nil, nil, err return nil, nil, err
} }
if proxyURL != nil { if proxyURL != nil {
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) dialer, err := proxy.FromURL(proxyURL, netDialerFunc(netDial))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -330,7 +336,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
defer func() { defer func() {
if netConn != nil { if netConn != nil {
netConn.Close() if err := netConn.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
} }
}() }()
@ -400,7 +408,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// debugging. // debugging.
buf := make([]byte, 1024) buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf) n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake return nil, resp, ErrBadHandshake
} }
@ -418,17 +426,19 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
break break
} }
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{}) if err := netConn.SetDeadline(time.Time{}); err != nil {
return nil, nil, err
}
netConn = nil // to avoid close in defer. netConn = nil // to avoid close in defer.
return conn, resp, nil return conn, resp, nil
} }
func cloneTLSConfig(cfg *tls.Config) *tls.Config { func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil { if cfg == nil {
return &tls.Config{} return &tls.Config{MinVersion: tls.VersionTLS12}
} }
return cfg.Clone() return cfg.Clone()
} }

View File

@ -14,7 +14,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -50,7 +49,6 @@ type cstHandler struct{ *testing.T }
type cstServer struct { type cstServer struct {
*httptest.Server *httptest.Server
URL string URL string
t *testing.T
} }
const ( const (
@ -94,7 +92,6 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
if err != nil { if err != nil {
t.Logf("Upgrade: %v", err)
return return
} }
defer ws.Close() defer ws.Close()
@ -106,20 +103,16 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
op, rd, err := ws.NextReader() op, rd, err := ws.NextReader()
if err != nil { if err != nil {
t.Logf("NextReader: %v", err)
return return
} }
wr, err := ws.NextWriter(op) wr, err := ws.NextWriter(op)
if err != nil { if err != nil {
t.Logf("NextWriter: %v", err)
return return
} }
if _, err = io.Copy(wr, rd); err != nil { if _, err = io.Copy(wr, rd); err != nil {
t.Logf("NextWriter: %v", err)
return return
} }
if err := wr.Close(); err != nil { if err := wr.Close(); err != nil {
t.Logf("Close: %v", err)
return return
} }
} }
@ -531,7 +524,9 @@ func TestRespOnBadHandshake(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(expectedStatus) w.WriteHeader(expectedStatus)
io.WriteString(w, expectedBody) if _, err := io.WriteString(w, expectedBody); err != nil {
t.Fatalf("WriteString: %v", err)
}
})) }))
defer s.Close() defer s.Close()
@ -549,7 +544,7 @@ func TestRespOnBadHandshake(t *testing.T) {
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
} }
p, err := ioutil.ReadAll(resp.Body) p, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatalf("ReadFull(resp.Body) returned error %v", err) t.Fatalf("ReadFull(resp.Body) returned error %v", err)
} }
@ -564,7 +559,6 @@ type testLogWriter struct {
} }
func (w testLogWriter) Write(p []byte) (int, error) { func (w testLogWriter) Write(p []byte) (int, error) {
w.t.Logf("%s", p)
return len(p), nil return len(p), nil
} }
@ -781,7 +775,10 @@ func TestSocksProxyDial(t *testing.T) {
} }
defer c1.Close() defer c1.Close()
c1.SetDeadline(time.Now().Add(30 * time.Second)) if err := c1.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
t.Errorf("set deadline failed: %v", err)
return
}
buf := make([]byte, 32) buf := make([]byte, 32)
if _, err := io.ReadFull(c1, buf[:3]); err != nil { if _, err := io.ReadFull(c1, buf[:3]); err != nil {
@ -820,10 +817,15 @@ func TestSocksProxyDial(t *testing.T) {
defer c2.Close() defer c2.Close()
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
io.Copy(c1, c2) if _, err := io.Copy(c1, c2); err != nil {
t.Errorf("copy failed: %v", err)
}
close(done) close(done)
}() }()
io.Copy(c2, c1) if _, err := io.Copy(c2, c1); err != nil {
t.Errorf("copy failed: %v", err)
return
}
<-done <-done
}() }()

View File

@ -8,6 +8,7 @@ import (
"compress/flate" "compress/flate"
"errors" "errors"
"io" "io"
"log"
"strings" "strings"
"sync" "sync"
) )
@ -33,7 +34,9 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
"\x01\x00\x00\xff\xff" "\x01\x00\x00\xff\xff"
fr, _ := flateReaderPool.Get().(io.ReadCloser) fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) if err := fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil); err != nil {
panic(err)
}
return &flateReadWrapper{fr} return &flateReadWrapper{fr}
} }
@ -132,7 +135,9 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) {
// Preemptively place the reader back in the pool. This helps with // Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after // scenarios where the application does not call NextReader() soon after
// this final read. // this final read.
r.Close() if err := r.Close(); err != nil {
log.Printf("websocket: flateReadWrapper.Close() returned error: %v", err)
}
} }
return n, err return n, err
} }

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"testing" "testing"
) )
@ -23,7 +22,9 @@ func TestTruncWriter(t *testing.T) {
if m > n { if m > n {
m = n m = n
} }
w.Write(p[:m]) if _, err := w.Write(p[:m]); err != nil {
t.Fatal(err)
}
p = p[m:] p = p[m:]
} }
if b.String() != data[:len(data)-len(w.p)] { if b.String() != data[:len(data)-len(w.p)] {
@ -42,25 +43,29 @@ func textMessages(num int) [][]byte {
} }
func BenchmarkWriteNoCompression(b *testing.B) { func BenchmarkWriteNoCompression(b *testing.B) {
w := ioutil.Discard w := io.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
c.WriteMessage(TextMessage, messages[i%len(messages)]) if err := c.WriteMessage(TextMessage, messages[i%len(messages)]); err != nil {
b.Fatal(err)
}
} }
b.ReportAllocs() b.ReportAllocs()
} }
func BenchmarkWriteWithCompression(b *testing.B) { func BenchmarkWriteWithCompression(b *testing.B) {
w := ioutil.Discard w := io.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
c.enableWriteCompression = true c.enableWriteCompression = true
c.newCompressionWriter = compressNoContextTakeover c.newCompressionWriter = compressNoContextTakeover
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
c.WriteMessage(TextMessage, messages[i%len(messages)]) if err := c.WriteMessage(TextMessage, messages[i%len(messages)]); err != nil {
b.Fatal(err)
}
} }
b.ReportAllocs() b.ReportAllocs()
} }

75
conn.go
View File

@ -6,11 +6,11 @@ package websocket
import ( import (
"bufio" "bufio"
"crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"io/ioutil" "log"
"math/rand"
"net" "net"
"strconv" "strconv"
"strings" "strings"
@ -181,13 +181,20 @@ var (
errInvalidControlFrame = errors.New("websocket: invalid control frame") errInvalidControlFrame = errors.New("websocket: invalid control frame")
) )
// maskRand is an io.Reader for generating mask bytes. The reader is initialized
// to crypto/rand Reader. Tests swap the reader to a math/rand reader for
// reproducible results.
var maskRand = rand.Reader
// newMaskKey returns a new 32 bit value for masking client frames.
func newMaskKey() [4]byte { func newMaskKey() [4]byte {
n := rand.Uint32() var k [4]byte
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} _, _ = io.ReadFull(maskRand, k[:])
return k
} }
func hideTempErr(err error) error { func hideTempErr(err error) error {
if e, ok := err.(net.Error); ok && e.Temporary() { if e, ok := err.(net.Error); ok {
err = &netError{msg: e.Error(), timeout: e.Timeout()} err = &netError{msg: e.Error(), timeout: e.Timeout()}
} }
return err return err
@ -372,7 +379,9 @@ func (c *Conn) read(n int) ([]byte, error) {
if err == io.EOF { if err == io.EOF {
err = errUnexpectedEOF err = errUnexpectedEOF
} }
c.br.Discard(len(p)) if _, err := c.br.Discard(len(p)); err != nil {
return p, err
}
return p, err return p, err
} }
@ -387,7 +396,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return err return err
} }
c.conn.SetWriteDeadline(deadline) if err := c.conn.SetWriteDeadline(deadline); err != nil {
return c.writeFatal(err)
}
if len(buf1) == 0 { if len(buf1) == 0 {
_, err = c.conn.Write(buf0) _, err = c.conn.Write(buf0)
} else { } else {
@ -397,7 +408,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return c.writeFatal(err) return c.writeFatal(err)
} }
if frameType == CloseMessage { if frameType == CloseMessage {
c.writeFatal(ErrCloseSent) _ = c.writeFatal(ErrCloseSent)
} }
return nil return nil
} }
@ -438,7 +449,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
d := 1000 * time.Hour d := 1000 * time.Hour
if !deadline.IsZero() { if !deadline.IsZero() {
d = deadline.Sub(time.Now()) d = time.Until(deadline)
if d < 0 { if d < 0 {
return errWriteTimeout return errWriteTimeout
} }
@ -460,13 +471,15 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
return err return err
} }
c.conn.SetWriteDeadline(deadline) if err := c.conn.SetWriteDeadline(deadline); err != nil {
return c.writeFatal(err)
}
_, err = c.conn.Write(buf) _, err = c.conn.Write(buf)
if err != nil { if err != nil {
return c.writeFatal(err) return c.writeFatal(err)
} }
if messageType == CloseMessage { if messageType == CloseMessage {
c.writeFatal(ErrCloseSent) _ = c.writeFatal(ErrCloseSent)
} }
return err return err
} }
@ -477,7 +490,9 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
// probably better to return an error in this situation, but we cannot // probably better to return an error in this situation, but we cannot
// change this without breaking existing applications. // change this without breaking existing applications.
if c.writer != nil { if c.writer != nil {
c.writer.Close() if err := c.writer.Close(); err != nil {
log.Printf("websocket: discarding writer close error: %v", err)
}
c.writer = nil c.writer = nil
} }
@ -630,7 +645,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
} }
if final { if final {
w.endMessage(errWriteClosed) _ = w.endMessage(errWriteClosed)
return nil return nil
} }
@ -795,7 +810,7 @@ func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame. // 1. Skip remainder of previous frame.
if c.readRemaining > 0 { if c.readRemaining > 0 {
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
return noFrame, err return noFrame, err
} }
} }
@ -817,7 +832,9 @@ func (c *Conn) advanceFrame() (int, error) {
rsv2 := p[0]&rsv2Bit != 0 rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0 rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0 mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f)) if err := c.setReadRemaining(int64(p[1] & 0x7f)); err != nil {
return noFrame, err
}
c.readDecompress = false c.readDecompress = false
if rsv1 { if rsv1 {
@ -922,7 +939,9 @@ func (c *Conn) advanceFrame() (int, error) {
} }
if c.readLimit > 0 && c.readLength > c.readLimit { if c.readLimit > 0 && c.readLength > c.readLimit {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) if err := c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)); err != nil {
return noFrame, err
}
return noFrame, ErrReadLimit return noFrame, ErrReadLimit
} }
@ -934,7 +953,9 @@ func (c *Conn) advanceFrame() (int, error) {
var payload []byte var payload []byte
if c.readRemaining > 0 { if c.readRemaining > 0 {
payload, err = c.read(int(c.readRemaining)) payload, err = c.read(int(c.readRemaining))
c.setReadRemaining(0) if err := c.setReadRemaining(0); err != nil {
return noFrame, err
}
if err != nil { if err != nil {
return noFrame, err return noFrame, err
} }
@ -981,7 +1002,9 @@ func (c *Conn) handleProtocolError(message string) error {
if len(data) > maxControlFramePayloadSize { if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize] data = data[:maxControlFramePayloadSize]
} }
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) if err := c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)); err != nil {
return err
}
return errors.New("websocket: " + message) return errors.New("websocket: " + message)
} }
@ -998,7 +1021,9 @@ func (c *Conn) handleProtocolError(message string) error {
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
// Close previous reader, only relevant for decompression. // Close previous reader, only relevant for decompression.
if c.reader != nil { if c.reader != nil {
c.reader.Close() if err := c.reader.Close(); err != nil {
log.Printf("websocket: discarding reader close error: %v", err)
}
c.reader = nil c.reader = nil
} }
@ -1054,7 +1079,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
} }
rem := c.readRemaining rem := c.readRemaining
rem -= int64(n) rem -= int64(n)
c.setReadRemaining(rem) if err := c.setReadRemaining(rem); err != nil {
return 0, err
}
if c.readRemaining > 0 && c.readErr == io.EOF { if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF c.readErr = errUnexpectedEOF
} }
@ -1094,7 +1121,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
if err != nil { if err != nil {
return messageType, nil, err return messageType, nil, err
} }
p, err = ioutil.ReadAll(r) p, err = io.ReadAll(r)
return messageType, p, err return messageType, p, err
} }
@ -1136,7 +1163,9 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
if h == nil { if h == nil {
h = func(code int, text string) error { h = func(code int, text string) error {
message := FormatCloseMessage(code, "") message := FormatCloseMessage(code, "")
c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) if err := c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)); err != nil {
return err
}
return nil return nil
} }
} }
@ -1161,7 +1190,7 @@ func (c *Conn) SetPingHandler(h func(appData string) error) {
err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
if err == ErrCloseSent { if err == ErrCloseSent {
return nil return nil
} else if e, ok := err.(net.Error); ok && e.Temporary() { } else if _, ok := err.(net.Error); ok {
return nil return nil
} }
return err return err

View File

@ -6,7 +6,6 @@ package websocket
import ( import (
"io" "io"
"io/ioutil"
"sync/atomic" "sync/atomic"
"testing" "testing"
) )
@ -45,7 +44,7 @@ func newBroadcastConn(c *Conn) *broadcastConn {
func newBroadcastBench(usePrepared, compression bool) *broadcastBench { func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
bench := &broadcastBench{ bench := &broadcastBench{
w: ioutil.Discard, w: io.Discard,
doneCh: make(chan struct{}), doneCh: make(chan struct{}),
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
usePrepared: usePrepared, usePrepared: usePrepared,
@ -70,9 +69,13 @@ func (b *broadcastBench) makeConns(numConns int) {
select { select {
case msg := <-c.msgCh: case msg := <-c.msgCh:
if msg.prepared != nil { if msg.prepared != nil {
c.conn.WritePreparedMessage(msg.prepared) if err := c.conn.WritePreparedMessage(msg.prepared); err != nil {
panic(err)
}
} else { } else {
c.conn.WriteMessage(TextMessage, msg.payload) if err := c.conn.WriteMessage(TextMessage, msg.payload); err != nil {
panic(err)
}
} }
val := atomic.AddInt32(&b.count, 1) val := atomic.AddInt32(&b.count, 1)
if val%int32(numConns) == 0 { if val%int32(numConns) == 0 {

View File

@ -10,7 +10,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"reflect" "reflect"
"sync" "sync"
@ -124,8 +123,7 @@ func TestFraming(t *testing.T) {
continue continue
} }
t.Logf("frame size: %d", n) rbuf, err := io.ReadAll(r)
rbuf, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
continue continue
@ -158,7 +156,10 @@ func TestControl(t *testing.T) {
wc := newTestConn(nil, &connBuf, isServer) wc := newTestConn(nil, &connBuf, isServer)
rc := newTestConn(&connBuf, nil, !isServer) rc := newTestConn(&connBuf, nil, !isServer)
if isWriteControl { if isWriteControl {
wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
t.Errorf("%s: wc.WriteControl() returned %v", name, err)
continue
}
} else { } else {
w, err := wc.NextWriter(PongMessage) w, err := wc.NextWriter(PongMessage)
if err != nil { if err != nil {
@ -175,7 +176,9 @@ func TestControl(t *testing.T) {
} }
var actualMessage string var actualMessage string
rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
rc.NextReader() if _, _, err := rc.NextReader(); err != nil {
continue
}
if actualMessage != message { if actualMessage != message {
t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
continue continue
@ -359,15 +362,19 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
rc := newTestConn(&b1, &b2, true) rc := newTestConn(&b1, &b2, true)
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2)) if _, err := w.Write(make([]byte, bufSize+bufSize/2)); err != nil {
wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) t.Fatalf("w.Write() returned %v", err)
}
if err := wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)); err != nil {
t.Fatalf("wc.WriteControl() returned %v", err)
}
w.Close() w.Close()
op, r, err := rc.NextReader() op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if !reflect.DeepEqual(err, expectedErr) { if !reflect.DeepEqual(err, expectedErr) {
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
} }
@ -386,7 +393,9 @@ func TestEOFWithinFrame(t *testing.T) {
rc := newTestConn(&b, nil, true) rc := newTestConn(&b, nil, true)
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize)) if _, err := w.Write(make([]byte, bufSize)); err != nil {
t.Fatalf("%d: w.Write() returned %v", n, err)
}
w.Close() w.Close()
if n >= b.Len() { if n >= b.Len() {
@ -401,7 +410,7 @@ func TestEOFWithinFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
} }
@ -420,13 +429,15 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
rc := newTestConn(&b1, &b2, true) rc := newTestConn(&b1, &b2, true)
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2)) if _, err := w.Write(make([]byte, bufSize+bufSize/2)); err != nil {
t.Fatalf("w.Write() returned %v", err)
}
op, r, err := rc.NextReader() op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
} }
@ -439,7 +450,9 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
func TestWriteAfterMessageWriterClose(t *testing.T) { func TestWriteAfterMessageWriterClose(t *testing.T) {
wc := newTestConn(nil, &bytes.Buffer{}, false) wc := newTestConn(nil, &bytes.Buffer{}, false)
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
io.WriteString(w, "hello") if _, err := io.WriteString(w, "hello"); err != nil {
t.Fatalf("unexpected error writing, %v", err)
}
if err := w.Close(); err != nil { if err := w.Close(); err != nil {
t.Fatalf("unxpected error closing message writer, %v", err) t.Fatalf("unxpected error closing message writer, %v", err)
} }
@ -449,7 +462,9 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
} }
w, _ = wc.NextWriter(BinaryMessage) w, _ = wc.NextWriter(BinaryMessage)
io.WriteString(w, "hello") if _, err := io.WriteString(w, "hello"); err != nil {
t.Fatalf("unexpected error writing after getting new writer, %v", err)
}
// close w by getting next writer // close w by getting next writer
_, err := wc.NextWriter(BinaryMessage) _, err := wc.NextWriter(BinaryMessage)
@ -474,13 +489,21 @@ func TestReadLimit(t *testing.T) {
// Send message at the limit with interleaved pong. // Send message at the limit with interleaved pong.
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
w.Write(message[:readLimit-1]) if _, err := w.Write(message[:readLimit-1]); err != nil {
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) t.Fatalf("w.WriteMessage() returned %v", err)
w.Write(message[:1]) }
if err := wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)); err != nil {
t.Fatalf("wc.WriteControl() returned %v", err)
}
if _, err := w.Write(message[:1]); err != nil {
t.Fatalf("w.Write() returned %v", err)
}
w.Close() w.Close()
// Send message larger than the limit. // Send message larger than the limit.
wc.WriteMessage(BinaryMessage, message[:readLimit+1]) if err := wc.WriteMessage(BinaryMessage, message[:readLimit+1]); err != nil {
t.Fatalf("wc.WriteMessage() returned %v", err)
}
op, _, err := rc.NextReader() op, _, err := rc.NextReader()
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
@ -490,7 +513,7 @@ func TestReadLimit(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("2: NextReader() returned %d, %v", op, err) t.Fatalf("2: NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != ErrReadLimit { if err != ErrReadLimit {
t.Fatalf("io.Copy() returned %v", err) t.Fatalf("io.Copy() returned %v", err)
} }
@ -593,7 +616,9 @@ func TestBufioReadBytes(t *testing.T) {
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
w.Write(m) if _, err := w.Write(m); err != nil {
t.Fatalf("w.Write() returned %v", err)
}
w.Close() w.Close()
op, r, err := rc.NextReader() op, r, err := rc.NextReader()
@ -667,7 +692,9 @@ func TestConcurrentWritePanic(t *testing.T) {
w := blockingWriter{make(chan struct{}), make(chan struct{})} w := blockingWriter{make(chan struct{}), make(chan struct{})}
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
go func() { go func() {
c.WriteMessage(TextMessage, []byte{}) if err := c.WriteMessage(TextMessage, []byte{}); err != nil {
t.Error(err)
}
}() }()
// wait for goroutine to block in write. // wait for goroutine to block in write.
@ -680,7 +707,9 @@ func TestConcurrentWritePanic(t *testing.T) {
} }
}() }()
c.WriteMessage(TextMessage, []byte{}) if err := c.WriteMessage(TextMessage, []byte{}); err != nil {
t.Error(err)
}
t.Fatal("should not get here") t.Fatal("should not get here")
} }
@ -700,7 +729,7 @@ func TestFailedConnectionReadPanic(t *testing.T) {
}() }()
for i := 0; i < 20000; i++ { for i := 0; i < 20000; i++ {
c.ReadMessage() _, _, _ = c.ReadMessage()
} }
t.Fatal("should not get here") t.Fatal("should not get here")
} }

View File

@ -178,7 +178,11 @@ func main() {
http.HandleFunc("/r", echoReadAllWriter) http.HandleFunc("/r", echoReadAllWriter)
http.HandleFunc("/m", echoReadAllWriteMessage) http.HandleFunc("/m", echoReadAllWriteMessage)
http.HandleFunc("/p", echoReadAllWritePreparedMessage) http.HandleFunc("/p", echoReadAllWritePreparedMessage)
err := http.ListenAndServe(*addr, nil) server := &http.Server{
Addr: *addr,
ReadHeaderTimeout: 3 * time.Second,
}
err := server.ListenAndServe()
if err != nil { if err != nil {
log.Fatal("ListenAndServe: ", err) log.Fatal("ListenAndServe: ", err)
} }

View File

@ -8,6 +8,7 @@ import (
"flag" "flag"
"log" "log"
"net/http" "net/http"
"time"
) )
var addr = flag.String("addr", ":8080", "http service address") var addr = flag.String("addr", ":8080", "http service address")
@ -33,7 +34,11 @@ func main() {
http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
serveWs(hub, w, r) serveWs(hub, w, r)
}) })
err := http.ListenAndServe(*addr, nil) server := &http.Server{
Addr: *addr,
ReadHeaderTimeout: 3 * time.Second,
}
err := server.ListenAndServe()
if err != nil { if err != nil {
log.Fatal("ListenAndServe: ", err) log.Fatal("ListenAndServe: ", err)
} }

View File

@ -189,5 +189,9 @@ func main() {
} }
http.HandleFunc("/", serveHome) http.HandleFunc("/", serveHome)
http.HandleFunc("/ws", serveWs) http.HandleFunc("/ws", serveWs)
log.Fatal(http.ListenAndServe(*addr, nil)) server := &http.Server{
Addr: *addr,
ReadHeaderTimeout: 3 * time.Second,
}
log.Fatal(server.ListenAndServe())
} }

View File

@ -11,6 +11,7 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"path/filepath"
"strconv" "strconv"
"time" "time"
@ -49,7 +50,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) {
if !fi.ModTime().After(lastMod) { if !fi.ModTime().After(lastMod) {
return nil, lastMod, nil return nil, lastMod, nil
} }
p, err := ioutil.ReadFile(filename) p, err := ioutil.ReadFile(filepath.Clean(filename))
if err != nil { if err != nil {
return nil, fi.ModTime(), err return nil, fi.ModTime(), err
} }
@ -163,7 +164,11 @@ func main() {
filename = flag.Args()[0] filename = flag.Args()[0]
http.HandleFunc("/", serveHome) http.HandleFunc("/", serveHome)
http.HandleFunc("/ws", serveWs) http.HandleFunc("/ws", serveWs)
if err := http.ListenAndServe(*addr, nil); err != nil { server := &http.Server{
Addr: *addr,
ReadHeaderTimeout: 3 * time.Second,
}
if err := server.ListenAndServe(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }

4
go.mod
View File

@ -1,3 +1,5 @@
module github.com/gorilla/websocket module github.com/gorilla/websocket
go 1.12 go 1.19
require golang.org/x/net v0.14.0

2
go.sum
View File

@ -0,0 +1,2 @@
golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14=
golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=

View File

@ -19,7 +19,9 @@ func TestJoinMessages(t *testing.T) {
wc := newTestConn(nil, &connBuf, true) wc := newTestConn(nil, &connBuf, true)
rc := newTestConn(&connBuf, nil, false) rc := newTestConn(&connBuf, nil, false)
for _, m := range messages { for _, m := range messages {
wc.WriteMessage(BinaryMessage, []byte(m)) if err := wc.WriteMessage(BinaryMessage, []byte(m)); err != nil {
t.Fatalf("write %q: %v", m, err)
}
} }
var result bytes.Buffer var result bytes.Buffer

View File

@ -9,6 +9,7 @@ package websocket
import "unsafe" import "unsafe"
// #nosec G103 -- (CWE-242) Has been audited
const wordSize = int(unsafe.Sizeof(uintptr(0))) const wordSize = int(unsafe.Sizeof(uintptr(0)))
func maskBytes(key [4]byte, pos int, b []byte) int { func maskBytes(key [4]byte, pos int, b []byte) int {
@ -22,6 +23,7 @@ func maskBytes(key [4]byte, pos int, b []byte) int {
} }
// Mask one byte at a time to word boundary. // Mask one byte at a time to word boundary.
//#nosec G103 -- (CWE-242) Has been audited
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
n = wordSize - n n = wordSize - n
for i := range b[:n] { for i := range b[:n] {
@ -36,11 +38,13 @@ func maskBytes(key [4]byte, pos int, b []byte) int {
for i := range k { for i := range k {
k[i] = key[(pos+i)&3] k[i] = key[(pos+i)&3]
} }
//#nosec G103 -- (CWE-242) Has been audited
kw := *(*uintptr)(unsafe.Pointer(&k)) kw := *(*uintptr)(unsafe.Pointer(&k))
// Mask one word at a time. // Mask one word at a time.
n := (len(b) / wordSize) * wordSize n := (len(b) / wordSize) * wordSize
for i := 0; i < n; i += wordSize { for i := 0; i < n; i += wordSize {
//#nosec G103 -- (CWE-242) Has been audited
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
} }

View File

@ -33,6 +33,11 @@ var preparedMessageTests = []struct {
} }
func TestPreparedMessage(t *testing.T) { func TestPreparedMessage(t *testing.T) {
testRand := rand.New(rand.NewSource(99))
prevMaskRand := maskRand
maskRand = testRand
defer func() { maskRand = prevMaskRand }()
for _, tt := range preparedMessageTests { for _, tt := range preparedMessageTests {
var data = []byte("this is a test") var data = []byte("this is a test")
var buf bytes.Buffer var buf bytes.Buffer
@ -40,10 +45,13 @@ func TestPreparedMessage(t *testing.T) {
if tt.enableWriteCompression { if tt.enableWriteCompression {
c.newCompressionWriter = compressNoContextTakeover c.newCompressionWriter = compressNoContextTakeover
} }
c.SetCompressionLevel(tt.compressionLevel)
if err := c.SetCompressionLevel(tt.compressionLevel); err != nil {
t.Fatal(err)
}
// Seed random number generator for consistent frame mask. // Seed random number generator for consistent frame mask.
rand.Seed(1234) testRand.Seed(1234)
if err := c.WriteMessage(tt.messageType, data); err != nil { if err := c.WriteMessage(tt.messageType, data); err != nil {
t.Fatal(err) t.Fatal(err)
@ -59,7 +67,7 @@ func TestPreparedMessage(t *testing.T) {
copy(data, "hello world") copy(data, "hello world")
// Seed random number generator for consistent frame mask. // Seed random number generator for consistent frame mask.
rand.Seed(1234) testRand.Seed(1234)
buf.Reset() buf.Reset()
if err := c.WritePreparedMessage(pm); err != nil { if err := c.WritePreparedMessage(pm); err != nil {
@ -68,7 +76,7 @@ func TestPreparedMessage(t *testing.T) {
got := buf.String() got := buf.String()
if got != want { if got != want {
t.Errorf("write message != prepared message for %+v", tt) t.Errorf("write message != prepared message, got %#v, want %#v", got, want)
} }
} }
} }

View File

@ -8,10 +8,13 @@ import (
"bufio" "bufio"
"encoding/base64" "encoding/base64"
"errors" "errors"
"log"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"golang.org/x/net/proxy"
) )
type netDialerFunc func(network, addr string) (net.Conn, error) type netDialerFunc func(network, addr string) (net.Conn, error)
@ -21,7 +24,7 @@ func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
} }
func init() { func init() {
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { proxy.RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy.Dialer) (proxy.Dialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
}) })
} }
@ -55,7 +58,9 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
} }
if err := connectReq.Write(conn); err != nil { if err := connectReq.Write(conn); err != nil {
conn.Close() if err := conn.Close(); err != nil {
log.Printf("httpProxyDialer: failed to close connection: %v", err)
}
return nil, err return nil, err
} }
@ -64,12 +69,16 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, connectReq) resp, err := http.ReadResponse(br, connectReq)
if err != nil { if err != nil {
conn.Close() if err := conn.Close(); err != nil {
log.Printf("httpProxyDialer: failed to close connection: %v", err)
}
return nil, err return nil, err
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
conn.Close() if err := conn.Close(); err != nil {
log.Printf("httpProxyDialer: failed to close connection: %v", err)
}
f := strings.SplitN(resp.Status, " ", 2) f := strings.SplitN(resp.Status, " ", 2)
return nil, errors.New(f[1]) return nil, errors.New(f[1])
} }

View File

@ -8,6 +8,7 @@ import (
"bufio" "bufio"
"errors" "errors"
"io" "io"
"log"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -183,7 +184,9 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
if brw.Reader.Buffered() > 0 { if brw.Reader.Buffered() > 0 {
netConn.Close() if err := netConn.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
return nil, errors.New("websocket: client sent data before handshake is complete") return nil, errors.New("websocket: client sent data before handshake is complete")
} }
@ -248,17 +251,34 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
p = append(p, "\r\n"...) p = append(p, "\r\n"...)
// Clear deadlines set by HTTP server. // Clear deadlines set by HTTP server.
netConn.SetDeadline(time.Time{}) if err := netConn.SetDeadline(time.Time{}); err != nil {
if err := netConn.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
return nil, err
}
if u.HandshakeTimeout > 0 { if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) if err := netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)); err != nil {
if err := netConn.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
return nil, err
}
} }
if _, err = netConn.Write(p); err != nil { if _, err = netConn.Write(p); err != nil {
netConn.Close() if err := netConn.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
return nil, err return nil, err
} }
if u.HandshakeTimeout > 0 { if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{}) if err := netConn.SetWriteDeadline(time.Time{}); err != nil {
if err := netConn.Close(); err != nil {
log.Printf("websocket: failed to close network connection: %v", err)
}
return nil, err
}
} }
return c, nil return c, nil
@ -356,8 +376,12 @@ func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
// bufio.Writer's underlying writer. // bufio.Writer's underlying writer.
var wh writeHook var wh writeHook
bw.Reset(&wh) bw.Reset(&wh)
bw.WriteByte(0) if err := bw.WriteByte(0); err != nil {
bw.Flush() panic(err)
}
if err := bw.Flush(); err != nil {
log.Printf("websocket: bufioWriterBuffer: Flush: %v", err)
}
bw.Reset(originalWriter) bw.Reset(originalWriter)

View File

@ -1,6 +1,3 @@
//go:build go1.17
// +build go1.17
package websocket package websocket
import ( import (

View File

@ -1,21 +0,0 @@
//go:build !go1.17
// +build !go1.17
package websocket
import (
"context"
"crypto/tls"
)
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View File

@ -6,7 +6,7 @@ package websocket
import ( import (
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1" //#nosec G505 -- (CWE-327) https://datatracker.ietf.org/doc/html/rfc6455#page-54
"encoding/base64" "encoding/base64"
"io" "io"
"net/http" "net/http"
@ -17,7 +17,7 @@ import (
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string { func computeAcceptKey(challengeKey string) string {
h := sha1.New() h := sha1.New() //#nosec G401 -- (CWE-326) https://datatracker.ietf.org/doc/html/rfc6455#page-54
h.Write([]byte(challengeKey)) h.Write([]byte(challengeKey))
h.Write(keyGUID) h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil)) return base64.StdEncoding.EncodeToString(h.Sum(nil))

27
vendor/golang.org/x/net/LICENSE generated vendored Normal file
View File

@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

22
vendor/golang.org/x/net/PATENTS generated vendored Normal file
View File

@ -0,0 +1,22 @@
Additional IP Rights Grant (Patents)
"This implementation" means the copyrightable works distributed by
Google as part of the Go project.
Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.

168
vendor/golang.org/x/net/internal/socks/client.go generated vendored Normal file
View File

@ -0,0 +1,168 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package socks
import (
"context"
"errors"
"io"
"net"
"strconv"
"time"
)
var (
noDeadline = time.Time{}
aLongTimeAgo = time.Unix(1, 0)
)
func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
host, port, err := splitHostPort(address)
if err != nil {
return nil, err
}
if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
c.SetDeadline(deadline)
defer c.SetDeadline(noDeadline)
}
if ctx != context.Background() {
errCh := make(chan error, 1)
done := make(chan struct{})
defer func() {
close(done)
if ctxErr == nil {
ctxErr = <-errCh
}
}()
go func() {
select {
case <-ctx.Done():
c.SetDeadline(aLongTimeAgo)
errCh <- ctx.Err()
case <-done:
errCh <- nil
}
}()
}
b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
b = append(b, Version5)
if len(d.AuthMethods) == 0 || d.Authenticate == nil {
b = append(b, 1, byte(AuthMethodNotRequired))
} else {
ams := d.AuthMethods
if len(ams) > 255 {
return nil, errors.New("too many authentication methods")
}
b = append(b, byte(len(ams)))
for _, am := range ams {
b = append(b, byte(am))
}
}
if _, ctxErr = c.Write(b); ctxErr != nil {
return
}
if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
return
}
if b[0] != Version5 {
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
}
am := AuthMethod(b[1])
if am == AuthMethodNoAcceptableMethods {
return nil, errors.New("no acceptable authentication methods")
}
if d.Authenticate != nil {
if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
return
}
}
b = b[:0]
b = append(b, Version5, byte(d.cmd), 0)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
b = append(b, AddrTypeIPv4)
b = append(b, ip4...)
} else if ip6 := ip.To16(); ip6 != nil {
b = append(b, AddrTypeIPv6)
b = append(b, ip6...)
} else {
return nil, errors.New("unknown address type")
}
} else {
if len(host) > 255 {
return nil, errors.New("FQDN too long")
}
b = append(b, AddrTypeFQDN)
b = append(b, byte(len(host)))
b = append(b, host...)
}
b = append(b, byte(port>>8), byte(port))
if _, ctxErr = c.Write(b); ctxErr != nil {
return
}
if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
return
}
if b[0] != Version5 {
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
}
if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
return nil, errors.New("unknown error " + cmdErr.String())
}
if b[2] != 0 {
return nil, errors.New("non-zero reserved field")
}
l := 2
var a Addr
switch b[3] {
case AddrTypeIPv4:
l += net.IPv4len
a.IP = make(net.IP, net.IPv4len)
case AddrTypeIPv6:
l += net.IPv6len
a.IP = make(net.IP, net.IPv6len)
case AddrTypeFQDN:
if _, err := io.ReadFull(c, b[:1]); err != nil {
return nil, err
}
l += int(b[0])
default:
return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
}
if cap(b) < l {
b = make([]byte, l)
} else {
b = b[:l]
}
if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
return
}
if a.IP != nil {
copy(a.IP, b)
} else {
a.Name = string(b[:len(b)-2])
}
a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
return &a, nil
}
func splitHostPort(address string) (string, int, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return "", 0, err
}
portnum, err := strconv.Atoi(port)
if err != nil {
return "", 0, err
}
if 1 > portnum || portnum > 0xffff {
return "", 0, errors.New("port number out of range " + port)
}
return host, portnum, nil
}

317
vendor/golang.org/x/net/internal/socks/socks.go generated vendored Normal file
View File

@ -0,0 +1,317 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package socks provides a SOCKS version 5 client implementation.
//
// SOCKS protocol version 5 is defined in RFC 1928.
// Username/Password authentication for SOCKS version 5 is defined in
// RFC 1929.
package socks
import (
"context"
"errors"
"io"
"net"
"strconv"
)
// A Command represents a SOCKS command.
type Command int
func (cmd Command) String() string {
switch cmd {
case CmdConnect:
return "socks connect"
case cmdBind:
return "socks bind"
default:
return "socks " + strconv.Itoa(int(cmd))
}
}
// An AuthMethod represents a SOCKS authentication method.
type AuthMethod int
// A Reply represents a SOCKS command reply code.
type Reply int
func (code Reply) String() string {
switch code {
case StatusSucceeded:
return "succeeded"
case 0x01:
return "general SOCKS server failure"
case 0x02:
return "connection not allowed by ruleset"
case 0x03:
return "network unreachable"
case 0x04:
return "host unreachable"
case 0x05:
return "connection refused"
case 0x06:
return "TTL expired"
case 0x07:
return "command not supported"
case 0x08:
return "address type not supported"
default:
return "unknown code: " + strconv.Itoa(int(code))
}
}
// Wire protocol constants.
const (
Version5 = 0x05
AddrTypeIPv4 = 0x01
AddrTypeFQDN = 0x03
AddrTypeIPv6 = 0x04
CmdConnect Command = 0x01 // establishes an active-open forward proxy connection
cmdBind Command = 0x02 // establishes a passive-open forward proxy connection
AuthMethodNotRequired AuthMethod = 0x00 // no authentication required
AuthMethodUsernamePassword AuthMethod = 0x02 // use username/password
AuthMethodNoAcceptableMethods AuthMethod = 0xff // no acceptable authentication methods
StatusSucceeded Reply = 0x00
)
// An Addr represents a SOCKS-specific address.
// Either Name or IP is used exclusively.
type Addr struct {
Name string // fully-qualified domain name
IP net.IP
Port int
}
func (a *Addr) Network() string { return "socks" }
func (a *Addr) String() string {
if a == nil {
return "<nil>"
}
port := strconv.Itoa(a.Port)
if a.IP == nil {
return net.JoinHostPort(a.Name, port)
}
return net.JoinHostPort(a.IP.String(), port)
}
// A Conn represents a forward proxy connection.
type Conn struct {
net.Conn
boundAddr net.Addr
}
// BoundAddr returns the address assigned by the proxy server for
// connecting to the command target address from the proxy server.
func (c *Conn) BoundAddr() net.Addr {
if c == nil {
return nil
}
return c.boundAddr
}
// A Dialer holds SOCKS-specific options.
type Dialer struct {
cmd Command // either CmdConnect or cmdBind
proxyNetwork string // network between a proxy server and a client
proxyAddress string // proxy server address
// ProxyDial specifies the optional dial function for
// establishing the transport connection.
ProxyDial func(context.Context, string, string) (net.Conn, error)
// AuthMethods specifies the list of request authentication
// methods.
// If empty, SOCKS client requests only AuthMethodNotRequired.
AuthMethods []AuthMethod
// Authenticate specifies the optional authentication
// function. It must be non-nil when AuthMethods is not empty.
// It must return an error when the authentication is failed.
Authenticate func(context.Context, io.ReadWriter, AuthMethod) error
}
// DialContext connects to the provided address on the provided
// network.
//
// The returned error value may be a net.OpError. When the Op field of
// net.OpError contains "socks", the Source field contains a proxy
// server address and the Addr field contains a command target
// address.
//
// See func Dial of the net package of standard library for a
// description of the network and address parameters.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if err := d.validateTarget(network, address); err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
if ctx == nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
}
var err error
var c net.Conn
if d.ProxyDial != nil {
c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress)
} else {
var dd net.Dialer
c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress)
}
if err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
a, err := d.connect(ctx, c, address)
if err != nil {
c.Close()
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
return &Conn{Conn: c, boundAddr: a}, nil
}
// DialWithConn initiates a connection from SOCKS server to the target
// network and address using the connection c that is already
// connected to the SOCKS server.
//
// It returns the connection's local address assigned by the SOCKS
// server.
func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) {
if err := d.validateTarget(network, address); err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
if ctx == nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
}
a, err := d.connect(ctx, c, address)
if err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
return a, nil
}
// Dial connects to the provided address on the provided network.
//
// Unlike DialContext, it returns a raw transport connection instead
// of a forward proxy connection.
//
// Deprecated: Use DialContext or DialWithConn instead.
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
if err := d.validateTarget(network, address); err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
var err error
var c net.Conn
if d.ProxyDial != nil {
c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress)
} else {
c, err = net.Dial(d.proxyNetwork, d.proxyAddress)
}
if err != nil {
proxy, dst, _ := d.pathAddrs(address)
return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
}
if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil {
c.Close()
return nil, err
}
return c, nil
}
func (d *Dialer) validateTarget(network, address string) error {
switch network {
case "tcp", "tcp6", "tcp4":
default:
return errors.New("network not implemented")
}
switch d.cmd {
case CmdConnect, cmdBind:
default:
return errors.New("command not implemented")
}
return nil
}
func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {
for i, s := range []string{d.proxyAddress, address} {
host, port, err := splitHostPort(s)
if err != nil {
return nil, nil, err
}
a := &Addr{Port: port}
a.IP = net.ParseIP(host)
if a.IP == nil {
a.Name = host
}
if i == 0 {
proxy = a
} else {
dst = a
}
}
return
}
// NewDialer returns a new Dialer that dials through the provided
// proxy server's network and address.
func NewDialer(network, address string) *Dialer {
return &Dialer{proxyNetwork: network, proxyAddress: address, cmd: CmdConnect}
}
const (
authUsernamePasswordVersion = 0x01
authStatusSucceeded = 0x00
)
// UsernamePassword are the credentials for the username/password
// authentication method.
type UsernamePassword struct {
Username string
Password string
}
// Authenticate authenticates a pair of username and password with the
// proxy server.
func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth AuthMethod) error {
switch auth {
case AuthMethodNotRequired:
return nil
case AuthMethodUsernamePassword:
if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) > 255 {
return errors.New("invalid username/password")
}
b := []byte{authUsernamePasswordVersion}
b = append(b, byte(len(up.Username)))
b = append(b, up.Username...)
b = append(b, byte(len(up.Password)))
b = append(b, up.Password...)
// TODO(mikio): handle IO deadlines and cancelation if
// necessary
if _, err := rw.Write(b); err != nil {
return err
}
if _, err := io.ReadFull(rw, b[:2]); err != nil {
return err
}
if b[0] != authUsernamePasswordVersion {
return errors.New("invalid username/password version")
}
if b[1] != authStatusSucceeded {
return errors.New("username/password authentication failed")
}
return nil
}
return errors.New("unsupported authentication method " + strconv.Itoa(int(auth)))
}

54
vendor/golang.org/x/net/proxy/dial.go generated vendored Normal file
View File

@ -0,0 +1,54 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"context"
"net"
)
// A ContextDialer dials using a context.
type ContextDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// Dial works like DialContext on net.Dialer but using a dialer returned by FromEnvironment.
//
// The passed ctx is only used for returning the Conn, not the lifetime of the Conn.
//
// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer
// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout.
//
// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
func Dial(ctx context.Context, network, address string) (net.Conn, error) {
d := FromEnvironment()
if xd, ok := d.(ContextDialer); ok {
return xd.DialContext(ctx, network, address)
}
return dialContext(ctx, d, network, address)
}
// WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout
// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
func dialContext(ctx context.Context, d Dialer, network, address string) (net.Conn, error) {
var (
conn net.Conn
done = make(chan struct{}, 1)
err error
)
go func() {
conn, err = d.Dial(network, address)
close(done)
if conn != nil && ctx.Err() != nil {
conn.Close()
}
}()
select {
case <-ctx.Done():
err = ctx.Err()
case <-done:
}
return conn, err
}

31
vendor/golang.org/x/net/proxy/direct.go generated vendored Normal file
View File

@ -0,0 +1,31 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"context"
"net"
)
type direct struct{}
// Direct implements Dialer by making network connections directly using net.Dial or net.DialContext.
var Direct = direct{}
var (
_ Dialer = Direct
_ ContextDialer = Direct
)
// Dial directly invokes net.Dial with the supplied parameters.
func (direct) Dial(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
}
// DialContext instantiates a net.Dialer and invokes its DialContext receiver with the supplied parameters.
func (direct) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, network, addr)
}

155
vendor/golang.org/x/net/proxy/per_host.go generated vendored Normal file
View File

@ -0,0 +1,155 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"context"
"net"
"strings"
)
// A PerHost directs connections to a default Dialer unless the host name
// requested matches one of a number of exceptions.
type PerHost struct {
def, bypass Dialer
bypassNetworks []*net.IPNet
bypassIPs []net.IP
bypassZones []string
bypassHosts []string
}
// NewPerHost returns a PerHost Dialer that directs connections to either
// defaultDialer or bypass, depending on whether the connection matches one of
// the configured rules.
func NewPerHost(defaultDialer, bypass Dialer) *PerHost {
return &PerHost{
def: defaultDialer,
bypass: bypass,
}
}
// Dial connects to the address addr on the given network through either
// defaultDialer or bypass.
func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return p.dialerForRequest(host).Dial(network, addr)
}
// DialContext connects to the address addr on the given network through either
// defaultDialer or bypass.
func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
d := p.dialerForRequest(host)
if x, ok := d.(ContextDialer); ok {
return x.DialContext(ctx, network, addr)
}
return dialContext(ctx, d, network, addr)
}
func (p *PerHost) dialerForRequest(host string) Dialer {
if ip := net.ParseIP(host); ip != nil {
for _, net := range p.bypassNetworks {
if net.Contains(ip) {
return p.bypass
}
}
for _, bypassIP := range p.bypassIPs {
if bypassIP.Equal(ip) {
return p.bypass
}
}
return p.def
}
for _, zone := range p.bypassZones {
if strings.HasSuffix(host, zone) {
return p.bypass
}
if host == zone[1:] {
// For a zone ".example.com", we match "example.com"
// too.
return p.bypass
}
}
for _, bypassHost := range p.bypassHosts {
if bypassHost == host {
return p.bypass
}
}
return p.def
}
// AddFromString parses a string that contains comma-separated values
// specifying hosts that should use the bypass proxy. Each value is either an
// IP address, a CIDR range, a zone (*.example.com) or a host name
// (localhost). A best effort is made to parse the string and errors are
// ignored.
func (p *PerHost) AddFromString(s string) {
hosts := strings.Split(s, ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
if len(host) == 0 {
continue
}
if strings.Contains(host, "/") {
// We assume that it's a CIDR address like 127.0.0.0/8
if _, net, err := net.ParseCIDR(host); err == nil {
p.AddNetwork(net)
}
continue
}
if ip := net.ParseIP(host); ip != nil {
p.AddIP(ip)
continue
}
if strings.HasPrefix(host, "*.") {
p.AddZone(host[1:])
continue
}
p.AddHost(host)
}
}
// AddIP specifies an IP address that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match an IP.
func (p *PerHost) AddIP(ip net.IP) {
p.bypassIPs = append(p.bypassIPs, ip)
}
// AddNetwork specifies an IP range that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match.
func (p *PerHost) AddNetwork(net *net.IPNet) {
p.bypassNetworks = append(p.bypassNetworks, net)
}
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *PerHost) AddZone(zone string) {
if strings.HasSuffix(zone, ".") {
zone = zone[:len(zone)-1]
}
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
p.bypassZones = append(p.bypassZones, zone)
}
// AddHost specifies a host name that will use the bypass proxy.
func (p *PerHost) AddHost(host string) {
if strings.HasSuffix(host, ".") {
host = host[:len(host)-1]
}
p.bypassHosts = append(p.bypassHosts, host)
}

149
vendor/golang.org/x/net/proxy/proxy.go generated vendored Normal file
View File

@ -0,0 +1,149 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package proxy provides support for a variety of protocols to proxy network
// data.
package proxy // import "golang.org/x/net/proxy"
import (
"errors"
"net"
"net/url"
"os"
"sync"
)
// A Dialer is a means to establish a connection.
// Custom dialers should also implement ContextDialer.
type Dialer interface {
// Dial connects to the given address via the proxy.
Dial(network, addr string) (c net.Conn, err error)
}
// Auth contains authentication parameters that specific Dialers may require.
type Auth struct {
User, Password string
}
// FromEnvironment returns the dialer specified by the proxy-related
// variables in the environment and makes underlying connections
// directly.
func FromEnvironment() Dialer {
return FromEnvironmentUsing(Direct)
}
// FromEnvironmentUsing returns the dialer specify by the proxy-related
// variables in the environment and makes underlying connections
// using the provided forwarding Dialer (for instance, a *net.Dialer
// with desired configuration).
func FromEnvironmentUsing(forward Dialer) Dialer {
allProxy := allProxyEnv.Get()
if len(allProxy) == 0 {
return forward
}
proxyURL, err := url.Parse(allProxy)
if err != nil {
return forward
}
proxy, err := FromURL(proxyURL, forward)
if err != nil {
return forward
}
noProxy := noProxyEnv.Get()
if len(noProxy) == 0 {
return proxy
}
perHost := NewPerHost(proxy, forward)
perHost.AddFromString(noProxy)
return perHost
}
// proxySchemes is a map from URL schemes to a function that creates a Dialer
// from a URL with such a scheme.
var proxySchemes map[string]func(*url.URL, Dialer) (Dialer, error)
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
// by FromURL.
func RegisterDialerType(scheme string, f func(*url.URL, Dialer) (Dialer, error)) {
if proxySchemes == nil {
proxySchemes = make(map[string]func(*url.URL, Dialer) (Dialer, error))
}
proxySchemes[scheme] = f
}
// FromURL returns a Dialer given a URL specification and an underlying
// Dialer for it to make network requests.
func FromURL(u *url.URL, forward Dialer) (Dialer, error) {
var auth *Auth
if u.User != nil {
auth = new(Auth)
auth.User = u.User.Username()
if p, ok := u.User.Password(); ok {
auth.Password = p
}
}
switch u.Scheme {
case "socks5", "socks5h":
addr := u.Hostname()
port := u.Port()
if port == "" {
port = "1080"
}
return SOCKS5("tcp", net.JoinHostPort(addr, port), auth, forward)
}
// If the scheme doesn't match any of the built-in schemes, see if it
// was registered by another package.
if proxySchemes != nil {
if f, ok := proxySchemes[u.Scheme]; ok {
return f(u, forward)
}
}
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
}
var (
allProxyEnv = &envOnce{
names: []string{"ALL_PROXY", "all_proxy"},
}
noProxyEnv = &envOnce{
names: []string{"NO_PROXY", "no_proxy"},
}
)
// envOnce looks up an environment variable (optionally by multiple
// names) once. It mitigates expensive lookups on some platforms
// (e.g. Windows).
// (Borrowed from net/http/transport.go)
type envOnce struct {
names []string
once sync.Once
val string
}
func (e *envOnce) Get() string {
e.once.Do(e.init)
return e.val
}
func (e *envOnce) init() {
for _, n := range e.names {
e.val = os.Getenv(n)
if e.val != "" {
return
}
}
}
// reset is used by tests
func (e *envOnce) reset() {
e.once = sync.Once{}
e.val = ""
}

42
vendor/golang.org/x/net/proxy/socks5.go generated vendored Normal file
View File

@ -0,0 +1,42 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package proxy
import (
"context"
"net"
"golang.org/x/net/internal/socks"
)
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given
// address with an optional username and password.
// See RFC 1928 and RFC 1929.
func SOCKS5(network, address string, auth *Auth, forward Dialer) (Dialer, error) {
d := socks.NewDialer(network, address)
if forward != nil {
if f, ok := forward.(ContextDialer); ok {
d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) {
return f.DialContext(ctx, network, address)
}
} else {
d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) {
return dialContext(ctx, forward, network, address)
}
}
}
if auth != nil {
up := socks.UsernamePassword{
Username: auth.User,
Password: auth.Password,
}
d.AuthMethods = []socks.AuthMethod{
socks.AuthMethodNotRequired,
socks.AuthMethodUsernamePassword,
}
d.Authenticate = up.Authenticate
}
return d, nil
}

4
vendor/modules.txt vendored Normal file
View File

@ -0,0 +1,4 @@
# golang.org/x/net v0.14.0
## explicit; go 1.17
golang.org/x/net/internal/socks
golang.org/x/net/proxy

View File

@ -1,473 +0,0 @@
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy
// Package proxy provides support for a variety of protocols to proxy network
// data.
//
package websocket
import (
"errors"
"io"
"net"
"net/url"
"os"
"strconv"
"strings"
"sync"
)
type proxy_direct struct{}
// Direct is a direct proxy: one that makes network connections directly.
var proxy_Direct = proxy_direct{}
func (proxy_direct) Dial(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
}
// A PerHost directs connections to a default Dialer unless the host name
// requested matches one of a number of exceptions.
type proxy_PerHost struct {
def, bypass proxy_Dialer
bypassNetworks []*net.IPNet
bypassIPs []net.IP
bypassZones []string
bypassHosts []string
}
// NewPerHost returns a PerHost Dialer that directs connections to either
// defaultDialer or bypass, depending on whether the connection matches one of
// the configured rules.
func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost {
return &proxy_PerHost{
def: defaultDialer,
bypass: bypass,
}
}
// Dial connects to the address addr on the given network through either
// defaultDialer or bypass.
func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return p.dialerForRequest(host).Dial(network, addr)
}
func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer {
if ip := net.ParseIP(host); ip != nil {
for _, net := range p.bypassNetworks {
if net.Contains(ip) {
return p.bypass
}
}
for _, bypassIP := range p.bypassIPs {
if bypassIP.Equal(ip) {
return p.bypass
}
}
return p.def
}
for _, zone := range p.bypassZones {
if strings.HasSuffix(host, zone) {
return p.bypass
}
if host == zone[1:] {
// For a zone ".example.com", we match "example.com"
// too.
return p.bypass
}
}
for _, bypassHost := range p.bypassHosts {
if bypassHost == host {
return p.bypass
}
}
return p.def
}
// AddFromString parses a string that contains comma-separated values
// specifying hosts that should use the bypass proxy. Each value is either an
// IP address, a CIDR range, a zone (*.example.com) or a host name
// (localhost). A best effort is made to parse the string and errors are
// ignored.
func (p *proxy_PerHost) AddFromString(s string) {
hosts := strings.Split(s, ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
if len(host) == 0 {
continue
}
if strings.Contains(host, "/") {
// We assume that it's a CIDR address like 127.0.0.0/8
if _, net, err := net.ParseCIDR(host); err == nil {
p.AddNetwork(net)
}
continue
}
if ip := net.ParseIP(host); ip != nil {
p.AddIP(ip)
continue
}
if strings.HasPrefix(host, "*.") {
p.AddZone(host[1:])
continue
}
p.AddHost(host)
}
}
// AddIP specifies an IP address that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match an IP.
func (p *proxy_PerHost) AddIP(ip net.IP) {
p.bypassIPs = append(p.bypassIPs, ip)
}
// AddNetwork specifies an IP range that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match.
func (p *proxy_PerHost) AddNetwork(net *net.IPNet) {
p.bypassNetworks = append(p.bypassNetworks, net)
}
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *proxy_PerHost) AddZone(zone string) {
if strings.HasSuffix(zone, ".") {
zone = zone[:len(zone)-1]
}
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
p.bypassZones = append(p.bypassZones, zone)
}
// AddHost specifies a host name that will use the bypass proxy.
func (p *proxy_PerHost) AddHost(host string) {
if strings.HasSuffix(host, ".") {
host = host[:len(host)-1]
}
p.bypassHosts = append(p.bypassHosts, host)
}
// A Dialer is a means to establish a connection.
type proxy_Dialer interface {
// Dial connects to the given address via the proxy.
Dial(network, addr string) (c net.Conn, err error)
}
// Auth contains authentication parameters that specific Dialers may require.
type proxy_Auth struct {
User, Password string
}
// FromEnvironment returns the dialer specified by the proxy related variables in
// the environment.
func proxy_FromEnvironment() proxy_Dialer {
allProxy := proxy_allProxyEnv.Get()
if len(allProxy) == 0 {
return proxy_Direct
}
proxyURL, err := url.Parse(allProxy)
if err != nil {
return proxy_Direct
}
proxy, err := proxy_FromURL(proxyURL, proxy_Direct)
if err != nil {
return proxy_Direct
}
noProxy := proxy_noProxyEnv.Get()
if len(noProxy) == 0 {
return proxy
}
perHost := proxy_NewPerHost(proxy, proxy_Direct)
perHost.AddFromString(noProxy)
return perHost
}
// proxySchemes is a map from URL schemes to a function that creates a Dialer
// from a URL with such a scheme.
var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
// by FromURL.
func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) {
if proxy_proxySchemes == nil {
proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error))
}
proxy_proxySchemes[scheme] = f
}
// FromURL returns a Dialer given a URL specification and an underlying
// Dialer for it to make network requests.
func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) {
var auth *proxy_Auth
if u.User != nil {
auth = new(proxy_Auth)
auth.User = u.User.Username()
if p, ok := u.User.Password(); ok {
auth.Password = p
}
}
switch u.Scheme {
case "socks5":
return proxy_SOCKS5("tcp", u.Host, auth, forward)
}
// If the scheme doesn't match any of the built-in schemes, see if it
// was registered by another package.
if proxy_proxySchemes != nil {
if f, ok := proxy_proxySchemes[u.Scheme]; ok {
return f(u, forward)
}
}
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
}
var (
proxy_allProxyEnv = &proxy_envOnce{
names: []string{"ALL_PROXY", "all_proxy"},
}
proxy_noProxyEnv = &proxy_envOnce{
names: []string{"NO_PROXY", "no_proxy"},
}
)
// envOnce looks up an environment variable (optionally by multiple
// names) once. It mitigates expensive lookups on some platforms
// (e.g. Windows).
// (Borrowed from net/http/transport.go)
type proxy_envOnce struct {
names []string
once sync.Once
val string
}
func (e *proxy_envOnce) Get() string {
e.once.Do(e.init)
return e.val
}
func (e *proxy_envOnce) init() {
for _, n := range e.names {
e.val = os.Getenv(n)
if e.val != "" {
return
}
}
}
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
// with an optional username and password. See RFC 1928 and RFC 1929.
func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) {
s := &proxy_socks5{
network: network,
addr: addr,
forward: forward,
}
if auth != nil {
s.user = auth.User
s.password = auth.Password
}
return s, nil
}
type proxy_socks5 struct {
user, password string
network, addr string
forward proxy_Dialer
}
const proxy_socks5Version = 5
const (
proxy_socks5AuthNone = 0
proxy_socks5AuthPassword = 2
)
const proxy_socks5Connect = 1
const (
proxy_socks5IP4 = 1
proxy_socks5Domain = 3
proxy_socks5IP6 = 4
)
var proxy_socks5Errors = []string{
"",
"general failure",
"connection forbidden",
"network unreachable",
"host unreachable",
"connection refused",
"TTL expired",
"command not supported",
"address type not supported",
}
// Dial connects to the address addr on the given network via the SOCKS5 proxy.
func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) {
switch network {
case "tcp", "tcp6", "tcp4":
default:
return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
}
conn, err := s.forward.Dial(s.network, s.addr)
if err != nil {
return nil, err
}
if err := s.connect(conn, addr); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
// connect takes an existing connection to a socks5 proxy server,
// and commands the server to extend that connection to target,
// which must be a canonical address with a host and port.
func (s *proxy_socks5) connect(conn net.Conn, target string) error {
host, portStr, err := net.SplitHostPort(target)
if err != nil {
return err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return errors.New("proxy: failed to parse port number: " + portStr)
}
if port < 1 || port > 0xffff {
return errors.New("proxy: port number out of range: " + portStr)
}
// the size here is just an estimate
buf := make([]byte, 0, 6+len(host))
buf = append(buf, proxy_socks5Version)
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword)
} else {
buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone)
}
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[0] != 5 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
}
if buf[1] == 0xff {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
}
// See RFC 1929
if buf[1] == proxy_socks5AuthPassword {
buf = buf[:0]
buf = append(buf, 1 /* password protocol version */)
buf = append(buf, uint8(len(s.user)))
buf = append(buf, s.user...)
buf = append(buf, uint8(len(s.password)))
buf = append(buf, s.password...)
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[1] != 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
}
}
buf = buf[:0]
buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, proxy_socks5IP4)
ip = ip4
} else {
buf = append(buf, proxy_socks5IP6)
}
buf = append(buf, ip...)
} else {
if len(host) > 255 {
return errors.New("proxy: destination host name too long: " + host)
}
buf = append(buf, proxy_socks5Domain)
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
buf = append(buf, byte(port>>8), byte(port))
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
failure := "unknown error"
if int(buf[1]) < len(proxy_socks5Errors) {
failure = proxy_socks5Errors[buf[1]]
}
if len(failure) > 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
}
bytesToDiscard := 0
switch buf[3] {
case proxy_socks5IP4:
bytesToDiscard = net.IPv4len
case proxy_socks5IP6:
bytesToDiscard = net.IPv6len
case proxy_socks5Domain:
_, err := io.ReadFull(conn, buf[:1])
if err != nil {
return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
bytesToDiscard = int(buf[0])
default:
return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
}
if cap(buf) < bytesToDiscard {
buf = make([]byte, bytesToDiscard)
} else {
buf = buf[:bytesToDiscard]
}
if _, err := io.ReadFull(conn, buf); err != nil {
return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// Also need to discard the port number
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
return nil
}