Compare commits

..

8 Commits

Author SHA1 Message Date
Daniel Holmes 81ab9ae479
Merge branch 'main' into prsubproto 2024-06-19 17:12:02 +10:00
Canelo Hill 688592ebe6 Improve client/server tests
Tests must not call *testing.T methods after the test function returns.
Use a sync.WaitGroup to ensure that server handler functions complete
before tests return.
2024-06-19 17:11:11 +10:00
tebuka 7e5e9b5a25 Improve hijack failure error text
Include "hijack" in text to indicate where in this package the error
occurred.
2024-06-19 17:10:25 +10:00
merlin 8890e3e578 fix: don't use errors.ErrUnsupported, it's available only since go1.21 2024-06-19 17:10:25 +10:00
merlin c7502098b0 use http.ResposnseController 2024-06-19 17:10:25 +10:00
Canelo Hill a70cea529a
Update for deprecated ioutil package (#931) 2024-06-19 14:44:41 +10:00
Canelo Hill ac1b326ac0
Set min Go version to 1.20 (#930)
Update go.mod and CI to Go version 1.20.
2024-06-19 14:40:57 +10:00
Daniel Holmes 227456c3cc chore: Retract v1.5.2 from go.mod
Maintainers accidentally changed the reference commit
for v1.5.2. This change retracts v1.5.2 which also
includes a number of avoidable issues.

Fixes #927
2024-06-19 04:30:55 +00:00
12 changed files with 68 additions and 37 deletions

View File

@ -67,4 +67,4 @@ workflows:
- test: - test:
matrix: matrix:
parameters: parameters:
version: ["1.18", "1.17", "1.16"] version: ["1.22", "1.21", "1.20"]

View File

@ -11,7 +11,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
@ -400,7 +399,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// debugging. // debugging.
buf := make([]byte, 1024) buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf) n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake return nil, resp, ErrBadHandshake
} }
@ -418,7 +417,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
break break
} }
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{}) netConn.SetDeadline(time.Time{})

View File

@ -14,7 +14,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -24,6 +23,7 @@ import (
"net/url" "net/url"
"reflect" "reflect"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
) )
@ -45,12 +45,15 @@ var cstDialer = Dialer{
HandshakeTimeout: 30 * time.Second, HandshakeTimeout: 30 * time.Second,
} }
type cstHandler struct{ *testing.T } type cstHandler struct {
*testing.T
s *cstServer
}
type cstServer struct { type cstServer struct {
*httptest.Server
URL string URL string
t *testing.T Server *httptest.Server
wg sync.WaitGroup
} }
const ( const (
@ -59,9 +62,15 @@ const (
cstRequestURI = cstPath + "?" + cstRawQuery cstRequestURI = cstPath + "?" + cstRawQuery
) )
func (s *cstServer) Close() {
s.Server.Close()
// Wait for handler functions to complete.
s.wg.Wait()
}
func newServer(t *testing.T) *cstServer { func newServer(t *testing.T) *cstServer {
var s cstServer var s cstServer
s.Server = httptest.NewServer(cstHandler{t}) s.Server = httptest.NewServer(cstHandler{T: t, s: &s})
s.Server.URL += cstRequestURI s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL) s.URL = makeWsProto(s.Server.URL)
return &s return &s
@ -69,13 +78,19 @@ func newServer(t *testing.T) *cstServer {
func newTLSServer(t *testing.T) *cstServer { func newTLSServer(t *testing.T) *cstServer {
var s cstServer var s cstServer
s.Server = httptest.NewTLSServer(cstHandler{t}) s.Server = httptest.NewTLSServer(cstHandler{T: t, s: &s})
s.Server.URL += cstRequestURI s.Server.URL += cstRequestURI
s.URL = makeWsProto(s.Server.URL) s.URL = makeWsProto(s.Server.URL)
return &s return &s
} }
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Because tests wait for a response from a server, we are guaranteed that
// the wait group count is incremented before the test waits on the group
// in the call to (*cstServer).Close().
t.s.wg.Add(1)
defer t.s.wg.Done()
if r.URL.Path != cstPath { if r.URL.Path != cstPath {
t.Logf("path=%v, want %v", r.URL.Path, cstPath) t.Logf("path=%v, want %v", r.URL.Path, cstPath)
http.Error(w, "bad path", http.StatusBadRequest) http.Error(w, "bad path", http.StatusBadRequest)
@ -549,7 +564,7 @@ func TestRespOnBadHandshake(t *testing.T) {
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
} }
p, err := ioutil.ReadAll(resp.Body) p, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatalf("ReadFull(resp.Body) returned error %v", err) t.Fatalf("ReadFull(resp.Body) returned error %v", err)
} }

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"testing" "testing"
) )
@ -42,7 +41,7 @@ func textMessages(num int) [][]byte {
} }
func BenchmarkWriteNoCompression(b *testing.B) { func BenchmarkWriteNoCompression(b *testing.B) {
w := ioutil.Discard w := io.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
b.ResetTimer() b.ResetTimer()
@ -53,7 +52,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
} }
func BenchmarkWriteWithCompression(b *testing.B) { func BenchmarkWriteWithCompression(b *testing.B) {
w := ioutil.Discard w := io.Discard
c := newTestConn(nil, w, false) c := newTestConn(nil, w, false)
messages := textMessages(100) messages := textMessages(100)
c.enableWriteCompression = true c.enableWriteCompression = true

View File

@ -9,7 +9,6 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"io/ioutil"
"math/rand" "math/rand"
"net" "net"
"strconv" "strconv"
@ -795,7 +794,7 @@ func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame. // 1. Skip remainder of previous frame.
if c.readRemaining > 0 { if c.readRemaining > 0 {
if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { if _, err := io.CopyN(io.Discard, c.br, c.readRemaining); err != nil {
return noFrame, err return noFrame, err
} }
} }
@ -1094,7 +1093,7 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
if err != nil { if err != nil {
return messageType, nil, err return messageType, nil, err
} }
p, err = ioutil.ReadAll(r) p, err = io.ReadAll(r)
return messageType, p, err return messageType, p, err
} }

View File

@ -6,7 +6,6 @@ package websocket
import ( import (
"io" "io"
"io/ioutil"
"sync/atomic" "sync/atomic"
"testing" "testing"
) )
@ -45,7 +44,7 @@ func newBroadcastConn(c *Conn) *broadcastConn {
func newBroadcastBench(usePrepared, compression bool) *broadcastBench { func newBroadcastBench(usePrepared, compression bool) *broadcastBench {
bench := &broadcastBench{ bench := &broadcastBench{
w: ioutil.Discard, w: io.Discard,
doneCh: make(chan struct{}), doneCh: make(chan struct{}),
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
usePrepared: usePrepared, usePrepared: usePrepared,

View File

@ -10,7 +10,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"reflect" "reflect"
"sync" "sync"
@ -125,7 +124,7 @@ func TestFraming(t *testing.T) {
} }
t.Logf("frame size: %d", n) t.Logf("frame size: %d", n)
rbuf, err := ioutil.ReadAll(r) rbuf, err := io.ReadAll(r)
if err != nil { if err != nil {
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
continue continue
@ -367,7 +366,7 @@ func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if !reflect.DeepEqual(err, expectedErr) { if !reflect.DeepEqual(err, expectedErr) {
t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
} }
@ -401,7 +400,7 @@ func TestEOFWithinFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
} }
@ -426,7 +425,7 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if err != errUnexpectedEOF {
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
} }
@ -490,7 +489,7 @@ func TestReadLimit(t *testing.T) {
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("2: NextReader() returned %d, %v", op, err) t.Fatalf("2: NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(io.Discard, r)
if err != ErrReadLimit { if err != ErrReadLimit {
t.Fatalf("io.Copy() returned %v", err) t.Fatalf("io.Copy() returned %v", err)
} }

View File

@ -84,7 +84,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) {
} }
// echoReadAll echoes messages from the client by reading the entire message // echoReadAll echoes messages from the client by reading the entire message
// with ioutil.ReadAll. // with io.ReadAll.
func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) { func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) {
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {

View File

@ -7,7 +7,6 @@ package main
import ( import (
"flag" "flag"
"html/template" "html/template"
"io/ioutil"
"log" "log"
"net/http" "net/http"
"os" "os"
@ -49,7 +48,7 @@ func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) {
if !fi.ModTime().After(lastMod) { if !fi.ModTime().After(lastMod) {
return nil, lastMod, nil return nil, lastMod, nil
} }
p, err := ioutil.ReadFile(filename) p, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, fi.ModTime(), err return nil, fi.ModTime(), err
} }

6
go.mod
View File

@ -1,3 +1,7 @@
module github.com/gorilla/websocket module github.com/gorilla/websocket
go 1.12 go 1.20
retract (
v1.5.2 // tag accidentally overwritten
)

View File

@ -172,14 +172,10 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
} }
h, ok := w.(http.Hijacker) netConn, brw, err := http.NewResponseController(w).Hijack()
if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var brw *bufio.ReadWriter
netConn, brw, err := h.Hijack()
if err != nil { if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error()) return u.returnError(w, r, http.StatusInternalServerError,
"websocket: hijack: "+err.Error())
} }
if brw.Reader.Buffered() > 0 { if brw.Reader.Buffered() > 0 {

View File

@ -7,8 +7,10 @@ package websocket
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -147,3 +149,23 @@ func TestBufioReuse(t *testing.T) {
} }
} }
} }
func TestHijack_NotSupported(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "upgrade")
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Sec-Websocket-Version", "13")
recorder := httptest.NewRecorder()
upgrader := Upgrader{}
_, err := upgrader.Upgrade(recorder, req, nil)
if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
}
}