Compare commits

...

17 Commits

Author SHA1 Message Date
Sleeyax 883b0a2814
Merge ea5afd45e1 into 70bf50955e 2024-06-19 17:37:57 +10:00
Canelo Hill 70bf50955e Silence false positive lint warning in proxy code 2024-06-19 17:31:46 +10:00
Konstantin Burkalev f78ed9f987 Added tests for subprotocol selection 2024-06-19 17:13:42 +10:00
Konstantin Burkalev 17f407278f Fixes subprotocol selection (aling with rfc6455) 2024-06-19 17:13:42 +10:00
mstmdev efaec3cbd1 Update README.md, replace master to main 2024-06-19 17:13:16 +10:00
Canelo Hill 688592ebe6 Improve client/server tests
Tests must not call *testing.T methods after the test function returns.
Use a sync.WaitGroup to ensure that server handler functions complete
before tests return.
2024-06-19 17:11:11 +10:00
tebuka 7e5e9b5a25 Improve hijack failure error text
Include "hijack" in text to indicate where in this package the error
occurred.
2024-06-19 17:10:25 +10:00
merlin 8890e3e578 fix: don't use errors.ErrUnsupported, it's available only since go1.21 2024-06-19 17:10:25 +10:00
merlin c7502098b0 use http.ResposnseController 2024-06-19 17:10:25 +10:00
Canelo Hill a70cea529a
Update for deprecated ioutil package (#931) 2024-06-19 14:44:41 +10:00
Canelo Hill ac1b326ac0
Set min Go version to 1.20 (#930)
Update go.mod and CI to Go version 1.20.
2024-06-19 14:40:57 +10:00
Daniel Holmes 227456c3cc chore: Retract v1.5.2 from go.mod
Maintainers accidentally changed the reference commit
for v1.5.2. This change retracts v1.5.2 which also
includes a number of avoidable issues.

Fixes #927
2024-06-19 04:30:55 +00:00
Corey Daley ea5afd45e1
Merge branch 'main' into http-proxy-tls-fix 2023-08-17 15:10:23 -04:00
Corey Daley fe5bdd86f3
Merge branch 'main' into http-proxy-tls-fix 2023-07-30 14:23:03 -04:00
Sleeyax 36867e9319 Include hostPort in function signature 2022-10-26 19:04:40 +02:00
Sleeyax adb13885c2 Check error 2022-05-14 15:47:30 +02:00
Sleeyax 502bd65db8 Add optional method ProxyTLSConnection (closes #779)
Removed the call to NetDialTLSContext from the HTTP proxy CONNECT step and replaced it with a regular net.Dial in order to prevent connection issues. Custom TLS connections can now be made via the new optional ProxyTLSConnection method, after the proxy connection has been successfully established.
2022-05-12 18:06:13 +02:00
15 changed files with 153 additions and 68 deletions

View File

@ -67,4 +67,4 @@ workflows:
- test: - test:
matrix: matrix:
parameters: parameters:
version: ["1.18", "1.17", "1.16"] version: ["1.22", "1.21", "1.20"]

View File

@ -10,10 +10,10 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
### Documentation ### Documentation
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) * [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) * [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat)
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) * [Command example](https://github.com/gorilla/websocket/tree/main/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) * [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo)
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) * [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch)
### Status ### Status
@ -29,5 +29,4 @@ package API is stable.
The Gorilla WebSocket package passes the server tests in the [Autobahn Test The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn).

View File

@ -11,7 +11,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
@ -66,6 +65,12 @@ type Dialer struct {
// TLSClientConfig is ignored. // TLSClientConfig is ignored.
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
// ProxyTLSConnection specifies the dial function for creating TLS connections through a Proxy. If
// ProxyTLSConnection is nil, NetDialTLSContext is used.
// If ProxyTLSConnection is set, Dial assumes the TLS handshake is done there and
// TLSClientConfig is ignored.
ProxyTLSConnection func(ctx context.Context, hostPort string, proxyConn net.Conn) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given // Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the // Request. If the function returns a non-nil error, the
// request is aborted with the provided error. // request is aborted with the provided error.
@ -334,7 +339,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
} }
}() }()
if u.Scheme == "https" && d.NetDialTLSContext == nil { if u.Scheme == "https" {
if d.ProxyTLSConnection != nil && d.Proxy != nil {
// If we are connected to a proxy, perform the TLS handshake through the existing tunnel
netConn, err = d.ProxyTLSConnection(ctx, hostPort, netConn)
if err != nil {
return nil, nil, err
}
} else if d.NetDialTLSContext == nil {
// If NetDialTLSContext is set, assume that the TLS handshake has already been done // If NetDialTLSContext is set, assume that the TLS handshake has already been done
cfg := cloneTLSConfig(d.TLSClientConfig) cfg := cloneTLSConfig(d.TLSClientConfig)
@ -356,6 +368,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return nil, nil, err return nil, nil, err
} }
} }
}
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
@ -400,7 +413,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// debugging. // debugging.
buf := make([]byte, 1024) buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf) n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake return nil, resp, ErrBadHandshake
} }
@ -418,7 +431,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
break break
} }
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{}) netConn.SetDeadline(time.Time{})

View File

@ -14,7 +14,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -24,6 +23,7 @@ import (
"net/url" "net/url"
"reflect" "reflect"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
) )
@ -45,12 +45,15 @@ var cstDialer = Dialer{
HandshakeTimeout: 30 * time.Second, HandshakeTimeout: 30 * time.Second,
} }
type cstHandler struct{ *testing.T } type cstHandler struct {
*testing.T
s *cstServer
}
type cstServer struct { type cstServer struct {
*httptest.Server
URL string URL string
t *testing.T Server *httptest.Server
wg sync.WaitGroup
} }
const ( const (
@ -59,9 +62,15 @@ const (
cstRequestURI = cstPath + "?" + cstRawQuery cstRequestURI = cstPath + "?" + cstRawQuery
) )
func (s *cstServer) Close() {
s.Server.Close()
// Wait for handler functions to complete.
s.wg.Wait()
}
func newServer(t *testing.T) *cstServer { func newServer(t *testing.T) *cstServer {
var s cstServer var s cstServer
s.Server = httptest.NewServer(cstHandler{t}) s.Server = httptest.NewServer(cstHandler{T: t, s: &s})
s.Server.URL += cstRequestURI s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL) s.URL = makeWsProto(s.Server.URL)
return &s return &s
@ -69,13 +78,19 @@ func newServer(t *testing.T) *cstServer {
func newTLSServer(t *testing.T) *cstServer { func newTLSServer(t *testing.T) *cstServer {
var s cstServer var s cstServer
s.Server = httptest.NewTLSServer(cstHandler{t}) s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s})
s.Server.URL += cstRequestURI s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL) s.URL = makeWsProto(s.Server.URL)
return &s return &s
} }
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Because tests wait for a response from a server, we are guaranteed that
// the wait group count is incremented before the test waits on the group
// in the call to (*cstServer).Close().
t.s.wg.Add(1)
defer t.s.wg.Done()
if r.URL.Path != cstPath { if r.URL.Path != cstPath {
t.Logf("path=%v, want %v", r.URL.Path, cstPath) t.Logf("path=%v, want %v", r.URL.Path, cstPath)
http.Error(w, "bad path", http.StatusBadRequest) http.Error(w, "bad path", http.StatusBadRequest)
@ -549,7 +564,7 @@ func TestRespOnBadHandshake(t *testing.T) {
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
} }
p, err := ioutil.ReadAll(resp.Body) p, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatalf("ReadFull(resp.Body) returned error %v", err) t.Fatalf("ReadFull(resp.Body) returned error %v", err)
} }

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"testing" "testing"
) )
@ -42,7 +41,7 @@ func textMessages(num int) [][]byte {
} }
func BenchmarkWriteNoCompression(b *testing.B) { func BenchmarkWriteNoCompression(b *testing.B) {
w := ioutil.Discard w := io.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
b.ResetTimer() b.ResetTimer()
@ -53,7 +52,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
} }
func BenchmarkWriteWithCompression(b *testing.B) { func BenchmarkWriteWithCompression(b *testing.B) {
w := ioutil.Discard w := io.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
c.enableWriteCompression = true c.enableWriteCompression = true

View File

@ -9,7 +9,6 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"io/ioutil"
"math/rand" "math/rand"
"net" "net"
"strconv" "strconv"
@ -795,7 +794,7 @@ func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame. // 1. Skip remainder of previous frame.
if c.readRemaining > 0 { if c.readRemaining > 0 {
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
return noFrame, err return noFrame, err
} }
} }
@ -1094,7 +1093,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
if err != nil { if err != nil {
return messageType, nil, err return messageType, nil, err
} }
p, err = ioutil.ReadAll(r) p, err = io.ReadAll(r)
return messageType, p, err return messageType, p, err
} }

View File

@ -6,7 +6,6 @@ package websocket
import ( import (
"io" "io"
"io/ioutil"
"sync/atomic" "sync/atomic"
"testing" "testing"
) )
@ -45,7 +44,7 @@ func newBroadcastConn(c *Conn) *broadcastConn {
func newBroadcastBench(usePrepared, compression bool) *broadcastBench { func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
bench := &broadcastBench{ bench := &broadcastBench{
w: ioutil.Discard, w: io.Discard,
doneCh: make(chan struct{}), doneCh: make(chan struct{}),
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
usePrepared: usePrepared, usePrepared: usePrepared,

View File

@ -10,7 +10,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"reflect" "reflect"
"sync" "sync"
@ -125,7 +124,7 @@ func TestFraming(t *testing.T) {
} }
t.Logf("frame size: %d", n) t.Logf("frame size: %d", n)
rbuf, err := ioutil.ReadAll(r) rbuf, err := io.ReadAll(r)
if err != nil { if err != nil {
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
continue continue
@ -367,7 +366,7 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if !reflect.DeepEqual(err, expectedErr) { if !reflect.DeepEqual(err, expectedErr) {
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
} }
@ -401,7 +400,7 @@ func TestEOFWithinFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
} }
@ -426,7 +425,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
} }
@ -490,7 +489,7 @@ func TestReadLimit(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("2: NextReader() returned %d, %v", op, err) t.Fatalf("2: NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != ErrReadLimit { if err != ErrReadLimit {
t.Fatalf("io.Copy() returned %v", err) t.Fatalf("io.Copy() returned %v", err)
} }

View File

@ -84,7 +84,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
} }
// echoReadAll echoes messages from the client by reading the entire message // echoReadAll echoes messages from the client by reading the entire message
// with ioutil.ReadAll. // with io.ReadAll.
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) { func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) {
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {

View File

@ -38,7 +38,7 @@ sends them to the hub.
### Hub ### Hub
The code for the `Hub` type is in The code for the `Hub` type is in
[hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go). [hub.go](https://github.com/gorilla/websocket/blob/main/examples/chat/hub.go).
The application's `main` function starts the hub's `run` method as a goroutine. The application's `main` function starts the hub's `run` method as a goroutine.
Clients send requests to the hub using the `register`, `unregister` and Clients send requests to the hub using the `register`, `unregister` and
`broadcast` channels. `broadcast` channels.
@ -57,7 +57,7 @@ unregisters the client and closes the websocket.
### Client ### Client
The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go). The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/main/examples/chat/client.go).
The `serveWs` function is registered by the application's `main` function as The `serveWs` function is registered by the application's `main` function as
an HTTP handler. The handler upgrades the HTTP connection to the WebSocket an HTTP handler. The handler upgrades the HTTP connection to the WebSocket
@ -85,7 +85,7 @@ network.
## Frontend ## Frontend
The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html). The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/main/examples/chat/home.html).
On document load, the script checks for websocket functionality in the browser. On document load, the script checks for websocket functionality in the browser.
If websocket functionality is available, then the script opens a connection to If websocket functionality is available, then the script opens a connection to

View File

@ -7,7 +7,6 @@ package main
import ( import (
"flag" "flag"
"html/template" "html/template"
"io/ioutil"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -49,7 +48,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) {
if !fi.ModTime().After(lastMod) { if !fi.ModTime().After(lastMod) {
return nil, lastMod, nil return nil, lastMod, nil
} }
p, err := ioutil.ReadFile(filename) p, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, fi.ModTime(), err return nil, fi.ModTime(), err
} }

6
go.mod
View File

@ -1,3 +1,7 @@
module github.com/gorilla/websocket module github.com/gorilla/websocket
go 1.12 go 1.20
retract (
v1.5.2 // tag accidentally overwritten
)

View File

@ -6,6 +6,7 @@ package websocket
import ( import (
"bufio" "bufio"
"bytes"
"encoding/base64" "encoding/base64"
"errors" "errors"
"net" "net"
@ -33,7 +34,7 @@ type httpProxyDialer struct {
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
hostPort, _ := hostPortNoPort(hpd.proxyURL) hostPort, _ := hostPortNoPort(hpd.proxyURL)
conn, err := hpd.forwardDial(network, hostPort) conn, err := net.Dial(network, hostPort)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -68,8 +69,18 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
return nil, err return nil, err
} }
if resp.StatusCode != 200 { // Close the response body to silence false positives from linters. Reset
conn.Close() // the buffered reader first to ensure that Close() does not read from
// conn.
// Note: Applications must call resp.Body.Close() on a response returned
// http.ReadResponse to inspect trailers or read another response from the
// buffered reader. The call to resp.Body.Close() does not release
// resources.
br.Reset(bytes.NewReader(nil))
_ = resp.Body.Close()
if resp.StatusCode != http.StatusOK {
_ = conn.Close()
f := strings.SplitN(resp.Status, " ", 2) f := strings.SplitN(resp.Status, " ", 2)
return nil, errors.New(f[1]) return nil, errors.New(f[1])
} }

View File

@ -101,8 +101,8 @@ func checkSameOrigin(r *http.Request) bool {
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
if u.Subprotocols != nil { if u.Subprotocols != nil {
clientProtocols := Subprotocols(r) clientProtocols := Subprotocols(r)
for _, serverProtocol := range u.Subprotocols {
for _, clientProtocol := range clientProtocols { for _, clientProtocol := range clientProtocols {
for _, serverProtocol := range u.Subprotocols {
if clientProtocol == serverProtocol { if clientProtocol == serverProtocol {
return clientProtocol return clientProtocol
} }
@ -172,14 +172,10 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
} }
h, ok := w.(http.Hijacker) netConn, brw, err := http.NewResponseController(w).Hijack()
if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var brw *bufio.ReadWriter
netConn, brw, err := h.Hijack()
if err != nil { if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error()) return u.returnError(w, r, http.StatusInternalServerError,
"websocket: hijack: "+err.Error())
} }
if brw.Reader.Buffered() > 0 { if brw.Reader.Buffered() > 0 {

View File

@ -7,8 +7,10 @@ package websocket
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -54,6 +56,36 @@ func TestIsWebSocketUpgrade(t *testing.T) {
} }
} }
func TestSubProtocolSelection(t *testing.T) {
upgrader := Upgrader{
Subprotocols: []string{"foo", "bar", "baz"},
}
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}}
s := upgrader.selectSubprotocol(&r, nil)
if s != "foo" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "bar" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "baz" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string")
}
}
var checkSameOriginTests = []struct { var checkSameOriginTests = []struct {
ok bool ok bool
r *http.Request r *http.Request
@ -117,3 +149,23 @@ func TestBufioReuse(t *testing.T) {
} }
} }
} }
func TestHijack_NotSupported(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "upgrade")
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Sec-Websocket-Version", "13")
recorder := httptest.NewRecorder()
upgrader := Upgrader{}
_, err := upgrader.Upgrade(recorder, req, nil)
if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
}
}