Compare commits

...

23 Commits

Author SHA1 Message Date
Philip Hamer 68a77f625e
Merge 4d0a40247b into 70bf50955e 2024-06-19 17:39:41 +10:00
Canelo Hill 70bf50955e Silence false positive lint warning in proxy code 2024-06-19 17:31:46 +10:00
Konstantin Burkalev f78ed9f987 Added tests for subprotocol selection 2024-06-19 17:13:42 +10:00
Konstantin Burkalev 17f407278f Fixes subprotocol selection (aling with rfc6455) 2024-06-19 17:13:42 +10:00
mstmdev efaec3cbd1 Update README.md, replace master to main 2024-06-19 17:13:16 +10:00
Canelo Hill 688592ebe6 Improve client/server tests
Tests must not call *testing.T methods after the test function returns.
Use a sync.WaitGroup to ensure that server handler functions complete
before tests return.
2024-06-19 17:11:11 +10:00
tebuka 7e5e9b5a25 Improve hijack failure error text
Include "hijack" in text to indicate where in this package the error
occurred.
2024-06-19 17:10:25 +10:00
merlin 8890e3e578 fix: don't use errors.ErrUnsupported, it's available only since go1.21 2024-06-19 17:10:25 +10:00
merlin c7502098b0 use http.ResposnseController 2024-06-19 17:10:25 +10:00
Canelo Hill a70cea529a
Update for deprecated ioutil package (#931) 2024-06-19 14:44:41 +10:00
Canelo Hill ac1b326ac0
Set min Go version to 1.20 (#930)
Update go.mod and CI to Go version 1.20.
2024-06-19 14:40:57 +10:00
Daniel Holmes 227456c3cc chore: Retract v1.5.2 from go.mod
Maintainers accidentally changed the reference commit
for v1.5.2. This change retracts v1.5.2 which also
includes a number of avoidable issues.

Fixes #927
2024-06-19 04:30:55 +00:00
Corey Daley 4d0a40247b
Merge branch 'main' into tls-proxy 2023-07-30 14:22:40 -04:00
Philip Hamer 444f3c080f
gofmt 2022-02-17 13:00:00 -05:00
Philip Hamer 9fb72e3db2
Merge branch 'master' into tls-proxy-build 2022-02-17 12:51:27 -05:00
Philip Hamer bb146cd3fd
fix build error 2021-12-06 18:52:57 -05:00
Philip Hamer b484a6e5a0
try compatibility with pre 1.15 as noop 2021-12-06 18:49:39 -05:00
Philip Hamer d16969baa1
add unit test for https proxy 2021-12-06 10:29:48 -05:00
Philip Hamer 2553869a29
clean up comment 2021-12-06 09:30:42 -05:00
Philip Hamer f724bd6a6c
do not edit the generated x_net_proxy.go 2021-12-03 15:59:00 -05:00
Philip Hamer 7f3a5bcae0
make it more intuitive for tls proxy 2021-11-29 14:49:47 -05:00
Philip Hamer 2a082eee69
simplify proxying with tls proxy 2021-11-29 14:19:00 -05:00
Philip Hamer d229c9f93d
try https proxy 2021-11-16 23:26:29 -05:00
18 changed files with 252 additions and 55 deletions

View File

@ -67,4 +67,4 @@ workflows:
- test: - test:
matrix: matrix:
parameters: parameters:
version: ["1.18", "1.17", "1.16"] version: ["1.22", "1.21", "1.20"]

View File

@ -10,10 +10,10 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
### Documentation ### Documentation
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) * [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) * [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat)
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) * [Command example](https://github.com/gorilla/websocket/tree/main/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) * [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo)
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) * [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch)
### Status ### Status
@ -29,5 +29,4 @@ package API is stable.
The Gorilla WebSocket package passes the server tests in the [Autobahn Test The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn).

View File

@ -11,7 +11,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
@ -304,7 +303,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return nil, nil, err return nil, nil, err
} }
if proxyURL != nil { if proxyURL != nil {
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) proxyDialer := &netDialerFunc{fn: netDial}
modifyProxyDialer(ctx, d, proxyURL, proxyDialer)
dialer, err := proxy_FromURL(proxyURL, proxyDialer)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -400,7 +401,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// debugging. // debugging.
buf := make([]byte, 1024) buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf) n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake return nil, resp, ErrBadHandshake
} }
@ -418,7 +419,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
break break
} }
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{}) netConn.SetDeadline(time.Time{})

View File

@ -0,0 +1,52 @@
//go:build go1.15
// +build go1.15
package websocket
import (
"crypto/tls"
"net/http"
"net/url"
"testing"
)
func TestHttpsProxy(t *testing.T) {
sTLS := newTLSServer(t)
defer sTLS.Close()
s := newServer(t)
defer s.Close()
surlTLS, _ := url.Parse(sTLS.Server.URL)
cstDialer := cstDialer // make local copy for modification on next line.
cstDialer.Proxy = http.ProxyURL(surlTLS)
connect := false
origHandler := sTLS.Server.Config.Handler
// Capture the request Host header.
sTLS.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" {
connect = true
w.WriteHeader(http.StatusOK)
return
}
if !connect {
t.Log("connect not received")
http.Error(w, "connect not received", http.StatusMethodNotAllowed)
return
}
origHandler.ServeHTTP(w, r)
})
cstDialer.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, sTLS.Server)}
ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
}

View File

@ -14,7 +14,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -24,6 +23,7 @@ import (
"net/url" "net/url"
"reflect" "reflect"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
) )
@ -45,12 +45,15 @@ var cstDialer = Dialer{
HandshakeTimeout: 30 * time.Second, HandshakeTimeout: 30 * time.Second,
} }
type cstHandler struct{ *testing.T } type cstHandler struct {
*testing.T
s *cstServer
}
type cstServer struct { type cstServer struct {
*httptest.Server
URL string URL string
t *testing.T Server *httptest.Server
wg sync.WaitGroup
} }
const ( const (
@ -59,9 +62,15 @@ const (
cstRequestURI = cstPath + "?" + cstRawQuery cstRequestURI = cstPath + "?" + cstRawQuery
) )
func (s *cstServer) Close() {
s.Server.Close()
// Wait for handler functions to complete.
s.wg.Wait()
}
func newServer(t *testing.T) *cstServer { func newServer(t *testing.T) *cstServer {
var s cstServer var s cstServer
s.Server = httptest.NewServer(cstHandler{t}) s.Server = httptest.NewServer(cstHandler{T: t, s: &s})
s.Server.URL += cstRequestURI s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL) s.URL = makeWsProto(s.Server.URL)
return &s return &s
@ -69,13 +78,19 @@ func newServer(t *testing.T) *cstServer {
func newTLSServer(t *testing.T) *cstServer { func newTLSServer(t *testing.T) *cstServer {
var s cstServer var s cstServer
s.Server = httptest.NewTLSServer(cstHandler{t}) s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s})
s.Server.URL += cstRequestURI s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL) s.URL = makeWsProto(s.Server.URL)
return &s return &s
} }
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Because tests wait for a response from a server, we are guaranteed that
// the wait group count is incremented before the test waits on the group
// in the call to (*cstServer).Close().
t.s.wg.Add(1)
defer t.s.wg.Done()
if r.URL.Path != cstPath { if r.URL.Path != cstPath {
t.Logf("path=%v, want %v", r.URL.Path, cstPath) t.Logf("path=%v, want %v", r.URL.Path, cstPath)
http.Error(w, "bad path", http.StatusBadRequest) http.Error(w, "bad path", http.StatusBadRequest)
@ -549,7 +564,7 @@ func TestRespOnBadHandshake(t *testing.T) {
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
} }
p, err := ioutil.ReadAll(resp.Body) p, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatalf("ReadFull(resp.Body) returned error %v", err) t.Fatalf("ReadFull(resp.Body) returned error %v", err)
} }

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"testing" "testing"
) )
@ -42,7 +41,7 @@ func textMessages(num int) [][]byte {
} }
func BenchmarkWriteNoCompression(b *testing.B) { func BenchmarkWriteNoCompression(b *testing.B) {
w := ioutil.Discard w := io.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
b.ResetTimer() b.ResetTimer()
@ -53,7 +52,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
} }
func BenchmarkWriteWithCompression(b *testing.B) { func BenchmarkWriteWithCompression(b *testing.B) {
w := ioutil.Discard w := io.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
c.enableWriteCompression = true c.enableWriteCompression = true

View File

@ -9,7 +9,6 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"io/ioutil"
"math/rand" "math/rand"
"net" "net"
"strconv" "strconv"
@ -795,7 +794,7 @@ func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame. // 1. Skip remainder of previous frame.
if c.readRemaining > 0 { if c.readRemaining > 0 {
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
return noFrame, err return noFrame, err
} }
} }
@ -1094,7 +1093,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
if err != nil { if err != nil {
return messageType, nil, err return messageType, nil, err
} }
p, err = ioutil.ReadAll(r) p, err = io.ReadAll(r)
return messageType, p, err return messageType, p, err
} }

View File

@ -6,7 +6,6 @@ package websocket
import ( import (
"io" "io"
"io/ioutil"
"sync/atomic" "sync/atomic"
"testing" "testing"
) )
@ -45,7 +44,7 @@ func newBroadcastConn(c *Conn) *broadcastConn {
func newBroadcastBench(usePrepared, compression bool) *broadcastBench { func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
bench := &broadcastBench{ bench := &broadcastBench{
w: ioutil.Discard, w: io.Discard,
doneCh: make(chan struct{}), doneCh: make(chan struct{}),
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
usePrepared: usePrepared, usePrepared: usePrepared,

View File

@ -10,7 +10,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"reflect" "reflect"
"sync" "sync"
@ -125,7 +124,7 @@ func TestFraming(t *testing.T) {
} }
t.Logf("frame size: %d", n) t.Logf("frame size: %d", n)
rbuf, err := ioutil.ReadAll(r) rbuf, err := io.ReadAll(r)
if err != nil { if err != nil {
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
continue continue
@ -367,7 +366,7 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if !reflect.DeepEqual(err, expectedErr) { if !reflect.DeepEqual(err, expectedErr) {
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
} }
@ -401,7 +400,7 @@ func TestEOFWithinFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
} }
@ -426,7 +425,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
} }
@ -490,7 +489,7 @@ func TestReadLimit(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("2: NextReader() returned %d, %v", op, err) t.Fatalf("2: NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != ErrReadLimit { if err != ErrReadLimit {
t.Fatalf("io.Copy() returned %v", err) t.Fatalf("io.Copy() returned %v", err)
} }

View File

@ -84,7 +84,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
} }
// echoReadAll echoes messages from the client by reading the entire message // echoReadAll echoes messages from the client by reading the entire message
// with ioutil.ReadAll. // with io.ReadAll.
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) { func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) {
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {

View File

@ -38,7 +38,7 @@ sends them to the hub.
### Hub ### Hub
The code for the `Hub` type is in The code for the `Hub` type is in
[hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go). [hub.go](https://github.com/gorilla/websocket/blob/main/examples/chat/hub.go).
The application's `main` function starts the hub's `run` method as a goroutine. The application's `main` function starts the hub's `run` method as a goroutine.
Clients send requests to the hub using the `register`, `unregister` and Clients send requests to the hub using the `register`, `unregister` and
`broadcast` channels. `broadcast` channels.
@ -57,7 +57,7 @@ unregisters the client and closes the websocket.
### Client ### Client
The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go). The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/main/examples/chat/client.go).
The `serveWs` function is registered by the application's `main` function as The `serveWs` function is registered by the application's `main` function as
an HTTP handler. The handler upgrades the HTTP connection to the WebSocket an HTTP handler. The handler upgrades the HTTP connection to the WebSocket
@ -85,7 +85,7 @@ network.
## Frontend ## Frontend
The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html). The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/main/examples/chat/home.html).
On document load, the script checks for websocket functionality in the browser. On document load, the script checks for websocket functionality in the browser.
If websocket functionality is available, then the script opens a connection to If websocket functionality is available, then the script opens a connection to

View File

@ -7,7 +7,6 @@ package main
import ( import (
"flag" "flag"
"html/template" "html/template"
"io/ioutil"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -49,7 +48,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) {
if !fi.ModTime().After(lastMod) { if !fi.ModTime().After(lastMod) {
return nil, lastMod, nil return nil, lastMod, nil
} }
p, err := ioutil.ReadFile(filename) p, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, fi.ModTime(), err return nil, fi.ModTime(), err
} }

6
go.mod
View File

@ -1,3 +1,7 @@
module github.com/gorilla/websocket module github.com/gorilla/websocket
go 1.12 go 1.20
retract (
v1.5.2 // tag accidentally overwritten
)

View File

@ -6,6 +6,7 @@ package websocket
import ( import (
"bufio" "bufio"
"bytes"
"encoding/base64" "encoding/base64"
"errors" "errors"
"net" "net"
@ -14,21 +15,37 @@ import (
"strings" "strings"
) )
type netDialerFunc func(network, addr string) (net.Conn, error) // proxyDialerEx extends the generated proxy_Dialer
type proxyDialerEx interface {
proxy_Dialer
// UsesTLS indicates whether we expect to dial to a TLS proxy
UsesTLS() bool
}
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { type netDialerFunc struct {
return fn(network, addr) fn func(network, addr string) (net.Conn, error)
usesTLS bool
}
func (ndf *netDialerFunc) Dial(network, addr string) (net.Conn, error) {
return ndf.fn(network, addr)
}
func (ndf *netDialerFunc) UsesTLS() bool {
return ndf.usesTLS
} }
func init() { func init() {
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial, usesTLS: false}, nil
}) })
registerDialerHttps()
} }
type httpProxyDialer struct { type httpProxyDialer struct {
proxyURL *url.URL proxyURL *url.URL
forwardDial func(network, addr string) (net.Conn, error) forwardDial func(network, addr string) (net.Conn, error)
usesTLS bool
} }
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
@ -68,10 +85,24 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
return nil, err return nil, err
} }
if resp.StatusCode != 200 { // Close the response body to silence false positives from linters. Reset
conn.Close() // the buffered reader first to ensure that Close() does not read from
// conn.
// Note: Applications must call resp.Body.Close() on a response returned
// http.ReadResponse to inspect trailers or read another response from the
// buffered reader. The call to resp.Body.Close() does not release
// resources.
br.Reset(bytes.NewReader(nil))
_ = resp.Body.Close()
if resp.StatusCode != http.StatusOK {
_ = conn.Close()
f := strings.SplitN(resp.Status, " ", 2) f := strings.SplitN(resp.Status, " ", 2)
return nil, errors.New(f[1]) return nil, errors.New(f[1])
} }
return conn, nil return conn, nil
} }
func (hpd *httpProxyDialer) UsesTLS() bool {
return hpd.usesTLS
}

37
proxy_https.go Normal file
View File

@ -0,0 +1,37 @@
//go:build go1.15
// +build go1.15
package websocket
import (
"context"
"crypto/tls"
"net"
"net/url"
)
func registerDialerHttps() {
proxy_RegisterDialerType("https", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
fwd := forwardDialer.Dial
if dialerEx, ok := forwardDialer.(proxyDialerEx); !ok || !dialerEx.UsesTLS() {
tlsDialer := &tls.Dialer{
Config: &tls.Config{},
NetDialer: &net.Dialer{},
}
fwd = tlsDialer.Dial
}
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: fwd, usesTLS: true}, nil
})
}
func modifyProxyDialer(ctx context.Context, d *Dialer, proxyURL *url.URL, proxyDialer *netDialerFunc) {
if proxyURL.Scheme == "https" {
proxyDialer.usesTLS = true
proxyDialer.fn = func(network, addr string) (net.Conn, error) {
t := tls.Dialer{}
t.Config = d.TLSClientConfig
t.NetDialer = &net.Dialer{}
return t.DialContext(ctx, network, addr)
}
}
}

15
proxy_https_legacy.go Normal file
View File

@ -0,0 +1,15 @@
//go:build !go1.15
// +build !go1.15
package websocket
import (
"context"
"net/url"
)
func registerDialerHttps() {
}
func modifyProxyDialer(ctx context.Context, d *Dialer, proxyURL *url.URL, proxyDialer *netDialerFunc) {
}

View File

@ -101,8 +101,8 @@ func checkSameOrigin(r *http.Request) bool {
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
if u.Subprotocols != nil { if u.Subprotocols != nil {
clientProtocols := Subprotocols(r) clientProtocols := Subprotocols(r)
for _, serverProtocol := range u.Subprotocols {
for _, clientProtocol := range clientProtocols { for _, clientProtocol := range clientProtocols {
for _, serverProtocol := range u.Subprotocols {
if clientProtocol == serverProtocol { if clientProtocol == serverProtocol {
return clientProtocol return clientProtocol
} }
@ -172,14 +172,10 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
} }
h, ok := w.(http.Hijacker) netConn, brw, err := http.NewResponseController(w).Hijack()
if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var brw *bufio.ReadWriter
netConn, brw, err := h.Hijack()
if err != nil { if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error()) return u.returnError(w, r, http.StatusInternalServerError,
"websocket: hijack: "+err.Error())
} }
if brw.Reader.Buffered() > 0 { if brw.Reader.Buffered() > 0 {

View File

@ -7,8 +7,10 @@ package websocket
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -54,6 +56,36 @@ func TestIsWebSocketUpgrade(t *testing.T) {
} }
} }
func TestSubProtocolSelection(t *testing.T) {
upgrader := Upgrader{
Subprotocols: []string{"foo", "bar", "baz"},
}
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}}
s := upgrader.selectSubprotocol(&r, nil)
if s != "foo" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "bar" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "baz" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string")
}
}
var checkSameOriginTests = []struct { var checkSameOriginTests = []struct {
ok bool ok bool
r *http.Request r *http.Request
@ -117,3 +149,23 @@ func TestBufioReuse(t *testing.T) {
} }
} }
} }
func TestHijack_NotSupported(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "upgrade")
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Sec-Websocket-Version", "13")
recorder := httptest.NewRecorder()
upgrader := Upgrader{}
_, err := upgrader.Upgrade(recorder, req, nil)
if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
}
}