diff --git a/server.go b/server.go index d8a205e..da96088 100644 --- a/server.go +++ b/server.go @@ -174,13 +174,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } } - h, ok := w.(http.Hijacker) - if !ok { + netConn, brw, err := http.NewResponseController(w).Hijack() + switch { + case errors.Is(err, errors.ErrUnsupported): return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") - } - var brw *bufio.ReadWriter - netConn, brw, err := h.Hijack() - if err != nil { + case err != nil: return u.returnError(w, r, http.StatusInternalServerError, err.Error()) } diff --git a/server_test.go b/server_test.go index b0dc625..d7eb880 100644 --- a/server_test.go +++ b/server_test.go @@ -7,8 +7,10 @@ package websocket import ( "bufio" "bytes" + "errors" "net" "net/http" + "net/http/httptest" "reflect" "strings" "testing" @@ -152,3 +154,23 @@ func TestBufioReuse(t *testing.T) { } } } + +func TestHijack_NotSupported(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "upgrade") + req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + req.Header.Set("Sec-Websocket-Version", "13") + + recorder := httptest.NewRecorder() + + upgrader := Upgrader{} + _, err := upgrader.Upgrade(recorder, req, nil) + + if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError { + t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError) + t.Fatalf("got err=%T and status_code=%d", err, recorder.Code) + } +}