Improve client host header handling.

- Set request host header to substring of the URL. Do not add default
  port to string.
- Do not include port when verifying TLS host name.
This commit is contained in:
Gary Burd 2014-05-08 11:21:56 -07:00
parent 1e6e1281b0
commit db7a2a1679
2 changed files with 102 additions and 31 deletions

View File

@ -96,7 +96,9 @@ type Dialer struct {
var errMalformedURL = errors.New("malformed ws or wss URL")
func parseURL(u string) (useTLS bool, host, port, opaque string, err error) {
// parseURL parses the URL. The url.Parse function is not used here because
// url.Parse mangles the path.
func parseURL(s string) (*url.URL, error) {
// From the RFC:
//
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
@ -106,33 +108,41 @@ func parseURL(u string) (useTLS bool, host, port, opaque string, err error) {
// not provide a way for applications to work around percent deocding in
// the net/url parser.
var u url.URL
switch {
case strings.HasPrefix(u, "ws://"):
u = u[len("ws://"):]
case strings.HasPrefix(u, "wss://"):
u = u[len("wss://"):]
useTLS = true
case strings.HasPrefix(s, "ws://"):
u.Scheme = "ws"
s = s[len("ws://"):]
case strings.HasPrefix(s, "wss://"):
u.Scheme = "wss"
s = s[len("wss://"):]
default:
return false, "", "", "", errMalformedURL
return nil, errMalformedURL
}
hostPort := u
opaque = "/"
if i := strings.Index(u, "/"); i >= 0 {
hostPort = u[:i]
opaque = u[i:]
u.Host = s
u.Opaque = "/"
if i := strings.Index(s, "/"); i >= 0 {
u.Host = s[:i]
u.Opaque = s[i:]
}
host = hostPort
port = ":80"
if i := strings.LastIndex(hostPort, ":"); i > strings.LastIndex(hostPort, "]") {
host = hostPort[:i]
port = hostPort[i:]
} else if useTLS {
port = ":443"
return &u, nil
}
return useTLS, host, port, opaque, nil
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostPort = u.Host
hostNoPort = u.Host
if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
hostNoPort = hostNoPort[:i]
} else {
if u.Scheme == "wss" {
hostPort += ":443"
} else {
hostPort += ":80"
}
}
return hostPort, hostNoPort
}
// DefaultDialer is a dialer with all fields set to the default zero values.
@ -147,12 +157,13 @@ var DefaultDialer *Dialer
// non-nil *http.Response so that callers can handle redirects, authentication,
// etc.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
useTLS, host, port, opaque, err := parseURL(urlStr)
u, err := parseURL(urlStr)
if err != nil {
return nil, nil, err
}
hostPort, hostNoPort := hostPortNoPort(u)
if d == nil {
d = &Dialer{}
}
@ -168,7 +179,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
netDial = netDialer.Dial
}
netConn, err := netDial("tcp", host+port)
netConn, err := netDial("tcp", hostPort)
if err != nil {
return nil, nil, err
}
@ -183,14 +194,14 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
return nil, nil, err
}
if useTLS {
if u.Scheme == "wss" {
cfg := d.TLSClientConfig
if cfg == nil {
cfg = &tls.Config{ServerName: host}
cfg = &tls.Config{ServerName: hostNoPort}
} else if cfg.ServerName == "" {
shallowCopy := *cfg
cfg = &shallowCopy
cfg.ServerName = host
cfg.ServerName = hostNoPort
}
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
@ -223,10 +234,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
requestHeader = h
}
conn, resp, err := NewClient(
netConn,
&url.URL{Host: host + port, Opaque: opaque},
requestHeader, readBufferSize, writeBufferSize)
conn, resp, err := NewClient(netConn, u, requestHeader, readBufferSize, writeBufferSize)
if err != nil {
return nil, resp, err
}

63
client_test.go Normal file
View File

@ -0,0 +1,63 @@
// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"net/url"
"reflect"
"testing"
)
var parseURLTests = []struct {
s string
u *url.URL
}{
{"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
{"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
{"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}},
{"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}},
{"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}},
{"ss://example.com/a/b", nil},
}
func TestParseURL(t *testing.T) {
for _, tt := range parseURLTests {
u, err := parseURL(tt.s)
if tt.u != nil && err != nil {
t.Errorf("parseURL(%q) returned error %v", tt.s, err)
continue
}
if tt.u == nil && err == nil {
t.Errorf("parseURL(%q) did not return error", tt.s)
continue
}
if !reflect.DeepEqual(u, tt.u) {
t.Errorf("parseURL(%q) returned %v, want %v", tt.s, u, tt.u)
continue
}
}
}
var hostPortNoPortTests = []struct {
u *url.URL
hostPort, hostNoPort string
}{
{&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"},
{&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"},
{&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"},
{&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"},
}
func TestHostPortNoPort(t *testing.T) {
for _, tt := range hostPortNoPortTests {
hostPort, hostNoPort := hostPortNoPort(tt.u)
if hostPort != tt.hostPort {
t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort)
}
if hostNoPort != tt.hostNoPort {
t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort)
}
}
}