Improve client host header handling.

- Set request host header to substring of the URL. Do not add default
  port to string.
- Do not include port when verifying TLS host name.
This commit is contained in:
Gary Burd 2014-05-08 11:21:56 -07:00
parent 1e6e1281b0
commit db7a2a1679
2 changed files with 102 additions and 31 deletions

View File

@ -96,7 +96,9 @@ type Dialer struct {
var errMalformedURL = errors.New("malformed ws or wss URL") var errMalformedURL = errors.New("malformed ws or wss URL")
func parseURL(u string) (useTLS bool, host, port, opaque string, err error) { // parseURL parses the URL. The url.Parse function is not used here because
// url.Parse mangles the path.
func parseURL(s string) (*url.URL, error) {
// From the RFC: // From the RFC:
// //
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
@ -106,33 +108,41 @@ func parseURL(u string) (useTLS bool, host, port, opaque string, err error) {
// not provide a way for applications to work around percent deocding in // not provide a way for applications to work around percent deocding in
// the net/url parser. // the net/url parser.
var u url.URL
switch { switch {
case strings.HasPrefix(u, "ws://"): case strings.HasPrefix(s, "ws://"):
u = u[len("ws://"):] u.Scheme = "ws"
case strings.HasPrefix(u, "wss://"): s = s[len("ws://"):]
u = u[len("wss://"):] case strings.HasPrefix(s, "wss://"):
useTLS = true u.Scheme = "wss"
s = s[len("wss://"):]
default: default:
return false, "", "", "", errMalformedURL return nil, errMalformedURL
} }
hostPort := u u.Host = s
opaque = "/" u.Opaque = "/"
if i := strings.Index(u, "/"); i >= 0 { if i := strings.Index(s, "/"); i >= 0 {
hostPort = u[:i] u.Host = s[:i]
opaque = u[i:] u.Opaque = s[i:]
} }
host = hostPort return &u, nil
port = ":80" }
if i := strings.LastIndex(hostPort, ":"); i > strings.LastIndex(hostPort, "]") {
host = hostPort[:i]
port = hostPort[i:]
} else if useTLS {
port = ":443"
}
return useTLS, host, port, opaque, nil func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostPort = u.Host
hostNoPort = u.Host
if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
hostNoPort = hostNoPort[:i]
} else {
if u.Scheme == "wss" {
hostPort += ":443"
} else {
hostPort += ":80"
}
}
return hostPort, hostNoPort
} }
// DefaultDialer is a dialer with all fields set to the default zero values. // DefaultDialer is a dialer with all fields set to the default zero values.
@ -147,12 +157,13 @@ var DefaultDialer *Dialer
// non-nil *http.Response so that callers can handle redirects, authentication, // non-nil *http.Response so that callers can handle redirects, authentication,
// etc. // etc.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
u, err := parseURL(urlStr)
useTLS, host, port, opaque, err := parseURL(urlStr)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
hostPort, hostNoPort := hostPortNoPort(u)
if d == nil { if d == nil {
d = &Dialer{} d = &Dialer{}
} }
@ -168,7 +179,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
netDial = netDialer.Dial netDial = netDialer.Dial
} }
netConn, err := netDial("tcp", host+port) netConn, err := netDial("tcp", hostPort)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -183,14 +194,14 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
return nil, nil, err return nil, nil, err
} }
if useTLS { if u.Scheme == "wss" {
cfg := d.TLSClientConfig cfg := d.TLSClientConfig
if cfg == nil { if cfg == nil {
cfg = &tls.Config{ServerName: host} cfg = &tls.Config{ServerName: hostNoPort}
} else if cfg.ServerName == "" { } else if cfg.ServerName == "" {
shallowCopy := *cfg shallowCopy := *cfg
cfg = &shallowCopy cfg = &shallowCopy
cfg.ServerName = host cfg.ServerName = hostNoPort
} }
tlsConn := tls.Client(netConn, cfg) tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn netConn = tlsConn
@ -223,10 +234,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
requestHeader = h requestHeader = h
} }
conn, resp, err := NewClient( conn, resp, err := NewClient(netConn, u, requestHeader, readBufferSize, writeBufferSize)
netConn,
&url.URL{Host: host + port, Opaque: opaque},
requestHeader, readBufferSize, writeBufferSize)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }

63
client_test.go Normal file
View File

@ -0,0 +1,63 @@
// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"net/url"
"reflect"
"testing"
)
var parseURLTests = []struct {
s string
u *url.URL
}{
{"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
{"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
{"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}},
{"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}},
{"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}},
{"ss://example.com/a/b", nil},
}
func TestParseURL(t *testing.T) {
for _, tt := range parseURLTests {
u, err := parseURL(tt.s)
if tt.u != nil && err != nil {
t.Errorf("parseURL(%q) returned error %v", tt.s, err)
continue
}
if tt.u == nil && err == nil {
t.Errorf("parseURL(%q) did not return error", tt.s)
continue
}
if !reflect.DeepEqual(u, tt.u) {
t.Errorf("parseURL(%q) returned %v, want %v", tt.s, u, tt.u)
continue
}
}
}
var hostPortNoPortTests = []struct {
u *url.URL
hostPort, hostNoPort string
}{
{&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"},
{&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"},
{&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"},
{&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"},
}
func TestHostPortNoPort(t *testing.T) {
for _, tt := range hostPortNoPortTests {
hostPort, hostNoPort := hostPortNoPort(tt.u)
if hostPort != tt.hostPort {
t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort)
}
if hostNoPort != tt.hostNoPort {
t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort)
}
}
}