mirror of https://github.com/gorilla/websocket.git
Return response body on bad handshake.
The Dialer.Dial method returns an *http.Response on a bad handshake. This CL updates the Dial method to include up to 1024 bytes of the response body in the returned *http.Response. Applications may find the response body helpful when debugging bad handshakes. Fixes issue #62.
This commit is contained in:
parent
ecff5aabe4
commit
b2fa8f6d58
15
client.go
15
client.go
|
@ -5,8 +5,11 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -155,7 +158,8 @@ var DefaultDialer *Dialer
|
||||||
//
|
//
|
||||||
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
|
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
|
||||||
// non-nil *http.Response so that callers can handle redirects, authentication,
|
// non-nil *http.Response so that callers can handle redirects, authentication,
|
||||||
// etc.
|
// etcetera. The response body may not contain the entire response and does not
|
||||||
|
// need to be closed by the application.
|
||||||
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
|
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
|
||||||
u, err := parseURL(urlStr)
|
u, err := parseURL(urlStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -225,7 +229,16 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize)
|
conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == ErrBadHandshake {
|
||||||
|
// Before closing the network connection on return from this
|
||||||
|
// function, slurp up some of the response to aid application
|
||||||
|
// debugging.
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, _ := io.ReadFull(resp.Body, buf)
|
||||||
|
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
|
||||||
|
}
|
||||||
return nil, resp, err
|
return nil, resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,11 +8,13 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -34,22 +36,22 @@ var cstDialer = Dialer{
|
||||||
|
|
||||||
type cstHandler struct{ *testing.T }
|
type cstHandler struct{ *testing.T }
|
||||||
|
|
||||||
type Server struct {
|
type cstServer struct {
|
||||||
*httptest.Server
|
*httptest.Server
|
||||||
URL string
|
URL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServer(t *testing.T) *Server {
|
func newServer(t *testing.T) *cstServer {
|
||||||
var s Server
|
var s cstServer
|
||||||
s.Server = httptest.NewServer(cstHandler{t})
|
s.Server = httptest.NewServer(cstHandler{t})
|
||||||
s.URL = "ws" + s.Server.URL[len("http"):]
|
s.URL = makeWsProto(s.Server.URL)
|
||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTLSServer(t *testing.T) *Server {
|
func newTLSServer(t *testing.T) *cstServer {
|
||||||
var s Server
|
var s cstServer
|
||||||
s.Server = httptest.NewTLSServer(cstHandler{t})
|
s.Server = httptest.NewTLSServer(cstHandler{t})
|
||||||
s.URL = "ws" + s.Server.URL[len("http"):]
|
s.URL = makeWsProto(s.Server.URL)
|
||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,6 +99,10 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeWsProto(s string) string {
|
||||||
|
return "ws" + strings.TrimPrefix(s, "http")
|
||||||
|
}
|
||||||
|
|
||||||
func sendRecv(t *testing.T, ws *Conn) {
|
func sendRecv(t *testing.T, ws *Conn) {
|
||||||
const message = "Hello World!"
|
const message = "Hello World!"
|
||||||
if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
|
if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
|
||||||
|
@ -157,6 +163,7 @@ func TestDialTLS(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func xTestDialTLSBadCert(t *testing.T) {
|
func xTestDialTLSBadCert(t *testing.T) {
|
||||||
|
// This test is deactivated because of noisy logging from the net/http package.
|
||||||
s := newTLSServer(t)
|
s := newTLSServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -247,3 +254,37 @@ func TestHandshake(t *testing.T) {
|
||||||
}
|
}
|
||||||
sendRecv(t, ws)
|
sendRecv(t, ws)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRespOnBadHandshake(t *testing.T) {
|
||||||
|
const expectedStatus = http.StatusGone
|
||||||
|
const expectedBody = "This is the response body."
|
||||||
|
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(expectedStatus)
|
||||||
|
io.WriteString(w, expectedBody)
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
|
||||||
|
if err == nil {
|
||||||
|
ws.Close()
|
||||||
|
t.Fatalf("Dial: nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp == nil {
|
||||||
|
t.Fatalf("resp=nil, err=%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != expectedStatus {
|
||||||
|
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFull(resp.Body) returned error %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(p) != expectedBody {
|
||||||
|
t.Errorf("resp.Body=%s, want %s", p, expectedBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue