diff --git a/client.go b/client.go index c25d24f..93db8dd 100644 --- a/client.go +++ b/client.go @@ -5,8 +5,11 @@ package websocket import ( + "bytes" "crypto/tls" "errors" + "io" + "io/ioutil" "net" "net/http" "net/url" @@ -127,6 +130,11 @@ func parseURL(s string) (*url.URL, error) { u.Opaque = s[i:] } + if strings.Contains(u.Host, "@") { + // WebSocket URIs do not contain user information. + return nil, errMalformedURL + } + return &u, nil } @@ -155,7 +163,8 @@ var DefaultDialer *Dialer // // If the WebSocket handshake fails, ErrBadHandshake is returned along with a // non-nil *http.Response so that callers can handle redirects, authentication, -// etc. +// etcetera. The response body may not contain the entire response and does not +// need to be closed by the application. func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { u, err := parseURL(urlStr) if err != nil { @@ -224,8 +233,33 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re requestHeader = h } + if len(requestHeader["Host"]) > 0 { + // This can be used to supply a Host: header which is different from + // the dial address. + u.Host = requestHeader.Get("Host") + + // Drop "Host" header + h := http.Header{} + for k, v := range requestHeader { + if k == "Host" { + continue + } + h[k] = v + } + requestHeader = h + } + conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize) + if err != nil { + if err == ErrBadHandshake { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) + } return nil, resp, err } diff --git a/client_server_test.go b/client_server_test.go index 8c608f6..749ef20 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -8,11 +8,13 @@ import ( "crypto/tls" "crypto/x509" "io" + "io/ioutil" "net" "net/http" "net/http/httptest" "net/url" "reflect" + "strings" "testing" "time" ) @@ -34,22 +36,22 @@ var cstDialer = Dialer{ type cstHandler struct{ *testing.T } -type Server struct { +type cstServer struct { *httptest.Server URL string } -func newServer(t *testing.T) *Server { - var s Server +func newServer(t *testing.T) *cstServer { + var s cstServer s.Server = httptest.NewServer(cstHandler{t}) - s.URL = "ws" + s.Server.URL[len("http"):] + s.URL = makeWsProto(s.Server.URL) return &s } -func newTLSServer(t *testing.T) *Server { - var s Server +func newTLSServer(t *testing.T) *cstServer { + var s cstServer s.Server = httptest.NewTLSServer(cstHandler{t}) - s.URL = "ws" + s.Server.URL[len("http"):] + s.URL = makeWsProto(s.Server.URL) return &s } @@ -97,6 +99,10 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +func makeWsProto(s string) string { + return "ws" + strings.TrimPrefix(s, "http") +} + func sendRecv(t *testing.T, ws *Conn) { const message = "Hello World!" if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { @@ -157,6 +163,7 @@ func TestDialTLS(t *testing.T) { } func xTestDialTLSBadCert(t *testing.T) { + // This test is deactivated because of noisy logging from the net/http package. s := newTLSServer(t) defer s.Close() @@ -247,3 +254,70 @@ func TestHandshake(t *testing.T) { } sendRecv(t, ws) } + +func TestRespOnBadHandshake(t *testing.T) { + const expectedStatus = http.StatusGone + const expectedBody = "This is the response body." + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(expectedStatus) + io.WriteString(w, expectedBody) + })) + defer s.Close() + + ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } + + if resp == nil { + t.Fatalf("resp=nil, err=%v", err) + } + + if resp.StatusCode != expectedStatus { + t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) + } + + p, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadFull(resp.Body) returned error %v", err) + } + + if string(p) != expectedBody { + t.Errorf("resp.Body=%s, want %s", p, expectedBody) + } +} + +// If the Host header is specified in `Dial()`, the server must receive it as +// the `Host:` header. +func TestHostHeader(t *testing.T) { + s := newServer(t) + defer s.Close() + + specifiedHost := make(chan string, 1) + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + specifiedHost <- r.Host + origHandler.ServeHTTP(w, r) + }) + + ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}}) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + + if resp.StatusCode != http.StatusSwitchingProtocols { + t.Fatalf("resp.StatusCode = %v, want http.StatusSwitchingProtocols", resp.StatusCode) + } + + if gotHost := <-specifiedHost; gotHost != "testhost" { + t.Fatalf("gotHost = %q, want \"testhost\"", gotHost) + } + + sendRecv(t, ws) +} diff --git a/client_test.go b/client_test.go index d2f2ebd..07a9cb4 100644 --- a/client_test.go +++ b/client_test.go @@ -20,6 +20,7 @@ var parseURLTests = []struct { {"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}}, {"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}}, {"ss://example.com/a/b", nil}, + {"ws://webmaster@example.com/", nil}, } func TestParseURL(t *testing.T) { diff --git a/conn.go b/conn.go index 86c35e5..e719f1c 100644 --- a/conn.go +++ b/conn.go @@ -801,7 +801,7 @@ func (c *Conn) SetPingHandler(h func(string) error) { c.handlePing = h } -// SetPongHandler sets then handler for pong messages received from the peer. +// SetPongHandler sets the handler for pong messages received from the peer. // The default pong handler does nothing. func (c *Conn) SetPongHandler(h func(string) error) { if h == nil { diff --git a/doc.go b/doc.go index 0d2bd91..f52925d 100644 --- a/doc.go +++ b/doc.go @@ -24,7 +24,7 @@ // ... Use conn to send and receive messages. // } // -// Call the connection WriteMessage and ReadMessages methods to send and +// Call the connection's WriteMessage and ReadMessage methods to send and // receive messages as a slice of bytes. This snippet of code shows how to echo // messages using these methods: // diff --git a/json.go b/json.go index e0668f2..18e62f2 100644 --- a/json.go +++ b/json.go @@ -6,6 +6,7 @@ package websocket import ( "encoding/json" + "io" ) // WriteJSON is deprecated, use c.WriteJSON instead. @@ -45,5 +46,12 @@ func (c *Conn) ReadJSON(v interface{}) error { if err != nil { return err } - return json.NewDecoder(r).Decode(v) + err = json.NewDecoder(r).Decode(v) + if err == io.EOF { + // Decode returns io.EOF when the message is empty or all whitespace. + // Convert to io.ErrUnexpectedEOF so that application can distinguish + // between an error reading the JSON value and the connection closing. + err = io.ErrUnexpectedEOF + } + return err } diff --git a/json_test.go b/json_test.go index 2edb28d..1b7a5ec 100644 --- a/json_test.go +++ b/json_test.go @@ -6,6 +6,8 @@ package websocket import ( "bytes" + "encoding/json" + "io" "reflect" "testing" ) @@ -36,6 +38,60 @@ func TestJSON(t *testing.T) { } } +func TestPartialJsonRead(t *testing.T) { + var buf bytes.Buffer + c := fakeNetConn{&buf, &buf} + wc := newConn(c, true, 1024, 1024) + rc := newConn(c, false, 1024, 1024) + + var v struct { + A int + B string + } + v.A = 1 + v.B = "hello" + + messageCount := 0 + + // Partial JSON values. + + data, err := json.Marshal(v) + if err != nil { + t.Fatal(err) + } + for i := len(data) - 1; i >= 0; i-- { + if err := wc.WriteMessage(TextMessage, data[:i]); err != nil { + t.Fatal(err) + } + messageCount++ + } + + // Whitespace. + + if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil { + t.Fatal(err) + } + messageCount++ + + // Close. + + if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil { + t.Fatal(err) + } + + for i := 0; i < messageCount; i++ { + err := rc.ReadJSON(&v) + if err != io.ErrUnexpectedEOF { + t.Error("read", i, err) + } + } + + err = rc.ReadJSON(&v) + if err != io.EOF { + t.Error("final", err) + } +} + func TestDeprecatedJSON(t *testing.T) { var buf bytes.Buffer c := fakeNetConn{&buf, &buf}