diff --git a/client.go b/client.go index a5c6659..443e815 100644 --- a/client.go +++ b/client.go @@ -30,50 +30,17 @@ var ErrBadHandshake = errors.New("websocket: bad handshake") // If the WebSocket handshake fails, ErrBadHandshake is returned along with a // non-nil *http.Response so that callers can handle redirects, authentication, // etc. +// +// Deprecated: Use Dialer instead. func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { - challengeKey, err := generateChallengeKey() - if err != nil { - return nil, nil, err + d := Dialer{ + ReadBufferSize: readBufSize, + WriteBufferSize: writeBufSize, + NetDial: func(net, addr string) (net.Conn, error) { + return netConn, nil + }, } - acceptKey := computeAcceptKey(challengeKey) - - c = newConn(netConn, false, readBufSize, writeBufSize) - p := c.writeBuf[:0] - p = append(p, "GET "...) - p = append(p, u.RequestURI()...) - p = append(p, " HTTP/1.1\r\nHost: "...) - p = append(p, u.Host...) - // "Upgrade" is capitalized for servers that do not use case insensitive - // comparisons on header tokens. - p = append(p, "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "...) - p = append(p, challengeKey...) - p = append(p, "\r\n"...) - for k, vs := range requestHeader { - for _, v := range vs { - p = append(p, k...) - p = append(p, ": "...) - p = append(p, v...) - p = append(p, "\r\n"...) - } - } - p = append(p, "\r\n"...) - - if _, err := netConn.Write(p); err != nil { - return nil, nil, err - } - - resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u}) - if err != nil { - return nil, nil, err - } - if resp.StatusCode != 101 || - !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || - !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || - resp.Header.Get("Sec-Websocket-Accept") != acceptKey { - return nil, resp, ErrBadHandshake - } - c.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") - return c, resp, nil + return d.Dial(u.String(), requestHeader) } // A Dialer contains options for connecting to WebSocket server. @@ -99,17 +66,15 @@ type Dialer struct { var errMalformedURL = errors.New("malformed ws or wss URL") -// parseURL parses the URL. The url.Parse function is not used here because -// url.Parse mangles the path. +// parseURL parses the URL. +// +// This function is a replacement for the standard library url.Parse function. +// In Go 1.4 and earlier, url.Parse loses information from the path. func parseURL(s string) (*url.URL, error) { // From the RFC: // // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] - // - // We don't use the net/url parser here because the dialer interface does - // not provide a way for applications to work around percent deocding in - // the net/url parser. var u url.URL switch { @@ -131,7 +96,8 @@ func parseURL(s string) (*url.URL, error) { } if strings.Contains(u.Host, "@") { - // WebSocket URIs do not contain user information. + // Don't bother parsing user information because user information is + // not allowed in websocket URIs. return nil, errMalformedURL } @@ -166,17 +132,68 @@ var DefaultDialer = &Dialer{} // etcetera. The response body may not contain the entire response and does not // need to be closed by the application. func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + + if d == nil { + d = &Dialer{} + } + + challengeKey, err := generateChallengeKey() + if err != nil { + return nil, nil, err + } + u, err := parseURL(urlStr) if err != nil { return nil, nil, err } - hostPort, hostNoPort := hostPortNoPort(u) - - if d == nil { - d = &Dialer{} + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return nil, nil, errMalformedURL } + if u.User != nil { + // User name and password are not allowed in websocket URIs. + return nil, nil, errMalformedURL + } + + req := &http.Request{ + Method: "GET", + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + + // Set the request headers using the capitalization for names and values in + // RFC examples. Although the capitalization shouldn't matter, there are + // servers that depend on it. The Header.Set method is not used because the + // method canonicalizes the header names. + req.Header["Upgrade"] = []string{"websocket"} + req.Header["Connection"] = []string{"Upgrade"} + req.Header["Sec-WebSocket-Key"] = []string{challengeKey} + req.Header["Sec-WebSocket-Version"] = []string{"13"} + if len(d.Subprotocols) > 0 { + req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} + } + for k, vs := range requestHeader { + if k == "Host" { + if len(vs) > 0 { + req.Host = vs[0] + } + } else { + req.Header[k] = vs + } + } + + hostPort, hostNoPort := hostPortNoPort(u) + var deadline time.Time if d.HandshakeTimeout != 0 { deadline = time.Now().Add(d.HandshakeTimeout) @@ -203,7 +220,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re return nil, nil, err } - if u.Scheme == "wss" { + if u.Scheme == "https" { cfg := d.TLSClientConfig if cfg == nil { cfg = &tls.Config{ServerName: hostNoPort} @@ -224,44 +241,32 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } } - if len(d.Subprotocols) > 0 { - h := http.Header{} - for k, v := range requestHeader { - h[k] = v - } - h.Set("Sec-Websocket-Protocol", strings.Join(d.Subprotocols, ", ")) - requestHeader = h + conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) + + if err := req.Write(netConn); err != nil { + return nil, nil, err } - if len(requestHeader["Host"]) > 0 { - // This can be used to supply a Host: header which is different from - // the dial address. - u.Host = requestHeader.Get("Host") - - // Drop "Host" header - h := http.Header{} - for k, v := range requestHeader { - if k == "Host" { - continue - } - h[k] = v - } - requestHeader = h - } - - conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize) - + resp, err := http.ReadResponse(conn.br, req) if err != nil { - if err == ErrBadHandshake { - // Before closing the network connection on return from this - // function, slurp up some of the response to aid application - // debugging. - buf := make([]byte, 1024) - n, _ := io.ReadFull(resp.Body, buf) - resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) - } - return nil, resp, err + return nil, nil, err } + if resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) + return nil, resp, ErrBadHandshake + } else { + resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) + } + + conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") netConn.SetDeadline(time.Time{}) netConn = nil // to avoid close in defer. diff --git a/client_server_test.go b/client_server_test.go index 749ef20..b6cb89a 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -289,8 +289,8 @@ func TestRespOnBadHandshake(t *testing.T) { } } -// If the Host header is specified in `Dial()`, the server must receive it as -// the `Host:` header. +// TestHostHeader confirms that the host header provided in the call to Dial is +// sent to the server. func TestHostHeader(t *testing.T) { s := newServer(t) defer s.Close() @@ -305,16 +305,12 @@ func TestHostHeader(t *testing.T) { origHandler.ServeHTTP(w, r) }) - ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}}) + ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}}) if err != nil { t.Fatalf("Dial: %v", err) } defer ws.Close() - if resp.StatusCode != http.StatusSwitchingProtocols { - t.Fatalf("resp.StatusCode = %v, want http.StatusSwitchingProtocols", resp.StatusCode) - } - if gotHost := <-specifiedHost; gotHost != "testhost" { t.Fatalf("gotHost = %q, want \"testhost\"", gotHost) }