Merge branch 'main' into prerrcheck

Signed-off-by: Canelo Hill <172609632+canelohill@users.noreply.github.com>
This commit is contained in:
Canelo Hill 2024-07-01 10:33:09 -07:00 committed by GitHub
commit c9d30b6eb0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 312 additions and 617 deletions

View File

@ -67,4 +67,4 @@ workflows:
- test: - test:
matrix: matrix:
parameters: parameters:
version: ["1.18", "1.17", "1.16"] version: ["1.22", "1.21", "1.20"]

View File

@ -10,10 +10,10 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
### Documentation ### Documentation
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) * [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) * [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat)
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) * [Command example](https://github.com/gorilla/websocket/tree/main/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) * [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo)
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) * [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch)
### Status ### Status
@ -29,5 +29,4 @@ package API is stable.
The Gorilla WebSocket package passes the server tests in the [Autobahn Test The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn).

View File

@ -11,7 +11,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
@ -53,7 +52,7 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
// It is safe to call Dialer's methods concurrently. // It is safe to call Dialer's methods concurrently.
type Dialer struct { type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If // NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used. // NetDial is nil, net.Dialer DialContext is used.
NetDial func(network, addr string) (net.Conn, error) NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If // NetDialContext specifies the dial function for creating TCP connections. If
@ -245,46 +244,25 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
defer cancel() defer cancel()
} }
// Get network dial function. var netDial netDialerFunc
var netDial func(network, add string) (net.Conn, error) switch {
case u.Scheme == "https" && d.NetDialTLSContext != nil:
switch u.Scheme { netDial = d.NetDialTLSContext
case "http": case d.NetDialContext != nil:
if d.NetDialContext != nil { netDial = d.NetDialContext
netDial = func(network, addr string) (net.Conn, error) { case d.NetDial != nil:
return d.NetDialContext(ctx, network, addr) netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
} return d.NetDial(net, addr)
} else if d.NetDial != nil {
netDial = d.NetDial
}
case "https":
if d.NetDialTLSContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialTLSContext(ctx, network, addr)
}
} else if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
} }
default: default:
return nil, nil, errMalformedURL netDial = (&net.Dialer{}).DialContext
}
if netDial == nil {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
}
} }
// If needed, wrap the dial function to set the connection deadline. // If needed, wrap the dial function to set the connection deadline.
if deadline, ok := ctx.Deadline(); ok { if deadline, ok := ctx.Deadline(); ok {
forwardDial := netDial forwardDial := netDial
netDial = func(network, addr string) (net.Conn, error) { netDial = func(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := forwardDial(network, addr) c, err := forwardDial(ctx, network, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -304,11 +282,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return nil, nil, err return nil, nil, err
} }
if proxyURL != nil { if proxyURL != nil {
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) netDial, err = proxyFromURL(proxyURL, netDial)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
netDial = dialer.Dial
} }
} }
@ -318,7 +295,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
trace.GetConn(hostPort) trace.GetConn(hostPort)
} }
netConn, err := netDial("tcp", hostPort) netConn, err := netDial(ctx, "tcp", hostPort)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -406,7 +383,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// debugging. // debugging.
buf := make([]byte, 1024) buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf) n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake return nil, resp, ErrBadHandshake
} }
@ -424,7 +401,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
break break
} }
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
if err := netConn.SetDeadline(time.Time{}); err != nil { if err := netConn.SetDeadline(time.Time{}); err != nil {

View File

@ -5,6 +5,7 @@
package websocket package websocket
import ( import (
"bufio"
"bytes" "bytes"
"context" "context"
"crypto/tls" "crypto/tls"
@ -14,7 +15,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -24,6 +24,7 @@ import (
"net/url" "net/url"
"reflect" "reflect"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
) )
@ -45,12 +46,15 @@ var cstDialer = Dialer{
HandshakeTimeout: 30 * time.Second, HandshakeTimeout: 30 * time.Second,
} }
type cstHandler struct{ *testing.T } type cstHandler struct {
*testing.T
s *cstServer
}
type cstServer struct { type cstServer struct {
*httptest.Server URL string
URL string Server *httptest.Server
t *testing.T wg sync.WaitGroup
} }
const ( const (
@ -59,9 +63,15 @@ const (
cstRequestURI = cstPath + "?" + cstRawQuery cstRequestURI = cstPath + "?" + cstRawQuery
) )
func (s *cstServer) Close() {
s.Server.Close()
// Wait for handler functions to complete.
s.wg.Wait()
}
func newServer(t *testing.T) *cstServer { func newServer(t *testing.T) *cstServer {
var s cstServer var s cstServer
s.Server = httptest.NewServer(cstHandler{t}) s.Server = httptest.NewServer(cstHandler{T: t, s: &s})
s.Server.URL += cstRequestURI s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL) s.URL = makeWsProto(s.Server.URL)
return &s return &s
@ -69,13 +79,19 @@ func newServer(t *testing.T) *cstServer {
func newTLSServer(t *testing.T) *cstServer { func newTLSServer(t *testing.T) *cstServer {
var s cstServer var s cstServer
s.Server = httptest.NewTLSServer(cstHandler{t}) s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s})
s.Server.URL += cstRequestURI s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL) s.URL = makeWsProto(s.Server.URL)
return &s return &s
} }
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Because tests wait for a response from a server, we are guaranteed that
// the wait group count is incremented before the test waits on the group
// in the call to (*cstServer).Close().
t.s.wg.Add(1)
defer t.s.wg.Done()
if r.URL.Path != cstPath { if r.URL.Path != cstPath {
t.Logf("path=%v, want %v", r.URL.Path, cstPath) t.Logf("path=%v, want %v", r.URL.Path, cstPath)
http.Error(w, "bad path", http.StatusBadRequest) http.Error(w, "bad path", http.StatusBadRequest)
@ -482,6 +498,37 @@ func TestBadMethod(t *testing.T) {
} }
} }
func TestNoUpgrade(t *testing.T) {
t.Parallel()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ws, err := cstUpgrader.Upgrade(w, r, nil)
if err == nil {
t.Errorf("handshake succeeded, expect fail")
ws.Close()
}
}))
defer s.Close()
req, err := http.NewRequest(http.MethodGet, s.URL, strings.NewReader(""))
if err != nil {
t.Fatalf("NewRequest returned error %v", err)
}
req.Header.Set("Connection", "upgrade")
req.Header.Set("Sec-Websocket-Version", "13")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Do returned error %v", err)
}
resp.Body.Close()
if u := resp.Header.Get("Upgrade"); u != "websocket" {
t.Errorf("Uprade response header is %q, want %q", u, "websocket")
}
if resp.StatusCode != http.StatusUpgradeRequired {
t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusUpgradeRequired)
}
}
func TestDialExtraTokensInRespHeaders(t *testing.T) { func TestDialExtraTokensInRespHeaders(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
challengeKey := r.Header.Get("Sec-Websocket-Key") challengeKey := r.Header.Get("Sec-Websocket-Key")
@ -549,7 +596,7 @@ func TestRespOnBadHandshake(t *testing.T) {
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
} }
p, err := ioutil.ReadAll(resp.Body) p, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatalf("ReadFull(resp.Body) returned error %v", err) t.Fatalf("ReadFull(resp.Body) returned error %v", err)
} }
@ -1133,3 +1180,66 @@ func TestNextProtos(t *testing.T) {
t.Fatalf("Dial succeeded, expect fail ") t.Fatalf("Dial succeeded, expect fail ")
} }
} }
type dataBeforeHandshakeResponseWriter struct {
http.ResponseWriter
}
type dataBeforeHandshakeConnection struct {
net.Conn
io.Reader
}
func (c *dataBeforeHandshakeConnection) Read(p []byte) (int, error) {
return c.Reader.Read(p)
}
func (w dataBeforeHandshakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
// Example single-frame masked text message from section 5.7 of the RFC.
message := []byte{0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58}
n := len(message) / 2
c, rw, err := http.NewResponseController(w.ResponseWriter).Hijack()
if rw != nil {
// Load first part of message into bufio.Reader. If the websocket
// connection reads more than n bytes from the bufio.Reader, then the
// test will fail with an unexpected EOF error.
rw.Reader.Reset(bytes.NewReader(message[:n]))
rw.Reader.Peek(n)
}
if c != nil {
// Inject second part of message before data read from the network connection.
c = &dataBeforeHandshakeConnection{
Conn: c,
Reader: io.MultiReader(bytes.NewReader(message[n:]), c),
}
}
return c, rw, err
}
func TestDataReceivedBeforeHandshake(t *testing.T) {
s := newServer(t)
defer s.Close()
origHandler := s.Server.Config.Handler
s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
origHandler.ServeHTTP(dataBeforeHandshakeResponseWriter{w}, r)
})
for _, readBufferSize := range []int{0, 1024} {
t.Run(fmt.Sprintf("ReadBufferSize=%d", readBufferSize), func(t *testing.T) {
dialer := cstDialer
dialer.ReadBufferSize = readBufferSize
ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
_, m, err := ws.ReadMessage()
if err != nil || string(m) != "Hello" {
t.Fatalf("ReadMessage() = %q, %v, want \"Hello\", nil", m, err)
}
})
}
}

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"testing" "testing"
) )
@ -42,7 +41,7 @@ func textMessages(num int) [][]byte {
} }
func BenchmarkWriteNoCompression(b *testing.B) { func BenchmarkWriteNoCompression(b *testing.B) {
w := ioutil.Discard w := io.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
b.ResetTimer() b.ResetTimer()
@ -53,7 +52,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
} }
func BenchmarkWriteWithCompression(b *testing.B) { func BenchmarkWriteWithCompression(b *testing.B) {
w := ioutil.Discard w := io.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
c.enableWriteCompression = true c.enableWriteCompression = true

18
conn.go
View File

@ -6,11 +6,10 @@ package websocket
import ( import (
"bufio" "bufio"
"crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"io/ioutil"
"math/rand"
"net" "net"
"strconv" "strconv"
"strings" "strings"
@ -181,9 +180,16 @@ var (
errInvalidControlFrame = errors.New("websocket: invalid control frame") errInvalidControlFrame = errors.New("websocket: invalid control frame")
) )
// maskRand is an io.Reader for generating mask bytes. The reader is initialized
// to crypto/rand Reader. Tests swap the reader to a math/rand reader for
// reproducible results.
var maskRand = rand.Reader
// newMaskKey returns a new 32 bit value for masking client frames.
func newMaskKey() [4]byte { func newMaskKey() [4]byte {
n := rand.Uint32() var k [4]byte
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} _, _ = io.ReadFull(maskRand, k[:])
return k
} }
func hideTempErr(err error) error { func hideTempErr(err error) error {
@ -800,7 +806,7 @@ func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame. // 1. Skip remainder of previous frame.
if c.readRemaining > 0 { if c.readRemaining > 0 {
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
return noFrame, err return noFrame, err
} }
} }
@ -1101,7 +1107,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
if err != nil { if err != nil {
return messageType, nil, err return messageType, nil, err
} }
p, err = ioutil.ReadAll(r) p, err = io.ReadAll(r)
return messageType, p, err return messageType, p, err
} }

View File

@ -6,7 +6,6 @@ package websocket
import ( import (
"io" "io"
"io/ioutil"
"sync/atomic" "sync/atomic"
"testing" "testing"
) )
@ -45,7 +44,7 @@ func newBroadcastConn(c *Conn) *broadcastConn {
func newBroadcastBench(usePrepared, compression bool) *broadcastBench { func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
bench := &broadcastBench{ bench := &broadcastBench{
w: ioutil.Discard, w: io.Discard,
doneCh: make(chan struct{}), doneCh: make(chan struct{}),
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
usePrepared: usePrepared, usePrepared: usePrepared,

View File

@ -10,7 +10,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"reflect" "reflect"
"sync" "sync"
@ -125,7 +124,7 @@ func TestFraming(t *testing.T) {
} }
t.Logf("frame size: %d", n) t.Logf("frame size: %d", n)
rbuf, err := ioutil.ReadAll(r) rbuf, err := io.ReadAll(r)
if err != nil { if err != nil {
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
continue continue
@ -367,7 +366,7 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if !reflect.DeepEqual(err, expectedErr) { if !reflect.DeepEqual(err, expectedErr) {
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
} }
@ -401,7 +400,7 @@ func TestEOFWithinFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
} }
@ -426,7 +425,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
} }
@ -490,7 +489,7 @@ func TestReadLimit(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("2: NextReader() returned %d, %v", op, err) t.Fatalf("2: NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != ErrReadLimit { if err != ErrReadLimit {
t.Fatalf("io.Copy() returned %v", err) t.Fatalf("io.Copy() returned %v", err)
} }

View File

@ -84,7 +84,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
} }
// echoReadAll echoes messages from the client by reading the entire message // echoReadAll echoes messages from the client by reading the entire message
// with ioutil.ReadAll. // with io.ReadAll.
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) { func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) {
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {

View File

@ -38,7 +38,7 @@ sends them to the hub.
### Hub ### Hub
The code for the `Hub` type is in The code for the `Hub` type is in
[hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go). [hub.go](https://github.com/gorilla/websocket/blob/main/examples/chat/hub.go).
The application's `main` function starts the hub's `run` method as a goroutine. The application's `main` function starts the hub's `run` method as a goroutine.
Clients send requests to the hub using the `register`, `unregister` and Clients send requests to the hub using the `register`, `unregister` and
`broadcast` channels. `broadcast` channels.
@ -57,7 +57,7 @@ unregisters the client and closes the websocket.
### Client ### Client
The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go). The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/main/examples/chat/client.go).
The `serveWs` function is registered by the application's `main` function as The `serveWs` function is registered by the application's `main` function as
an HTTP handler. The handler upgrades the HTTP connection to the WebSocket an HTTP handler. The handler upgrades the HTTP connection to the WebSocket
@ -85,7 +85,7 @@ network.
## Frontend ## Frontend
The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html). The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/main/examples/chat/home.html).
On document load, the script checks for websocket functionality in the browser. On document load, the script checks for websocket functionality in the browser.
If websocket functionality is available, then the script opens a connection to If websocket functionality is available, then the script opens a connection to

View File

@ -7,7 +7,6 @@ package main
import ( import (
"flag" "flag"
"html/template" "html/template"
"io/ioutil"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -49,7 +48,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) {
if !fi.ModTime().After(lastMod) { if !fi.ModTime().After(lastMod) {
return nil, lastMod, nil return nil, lastMod, nil
} }
p, err := ioutil.ReadFile(filename) p, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, fi.ModTime(), err return nil, fi.ModTime(), err
} }

8
go.mod
View File

@ -1,3 +1,9 @@
module github.com/gorilla/websocket module github.com/gorilla/websocket
go 1.12 go 1.20
retract (
v1.5.2 // tag accidentally overwritten
)
require golang.org/x/net v0.26.0

2
go.sum
View File

@ -0,0 +1,2 @@
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=

View File

@ -33,6 +33,11 @@ var preparedMessageTests = []struct {
} }
func TestPreparedMessage(t *testing.T) { func TestPreparedMessage(t *testing.T) {
testRand := rand.New(rand.NewSource(99))
prevMaskRand := maskRand
maskRand = testRand
defer func() { maskRand = prevMaskRand }()
for _, tt := range preparedMessageTests { for _, tt := range preparedMessageTests {
var data = []byte("this is a test") var data = []byte("this is a test")
var buf bytes.Buffer var buf bytes.Buffer
@ -45,7 +50,7 @@ func TestPreparedMessage(t *testing.T) {
} }
// Seed random number generator for consistent frame mask. // Seed random number generator for consistent frame mask.
rand.Seed(1234) testRand.Seed(1234)
if err := c.WriteMessage(tt.messageType, data); err != nil { if err := c.WriteMessage(tt.messageType, data); err != nil {
t.Fatal(err) t.Fatal(err)
@ -61,7 +66,7 @@ func TestPreparedMessage(t *testing.T) {
copy(data, "hello world") copy(data, "hello world")
// Seed random number generator for consistent frame mask. // Seed random number generator for consistent frame mask.
rand.Seed(1234) testRand.Seed(1234)
buf.Reset() buf.Reset()
if err := c.WritePreparedMessage(pm); err != nil { if err := c.WritePreparedMessage(pm); err != nil {

View File

@ -6,34 +6,52 @@ package websocket
import ( import (
"bufio" "bufio"
"bytes"
"context"
"encoding/base64" "encoding/base64"
"errors" "errors"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"golang.org/x/net/proxy"
) )
type netDialerFunc func(network, addr string) (net.Conn, error) type netDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error)
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
return fn(network, addr) return fn(context.Background(), network, addr)
} }
func init() { func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { return fn(ctx, network, addr)
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil }
})
func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) {
if proxyURL.Scheme == "http" {
return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil
}
dialer, err := proxy.FromURL(proxyURL, forwardDial)
if err != nil {
return nil, err
}
if d, ok := dialer.(proxy.ContextDialer); ok {
return d.DialContext, nil
}
return func(ctx context.Context, net, addr string) (net.Conn, error) {
return dialer.Dial(net, addr)
}, nil
} }
type httpProxyDialer struct { type httpProxyDialer struct {
proxyURL *url.URL proxyURL *url.URL
forwardDial func(network, addr string) (net.Conn, error) forwardDial netDialerFunc
} }
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
hostPort, _ := hostPortNoPort(hpd.proxyURL) hostPort, _ := hostPortNoPort(hpd.proxyURL)
conn, err := hpd.forwardDial(network, hostPort) conn, err := hpd.forwardDial(ctx, network, hostPort)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -68,8 +86,18 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
return nil, err return nil, err
} }
if resp.StatusCode != 200 { // Close the response body to silence false positives from linters. Reset
conn.Close() // the buffered reader first to ensure that Close() does not read from
// conn.
// Note: Applications must call resp.Body.Close() on a response returned
// http.ReadResponse to inspect trailers or read another response from the
// buffered reader. The call to resp.Body.Close() does not release
// resources.
br.Reset(bytes.NewReader(nil))
_ = resp.Body.Close()
if resp.StatusCode != http.StatusOK {
_ = conn.Close()
f := strings.SplitN(resp.Status, " ", 2) f := strings.SplitN(resp.Status, " ", 2)
return nil, errors.New(f[1]) return nil, errors.New(f[1])
} }

View File

@ -6,8 +6,7 @@ package websocket
import ( import (
"bufio" "bufio"
"errors" "net"
"io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -101,8 +100,8 @@ func checkSameOrigin(r *http.Request) bool {
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
if u.Subprotocols != nil { if u.Subprotocols != nil {
clientProtocols := Subprotocols(r) clientProtocols := Subprotocols(r)
for _, serverProtocol := range u.Subprotocols { for _, clientProtocol := range clientProtocols {
for _, clientProtocol := range clientProtocols { for _, serverProtocol := range u.Subprotocols {
if clientProtocol == serverProtocol { if clientProtocol == serverProtocol {
return clientProtocol return clientProtocol
} }
@ -130,7 +129,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header") w.Header().Set("Upgrade", "websocket")
return u.returnError(w, r, http.StatusUpgradeRequired, badHandshake+"'websocket' token not found in 'Upgrade' header")
} }
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
@ -172,18 +172,10 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
} }
h, ok := w.(http.Hijacker) netConn, brw, err := http.NewResponseController(w).Hijack()
if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var brw *bufio.ReadWriter
netConn, brw, err := h.Hijack()
if err != nil { if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error()) return u.returnError(w, r, http.StatusInternalServerError,
} "websocket: hijack: "+err.Error())
if brw.Reader.Buffered() > 0 {
netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete")
} }
// Close the network connection when returning an error. The variable // Close the network connection when returning an error. The variable
@ -199,12 +191,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
}() }()
var br *bufio.Reader var br *bufio.Reader
if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 { if u.ReadBufferSize == 0 && brw.Reader.Size() > 256 {
// Reuse hijacked buffered reader as connection reader. // Use hijacked buffered reader as the connection reader.
br = brw.Reader br = brw.Reader
} else if brw.Reader.Buffered() > 0 {
// Wrap the network connection to read buffered data in brw.Reader
// before reading from the network connection. This should be rare
// because a client must not send message data before receiving the
// handshake response.
netConn = &brNetConn{br: brw.Reader, Conn: netConn}
} }
buf := bufioWriterBuffer(netConn, brw.Writer) buf := brw.Writer.AvailableBuffer()
var writeBuf []byte var writeBuf []byte
if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 { if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
@ -348,39 +346,28 @@ func IsWebSocketUpgrade(r *http.Request) bool {
tokenListContainsValue(r.Header, "Upgrade", "websocket") tokenListContainsValue(r.Header, "Upgrade", "websocket")
} }
// bufioReaderSize size returns the size of a bufio.Reader. type brNetConn struct {
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int { br *bufio.Reader
// This code assumes that peek on a reset reader returns net.Conn
// bufio.Reader.buf[:0]. }
// TODO: Use bufio.Reader.Size() after Go 1.10
br.Reset(originalReader) func (b *brNetConn) Read(p []byte) (n int, err error) {
if p, err := br.Peek(0); err == nil { if b.br != nil {
return cap(p) // Limit read to buferred data.
if n := b.br.Buffered(); len(p) > n {
p = p[:n]
}
n, err = b.br.Read(p)
if b.br.Buffered() == 0 {
b.br = nil
}
return n, err
} }
return 0 return b.Conn.Read(p)
} }
// writeHook is an io.Writer that records the last slice passed to it vio // NetConn returns the underlying connection that is wrapped by b.
// io.Writer.Write. func (b *brNetConn) NetConn() net.Conn {
type writeHook struct { return b.Conn
p []byte
} }
func (wh *writeHook) Write(p []byte) (int, error) {
wh.p = p
return len(p), nil
}
// bufioWriterBuffer grabs the buffer from a bufio.Writer.
func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
// This code assumes that bufio.Writer.buf[:1] is passed to the
// bufio.Writer's underlying writer.
var wh writeHook
bw.Reset(&wh)
bw.WriteByte(0)
bw.Flush()
bw.Reset(originalWriter)
return wh.p[:cap(wh.p)]
}

View File

@ -7,8 +7,10 @@ package websocket
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -54,6 +56,36 @@ func TestIsWebSocketUpgrade(t *testing.T) {
} }
} }
func TestSubProtocolSelection(t *testing.T) {
upgrader := Upgrader{
Subprotocols: []string{"foo", "bar", "baz"},
}
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}}
s := upgrader.selectSubprotocol(&r, nil)
if s != "foo" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "bar" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "baz" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string")
}
}
var checkSameOriginTests = []struct { var checkSameOriginTests = []struct {
ok bool ok bool
r *http.Request r *http.Request
@ -89,7 +121,7 @@ var bufioReuseTests = []struct {
{128, false}, {128, false},
} }
func TestBufioReuse(t *testing.T) { func xTestBufioReuse(t *testing.T) {
for i, tt := range bufioReuseTests { for i, tt := range bufioReuseTests {
br := bufio.NewReaderSize(strings.NewReader(""), tt.n) br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n) bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
@ -111,9 +143,29 @@ func TestBufioReuse(t *testing.T) {
if reuse := c.br == br; reuse != tt.reuse { if reuse := c.br == br; reuse != tt.reuse {
t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse) t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
} }
writeBuf := bufioWriterBuffer(c.NetConn(), bw) writeBuf := bw.AvailableBuffer()
if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse { if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse) t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
} }
} }
} }
func TestHijack_NotSupported(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "upgrade")
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Sec-Websocket-Version", "13")
recorder := httptest.NewRecorder()
upgrader := Upgrader{}
_, err := upgrader.Upgrade(recorder, req, nil)
if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
}
}

View File

@ -1,473 +0,0 @@
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy
// Package proxy provides support for a variety of protocols to proxy network
// data.
//
package websocket
import (
"errors"
"io"
"net"
"net/url"
"os"
"strconv"
"strings"
"sync"
)
type proxy_direct struct{}
// Direct is a direct proxy: one that makes network connections directly.
var proxy_Direct = proxy_direct{}
func (proxy_direct) Dial(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
}
// A PerHost directs connections to a default Dialer unless the host name
// requested matches one of a number of exceptions.
type proxy_PerHost struct {
def, bypass proxy_Dialer
bypassNetworks []*net.IPNet
bypassIPs []net.IP
bypassZones []string
bypassHosts []string
}
// NewPerHost returns a PerHost Dialer that directs connections to either
// defaultDialer or bypass, depending on whether the connection matches one of
// the configured rules.
func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost {
return &proxy_PerHost{
def: defaultDialer,
bypass: bypass,
}
}
// Dial connects to the address addr on the given network through either
// defaultDialer or bypass.
func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return p.dialerForRequest(host).Dial(network, addr)
}
func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer {
if ip := net.ParseIP(host); ip != nil {
for _, net := range p.bypassNetworks {
if net.Contains(ip) {
return p.bypass
}
}
for _, bypassIP := range p.bypassIPs {
if bypassIP.Equal(ip) {
return p.bypass
}
}
return p.def
}
for _, zone := range p.bypassZones {
if strings.HasSuffix(host, zone) {
return p.bypass
}
if host == zone[1:] {
// For a zone ".example.com", we match "example.com"
// too.
return p.bypass
}
}
for _, bypassHost := range p.bypassHosts {
if bypassHost == host {
return p.bypass
}
}
return p.def
}
// AddFromString parses a string that contains comma-separated values
// specifying hosts that should use the bypass proxy. Each value is either an
// IP address, a CIDR range, a zone (*.example.com) or a host name
// (localhost). A best effort is made to parse the string and errors are
// ignored.
func (p *proxy_PerHost) AddFromString(s string) {
hosts := strings.Split(s, ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
if len(host) == 0 {
continue
}
if strings.Contains(host, "/") {
// We assume that it's a CIDR address like 127.0.0.0/8
if _, net, err := net.ParseCIDR(host); err == nil {
p.AddNetwork(net)
}
continue
}
if ip := net.ParseIP(host); ip != nil {
p.AddIP(ip)
continue
}
if strings.HasPrefix(host, "*.") {
p.AddZone(host[1:])
continue
}
p.AddHost(host)
}
}
// AddIP specifies an IP address that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match an IP.
func (p *proxy_PerHost) AddIP(ip net.IP) {
p.bypassIPs = append(p.bypassIPs, ip)
}
// AddNetwork specifies an IP range that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match.
func (p *proxy_PerHost) AddNetwork(net *net.IPNet) {
p.bypassNetworks = append(p.bypassNetworks, net)
}
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *proxy_PerHost) AddZone(zone string) {
if strings.HasSuffix(zone, ".") {
zone = zone[:len(zone)-1]
}
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
p.bypassZones = append(p.bypassZones, zone)
}
// AddHost specifies a host name that will use the bypass proxy.
func (p *proxy_PerHost) AddHost(host string) {
if strings.HasSuffix(host, ".") {
host = host[:len(host)-1]
}
p.bypassHosts = append(p.bypassHosts, host)
}
// A Dialer is a means to establish a connection.
type proxy_Dialer interface {
// Dial connects to the given address via the proxy.
Dial(network, addr string) (c net.Conn, err error)
}
// Auth contains authentication parameters that specific Dialers may require.
type proxy_Auth struct {
User, Password string
}
// FromEnvironment returns the dialer specified by the proxy related variables in
// the environment.
func proxy_FromEnvironment() proxy_Dialer {
allProxy := proxy_allProxyEnv.Get()
if len(allProxy) == 0 {
return proxy_Direct
}
proxyURL, err := url.Parse(allProxy)
if err != nil {
return proxy_Direct
}
proxy, err := proxy_FromURL(proxyURL, proxy_Direct)
if err != nil {
return proxy_Direct
}
noProxy := proxy_noProxyEnv.Get()
if len(noProxy) == 0 {
return proxy
}
perHost := proxy_NewPerHost(proxy, proxy_Direct)
perHost.AddFromString(noProxy)
return perHost
}
// proxySchemes is a map from URL schemes to a function that creates a Dialer
// from a URL with such a scheme.
var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
// by FromURL.
func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) {
if proxy_proxySchemes == nil {
proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error))
}
proxy_proxySchemes[scheme] = f
}
// FromURL returns a Dialer given a URL specification and an underlying
// Dialer for it to make network requests.
func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) {
var auth *proxy_Auth
if u.User != nil {
auth = new(proxy_Auth)
auth.User = u.User.Username()
if p, ok := u.User.Password(); ok {
auth.Password = p
}
}
switch u.Scheme {
case "socks5":
return proxy_SOCKS5("tcp", u.Host, auth, forward)
}
// If the scheme doesn't match any of the built-in schemes, see if it
// was registered by another package.
if proxy_proxySchemes != nil {
if f, ok := proxy_proxySchemes[u.Scheme]; ok {
return f(u, forward)
}
}
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
}
var (
proxy_allProxyEnv = &proxy_envOnce{
names: []string{"ALL_PROXY", "all_proxy"},
}
proxy_noProxyEnv = &proxy_envOnce{
names: []string{"NO_PROXY", "no_proxy"},
}
)
// envOnce looks up an environment variable (optionally by multiple
// names) once. It mitigates expensive lookups on some platforms
// (e.g. Windows).
// (Borrowed from net/http/transport.go)
type proxy_envOnce struct {
names []string
once sync.Once
val string
}
func (e *proxy_envOnce) Get() string {
e.once.Do(e.init)
return e.val
}
func (e *proxy_envOnce) init() {
for _, n := range e.names {
e.val = os.Getenv(n)
if e.val != "" {
return
}
}
}
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
// with an optional username and password. See RFC 1928 and RFC 1929.
func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) {
s := &proxy_socks5{
network: network,
addr: addr,
forward: forward,
}
if auth != nil {
s.user = auth.User
s.password = auth.Password
}
return s, nil
}
type proxy_socks5 struct {
user, password string
network, addr string
forward proxy_Dialer
}
const proxy_socks5Version = 5
const (
proxy_socks5AuthNone = 0
proxy_socks5AuthPassword = 2
)
const proxy_socks5Connect = 1
const (
proxy_socks5IP4 = 1
proxy_socks5Domain = 3
proxy_socks5IP6 = 4
)
var proxy_socks5Errors = []string{
"",
"general failure",
"connection forbidden",
"network unreachable",
"host unreachable",
"connection refused",
"TTL expired",
"command not supported",
"address type not supported",
}
// Dial connects to the address addr on the given network via the SOCKS5 proxy.
func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) {
switch network {
case "tcp", "tcp6", "tcp4":
default:
return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
}
conn, err := s.forward.Dial(s.network, s.addr)
if err != nil {
return nil, err
}
if err := s.connect(conn, addr); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
// connect takes an existing connection to a socks5 proxy server,
// and commands the server to extend that connection to target,
// which must be a canonical address with a host and port.
func (s *proxy_socks5) connect(conn net.Conn, target string) error {
host, portStr, err := net.SplitHostPort(target)
if err != nil {
return err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return errors.New("proxy: failed to parse port number: " + portStr)
}
if port < 1 || port > 0xffff {
return errors.New("proxy: port number out of range: " + portStr)
}
// the size here is just an estimate
buf := make([]byte, 0, 6+len(host))
buf = append(buf, proxy_socks5Version)
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword)
} else {
buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone)
}
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[0] != 5 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
}
if buf[1] == 0xff {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
}
// See RFC 1929
if buf[1] == proxy_socks5AuthPassword {
buf = buf[:0]
buf = append(buf, 1 /* password protocol version */)
buf = append(buf, uint8(len(s.user)))
buf = append(buf, s.user...)
buf = append(buf, uint8(len(s.password)))
buf = append(buf, s.password...)
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[1] != 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
}
}
buf = buf[:0]
buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, proxy_socks5IP4)
ip = ip4
} else {
buf = append(buf, proxy_socks5IP6)
}
buf = append(buf, ip...)
} else {
if len(host) > 255 {
return errors.New("proxy: destination host name too long: " + host)
}
buf = append(buf, proxy_socks5Domain)
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
buf = append(buf, byte(port>>8), byte(port))
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
failure := "unknown error"
if int(buf[1]) < len(proxy_socks5Errors) {
failure = proxy_socks5Errors[buf[1]]
}
if len(failure) > 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
}
bytesToDiscard := 0
switch buf[3] {
case proxy_socks5IP4:
bytesToDiscard = net.IPv4len
case proxy_socks5IP6:
bytesToDiscard = net.IPv6len
case proxy_socks5Domain:
_, err := io.ReadFull(conn, buf[:1])
if err != nil {
return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
bytesToDiscard = int(buf[0])
default:
return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
}
if cap(buf) < bytesToDiscard {
buf = make([]byte, bytesToDiscard)
} else {
buf = buf[:bytesToDiscard]
}
if _, err := io.ReadFull(conn, buf); err != nil {
return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// Also need to discard the port number
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
return nil
}