From a943a8db9e995b8397cb2308d59f2814ade15343 Mon Sep 17 00:00:00 2001 From: Thomas Massie Date: Sat, 25 May 2024 17:16:00 -1000 Subject: [PATCH] Return 426 status on missing upgrade header --- client_server_test.go | 31 +++++++++++++++++++++++++++++++ server.go | 3 ++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/client_server_test.go b/client_server_test.go index 610fbe2..c5c08d7 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -497,6 +497,37 @@ func TestBadMethod(t *testing.T) { } } +func TestNoUpgrade(t *testing.T) { + t.Parallel() + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := cstUpgrader.Upgrade(w, r, nil) + if err == nil { + t.Errorf("handshake succeeded, expect fail") + ws.Close() + } + })) + defer s.Close() + + req, err := http.NewRequest(http.MethodGet, s.URL, strings.NewReader("")) + if err != nil { + t.Fatalf("NewRequest returned error %v", err) + } + req.Header.Set("Connection", "upgrade") + req.Header.Set("Sec-Websocket-Version", "13") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Do returned error %v", err) + } + resp.Body.Close() + if u := resp.Header.Get("Upgrade"); u != "websocket" { + t.Errorf("Uprade response header is %q, want %q", u, "websocket") + } + if resp.StatusCode != http.StatusUpgradeRequired { + t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusUpgradeRequired) + } +} + func TestDialExtraTokensInRespHeaders(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { challengeKey := r.Header.Get("Sec-Websocket-Key") diff --git a/server.go b/server.go index ff7d03a..28f2013 100644 --- a/server.go +++ b/server.go @@ -130,7 +130,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { - return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header") + w.Header().Set("Upgrade", "websocket") + return u.returnError(w, r, http.StatusUpgradeRequired, badHandshake+"'websocket' token not found in 'Upgrade' header") } if r.Method != http.MethodGet {