Compare commits

..

1 Commits

Author SHA1 Message Date
Allen 523c8697a4
Merge a889672aa4 into 80393295c1 2023-08-23 10:11:45 +08:00
17 changed files with 195 additions and 121 deletions

View File

@ -67,4 +67,4 @@ workflows:
- test: - test:
matrix: matrix:
parameters: parameters:
version: ["1.22", "1.21", "1.20"] version: ["1.18", "1.17", "1.16"]

View File

@ -10,10 +10,11 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
### Documentation ### Documentation
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) * [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
* [Chat example](https://github.com/gorilla/websocket/tree/main/examples/chat) * [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
* [Command example](https://github.com/gorilla/websocket/tree/main/examples/command) * [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/main/examples/echo) * [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
* [File watch example](https://github.com/gorilla/websocket/tree/main/examples/filewatch) * [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
* [Write buffer pool example](https://github.com/gorilla/websocket/tree/master/examples/bufferpool)
### Status ### Status
@ -29,4 +30,5 @@ package API is stable.
The Gorilla WebSocket package passes the server tests in the [Autobahn Test The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/main/examples/autobahn). subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).

View File

@ -11,6 +11,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
@ -402,7 +403,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// debugging. // debugging.
buf := make([]byte, 1024) buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf) n, _ := io.ReadFull(resp.Body, buf)
resp.Body = io.NopCloser(bytes.NewReader(buf[:n])) resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake return nil, resp, ErrBadHandshake
} }
@ -420,7 +421,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
break break
} }
resp.Body = io.NopCloser(bytes.NewReader([]byte{})) resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{}) netConn.SetDeadline(time.Time{})

View File

@ -14,6 +14,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -23,7 +24,6 @@ import (
"net/url" "net/url"
"reflect" "reflect"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
) )
@ -45,15 +45,12 @@ var cstDialer = Dialer{
HandshakeTimeout: 30 * time.Second, HandshakeTimeout: 30 * time.Second,
} }
type cstHandler struct { type cstHandler struct{ *testing.T }
*testing.T
s *cstServer
}
type cstServer struct { type cstServer struct {
URL string *httptest.Server
Server *httptest.Server URL string
wg sync.WaitGroup t *testing.T
} }
const ( const (
@ -62,15 +59,9 @@ const (
cstRequestURI = cstPath + "?" + cstRawQuery cstRequestURI = cstPath + "?" + cstRawQuery
) )
func (s *cstServer) Close() {
s.Server.Close()
// Wait for handler functions to complete.
s.wg.Wait()
}
func newServer(t *testing.T) *cstServer { func newServer(t *testing.T) *cstServer {
var s cstServer var s cstServer
s.Server = httptest.NewServer(cstHandler{T: t, s: &s}) s.Server = httptest.NewServer(cstHandler{t})
s.Server.URL += cstRequestURI s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL) s.URL = makeWsProto(s.Server.URL)
return &s return &s
@ -78,19 +69,13 @@ func newServer(t *testing.T) *cstServer {
func newTLSServer(t *testing.T) *cstServer { func newTLSServer(t *testing.T) *cstServer {
var s cstServer var s cstServer
s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s}) s.Server = httptest.NewTLSServer(cstHandler{t})
s.Server.URL += cstRequestURI s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL) s.URL = makeWsProto(s.Server.URL)
return &s return &s
} }
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Because tests wait for a response from a server, we are guaranteed that
// the wait group count is incremented before the test waits on the group
// in the call to (*cstServer).Close().
t.s.wg.Add(1)
defer t.s.wg.Done()
if r.URL.Path != cstPath { if r.URL.Path != cstPath {
t.Logf("path=%v, want %v", r.URL.Path, cstPath) t.Logf("path=%v, want %v", r.URL.Path, cstPath)
http.Error(w, "bad path", http.StatusBadRequest) http.Error(w, "bad path", http.StatusBadRequest)
@ -571,7 +556,7 @@ func TestRespOnBadHandshake(t *testing.T) {
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
} }
p, err := io.ReadAll(resp.Body) p, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatalf("ReadFull(resp.Body) returned error %v", err) t.Fatalf("ReadFull(resp.Body) returned error %v", err)
} }

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"testing" "testing"
) )
@ -41,7 +42,7 @@ func textMessages(num int) [][]byte {
} }
func BenchmarkWriteNoCompression(b *testing.B) { func BenchmarkWriteNoCompression(b *testing.B) {
w := io.Discard w := ioutil.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
b.ResetTimer() b.ResetTimer()
@ -52,7 +53,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
} }
func BenchmarkWriteWithCompression(b *testing.B) { func BenchmarkWriteWithCompression(b *testing.B) {
w := io.Discard w := ioutil.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
c.enableWriteCompression = true c.enableWriteCompression = true

View File

@ -9,6 +9,7 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"io/ioutil"
"math/rand" "math/rand"
"net" "net"
"strconv" "strconv"
@ -794,7 +795,7 @@ func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame. // 1. Skip remainder of previous frame.
if c.readRemaining > 0 { if c.readRemaining > 0 {
if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil { if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
return noFrame, err return noFrame, err
} }
} }
@ -1093,7 +1094,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
if err != nil { if err != nil {
return messageType, nil, err return messageType, nil, err
} }
p, err = io.ReadAll(r) p, err = ioutil.ReadAll(r)
return messageType, p, err return messageType, p, err
} }

View File

@ -6,6 +6,7 @@ package websocket
import ( import (
"io" "io"
"io/ioutil"
"sync/atomic" "sync/atomic"
"testing" "testing"
) )
@ -44,7 +45,7 @@ func newBroadcastConn(c *Conn) *broadcastConn {
func newBroadcastBench(usePrepared, compression bool) *broadcastBench { func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
bench := &broadcastBench{ bench := &broadcastBench{
w: io.Discard, w: ioutil.Discard,
doneCh: make(chan struct{}), doneCh: make(chan struct{}),
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
usePrepared: usePrepared, usePrepared: usePrepared,

View File

@ -10,6 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"reflect" "reflect"
"sync" "sync"
@ -124,7 +125,7 @@ func TestFraming(t *testing.T) {
} }
t.Logf("frame size: %d", n) t.Logf("frame size: %d", n)
rbuf, err := io.ReadAll(r) rbuf, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
continue continue
@ -366,7 +367,7 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(io.Discard, r) _, err = io.Copy(ioutil.Discard, r)
if !reflect.DeepEqual(err, expectedErr) { if !reflect.DeepEqual(err, expectedErr) {
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
} }
@ -400,7 +401,7 @@ func TestEOFWithinFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
} }
_, err = io.Copy(io.Discard, r) _, err = io.Copy(ioutil.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
} }
@ -425,7 +426,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(io.Discard, r) _, err = io.Copy(ioutil.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
} }
@ -489,7 +490,7 @@ func TestReadLimit(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("2: NextReader() returned %d, %v", op, err) t.Fatalf("2: NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(io.Discard, r) _, err = io.Copy(ioutil.Discard, r)
if err != ErrReadLimit { if err != ErrReadLimit {
t.Fatalf("io.Copy() returned %v", err) t.Fatalf("io.Copy() returned %v", err)
} }

View File

@ -84,7 +84,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
} }
// echoReadAll echoes messages from the client by reading the entire message // echoReadAll echoes messages from the client by reading the entire message
// with io.ReadAll. // with ioutil.ReadAll.
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) { func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) {
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {

View File

@ -0,0 +1,89 @@
//go:build ignore
// +build ignore
package main
import (
"flag"
"log"
"net/url"
"os"
"os/signal"
"sync"
"time"
"github.com/gorilla/websocket"
)
var addr = flag.String("addr", "localhost:8080", "http service address")
func runNewConn(wg *sync.WaitGroup) {
defer wg.Done()
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt)
u := url.URL{Scheme: "ws", Host: *addr, Path: "/ws"}
log.Printf("connecting to %s", u.String())
c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil {
log.Fatal("dial:", err)
}
defer c.Close()
done := make(chan struct{})
go func() {
defer close(done)
for {
_, message, err := c.ReadMessage()
if err != nil {
log.Println("read:", err)
return
}
log.Printf("recv: %s", message)
}
}()
ticker := time.NewTicker(time.Minute * 5)
defer ticker.Stop()
for {
select {
case <-done:
return
case t := <-ticker.C:
err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
if err != nil {
log.Println("write:", err)
return
}
case <-interrupt:
log.Println("interrupt")
// Cleanly close the connection by sending a close message and then
// waiting (with timeout) for the server to close the connection.
err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
log.Println("write close:", err)
return
}
select {
case <-done:
case <-time.After(time.Second):
}
return
}
}
}
func main() {
flag.Parse()
log.SetFlags(0)
wg := &sync.WaitGroup{}
for i := 0; i < 1000; i++ {
wg.Add(1)
go runNewConn(wg)
}
wg.Wait()
}

View File

@ -0,0 +1,55 @@
//go:build ignore
// +build ignore
package main
import (
"flag"
"log"
"net/http"
"sync"
_ "net/http/pprof"
"github.com/gorilla/websocket"
)
var addr = flag.String("addr", "localhost:8080", "http service address")
var upgrader = websocket.Upgrader{
ReadBufferSize: 256,
WriteBufferSize: 256,
WriteBufferPool: &sync.Pool{},
}
func process(c *websocket.Conn) {
defer c.Close()
for {
_, message, err := c.ReadMessage()
if err != nil {
log.Println("read:", err)
break
}
log.Printf("recv: %s", message)
}
}
func handler(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Print("upgrade:", err)
return
}
// Process connection in a new goroutine
go process(c)
// Let the http handler return, the 8k buffer created by it will be garbage collected
}
func main() {
flag.Parse()
log.SetFlags(0)
http.HandleFunc("/ws", handler)
log.Fatal(http.ListenAndServe(*addr, nil))
}

View File

@ -38,7 +38,7 @@ sends them to the hub.
### Hub ### Hub
The code for the `Hub` type is in The code for the `Hub` type is in
[hub.go](https://github.com/gorilla/websocket/blob/main/examples/chat/hub.go). [hub.go](https://github.com/gorilla/websocket/blob/master/examples/chat/hub.go).
The application's `main` function starts the hub's `run` method as a goroutine. The application's `main` function starts the hub's `run` method as a goroutine.
Clients send requests to the hub using the `register`, `unregister` and Clients send requests to the hub using the `register`, `unregister` and
`broadcast` channels. `broadcast` channels.
@ -57,7 +57,7 @@ unregisters the client and closes the websocket.
### Client ### Client
The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/main/examples/chat/client.go). The code for the `Client` type is in [client.go](https://github.com/gorilla/websocket/blob/master/examples/chat/client.go).
The `serveWs` function is registered by the application's `main` function as The `serveWs` function is registered by the application's `main` function as
an HTTP handler. The handler upgrades the HTTP connection to the WebSocket an HTTP handler. The handler upgrades the HTTP connection to the WebSocket
@ -85,7 +85,7 @@ network.
## Frontend ## Frontend
The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/main/examples/chat/home.html). The frontend code is in [home.html](https://github.com/gorilla/websocket/blob/master/examples/chat/home.html).
On document load, the script checks for websocket functionality in the browser. On document load, the script checks for websocket functionality in the browser.
If websocket functionality is available, then the script opens a connection to If websocket functionality is available, then the script opens a connection to

View File

@ -7,6 +7,7 @@ package main
import ( import (
"flag" "flag"
"html/template" "html/template"
"io/ioutil"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -48,7 +49,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) {
if !fi.ModTime().After(lastMod) { if !fi.ModTime().After(lastMod) {
return nil, lastMod, nil return nil, lastMod, nil
} }
p, err := os.ReadFile(filename) p, err := ioutil.ReadFile(filename)
if err != nil { if err != nil {
return nil, fi.ModTime(), err return nil, fi.ModTime(), err
} }

6
go.mod
View File

@ -1,7 +1,3 @@
module github.com/gorilla/websocket module github.com/gorilla/websocket
go 1.20 go 1.12
retract (
v1.5.2 // tag accidentally overwritten
)

View File

@ -6,7 +6,6 @@ package websocket
import ( import (
"bufio" "bufio"
"bytes"
"encoding/base64" "encoding/base64"
"errors" "errors"
"net" "net"
@ -81,18 +80,8 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
return nil, err return nil, err
} }
// Close the response body to silence false positives from linters. Reset if resp.StatusCode != 200 {
// the buffered reader first to ensure that Close() does not read from conn.Close()
// conn.
// Note: Applications must call resp.Body.Close() on a response returned
// http.ReadResponse to inspect trailers or read another response from the
// buffered reader. The call to resp.Body.Close() does not release
// resources.
br.Reset(bytes.NewReader(nil))
_ = resp.Body.Close()
if resp.StatusCode != http.StatusOK {
_ = conn.Close()
f := strings.SplitN(resp.Status, " ", 2) f := strings.SplitN(resp.Status, " ", 2)
return nil, errors.New(f[1]) return nil, errors.New(f[1])
} }

View File

@ -101,8 +101,8 @@ func checkSameOrigin(r *http.Request) bool {
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
if u.Subprotocols != nil { if u.Subprotocols != nil {
clientProtocols := Subprotocols(r) clientProtocols := Subprotocols(r)
for _, clientProtocol := range clientProtocols { for _, serverProtocol := range u.Subprotocols {
for _, serverProtocol := range u.Subprotocols { for _, clientProtocol := range clientProtocols {
if clientProtocol == serverProtocol { if clientProtocol == serverProtocol {
return clientProtocol return clientProtocol
} }
@ -172,10 +172,14 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
} }
netConn, brw, err := http.NewResponseController(w).Hijack() h, ok := w.(http.Hijacker)
if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var brw *bufio.ReadWriter
netConn, brw, err := h.Hijack()
if err != nil { if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, return u.returnError(w, r, http.StatusInternalServerError, err.Error())
"websocket: hijack: "+err.Error())
} }
if brw.Reader.Buffered() > 0 { if brw.Reader.Buffered() > 0 {

View File

@ -7,10 +7,8 @@ package websocket
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -56,36 +54,6 @@ func TestIsWebSocketUpgrade(t *testing.T) {
} }
} }
func TestSubProtocolSelection(t *testing.T) {
upgrader := Upgrader{
Subprotocols: []string{"foo", "bar", "baz"},
}
r := http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"foo", "bar"}}}
s := upgrader.selectSubprotocol(&r, nil)
if s != "foo" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "foo")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"bar", "foo"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "bar" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "bar")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"baz"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "baz" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "baz")
}
r = http.Request{Header: http.Header{"Sec-Websocket-Protocol": {"quux"}}}
s = upgrader.selectSubprotocol(&r, nil)
if s != "" {
t.Errorf("Upgrader.selectSubprotocol returned %v, want %v", s, "empty string")
}
}
var checkSameOriginTests = []struct { var checkSameOriginTests = []struct {
ok bool ok bool
r *http.Request r *http.Request
@ -149,23 +117,3 @@ func TestBufioReuse(t *testing.T) {
} }
} }
} }
func TestHijack_NotSupported(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "upgrade")
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Sec-Websocket-Version", "13")
recorder := httptest.NewRecorder()
upgrader := Upgrader{}
_, err := upgrader.Upgrade(recorder, req, nil)
if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
}
}