From db7a2a1679028d026a418babdf072d8765f12898 Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Thu, 8 May 2014 11:21:56 -0700 Subject: [PATCH] Improve client host header handling. - Set request host header to substring of the URL. Do not add default port to string. - Do not include port when verifying TLS host name. --- client.go | 70 ++++++++++++++++++++++++++++---------------------- client_test.go | 63 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 31 deletions(-) create mode 100644 client_test.go diff --git a/client.go b/client.go index f20fe83..3b5cac4 100644 --- a/client.go +++ b/client.go @@ -96,7 +96,9 @@ type Dialer struct { var errMalformedURL = errors.New("malformed ws or wss URL") -func parseURL(u string) (useTLS bool, host, port, opaque string, err error) { +// parseURL parses the URL. The url.Parse function is not used here because +// url.Parse mangles the path. +func parseURL(s string) (*url.URL, error) { // From the RFC: // // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] @@ -106,33 +108,41 @@ func parseURL(u string) (useTLS bool, host, port, opaque string, err error) { // not provide a way for applications to work around percent deocding in // the net/url parser. + var u url.URL switch { - case strings.HasPrefix(u, "ws://"): - u = u[len("ws://"):] - case strings.HasPrefix(u, "wss://"): - u = u[len("wss://"):] - useTLS = true + case strings.HasPrefix(s, "ws://"): + u.Scheme = "ws" + s = s[len("ws://"):] + case strings.HasPrefix(s, "wss://"): + u.Scheme = "wss" + s = s[len("wss://"):] default: - return false, "", "", "", errMalformedURL + return nil, errMalformedURL } - hostPort := u - opaque = "/" - if i := strings.Index(u, "/"); i >= 0 { - hostPort = u[:i] - opaque = u[i:] + u.Host = s + u.Opaque = "/" + if i := strings.Index(s, "/"); i >= 0 { + u.Host = s[:i] + u.Opaque = s[i:] } - host = hostPort - port = ":80" - if i := strings.LastIndex(hostPort, ":"); i > strings.LastIndex(hostPort, "]") { - host = hostPort[:i] - port = hostPort[i:] - } else if useTLS { - port = ":443" - } + return &u, nil +} - return useTLS, host, port, opaque, nil +func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { + hostPort = u.Host + hostNoPort = u.Host + if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { + hostNoPort = hostNoPort[:i] + } else { + if u.Scheme == "wss" { + hostPort += ":443" + } else { + hostPort += ":80" + } + } + return hostPort, hostNoPort } // DefaultDialer is a dialer with all fields set to the default zero values. @@ -147,12 +157,13 @@ var DefaultDialer *Dialer // non-nil *http.Response so that callers can handle redirects, authentication, // etc. func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { - - useTLS, host, port, opaque, err := parseURL(urlStr) + u, err := parseURL(urlStr) if err != nil { return nil, nil, err } + hostPort, hostNoPort := hostPortNoPort(u) + if d == nil { d = &Dialer{} } @@ -168,7 +179,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re netDial = netDialer.Dial } - netConn, err := netDial("tcp", host+port) + netConn, err := netDial("tcp", hostPort) if err != nil { return nil, nil, err } @@ -183,14 +194,14 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re return nil, nil, err } - if useTLS { + if u.Scheme == "wss" { cfg := d.TLSClientConfig if cfg == nil { - cfg = &tls.Config{ServerName: host} + cfg = &tls.Config{ServerName: hostNoPort} } else if cfg.ServerName == "" { shallowCopy := *cfg cfg = &shallowCopy - cfg.ServerName = host + cfg.ServerName = hostNoPort } tlsConn := tls.Client(netConn, cfg) netConn = tlsConn @@ -223,10 +234,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re requestHeader = h } - conn, resp, err := NewClient( - netConn, - &url.URL{Host: host + port, Opaque: opaque}, - requestHeader, readBufferSize, writeBufferSize) + conn, resp, err := NewClient(netConn, u, requestHeader, readBufferSize, writeBufferSize) if err != nil { return nil, resp, err } diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..d2f2ebd --- /dev/null +++ b/client_test.go @@ -0,0 +1,63 @@ +// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "net/url" + "reflect" + "testing" +) + +var parseURLTests = []struct { + s string + u *url.URL +}{ + {"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}}, + {"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}}, + {"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}}, + {"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}}, + {"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}}, + {"ss://example.com/a/b", nil}, +} + +func TestParseURL(t *testing.T) { + for _, tt := range parseURLTests { + u, err := parseURL(tt.s) + if tt.u != nil && err != nil { + t.Errorf("parseURL(%q) returned error %v", tt.s, err) + continue + } + if tt.u == nil && err == nil { + t.Errorf("parseURL(%q) did not return error", tt.s) + continue + } + if !reflect.DeepEqual(u, tt.u) { + t.Errorf("parseURL(%q) returned %v, want %v", tt.s, u, tt.u) + continue + } + } +} + +var hostPortNoPortTests = []struct { + u *url.URL + hostPort, hostNoPort string +}{ + {&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"}, + {&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"}, + {&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"}, + {&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"}, +} + +func TestHostPortNoPort(t *testing.T) { + for _, tt := range hostPortNoPortTests { + hostPort, hostNoPort := hostPortNoPort(tt.u) + if hostPort != tt.hostPort { + t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort) + } + if hostNoPort != tt.hostNoPort { + t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort) + } + } +}