Return response body on bad handshake.

The Dialer.Dial method returns an *http.Response on a bad handshake.
This CL updates the Dial method to include up to 1024 bytes of the
response body in the returned *http.Response. Applications may find the
response body helpful when debugging bad handshakes.

Fixes issue #62.
This commit is contained in:
Gary Burd 2015-05-08 14:59:31 -07:00
parent ecff5aabe4
commit b2fa8f6d58
2 changed files with 62 additions and 8 deletions

View File

@ -5,8 +5,11 @@
package websocket package websocket
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"errors" "errors"
"io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -155,7 +158,8 @@ var DefaultDialer *Dialer
// //
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication, // non-nil *http.Response so that callers can handle redirects, authentication,
// etc. // etcetera. The response body may not contain the entire response and does not
// need to be closed by the application.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
u, err := parseURL(urlStr) u, err := parseURL(urlStr)
if err != nil { if err != nil {
@ -225,7 +229,16 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
} }
conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize) conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize)
if err != nil { if err != nil {
if err == ErrBadHandshake {
// Before closing the network connection on return from this
// function, slurp up some of the response to aid application
// debugging.
buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
}
return nil, resp, err return nil, resp, err
} }

View File

@ -8,11 +8,13 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect" "reflect"
"strings"
"testing" "testing"
"time" "time"
) )
@ -34,22 +36,22 @@ var cstDialer = Dialer{
type cstHandler struct{ *testing.T } type cstHandler struct{ *testing.T }
type Server struct { type cstServer struct {
*httptest.Server *httptest.Server
URL string URL string
} }
func newServer(t *testing.T) *Server { func newServer(t *testing.T) *cstServer {
var s Server var s cstServer
s.Server = httptest.NewServer(cstHandler{t}) s.Server = httptest.NewServer(cstHandler{t})
s.URL = "ws" + s.Server.URL[len("http"):] s.URL = makeWsProto(s.Server.URL)
return &s return &s
} }
func newTLSServer(t *testing.T) *Server { func newTLSServer(t *testing.T) *cstServer {
var s Server var s cstServer
s.Server = httptest.NewTLSServer(cstHandler{t}) s.Server = httptest.NewTLSServer(cstHandler{t})
s.URL = "ws" + s.Server.URL[len("http"):] s.URL = makeWsProto(s.Server.URL)
return &s return &s
} }
@ -97,6 +99,10 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
func makeWsProto(s string) string {
return "ws" + strings.TrimPrefix(s, "http")
}
func sendRecv(t *testing.T, ws *Conn) { func sendRecv(t *testing.T, ws *Conn) {
const message = "Hello World!" const message = "Hello World!"
if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
@ -157,6 +163,7 @@ func TestDialTLS(t *testing.T) {
} }
func xTestDialTLSBadCert(t *testing.T) { func xTestDialTLSBadCert(t *testing.T) {
// This test is deactivated because of noisy logging from the net/http package.
s := newTLSServer(t) s := newTLSServer(t)
defer s.Close() defer s.Close()
@ -247,3 +254,37 @@ func TestHandshake(t *testing.T) {
} }
sendRecv(t, ws) sendRecv(t, ws)
} }
func TestRespOnBadHandshake(t *testing.T) {
const expectedStatus = http.StatusGone
const expectedBody = "This is the response body."
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(expectedStatus)
io.WriteString(w, expectedBody)
}))
defer s.Close()
ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
if err == nil {
ws.Close()
t.Fatalf("Dial: nil")
}
if resp == nil {
t.Fatalf("resp=nil, err=%v", err)
}
if resp.StatusCode != expectedStatus {
t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
}
p, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ReadFull(resp.Body) returned error %v", err)
}
if string(p) != expectedBody {
t.Errorf("resp.Body=%s, want %s", p, expectedBody)
}
}