use http.ResposnseController

This commit is contained in:
merlin 2023-11-22 21:47:04 +02:00 committed by Daniel Holmes
parent a70cea529a
commit c7502098b0
2 changed files with 26 additions and 6 deletions

View File

@ -172,13 +172,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
}
}
h, ok := w.(http.Hijacker)
if !ok {
netConn, brw, err := http.NewResponseController(w).Hijack()
switch {
case errors.Is(err, errors.ErrUnsupported):
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var brw *bufio.ReadWriter
netConn, brw, err := h.Hijack()
if err != nil {
case err != nil:
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
}

View File

@ -7,8 +7,10 @@ package websocket
import (
"bufio"
"bytes"
"errors"
"net"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
@ -117,3 +119,23 @@ func TestBufioReuse(t *testing.T) {
}
}
}
func TestHijack_NotSupported(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Connection", "upgrade")
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Sec-Websocket-Version", "13")
recorder := httptest.NewRecorder()
upgrader := Upgrader{}
_, err := upgrader.Upgrade(recorder, req, nil)
if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
}
}