diff --git a/client.go b/client.go index c25d24f..ebca8ed 100644 --- a/client.go +++ b/client.go @@ -5,8 +5,11 @@ package websocket import ( + "bytes" "crypto/tls" "errors" + "io" + "io/ioutil" "net" "net/http" "net/url" @@ -155,7 +158,8 @@ var DefaultDialer *Dialer // // If the WebSocket handshake fails, ErrBadHandshake is returned along with a // non-nil *http.Response so that callers can handle redirects, authentication, -// etc. +// etcetera. The response body may not contain the entire response and does not +// need to be closed by the application. func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { u, err := parseURL(urlStr) if err != nil { @@ -225,7 +229,16 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize) + if err != nil { + if err == ErrBadHandshake { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) + } return nil, resp, err } diff --git a/client_server_test.go b/client_server_test.go index 8c608f6..38a1afc 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -8,11 +8,13 @@ import ( "crypto/tls" "crypto/x509" "io" + "io/ioutil" "net" "net/http" "net/http/httptest" "net/url" "reflect" + "strings" "testing" "time" ) @@ -34,22 +36,22 @@ var cstDialer = Dialer{ type cstHandler struct{ *testing.T } -type Server struct { +type cstServer struct { *httptest.Server URL string } -func newServer(t *testing.T) *Server { - var s Server +func newServer(t *testing.T) *cstServer { + var s cstServer s.Server = httptest.NewServer(cstHandler{t}) - s.URL = "ws" + s.Server.URL[len("http"):] + s.URL = makeWsProto(s.Server.URL) return &s } -func newTLSServer(t *testing.T) *Server { - var s Server +func newTLSServer(t *testing.T) *cstServer { + var s cstServer s.Server = httptest.NewTLSServer(cstHandler{t}) - s.URL = "ws" + s.Server.URL[len("http"):] + s.URL = makeWsProto(s.Server.URL) return &s } @@ -97,6 +99,10 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +func makeWsProto(s string) string { + return "ws" + strings.TrimPrefix(s, "http") +} + func sendRecv(t *testing.T, ws *Conn) { const message = "Hello World!" if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { @@ -157,6 +163,7 @@ func TestDialTLS(t *testing.T) { } func xTestDialTLSBadCert(t *testing.T) { + // This test is deactivated because of noisy logging from the net/http package. s := newTLSServer(t) defer s.Close() @@ -247,3 +254,37 @@ func TestHandshake(t *testing.T) { } sendRecv(t, ws) } + +func TestRespOnBadHandshake(t *testing.T) { + const expectedStatus = http.StatusGone + const expectedBody = "This is the response body." + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(expectedStatus) + io.WriteString(w, expectedBody) + })) + defer s.Close() + + ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } + + if resp == nil { + t.Fatalf("resp=nil, err=%v", err) + } + + if resp.StatusCode != expectedStatus { + t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) + } + + p, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadFull(resp.Body) returned error %v", err) + } + + if string(p) != expectedBody { + t.Errorf("resp.Body=%s, want %s", p, expectedBody) + } +}