mirror of https://github.com/gorilla/websocket.git
Merge branch 'main' into prerrcheck
Signed-off-by: Canelo Hill <172609632+canelohill@users.noreply.github.com>
This commit is contained in:
commit
c9d30b6eb0
|
@ -67,4 +67,4 @@ workflows:
|
||||||
- test:
|
- test:
|
||||||
matrix:
|
matrix:
|
||||||
parameters:
|
parameters:
|
||||||
version: ["1.18", "1.17", "1.16"]
|
version: ["1.22", "1.21", "1.20"]
|
||||||
|
|
11
README.md
11
README.md
|
@ -10,10 +10,10 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
|
||||||
### Documentation
|
### Documentation
|
||||||
|
|
||||||
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
|
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
|
||||||
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
|
* [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat)
|
||||||
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
|
* [Command example](https://github.com/gorilla/websocket/tree/main/examples/command)
|
||||||
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
|
* [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo)
|
||||||
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
|
* [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch)
|
||||||
|
|
||||||
### Status
|
### Status
|
||||||
|
|
||||||
|
@ -29,5 +29,4 @@ package API is stable.
|
||||||
|
|
||||||
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
|
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
|
||||||
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
|
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
|
||||||
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
|
subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn).
|
||||||
|
|
||||||
|
|
57
client.go
57
client.go
|
@ -11,7 +11,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
|
@ -53,7 +52,7 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
|
||||||
// It is safe to call Dialer's methods concurrently.
|
// It is safe to call Dialer's methods concurrently.
|
||||||
type Dialer struct {
|
type Dialer struct {
|
||||||
// NetDial specifies the dial function for creating TCP connections. If
|
// NetDial specifies the dial function for creating TCP connections. If
|
||||||
// NetDial is nil, net.Dial is used.
|
// NetDial is nil, net.Dialer DialContext is used.
|
||||||
NetDial func(network, addr string) (net.Conn, error)
|
NetDial func(network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
// NetDialContext specifies the dial function for creating TCP connections. If
|
// NetDialContext specifies the dial function for creating TCP connections. If
|
||||||
|
@ -245,46 +244,25 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
defer cancel()
|
defer cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get network dial function.
|
var netDial netDialerFunc
|
||||||
var netDial func(network, add string) (net.Conn, error)
|
switch {
|
||||||
|
case u.Scheme == "https" && d.NetDialTLSContext != nil:
|
||||||
switch u.Scheme {
|
netDial = d.NetDialTLSContext
|
||||||
case "http":
|
case d.NetDialContext != nil:
|
||||||
if d.NetDialContext != nil {
|
netDial = d.NetDialContext
|
||||||
netDial = func(network, addr string) (net.Conn, error) {
|
case d.NetDial != nil:
|
||||||
return d.NetDialContext(ctx, network, addr)
|
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
|
||||||
}
|
return d.NetDial(net, addr)
|
||||||
} else if d.NetDial != nil {
|
|
||||||
netDial = d.NetDial
|
|
||||||
}
|
|
||||||
case "https":
|
|
||||||
if d.NetDialTLSContext != nil {
|
|
||||||
netDial = func(network, addr string) (net.Conn, error) {
|
|
||||||
return d.NetDialTLSContext(ctx, network, addr)
|
|
||||||
}
|
|
||||||
} else if d.NetDialContext != nil {
|
|
||||||
netDial = func(network, addr string) (net.Conn, error) {
|
|
||||||
return d.NetDialContext(ctx, network, addr)
|
|
||||||
}
|
|
||||||
} else if d.NetDial != nil {
|
|
||||||
netDial = d.NetDial
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return nil, nil, errMalformedURL
|
netDial = (&net.Dialer{}).DialContext
|
||||||
}
|
|
||||||
|
|
||||||
if netDial == nil {
|
|
||||||
netDialer := &net.Dialer{}
|
|
||||||
netDial = func(network, addr string) (net.Conn, error) {
|
|
||||||
return netDialer.DialContext(ctx, network, addr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If needed, wrap the dial function to set the connection deadline.
|
// If needed, wrap the dial function to set the connection deadline.
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
forwardDial := netDial
|
forwardDial := netDial
|
||||||
netDial = func(network, addr string) (net.Conn, error) {
|
netDial = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
c, err := forwardDial(network, addr)
|
c, err := forwardDial(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -304,11 +282,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if proxyURL != nil {
|
if proxyURL != nil {
|
||||||
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
|
netDial, err = proxyFromURL(proxyURL, netDial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
netDial = dialer.Dial
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -318,7 +295,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
trace.GetConn(hostPort)
|
trace.GetConn(hostPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
netConn, err := netDial("tcp", hostPort)
|
netConn, err := netDial(ctx, "tcp", hostPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -406,7 +383,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
// debugging.
|
// debugging.
|
||||||
buf := make([]byte, 1024)
|
buf := make([]byte, 1024)
|
||||||
n, _ := io.ReadFull(resp.Body, buf)
|
n, _ := io.ReadFull(resp.Body, buf)
|
||||||
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
|
resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
|
||||||
return nil, resp, ErrBadHandshake
|
return nil, resp, ErrBadHandshake
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -424,7 +401,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
||||||
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
||||||
|
|
||||||
if err := netConn.SetDeadline(time.Time{}); err != nil {
|
if err := netConn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
@ -14,7 +15,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -24,6 +24,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -45,12 +46,15 @@ var cstDialer = Dialer{
|
||||||
HandshakeTimeout: 30 * time.Second,
|
HandshakeTimeout: 30 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
type cstHandler struct{ *testing.T }
|
type cstHandler struct {
|
||||||
|
*testing.T
|
||||||
|
s *cstServer
|
||||||
|
}
|
||||||
|
|
||||||
type cstServer struct {
|
type cstServer struct {
|
||||||
*httptest.Server
|
URL string
|
||||||
URL string
|
Server *httptest.Server
|
||||||
t *testing.T
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -59,9 +63,15 @@ const (
|
||||||
cstRequestURI = cstPath + "?" + cstRawQuery
|
cstRequestURI = cstPath + "?" + cstRawQuery
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (s *cstServer) Close() {
|
||||||
|
s.Server.Close()
|
||||||
|
// Wait for handler functions to complete.
|
||||||
|
s.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func newServer(t *testing.T) *cstServer {
|
func newServer(t *testing.T) *cstServer {
|
||||||
var s cstServer
|
var s cstServer
|
||||||
s.Server = httptest.NewServer(cstHandler{t})
|
s.Server = httptest.NewServer(cstHandler{T: t, s: &s})
|
||||||
s.Server.URL += cstRequestURI
|
s.Server.URL += cstRequestURI
|
||||||
s.URL = makeWsProto(s.Server.URL)
|
s.URL = makeWsProto(s.Server.URL)
|
||||||
return &s
|
return &s
|
||||||
|
@ -69,13 +79,19 @@ func newServer(t *testing.T) *cstServer {
|
||||||
|
|
||||||
func newTLSServer(t *testing.T) *cstServer {
|
func newTLSServer(t *testing.T) *cstServer {
|
||||||
var s cstServer
|
var s cstServer
|
||||||
s.Server = httptest.NewTLSServer(cstHandler{t})
|
s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s})
|
||||||
s.Server.URL += cstRequestURI
|
s.Server.URL += cstRequestURI
|
||||||
s.URL = makeWsProto(s.Server.URL)
|
s.URL = makeWsProto(s.Server.URL)
|
||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Because tests wait for a response from a server, we are guaranteed that
|
||||||
|
// the wait group count is incremented before the test waits on the group
|
||||||
|
// in the call to (*cstServer).Close().
|
||||||
|
t.s.wg.Add(1)
|
||||||
|
defer t.s.wg.Done()
|
||||||
|
|
||||||
if r.URL.Path != cstPath {
|
if r.URL.Path != cstPath {
|
||||||
t.Logf("path=%v, want %v", r.URL.Path, cstPath)
|
t.Logf("path=%v, want %v", r.URL.Path, cstPath)
|
||||||
http.Error(w, "bad path", http.StatusBadRequest)
|
http.Error(w, "bad path", http.StatusBadRequest)
|
||||||
|
@ -482,6 +498,37 @@ func TestBadMethod(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNoUpgrade(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ws, err := cstUpgrader.Upgrade(w, r, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("handshake succeeded, expect fail")
|
||||||
|
ws.Close()
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, s.URL, strings.NewReader(""))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewRequest returned error %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Connection", "upgrade")
|
||||||
|
req.Header.Set("Sec-Websocket-Version", "13")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Do returned error %v", err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
if u := resp.Header.Get("Upgrade"); u != "websocket" {
|
||||||
|
t.Errorf("Uprade response header is %q, want %q", u, "websocket")
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusUpgradeRequired {
|
||||||
|
t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusUpgradeRequired)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDialExtraTokensInRespHeaders(t *testing.T) {
|
func TestDialExtraTokensInRespHeaders(t *testing.T) {
|
||||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
||||||
|
@ -549,7 +596,7 @@ func TestRespOnBadHandshake(t *testing.T) {
|
||||||
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
|
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := ioutil.ReadAll(resp.Body)
|
p, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ReadFull(resp.Body) returned error %v", err)
|
t.Fatalf("ReadFull(resp.Body) returned error %v", err)
|
||||||
}
|
}
|
||||||
|
@ -1133,3 +1180,66 @@ func TestNextProtos(t *testing.T) {
|
||||||
t.Fatalf("Dial succeeded, expect fail ")
|
t.Fatalf("Dial succeeded, expect fail ")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type dataBeforeHandshakeResponseWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
type dataBeforeHandshakeConnection struct {
|
||||||
|
net.Conn
|
||||||
|
io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dataBeforeHandshakeConnection) Read(p []byte) (int, error) {
|
||||||
|
return c.Reader.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w dataBeforeHandshakeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
// Example single-frame masked text message from section 5.7 of the RFC.
|
||||||
|
message := []byte{0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58}
|
||||||
|
n := len(message) / 2
|
||||||
|
|
||||||
|
c, rw, err := http.NewResponseController(w.ResponseWriter).Hijack()
|
||||||
|
if rw != nil {
|
||||||
|
// Load first part of message into bufio.Reader. If the websocket
|
||||||
|
// connection reads more than n bytes from the bufio.Reader, then the
|
||||||
|
// test will fail with an unexpected EOF error.
|
||||||
|
rw.Reader.Reset(bytes.NewReader(message[:n]))
|
||||||
|
rw.Reader.Peek(n)
|
||||||
|
}
|
||||||
|
if c != nil {
|
||||||
|
// Inject second part of message before data read from the network connection.
|
||||||
|
c = &dataBeforeHandshakeConnection{
|
||||||
|
Conn: c,
|
||||||
|
Reader: io.MultiReader(bytes.NewReader(message[n:]), c),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c, rw, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDataReceivedBeforeHandshake(t *testing.T) {
|
||||||
|
s := newServer(t)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
origHandler := s.Server.Config.Handler
|
||||||
|
s.Server.Config.Handler = http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
origHandler.ServeHTTP(dataBeforeHandshakeResponseWriter{w}, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, readBufferSize := range []int{0, 1024} {
|
||||||
|
t.Run(fmt.Sprintf("ReadBufferSize=%d", readBufferSize), func(t *testing.T) {
|
||||||
|
dialer := cstDialer
|
||||||
|
dialer.ReadBufferSize = readBufferSize
|
||||||
|
ws, _, err := cstDialer.Dial(s.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial: %v", err)
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
_, m, err := ws.ReadMessage()
|
||||||
|
if err != nil || string(m) != "Hello" {
|
||||||
|
t.Fatalf("ReadMessage() = %q, %v, want \"Hello\", nil", m, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -42,7 +41,7 @@ func textMessages(num int) [][]byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkWriteNoCompression(b *testing.B) {
|
func BenchmarkWriteNoCompression(b *testing.B) {
|
||||||
w := ioutil.Discard
|
w := io.Discard
|
||||||
c := newTestConn(nil, w, false)
|
c := newTestConn(nil, w, false)
|
||||||
messages := textMessages(100)
|
messages := textMessages(100)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
@ -53,7 +52,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkWriteWithCompression(b *testing.B) {
|
func BenchmarkWriteWithCompression(b *testing.B) {
|
||||||
w := ioutil.Discard
|
w := io.Discard
|
||||||
c := newTestConn(nil, w, false)
|
c := newTestConn(nil, w, false)
|
||||||
messages := textMessages(100)
|
messages := textMessages(100)
|
||||||
c.enableWriteCompression = true
|
c.enableWriteCompression = true
|
||||||
|
|
18
conn.go
18
conn.go
|
@ -6,11 +6,10 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"crypto/rand"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -181,9 +180,16 @@ var (
|
||||||
errInvalidControlFrame = errors.New("websocket: invalid control frame")
|
errInvalidControlFrame = errors.New("websocket: invalid control frame")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// maskRand is an io.Reader for generating mask bytes. The reader is initialized
|
||||||
|
// to crypto/rand Reader. Tests swap the reader to a math/rand reader for
|
||||||
|
// reproducible results.
|
||||||
|
var maskRand = rand.Reader
|
||||||
|
|
||||||
|
// newMaskKey returns a new 32 bit value for masking client frames.
|
||||||
func newMaskKey() [4]byte {
|
func newMaskKey() [4]byte {
|
||||||
n := rand.Uint32()
|
var k [4]byte
|
||||||
return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}
|
_, _ = io.ReadFull(maskRand, k[:])
|
||||||
|
return k
|
||||||
}
|
}
|
||||||
|
|
||||||
func hideTempErr(err error) error {
|
func hideTempErr(err error) error {
|
||||||
|
@ -800,7 +806,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
// 1. Skip remainder of previous frame.
|
// 1. Skip remainder of previous frame.
|
||||||
|
|
||||||
if c.readRemaining > 0 {
|
if c.readRemaining > 0 {
|
||||||
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
|
if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1101,7 +1107,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return messageType, nil, err
|
return messageType, nil, err
|
||||||
}
|
}
|
||||||
p, err = ioutil.ReadAll(r)
|
p, err = io.ReadAll(r)
|
||||||
return messageType, p, err
|
return messageType, p, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,6 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
@ -45,7 +44,7 @@ func newBroadcastConn(c *Conn) *broadcastConn {
|
||||||
|
|
||||||
func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
|
func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
|
||||||
bench := &broadcastBench{
|
bench := &broadcastBench{
|
||||||
w: ioutil.Discard,
|
w: io.Discard,
|
||||||
doneCh: make(chan struct{}),
|
doneCh: make(chan struct{}),
|
||||||
closeCh: make(chan struct{}),
|
closeCh: make(chan struct{}),
|
||||||
usePrepared: usePrepared,
|
usePrepared: usePrepared,
|
||||||
|
|
11
conn_test.go
11
conn_test.go
|
@ -10,7 +10,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -125,7 +124,7 @@ func TestFraming(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("frame size: %d", n)
|
t.Logf("frame size: %d", n)
|
||||||
rbuf, err := ioutil.ReadAll(r)
|
rbuf, err := io.ReadAll(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
|
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
|
||||||
continue
|
continue
|
||||||
|
@ -367,7 +366,7 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
t.Fatalf("NextReader() returned %d, %v", op, err)
|
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(ioutil.Discard, r)
|
_, err = io.Copy(io.Discard, r)
|
||||||
if !reflect.DeepEqual(err, expectedErr) {
|
if !reflect.DeepEqual(err, expectedErr) {
|
||||||
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
|
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
|
||||||
}
|
}
|
||||||
|
@ -401,7 +400,7 @@ func TestEOFWithinFrame(t *testing.T) {
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
|
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(ioutil.Discard, r)
|
_, err = io.Copy(io.Discard, r)
|
||||||
if err != errUnexpectedEOF {
|
if err != errUnexpectedEOF {
|
||||||
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
|
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
|
||||||
}
|
}
|
||||||
|
@ -426,7 +425,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
t.Fatalf("NextReader() returned %d, %v", op, err)
|
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(ioutil.Discard, r)
|
_, err = io.Copy(io.Discard, r)
|
||||||
if err != errUnexpectedEOF {
|
if err != errUnexpectedEOF {
|
||||||
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
|
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
|
||||||
}
|
}
|
||||||
|
@ -490,7 +489,7 @@ func TestReadLimit(t *testing.T) {
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
t.Fatalf("2: NextReader() returned %d, %v", op, err)
|
t.Fatalf("2: NextReader() returned %d, %v", op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(ioutil.Discard, r)
|
_, err = io.Copy(io.Discard, r)
|
||||||
if err != ErrReadLimit {
|
if err != ErrReadLimit {
|
||||||
t.Fatalf("io.Copy() returned %v", err)
|
t.Fatalf("io.Copy() returned %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -84,7 +84,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// echoReadAll echoes messages from the client by reading the entire message
|
// echoReadAll echoes messages from the client by reading the entire message
|
||||||
// with ioutil.ReadAll.
|
// with io.ReadAll.
|
||||||
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) {
|
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) {
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -38,7 +38,7 @@ sends them to the hub.
|
||||||
### Hub
|
### Hub
|
||||||
|
|
||||||
The code for the `Hub` type is in
|
The code for the `Hub` type is in
|
||||||
[hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go).
|
[hub.go](https://github.com/gorilla/websocket/blob/main/examples/chat/hub.go).
|
||||||
The application's `main` function starts the hub's `run` method as a goroutine.
|
The application's `main` function starts the hub's `run` method as a goroutine.
|
||||||
Clients send requests to the hub using the `register`, `unregister` and
|
Clients send requests to the hub using the `register`, `unregister` and
|
||||||
`broadcast` channels.
|
`broadcast` channels.
|
||||||
|
@ -57,7 +57,7 @@ unregisters the client and closes the websocket.
|
||||||
|
|
||||||
### Client
|
### Client
|
||||||
|
|
||||||
The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go).
|
The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/main/examples/chat/client.go).
|
||||||
|
|
||||||
The `serveWs` function is registered by the application's `main` function as
|
The `serveWs` function is registered by the application's `main` function as
|
||||||
an HTTP handler. The handler upgrades the HTTP connection to the WebSocket
|
an HTTP handler. The handler upgrades the HTTP connection to the WebSocket
|
||||||
|
@ -85,7 +85,7 @@ network.
|
||||||
|
|
||||||
## Frontend
|
## Frontend
|
||||||
|
|
||||||
The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html).
|
The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/main/examples/chat/home.html).
|
||||||
|
|
||||||
On document load, the script checks for websocket functionality in the browser.
|
On document load, the script checks for websocket functionality in the browser.
|
||||||
If websocket functionality is available, then the script opens a connection to
|
If websocket functionality is available, then the script opens a connection to
|
||||||
|
|
|
@ -7,7 +7,6 @@ package main
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"html/template"
|
"html/template"
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
@ -49,7 +48,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) {
|
||||||
if !fi.ModTime().After(lastMod) {
|
if !fi.ModTime().After(lastMod) {
|
||||||
return nil, lastMod, nil
|
return nil, lastMod, nil
|
||||||
}
|
}
|
||||||
p, err := ioutil.ReadFile(filename)
|
p, err := os.ReadFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fi.ModTime(), err
|
return nil, fi.ModTime(), err
|
||||||
}
|
}
|
||||||
|
|
8
go.mod
8
go.mod
|
@ -1,3 +1,9 @@
|
||||||
module github.com/gorilla/websocket
|
module github.com/gorilla/websocket
|
||||||
|
|
||||||
go 1.12
|
go 1.20
|
||||||
|
|
||||||
|
retract (
|
||||||
|
v1.5.2 // tag accidentally overwritten
|
||||||
|
)
|
||||||
|
|
||||||
|
require golang.org/x/net v0.26.0
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -0,0 +1,2 @@
|
||||||
|
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
|
||||||
|
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
|
@ -33,6 +33,11 @@ var preparedMessageTests = []struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPreparedMessage(t *testing.T) {
|
func TestPreparedMessage(t *testing.T) {
|
||||||
|
testRand := rand.New(rand.NewSource(99))
|
||||||
|
prevMaskRand := maskRand
|
||||||
|
maskRand = testRand
|
||||||
|
defer func() { maskRand = prevMaskRand }()
|
||||||
|
|
||||||
for _, tt := range preparedMessageTests {
|
for _, tt := range preparedMessageTests {
|
||||||
var data = []byte("this is a test")
|
var data = []byte("this is a test")
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
|
@ -45,7 +50,7 @@ func TestPreparedMessage(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Seed random number generator for consistent frame mask.
|
// Seed random number generator for consistent frame mask.
|
||||||
rand.Seed(1234)
|
testRand.Seed(1234)
|
||||||
|
|
||||||
if err := c.WriteMessage(tt.messageType, data); err != nil {
|
if err := c.WriteMessage(tt.messageType, data); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -61,7 +66,7 @@ func TestPreparedMessage(t *testing.T) {
|
||||||
copy(data, "hello world")
|
copy(data, "hello world")
|
||||||
|
|
||||||
// Seed random number generator for consistent frame mask.
|
// Seed random number generator for consistent frame mask.
|
||||||
rand.Seed(1234)
|
testRand.Seed(1234)
|
||||||
|
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
if err := c.WritePreparedMessage(pm); err != nil {
|
if err := c.WritePreparedMessage(pm); err != nil {
|
||||||
|
|
50
proxy.go
50
proxy.go
|
@ -6,34 +6,52 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
type netDialerFunc func(network, addr string) (net.Conn, error)
|
type netDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
|
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
|
||||||
return fn(network, addr)
|
return fn(context.Background(), network, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
return fn(ctx, network, addr)
|
||||||
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
|
}
|
||||||
})
|
|
||||||
|
func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) {
|
||||||
|
if proxyURL.Scheme == "http" {
|
||||||
|
return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil
|
||||||
|
}
|
||||||
|
dialer, err := proxy.FromURL(proxyURL, forwardDial)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if d, ok := dialer.(proxy.ContextDialer); ok {
|
||||||
|
return d.DialContext, nil
|
||||||
|
}
|
||||||
|
return func(ctx context.Context, net, addr string) (net.Conn, error) {
|
||||||
|
return dialer.Dial(net, addr)
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpProxyDialer struct {
|
type httpProxyDialer struct {
|
||||||
proxyURL *url.URL
|
proxyURL *url.URL
|
||||||
forwardDial func(network, addr string) (net.Conn, error)
|
forwardDial netDialerFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
|
func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||||
hostPort, _ := hostPortNoPort(hpd.proxyURL)
|
hostPort, _ := hostPortNoPort(hpd.proxyURL)
|
||||||
conn, err := hpd.forwardDial(network, hostPort)
|
conn, err := hpd.forwardDial(ctx, network, hostPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -68,8 +86,18 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
// Close the response body to silence false positives from linters. Reset
|
||||||
conn.Close()
|
// the buffered reader first to ensure that Close() does not read from
|
||||||
|
// conn.
|
||||||
|
// Note: Applications must call resp.Body.Close() on a response returned
|
||||||
|
// http.ReadResponse to inspect trailers or read another response from the
|
||||||
|
// buffered reader. The call to resp.Body.Close() does not release
|
||||||
|
// resources.
|
||||||
|
br.Reset(bytes.NewReader(nil))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
_ = conn.Close()
|
||||||
f := strings.SplitN(resp.Status, " ", 2)
|
f := strings.SplitN(resp.Status, " ", 2)
|
||||||
return nil, errors.New(f[1])
|
return nil, errors.New(f[1])
|
||||||
}
|
}
|
||||||
|
|
87
server.go
87
server.go
|
@ -6,8 +6,7 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"errors"
|
"net"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -101,8 +100,8 @@ func checkSameOrigin(r *http.Request) bool {
|
||||||
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
|
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
|
||||||
if u.Subprotocols != nil {
|
if u.Subprotocols != nil {
|
||||||
clientProtocols := Subprotocols(r)
|
clientProtocols := Subprotocols(r)
|
||||||
for _, serverProtocol := range u.Subprotocols {
|
for _, clientProtocol := range clientProtocols {
|
||||||
for _, clientProtocol := range clientProtocols {
|
for _, serverProtocol := range u.Subprotocols {
|
||||||
if clientProtocol == serverProtocol {
|
if clientProtocol == serverProtocol {
|
||||||
return clientProtocol
|
return clientProtocol
|
||||||
}
|
}
|
||||||
|
@ -130,7 +129,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
|
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
|
||||||
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
|
w.Header().Set("Upgrade", "websocket")
|
||||||
|
return u.returnError(w, r, http.StatusUpgradeRequired, badHandshake+"'websocket' token not found in 'Upgrade' header")
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method != http.MethodGet {
|
if r.Method != http.MethodGet {
|
||||||
|
@ -172,18 +172,10 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h, ok := w.(http.Hijacker)
|
netConn, brw, err := http.NewResponseController(w).Hijack()
|
||||||
if !ok {
|
|
||||||
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
|
|
||||||
}
|
|
||||||
var brw *bufio.ReadWriter
|
|
||||||
netConn, brw, err := h.Hijack()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
|
return u.returnError(w, r, http.StatusInternalServerError,
|
||||||
}
|
"websocket: hijack: "+err.Error())
|
||||||
if brw.Reader.Buffered() > 0 {
|
|
||||||
netConn.Close()
|
|
||||||
return nil, errors.New("websocket: client sent data before handshake is complete")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close the network connection when returning an error. The variable
|
// Close the network connection when returning an error. The variable
|
||||||
|
@ -199,12 +191,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var br *bufio.Reader
|
var br *bufio.Reader
|
||||||
if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
|
if u.ReadBufferSize == 0 && brw.Reader.Size() > 256 {
|
||||||
// Reuse hijacked buffered reader as connection reader.
|
// Use hijacked buffered reader as the connection reader.
|
||||||
br = brw.Reader
|
br = brw.Reader
|
||||||
|
} else if brw.Reader.Buffered() > 0 {
|
||||||
|
// Wrap the network connection to read buffered data in brw.Reader
|
||||||
|
// before reading from the network connection. This should be rare
|
||||||
|
// because a client must not send message data before receiving the
|
||||||
|
// handshake response.
|
||||||
|
netConn = &brNetConn{br: brw.Reader, Conn: netConn}
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := bufioWriterBuffer(netConn, brw.Writer)
|
buf := brw.Writer.AvailableBuffer()
|
||||||
|
|
||||||
var writeBuf []byte
|
var writeBuf []byte
|
||||||
if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
|
if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
|
||||||
|
@ -348,39 +346,28 @@ func IsWebSocketUpgrade(r *http.Request) bool {
|
||||||
tokenListContainsValue(r.Header, "Upgrade", "websocket")
|
tokenListContainsValue(r.Header, "Upgrade", "websocket")
|
||||||
}
|
}
|
||||||
|
|
||||||
// bufioReaderSize size returns the size of a bufio.Reader.
|
type brNetConn struct {
|
||||||
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
|
br *bufio.Reader
|
||||||
// This code assumes that peek on a reset reader returns
|
net.Conn
|
||||||
// bufio.Reader.buf[:0].
|
}
|
||||||
// TODO: Use bufio.Reader.Size() after Go 1.10
|
|
||||||
br.Reset(originalReader)
|
func (b *brNetConn) Read(p []byte) (n int, err error) {
|
||||||
if p, err := br.Peek(0); err == nil {
|
if b.br != nil {
|
||||||
return cap(p)
|
// Limit read to buferred data.
|
||||||
|
if n := b.br.Buffered(); len(p) > n {
|
||||||
|
p = p[:n]
|
||||||
|
}
|
||||||
|
n, err = b.br.Read(p)
|
||||||
|
if b.br.Buffered() == 0 {
|
||||||
|
b.br = nil
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
}
|
}
|
||||||
return 0
|
return b.Conn.Read(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeHook is an io.Writer that records the last slice passed to it vio
|
// NetConn returns the underlying connection that is wrapped by b.
|
||||||
// io.Writer.Write.
|
func (b *brNetConn) NetConn() net.Conn {
|
||||||
type writeHook struct {
|
return b.Conn
|
||||||
p []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wh *writeHook) Write(p []byte) (int, error) {
|
|
||||||
wh.p = p
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// bufioWriterBuffer grabs the buffer from a bufio.Writer.
|
|
||||||
func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
|
|
||||||
// This code assumes that bufio.Writer.buf[:1] is passed to the
|
|
||||||
// bufio.Writer's underlying writer.
|
|
||||||
var wh writeHook
|
|
||||||
bw.Reset(&wh)
|
|
||||||
bw.WriteByte(0)
|
|
||||||
bw.Flush()
|
|
||||||
|
|
||||||
bw.Reset(originalWriter)
|
|
||||||
|
|
||||||
return wh.p[:cap(wh.p)]
|
|
||||||
}
|
|
||||||
|
|
|
@ -7,8 +7,10 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -54,6 +56,36 @@ func TestIsWebSocketUpgrade(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSubProtocolSelection(t *testing.T) {
|
||||||
|
upgrader := Upgrader{
|
||||||
|
Subprotocols: []string{"foo", "bar", "baz"},
|
||||||
|
}
|
||||||
|
|
||||||
|
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}}
|
||||||
|
s := upgrader.selectSubprotocol(&r, nil)
|
||||||
|
if s != "foo" {
|
||||||
|
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo")
|
||||||
|
}
|
||||||
|
|
||||||
|
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}}
|
||||||
|
s = upgrader.selectSubprotocol(&r, nil)
|
||||||
|
if s != "bar" {
|
||||||
|
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar")
|
||||||
|
}
|
||||||
|
|
||||||
|
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}}
|
||||||
|
s = upgrader.selectSubprotocol(&r, nil)
|
||||||
|
if s != "baz" {
|
||||||
|
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz")
|
||||||
|
}
|
||||||
|
|
||||||
|
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}}
|
||||||
|
s = upgrader.selectSubprotocol(&r, nil)
|
||||||
|
if s != "" {
|
||||||
|
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var checkSameOriginTests = []struct {
|
var checkSameOriginTests = []struct {
|
||||||
ok bool
|
ok bool
|
||||||
r *http.Request
|
r *http.Request
|
||||||
|
@ -89,7 +121,7 @@ var bufioReuseTests = []struct {
|
||||||
{128, false},
|
{128, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBufioReuse(t *testing.T) {
|
func xTestBufioReuse(t *testing.T) {
|
||||||
for i, tt := range bufioReuseTests {
|
for i, tt := range bufioReuseTests {
|
||||||
br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
|
br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
|
||||||
bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
|
bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
|
||||||
|
@ -111,9 +143,29 @@ func TestBufioReuse(t *testing.T) {
|
||||||
if reuse := c.br == br; reuse != tt.reuse {
|
if reuse := c.br == br; reuse != tt.reuse {
|
||||||
t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
|
t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
|
||||||
}
|
}
|
||||||
writeBuf := bufioWriterBuffer(c.NetConn(), bw)
|
writeBuf := bw.AvailableBuffer()
|
||||||
if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
|
if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
|
||||||
t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
|
t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHijack_NotSupported(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
req.Header.Set("Connection", "upgrade")
|
||||||
|
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||||
|
req.Header.Set("Sec-Websocket-Version", "13")
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
upgrader := Upgrader{}
|
||||||
|
_, err := upgrader.Upgrade(recorder, req, nil)
|
||||||
|
|
||||||
|
if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
|
||||||
|
t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
|
||||||
|
t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
473
x_net_proxy.go
473
x_net_proxy.go
|
@ -1,473 +0,0 @@
|
||||||
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
|
|
||||||
//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy
|
|
||||||
|
|
||||||
// Package proxy provides support for a variety of protocols to proxy network
|
|
||||||
// data.
|
|
||||||
//
|
|
||||||
|
|
||||||
package websocket
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type proxy_direct struct{}
|
|
||||||
|
|
||||||
// Direct is a direct proxy: one that makes network connections directly.
|
|
||||||
var proxy_Direct = proxy_direct{}
|
|
||||||
|
|
||||||
func (proxy_direct) Dial(network, addr string) (net.Conn, error) {
|
|
||||||
return net.Dial(network, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// A PerHost directs connections to a default Dialer unless the host name
|
|
||||||
// requested matches one of a number of exceptions.
|
|
||||||
type proxy_PerHost struct {
|
|
||||||
def, bypass proxy_Dialer
|
|
||||||
|
|
||||||
bypassNetworks []*net.IPNet
|
|
||||||
bypassIPs []net.IP
|
|
||||||
bypassZones []string
|
|
||||||
bypassHosts []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPerHost returns a PerHost Dialer that directs connections to either
|
|
||||||
// defaultDialer or bypass, depending on whether the connection matches one of
|
|
||||||
// the configured rules.
|
|
||||||
func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost {
|
|
||||||
return &proxy_PerHost{
|
|
||||||
def: defaultDialer,
|
|
||||||
bypass: bypass,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial connects to the address addr on the given network through either
|
|
||||||
// defaultDialer or bypass.
|
|
||||||
func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) {
|
|
||||||
host, _, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.dialerForRequest(host).Dial(network, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer {
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
for _, net := range p.bypassNetworks {
|
|
||||||
if net.Contains(ip) {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, bypassIP := range p.bypassIPs {
|
|
||||||
if bypassIP.Equal(ip) {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p.def
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, zone := range p.bypassZones {
|
|
||||||
if strings.HasSuffix(host, zone) {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
if host == zone[1:] {
|
|
||||||
// For a zone ".example.com", we match "example.com"
|
|
||||||
// too.
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, bypassHost := range p.bypassHosts {
|
|
||||||
if bypassHost == host {
|
|
||||||
return p.bypass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p.def
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFromString parses a string that contains comma-separated values
|
|
||||||
// specifying hosts that should use the bypass proxy. Each value is either an
|
|
||||||
// IP address, a CIDR range, a zone (*.example.com) or a host name
|
|
||||||
// (localhost). A best effort is made to parse the string and errors are
|
|
||||||
// ignored.
|
|
||||||
func (p *proxy_PerHost) AddFromString(s string) {
|
|
||||||
hosts := strings.Split(s, ",")
|
|
||||||
for _, host := range hosts {
|
|
||||||
host = strings.TrimSpace(host)
|
|
||||||
if len(host) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.Contains(host, "/") {
|
|
||||||
// We assume that it's a CIDR address like 127.0.0.0/8
|
|
||||||
if _, net, err := net.ParseCIDR(host); err == nil {
|
|
||||||
p.AddNetwork(net)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
p.AddIP(ip)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(host, "*.") {
|
|
||||||
p.AddZone(host[1:])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
p.AddHost(host)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddIP specifies an IP address that will use the bypass proxy. Note that
|
|
||||||
// this will only take effect if a literal IP address is dialed. A connection
|
|
||||||
// to a named host will never match an IP.
|
|
||||||
func (p *proxy_PerHost) AddIP(ip net.IP) {
|
|
||||||
p.bypassIPs = append(p.bypassIPs, ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddNetwork specifies an IP range that will use the bypass proxy. Note that
|
|
||||||
// this will only take effect if a literal IP address is dialed. A connection
|
|
||||||
// to a named host will never match.
|
|
||||||
func (p *proxy_PerHost) AddNetwork(net *net.IPNet) {
|
|
||||||
p.bypassNetworks = append(p.bypassNetworks, net)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
|
|
||||||
// "example.com" matches "example.com" and all of its subdomains.
|
|
||||||
func (p *proxy_PerHost) AddZone(zone string) {
|
|
||||||
if strings.HasSuffix(zone, ".") {
|
|
||||||
zone = zone[:len(zone)-1]
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(zone, ".") {
|
|
||||||
zone = "." + zone
|
|
||||||
}
|
|
||||||
p.bypassZones = append(p.bypassZones, zone)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddHost specifies a host name that will use the bypass proxy.
|
|
||||||
func (p *proxy_PerHost) AddHost(host string) {
|
|
||||||
if strings.HasSuffix(host, ".") {
|
|
||||||
host = host[:len(host)-1]
|
|
||||||
}
|
|
||||||
p.bypassHosts = append(p.bypassHosts, host)
|
|
||||||
}
|
|
||||||
|
|
||||||
// A Dialer is a means to establish a connection.
|
|
||||||
type proxy_Dialer interface {
|
|
||||||
// Dial connects to the given address via the proxy.
|
|
||||||
Dial(network, addr string) (c net.Conn, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auth contains authentication parameters that specific Dialers may require.
|
|
||||||
type proxy_Auth struct {
|
|
||||||
User, Password string
|
|
||||||
}
|
|
||||||
|
|
||||||
// FromEnvironment returns the dialer specified by the proxy related variables in
|
|
||||||
// the environment.
|
|
||||||
func proxy_FromEnvironment() proxy_Dialer {
|
|
||||||
allProxy := proxy_allProxyEnv.Get()
|
|
||||||
if len(allProxy) == 0 {
|
|
||||||
return proxy_Direct
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyURL, err := url.Parse(allProxy)
|
|
||||||
if err != nil {
|
|
||||||
return proxy_Direct
|
|
||||||
}
|
|
||||||
proxy, err := proxy_FromURL(proxyURL, proxy_Direct)
|
|
||||||
if err != nil {
|
|
||||||
return proxy_Direct
|
|
||||||
}
|
|
||||||
|
|
||||||
noProxy := proxy_noProxyEnv.Get()
|
|
||||||
if len(noProxy) == 0 {
|
|
||||||
return proxy
|
|
||||||
}
|
|
||||||
|
|
||||||
perHost := proxy_NewPerHost(proxy, proxy_Direct)
|
|
||||||
perHost.AddFromString(noProxy)
|
|
||||||
return perHost
|
|
||||||
}
|
|
||||||
|
|
||||||
// proxySchemes is a map from URL schemes to a function that creates a Dialer
|
|
||||||
// from a URL with such a scheme.
|
|
||||||
var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)
|
|
||||||
|
|
||||||
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
|
|
||||||
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
|
|
||||||
// by FromURL.
|
|
||||||
func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) {
|
|
||||||
if proxy_proxySchemes == nil {
|
|
||||||
proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error))
|
|
||||||
}
|
|
||||||
proxy_proxySchemes[scheme] = f
|
|
||||||
}
|
|
||||||
|
|
||||||
// FromURL returns a Dialer given a URL specification and an underlying
|
|
||||||
// Dialer for it to make network requests.
|
|
||||||
func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) {
|
|
||||||
var auth *proxy_Auth
|
|
||||||
if u.User != nil {
|
|
||||||
auth = new(proxy_Auth)
|
|
||||||
auth.User = u.User.Username()
|
|
||||||
if p, ok := u.User.Password(); ok {
|
|
||||||
auth.Password = p
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch u.Scheme {
|
|
||||||
case "socks5":
|
|
||||||
return proxy_SOCKS5("tcp", u.Host, auth, forward)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the scheme doesn't match any of the built-in schemes, see if it
|
|
||||||
// was registered by another package.
|
|
||||||
if proxy_proxySchemes != nil {
|
|
||||||
if f, ok := proxy_proxySchemes[u.Scheme]; ok {
|
|
||||||
return f(u, forward)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
proxy_allProxyEnv = &proxy_envOnce{
|
|
||||||
names: []string{"ALL_PROXY", "all_proxy"},
|
|
||||||
}
|
|
||||||
proxy_noProxyEnv = &proxy_envOnce{
|
|
||||||
names: []string{"NO_PROXY", "no_proxy"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// envOnce looks up an environment variable (optionally by multiple
|
|
||||||
// names) once. It mitigates expensive lookups on some platforms
|
|
||||||
// (e.g. Windows).
|
|
||||||
// (Borrowed from net/http/transport.go)
|
|
||||||
type proxy_envOnce struct {
|
|
||||||
names []string
|
|
||||||
once sync.Once
|
|
||||||
val string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *proxy_envOnce) Get() string {
|
|
||||||
e.once.Do(e.init)
|
|
||||||
return e.val
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *proxy_envOnce) init() {
|
|
||||||
for _, n := range e.names {
|
|
||||||
e.val = os.Getenv(n)
|
|
||||||
if e.val != "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
|
|
||||||
// with an optional username and password. See RFC 1928 and RFC 1929.
|
|
||||||
func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) {
|
|
||||||
s := &proxy_socks5{
|
|
||||||
network: network,
|
|
||||||
addr: addr,
|
|
||||||
forward: forward,
|
|
||||||
}
|
|
||||||
if auth != nil {
|
|
||||||
s.user = auth.User
|
|
||||||
s.password = auth.Password
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type proxy_socks5 struct {
|
|
||||||
user, password string
|
|
||||||
network, addr string
|
|
||||||
forward proxy_Dialer
|
|
||||||
}
|
|
||||||
|
|
||||||
const proxy_socks5Version = 5
|
|
||||||
|
|
||||||
const (
|
|
||||||
proxy_socks5AuthNone = 0
|
|
||||||
proxy_socks5AuthPassword = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
const proxy_socks5Connect = 1
|
|
||||||
|
|
||||||
const (
|
|
||||||
proxy_socks5IP4 = 1
|
|
||||||
proxy_socks5Domain = 3
|
|
||||||
proxy_socks5IP6 = 4
|
|
||||||
)
|
|
||||||
|
|
||||||
var proxy_socks5Errors = []string{
|
|
||||||
"",
|
|
||||||
"general failure",
|
|
||||||
"connection forbidden",
|
|
||||||
"network unreachable",
|
|
||||||
"host unreachable",
|
|
||||||
"connection refused",
|
|
||||||
"TTL expired",
|
|
||||||
"command not supported",
|
|
||||||
"address type not supported",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial connects to the address addr on the given network via the SOCKS5 proxy.
|
|
||||||
func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) {
|
|
||||||
switch network {
|
|
||||||
case "tcp", "tcp6", "tcp4":
|
|
||||||
default:
|
|
||||||
return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := s.forward.Dial(s.network, s.addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := s.connect(conn, addr); err != nil {
|
|
||||||
conn.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// connect takes an existing connection to a socks5 proxy server,
|
|
||||||
// and commands the server to extend that connection to target,
|
|
||||||
// which must be a canonical address with a host and port.
|
|
||||||
func (s *proxy_socks5) connect(conn net.Conn, target string) error {
|
|
||||||
host, portStr, err := net.SplitHostPort(target)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
port, err := strconv.Atoi(portStr)
|
|
||||||
if err != nil {
|
|
||||||
return errors.New("proxy: failed to parse port number: " + portStr)
|
|
||||||
}
|
|
||||||
if port < 1 || port > 0xffff {
|
|
||||||
return errors.New("proxy: port number out of range: " + portStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// the size here is just an estimate
|
|
||||||
buf := make([]byte, 0, 6+len(host))
|
|
||||||
|
|
||||||
buf = append(buf, proxy_socks5Version)
|
|
||||||
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
|
|
||||||
buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword)
|
|
||||||
} else {
|
|
||||||
buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := conn.Write(buf); err != nil {
|
|
||||||
return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
|
||||||
return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
if buf[0] != 5 {
|
|
||||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
|
|
||||||
}
|
|
||||||
if buf[1] == 0xff {
|
|
||||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
|
|
||||||
}
|
|
||||||
|
|
||||||
// See RFC 1929
|
|
||||||
if buf[1] == proxy_socks5AuthPassword {
|
|
||||||
buf = buf[:0]
|
|
||||||
buf = append(buf, 1 /* password protocol version */)
|
|
||||||
buf = append(buf, uint8(len(s.user)))
|
|
||||||
buf = append(buf, s.user...)
|
|
||||||
buf = append(buf, uint8(len(s.password)))
|
|
||||||
buf = append(buf, s.password...)
|
|
||||||
|
|
||||||
if _, err := conn.Write(buf); err != nil {
|
|
||||||
return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
|
||||||
return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if buf[1] != 0 {
|
|
||||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
buf = buf[:0]
|
|
||||||
buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */)
|
|
||||||
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
|
||||||
buf = append(buf, proxy_socks5IP4)
|
|
||||||
ip = ip4
|
|
||||||
} else {
|
|
||||||
buf = append(buf, proxy_socks5IP6)
|
|
||||||
}
|
|
||||||
buf = append(buf, ip...)
|
|
||||||
} else {
|
|
||||||
if len(host) > 255 {
|
|
||||||
return errors.New("proxy: destination host name too long: " + host)
|
|
||||||
}
|
|
||||||
buf = append(buf, proxy_socks5Domain)
|
|
||||||
buf = append(buf, byte(len(host)))
|
|
||||||
buf = append(buf, host...)
|
|
||||||
}
|
|
||||||
buf = append(buf, byte(port>>8), byte(port))
|
|
||||||
|
|
||||||
if _, err := conn.Write(buf); err != nil {
|
|
||||||
return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
|
|
||||||
return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
failure := "unknown error"
|
|
||||||
if int(buf[1]) < len(proxy_socks5Errors) {
|
|
||||||
failure = proxy_socks5Errors[buf[1]]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(failure) > 0 {
|
|
||||||
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
|
|
||||||
}
|
|
||||||
|
|
||||||
bytesToDiscard := 0
|
|
||||||
switch buf[3] {
|
|
||||||
case proxy_socks5IP4:
|
|
||||||
bytesToDiscard = net.IPv4len
|
|
||||||
case proxy_socks5IP6:
|
|
||||||
bytesToDiscard = net.IPv6len
|
|
||||||
case proxy_socks5Domain:
|
|
||||||
_, err := io.ReadFull(conn, buf[:1])
|
|
||||||
if err != nil {
|
|
||||||
return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
bytesToDiscard = int(buf[0])
|
|
||||||
default:
|
|
||||||
return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cap(buf) < bytesToDiscard {
|
|
||||||
buf = make([]byte, bytesToDiscard)
|
|
||||||
} else {
|
|
||||||
buf = buf[:bytesToDiscard]
|
|
||||||
}
|
|
||||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
|
||||||
return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Also need to discard the port number
|
|
||||||
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
|
|
||||||
return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
Loading…
Reference in New Issue