From c7502098b0f8461511d811d27c26075165735b05 Mon Sep 17 00:00:00 2001 From: merlin Date: Wed, 22 Nov 2023 21:47:04 +0200 Subject: [PATCH] use http.ResposnseController --- server.go | 10 ++++------ server_test.go | 22 ++++++++++++++++++++++ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/server.go b/server.go index bb33597..69e6f83 100644 --- a/server.go +++ b/server.go @@ -172,13 +172,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } } - h, ok := w.(http.Hijacker) - if !ok { + netConn, brw, err := http.NewResponseController(w).Hijack() + switch { + case errors.Is(err, errors.ErrUnsupported): return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") - } - var brw *bufio.ReadWriter - netConn, brw, err := h.Hijack() - if err != nil { + case err != nil: return u.returnError(w, r, http.StatusInternalServerError, err.Error()) } diff --git a/server_test.go b/server_test.go index 5804be1..2ce6f7f 100644 --- a/server_test.go +++ b/server_test.go @@ -7,8 +7,10 @@ package websocket import ( "bufio" "bytes" + "errors" "net" "net/http" + "net/http/httptest" "reflect" "strings" "testing" @@ -117,3 +119,23 @@ func TestBufioReuse(t *testing.T) { } } } + +func TestHijack_NotSupported(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "upgrade") + req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Sec-Websocket-Version", "13") + + recorder := httptest.NewRecorder() + + upgrader := Upgrader{} + _, err := upgrader.Upgrade(recorder, req, nil) + + if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError { + t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError) + t.Fatalf("got err=%T and status_code=%d", err, recorder.Code) + } +}