mirror of https://github.com/gorilla/websocket.git
Compare commits
1 Commits
a7a56527e5
...
523c8697a4
Author | SHA1 | Date |
---|---|---|
Allen | 523c8697a4 |
|
@ -67,4 +67,4 @@ workflows:
|
|||
- test:
|
||||
matrix:
|
||||
parameters:
|
||||
version: ["1.22", "1.21", "1.20"]
|
||||
version: ["1.18", "1.17", "1.16"]
|
||||
|
|
12
README.md
12
README.md
|
@ -10,10 +10,11 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
|
|||
### Documentation
|
||||
|
||||
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
|
||||
* [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat)
|
||||
* [Command example](https://github.com/gorilla/websocket/tree/main/examples/command)
|
||||
* [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo)
|
||||
* [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch)
|
||||
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
|
||||
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
|
||||
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
|
||||
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
|
||||
* [Write buffer pool example](https://github.com/gorilla/websocket/tree/master/examples/bufferpool)
|
||||
|
||||
### Status
|
||||
|
||||
|
@ -29,4 +30,5 @@ package API is stable.
|
|||
|
||||
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
|
||||
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
|
||||
subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn).
|
||||
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
|
@ -402,7 +403,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
// debugging.
|
||||
buf := make([]byte, 1024)
|
||||
n, _ := io.ReadFull(resp.Body, buf)
|
||||
resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
|
||||
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
|
||||
return nil, resp, ErrBadHandshake
|
||||
}
|
||||
|
||||
|
@ -420,7 +421,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
break
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
|
||||
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
||||
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
||||
|
||||
netConn.SetDeadline(time.Time{})
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -23,7 +24,6 @@ import (
|
|||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
@ -45,15 +45,12 @@ var cstDialer = Dialer{
|
|||
HandshakeTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
type cstHandler struct {
|
||||
*testing.T
|
||||
s *cstServer
|
||||
}
|
||||
type cstHandler struct{ *testing.T }
|
||||
|
||||
type cstServer struct {
|
||||
URL string
|
||||
Server *httptest.Server
|
||||
wg sync.WaitGroup
|
||||
*httptest.Server
|
||||
URL string
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -62,15 +59,9 @@ const (
|
|||
cstRequestURI = cstPath + "?" + cstRawQuery
|
||||
)
|
||||
|
||||
func (s *cstServer) Close() {
|
||||
s.Server.Close()
|
||||
// Wait for handler functions to complete.
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func newServer(t *testing.T) *cstServer {
|
||||
var s cstServer
|
||||
s.Server = httptest.NewServer(cstHandler{T: t, s: &s})
|
||||
s.Server = httptest.NewServer(cstHandler{t})
|
||||
s.Server.URL += cstRequestURI
|
||||
s.URL = makeWsProto(s.Server.URL)
|
||||
return &s
|
||||
|
@ -78,19 +69,13 @@ func newServer(t *testing.T) *cstServer {
|
|||
|
||||
func newTLSServer(t *testing.T) *cstServer {
|
||||
var s cstServer
|
||||
s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s})
|
||||
s.Server = httptest.NewTLSServer(cstHandler{t})
|
||||
s.Server.URL += cstRequestURI
|
||||
s.URL = makeWsProto(s.Server.URL)
|
||||
return &s
|
||||
}
|
||||
|
||||
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Because tests wait for a response from a server, we are guaranteed that
|
||||
// the wait group count is incremented before the test waits on the group
|
||||
// in the call to (*cstServer).Close().
|
||||
t.s.wg.Add(1)
|
||||
defer t.s.wg.Done()
|
||||
|
||||
if r.URL.Path != cstPath {
|
||||
t.Logf("path=%v, want %v", r.URL.Path, cstPath)
|
||||
http.Error(w, "bad path", http.StatusBadRequest)
|
||||
|
@ -571,7 +556,7 @@ func TestRespOnBadHandshake(t *testing.T) {
|
|||
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
|
||||
}
|
||||
|
||||
p, err := io.ReadAll(resp.Body)
|
||||
p, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFull(resp.Body) returned error %v", err)
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -41,7 +42,7 @@ func textMessages(num int) [][]byte {
|
|||
}
|
||||
|
||||
func BenchmarkWriteNoCompression(b *testing.B) {
|
||||
w := io.Discard
|
||||
w := ioutil.Discard
|
||||
c := newTestConn(nil, w, false)
|
||||
messages := textMessages(100)
|
||||
b.ResetTimer()
|
||||
|
@ -52,7 +53,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkWriteWithCompression(b *testing.B) {
|
||||
w := io.Discard
|
||||
w := ioutil.Discard
|
||||
c := newTestConn(nil, w, false)
|
||||
messages := textMessages(100)
|
||||
c.enableWriteCompression = true
|
||||
|
|
5
conn.go
5
conn.go
|
@ -9,6 +9,7 @@ import (
|
|||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
|
@ -794,7 +795,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
|||
// 1. Skip remainder of previous frame.
|
||||
|
||||
if c.readRemaining > 0 {
|
||||
if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
|
||||
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
|
||||
return noFrame, err
|
||||
}
|
||||
}
|
||||
|
@ -1093,7 +1094,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
|
|||
if err != nil {
|
||||
return messageType, nil, err
|
||||
}
|
||||
p, err = io.ReadAll(r)
|
||||
p, err = ioutil.ReadAll(r)
|
||||
return messageType, p, err
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ package websocket
|
|||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
@ -44,7 +45,7 @@ func newBroadcastConn(c *Conn) *broadcastConn {
|
|||
|
||||
func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
|
||||
bench := &broadcastBench{
|
||||
w: io.Discard,
|
||||
w: ioutil.Discard,
|
||||
doneCh: make(chan struct{}),
|
||||
closeCh: make(chan struct{}),
|
||||
usePrepared: usePrepared,
|
||||
|
|
11
conn_test.go
11
conn_test.go
|
@ -10,6 +10,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
@ -124,7 +125,7 @@ func TestFraming(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Logf("frame size: %d", n)
|
||||
rbuf, err := io.ReadAll(r)
|
||||
rbuf, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
|
||||
continue
|
||||
|
@ -366,7 +367,7 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
|||
if op != BinaryMessage || err != nil {
|
||||
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||
}
|
||||
_, err = io.Copy(io.Discard, r)
|
||||
_, err = io.Copy(ioutil.Discard, r)
|
||||
if !reflect.DeepEqual(err, expectedErr) {
|
||||
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
|
||||
}
|
||||
|
@ -400,7 +401,7 @@ func TestEOFWithinFrame(t *testing.T) {
|
|||
if op != BinaryMessage || err != nil {
|
||||
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
|
||||
}
|
||||
_, err = io.Copy(io.Discard, r)
|
||||
_, err = io.Copy(ioutil.Discard, r)
|
||||
if err != errUnexpectedEOF {
|
||||
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
|
||||
}
|
||||
|
@ -425,7 +426,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
|||
if op != BinaryMessage || err != nil {
|
||||
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||
}
|
||||
_, err = io.Copy(io.Discard, r)
|
||||
_, err = io.Copy(ioutil.Discard, r)
|
||||
if err != errUnexpectedEOF {
|
||||
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
|
||||
}
|
||||
|
@ -489,7 +490,7 @@ func TestReadLimit(t *testing.T) {
|
|||
if op != BinaryMessage || err != nil {
|
||||
t.Fatalf("2: NextReader() returned %d, %v", op, err)
|
||||
}
|
||||
_, err = io.Copy(io.Discard, r)
|
||||
_, err = io.Copy(ioutil.Discard, r)
|
||||
if err != ErrReadLimit {
|
||||
t.Fatalf("io.Copy() returned %v", err)
|
||||
}
|
||||
|
|
|
@ -84,7 +84,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// echoReadAll echoes messages from the client by reading the entire message
|
||||
// with io.ReadAll.
|
||||
// with ioutil.ReadAll.
|
||||
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
//go:build ignore
|
||||
// +build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var addr = flag.String("addr", "localhost:8080", "http service address")
|
||||
|
||||
func runNewConn(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
interrupt := make(chan os.Signal, 1)
|
||||
signal.Notify(interrupt, os.Interrupt)
|
||||
|
||||
u := url.URL{Scheme: "ws", Host: *addr, Path: "/ws"}
|
||||
log.Printf("connecting to %s", u.String())
|
||||
c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||
if err != nil {
|
||||
log.Fatal("dial:", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
_, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
log.Println("read:", err)
|
||||
return
|
||||
}
|
||||
log.Printf("recv: %s", message)
|
||||
}
|
||||
}()
|
||||
|
||||
ticker := time.NewTicker(time.Minute * 5)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case t := <-ticker.C:
|
||||
err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
|
||||
if err != nil {
|
||||
log.Println("write:", err)
|
||||
return
|
||||
}
|
||||
case <-interrupt:
|
||||
log.Println("interrupt")
|
||||
|
||||
// Cleanly close the connection by sending a close message and then
|
||||
// waiting (with timeout) for the server to close the connection.
|
||||
err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
if err != nil {
|
||||
log.Println("write close:", err)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
log.SetFlags(0)
|
||||
wg := &sync.WaitGroup{}
|
||||
for i := 0; i < 1000; i++ {
|
||||
wg.Add(1)
|
||||
go runNewConn(wg)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
//go:build ignore
|
||||
// +build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
_ "net/http/pprof"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var addr = flag.String("addr", "localhost:8080", "http service address")
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 256,
|
||||
WriteBufferSize: 256,
|
||||
WriteBufferPool: &sync.Pool{},
|
||||
}
|
||||
|
||||
func process(c *websocket.Conn) {
|
||||
defer c.Close()
|
||||
for {
|
||||
_, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
log.Println("read:", err)
|
||||
break
|
||||
}
|
||||
log.Printf("recv: %s", message)
|
||||
}
|
||||
}
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
log.Print("upgrade:", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Process connection in a new goroutine
|
||||
go process(c)
|
||||
|
||||
// Let the http handler return, the 8k buffer created by it will be garbage collected
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
log.SetFlags(0)
|
||||
http.HandleFunc("/ws", handler)
|
||||
log.Fatal(http.ListenAndServe(*addr, nil))
|
||||
}
|
|
@ -38,7 +38,7 @@ sends them to the hub.
|
|||
### Hub
|
||||
|
||||
The code for the `Hub` type is in
|
||||
[hub.go](https://github.com/gorilla/websocket/blob/main/examples/chat/hub.go).
|
||||
[hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go).
|
||||
The application's `main` function starts the hub's `run` method as a goroutine.
|
||||
Clients send requests to the hub using the `register`, `unregister` and
|
||||
`broadcast` channels.
|
||||
|
@ -57,7 +57,7 @@ unregisters the client and closes the websocket.
|
|||
|
||||
### Client
|
||||
|
||||
The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/main/examples/chat/client.go).
|
||||
The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go).
|
||||
|
||||
The `serveWs` function is registered by the application's `main` function as
|
||||
an HTTP handler. The handler upgrades the HTTP connection to the WebSocket
|
||||
|
@ -85,7 +85,7 @@ network.
|
|||
|
||||
## Frontend
|
||||
|
||||
The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/main/examples/chat/home.html).
|
||||
The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html).
|
||||
|
||||
On document load, the script checks for websocket functionality in the browser.
|
||||
If websocket functionality is available, then the script opens a connection to
|
||||
|
|
|
@ -7,6 +7,7 @@ package main
|
|||
import (
|
||||
"flag"
|
||||
"html/template"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
@ -48,7 +49,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) {
|
|||
if !fi.ModTime().After(lastMod) {
|
||||
return nil, lastMod, nil
|
||||
}
|
||||
p, err := os.ReadFile(filename)
|
||||
p, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, fi.ModTime(), err
|
||||
}
|
||||
|
|
6
go.mod
6
go.mod
|
@ -1,7 +1,3 @@
|
|||
module github.com/gorilla/websocket
|
||||
|
||||
go 1.20
|
||||
|
||||
retract (
|
||||
v1.5.2 // tag accidentally overwritten
|
||||
)
|
||||
go 1.12
|
||||
|
|
15
proxy.go
15
proxy.go
|
@ -6,7 +6,6 @@ package websocket
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net"
|
||||
|
@ -81,18 +80,8 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Close the response body to silence false positives from linters. Reset
|
||||
// the buffered reader first to ensure that Close() does not read from
|
||||
// conn.
|
||||
// Note: Applications must call resp.Body.Close() on a response returned
|
||||
// http.ReadResponse to inspect trailers or read another response from the
|
||||
// buffered reader. The call to resp.Body.Close() does not release
|
||||
// resources.
|
||||
br.Reset(bytes.NewReader(nil))
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_ = conn.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
conn.Close()
|
||||
f := strings.SplitN(resp.Status, " ", 2)
|
||||
return nil, errors.New(f[1])
|
||||
}
|
||||
|
|
14
server.go
14
server.go
|
@ -101,8 +101,8 @@ func checkSameOrigin(r *http.Request) bool {
|
|||
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
|
||||
if u.Subprotocols != nil {
|
||||
clientProtocols := Subprotocols(r)
|
||||
for _, clientProtocol := range clientProtocols {
|
||||
for _, serverProtocol := range u.Subprotocols {
|
||||
for _, serverProtocol := range u.Subprotocols {
|
||||
for _, clientProtocol := range clientProtocols {
|
||||
if clientProtocol == serverProtocol {
|
||||
return clientProtocol
|
||||
}
|
||||
|
@ -172,10 +172,14 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
|||
}
|
||||
}
|
||||
|
||||
netConn, brw, err := http.NewResponseController(w).Hijack()
|
||||
h, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
|
||||
}
|
||||
var brw *bufio.ReadWriter
|
||||
netConn, brw, err := h.Hijack()
|
||||
if err != nil {
|
||||
return u.returnError(w, r, http.StatusInternalServerError,
|
||||
"websocket: hijack: "+err.Error())
|
||||
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
if brw.Reader.Buffered() > 0 {
|
||||
|
|
|
@ -7,10 +7,8 @@ package websocket
|
|||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -56,36 +54,6 @@ func TestIsWebSocketUpgrade(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSubProtocolSelection(t *testing.T) {
|
||||
upgrader := Upgrader{
|
||||
Subprotocols: []string{"foo", "bar", "baz"},
|
||||
}
|
||||
|
||||
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}}
|
||||
s := upgrader.selectSubprotocol(&r, nil)
|
||||
if s != "foo" {
|
||||
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo")
|
||||
}
|
||||
|
||||
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}}
|
||||
s = upgrader.selectSubprotocol(&r, nil)
|
||||
if s != "bar" {
|
||||
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar")
|
||||
}
|
||||
|
||||
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}}
|
||||
s = upgrader.selectSubprotocol(&r, nil)
|
||||
if s != "baz" {
|
||||
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz")
|
||||
}
|
||||
|
||||
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}}
|
||||
s = upgrader.selectSubprotocol(&r, nil)
|
||||
if s != "" {
|
||||
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string")
|
||||
}
|
||||
}
|
||||
|
||||
var checkSameOriginTests = []struct {
|
||||
ok bool
|
||||
r *http.Request
|
||||
|
@ -149,23 +117,3 @@ func TestBufioReuse(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHijack_NotSupported(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "upgrade")
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
req.Header.Set("Sec-Websocket-Version", "13")
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
upgrader := Upgrader{}
|
||||
_, err := upgrader.Upgrade(recorder, req, nil)
|
||||
|
||||
if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
|
||||
t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
|
||||
t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue