mirror of https://github.com/gorilla/websocket.git
use http.ResposnseController
This commit is contained in:
parent
a70cea529a
commit
c7502098b0
10
server.go
10
server.go
|
@ -172,13 +172,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
h, ok := w.(http.Hijacker)
|
netConn, brw, err := http.NewResponseController(w).Hijack()
|
||||||
if !ok {
|
switch {
|
||||||
|
case errors.Is(err, errors.ErrUnsupported):
|
||||||
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
|
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
|
||||||
}
|
case err != nil:
|
||||||
var brw *bufio.ReadWriter
|
|
||||||
netConn, brw, err := h.Hijack()
|
|
||||||
if err != nil {
|
|
||||||
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
|
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,10 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -117,3 +119,23 @@ func TestBufioReuse(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHijack_NotSupported(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
req.Header.Set("Connection", "upgrade")
|
||||||
|
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||||
|
req.Header.Set("Sec-Websocket-Version", "13")
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
upgrader := Upgrader{}
|
||||||
|
_, err := upgrader.Upgrade(recorder, req, nil)
|
||||||
|
|
||||||
|
if want := (HandshakeError{}); !errors.As(err, &want) || recorder.Code != http.StatusInternalServerError {
|
||||||
|
t.Errorf("want %T and status_code=%d", want, http.StatusInternalServerError)
|
||||||
|
t.Fatalf("got err=%T and status_code=%d", err, recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue