mirror of https://github.com/gorilla/websocket.git
Improve client host header handling.
- Set request host header to substring of the URL. Do not add default port to string. - Do not include port when verifying TLS host name.
This commit is contained in:
parent
1e6e1281b0
commit
db7a2a1679
68
client.go
68
client.go
|
@ -96,7 +96,9 @@ type Dialer struct {
|
|||
|
||||
var errMalformedURL = errors.New("malformed ws or wss URL")
|
||||
|
||||
func parseURL(u string) (useTLS bool, host, port, opaque string, err error) {
|
||||
// parseURL parses the URL. The url.Parse function is not used here because
|
||||
// url.Parse mangles the path.
|
||||
func parseURL(s string) (*url.URL, error) {
|
||||
// From the RFC:
|
||||
//
|
||||
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
|
||||
|
@ -106,33 +108,41 @@ func parseURL(u string) (useTLS bool, host, port, opaque string, err error) {
|
|||
// not provide a way for applications to work around percent deocding in
|
||||
// the net/url parser.
|
||||
|
||||
var u url.URL
|
||||
switch {
|
||||
case strings.HasPrefix(u, "ws://"):
|
||||
u = u[len("ws://"):]
|
||||
case strings.HasPrefix(u, "wss://"):
|
||||
u = u[len("wss://"):]
|
||||
useTLS = true
|
||||
case strings.HasPrefix(s, "ws://"):
|
||||
u.Scheme = "ws"
|
||||
s = s[len("ws://"):]
|
||||
case strings.HasPrefix(s, "wss://"):
|
||||
u.Scheme = "wss"
|
||||
s = s[len("wss://"):]
|
||||
default:
|
||||
return false, "", "", "", errMalformedURL
|
||||
return nil, errMalformedURL
|
||||
}
|
||||
|
||||
hostPort := u
|
||||
opaque = "/"
|
||||
if i := strings.Index(u, "/"); i >= 0 {
|
||||
hostPort = u[:i]
|
||||
opaque = u[i:]
|
||||
u.Host = s
|
||||
u.Opaque = "/"
|
||||
if i := strings.Index(s, "/"); i >= 0 {
|
||||
u.Host = s[:i]
|
||||
u.Opaque = s[i:]
|
||||
}
|
||||
|
||||
host = hostPort
|
||||
port = ":80"
|
||||
if i := strings.LastIndex(hostPort, ":"); i > strings.LastIndex(hostPort, "]") {
|
||||
host = hostPort[:i]
|
||||
port = hostPort[i:]
|
||||
} else if useTLS {
|
||||
port = ":443"
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
return useTLS, host, port, opaque, nil
|
||||
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
|
||||
hostPort = u.Host
|
||||
hostNoPort = u.Host
|
||||
if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
|
||||
hostNoPort = hostNoPort[:i]
|
||||
} else {
|
||||
if u.Scheme == "wss" {
|
||||
hostPort += ":443"
|
||||
} else {
|
||||
hostPort += ":80"
|
||||
}
|
||||
}
|
||||
return hostPort, hostNoPort
|
||||
}
|
||||
|
||||
// DefaultDialer is a dialer with all fields set to the default zero values.
|
||||
|
@ -147,12 +157,13 @@ var DefaultDialer *Dialer
|
|||
// non-nil *http.Response so that callers can handle redirects, authentication,
|
||||
// etc.
|
||||
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
|
||||
|
||||
useTLS, host, port, opaque, err := parseURL(urlStr)
|
||||
u, err := parseURL(urlStr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
hostPort, hostNoPort := hostPortNoPort(u)
|
||||
|
||||
if d == nil {
|
||||
d = &Dialer{}
|
||||
}
|
||||
|
@ -168,7 +179,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
netDial = netDialer.Dial
|
||||
}
|
||||
|
||||
netConn, err := netDial("tcp", host+port)
|
||||
netConn, err := netDial("tcp", hostPort)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -183,14 +194,14 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
if useTLS {
|
||||
if u.Scheme == "wss" {
|
||||
cfg := d.TLSClientConfig
|
||||
if cfg == nil {
|
||||
cfg = &tls.Config{ServerName: host}
|
||||
cfg = &tls.Config{ServerName: hostNoPort}
|
||||
} else if cfg.ServerName == "" {
|
||||
shallowCopy := *cfg
|
||||
cfg = &shallowCopy
|
||||
cfg.ServerName = host
|
||||
cfg.ServerName = hostNoPort
|
||||
}
|
||||
tlsConn := tls.Client(netConn, cfg)
|
||||
netConn = tlsConn
|
||||
|
@ -223,10 +234,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
requestHeader = h
|
||||
}
|
||||
|
||||
conn, resp, err := NewClient(
|
||||
netConn,
|
||||
&url.URL{Host: host + port, Opaque: opaque},
|
||||
requestHeader, readBufferSize, writeBufferSize)
|
||||
conn, resp, err := NewClient(netConn, u, requestHeader, readBufferSize, writeBufferSize)
|
||||
if err != nil {
|
||||
return nil, resp, err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var parseURLTests = []struct {
|
||||
s string
|
||||
u *url.URL
|
||||
}{
|
||||
{"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
|
||||
{"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
|
||||
{"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}},
|
||||
{"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}},
|
||||
{"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}},
|
||||
{"ss://example.com/a/b", nil},
|
||||
}
|
||||
|
||||
func TestParseURL(t *testing.T) {
|
||||
for _, tt := range parseURLTests {
|
||||
u, err := parseURL(tt.s)
|
||||
if tt.u != nil && err != nil {
|
||||
t.Errorf("parseURL(%q) returned error %v", tt.s, err)
|
||||
continue
|
||||
}
|
||||
if tt.u == nil && err == nil {
|
||||
t.Errorf("parseURL(%q) did not return error", tt.s)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(u, tt.u) {
|
||||
t.Errorf("parseURL(%q) returned %v, want %v", tt.s, u, tt.u)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var hostPortNoPortTests = []struct {
|
||||
u *url.URL
|
||||
hostPort, hostNoPort string
|
||||
}{
|
||||
{&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"},
|
||||
{&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"},
|
||||
{&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"},
|
||||
{&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"},
|
||||
}
|
||||
|
||||
func TestHostPortNoPort(t *testing.T) {
|
||||
for _, tt := range hostPortNoPortTests {
|
||||
hostPort, hostNoPort := hostPortNoPort(tt.u)
|
||||
if hostPort != tt.hostPort {
|
||||
t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort)
|
||||
}
|
||||
if hostNoPort != tt.hostNoPort {
|
||||
t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue