mirror of https://github.com/gorilla/websocket.git
Improve client host header handling.
- Set request host header to substring of the URL. Do not add default port to string. - Do not include port when verifying TLS host name.
This commit is contained in:
parent
1e6e1281b0
commit
db7a2a1679
70
client.go
70
client.go
|
@ -96,7 +96,9 @@ type Dialer struct {
|
||||||
|
|
||||||
var errMalformedURL = errors.New("malformed ws or wss URL")
|
var errMalformedURL = errors.New("malformed ws or wss URL")
|
||||||
|
|
||||||
func parseURL(u string) (useTLS bool, host, port, opaque string, err error) {
|
// parseURL parses the URL. The url.Parse function is not used here because
|
||||||
|
// url.Parse mangles the path.
|
||||||
|
func parseURL(s string) (*url.URL, error) {
|
||||||
// From the RFC:
|
// From the RFC:
|
||||||
//
|
//
|
||||||
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
|
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
|
||||||
|
@ -106,33 +108,41 @@ func parseURL(u string) (useTLS bool, host, port, opaque string, err error) {
|
||||||
// not provide a way for applications to work around percent deocding in
|
// not provide a way for applications to work around percent deocding in
|
||||||
// the net/url parser.
|
// the net/url parser.
|
||||||
|
|
||||||
|
var u url.URL
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(u, "ws://"):
|
case strings.HasPrefix(s, "ws://"):
|
||||||
u = u[len("ws://"):]
|
u.Scheme = "ws"
|
||||||
case strings.HasPrefix(u, "wss://"):
|
s = s[len("ws://"):]
|
||||||
u = u[len("wss://"):]
|
case strings.HasPrefix(s, "wss://"):
|
||||||
useTLS = true
|
u.Scheme = "wss"
|
||||||
|
s = s[len("wss://"):]
|
||||||
default:
|
default:
|
||||||
return false, "", "", "", errMalformedURL
|
return nil, errMalformedURL
|
||||||
}
|
}
|
||||||
|
|
||||||
hostPort := u
|
u.Host = s
|
||||||
opaque = "/"
|
u.Opaque = "/"
|
||||||
if i := strings.Index(u, "/"); i >= 0 {
|
if i := strings.Index(s, "/"); i >= 0 {
|
||||||
hostPort = u[:i]
|
u.Host = s[:i]
|
||||||
opaque = u[i:]
|
u.Opaque = s[i:]
|
||||||
}
|
}
|
||||||
|
|
||||||
host = hostPort
|
return &u, nil
|
||||||
port = ":80"
|
}
|
||||||
if i := strings.LastIndex(hostPort, ":"); i > strings.LastIndex(hostPort, "]") {
|
|
||||||
host = hostPort[:i]
|
|
||||||
port = hostPort[i:]
|
|
||||||
} else if useTLS {
|
|
||||||
port = ":443"
|
|
||||||
}
|
|
||||||
|
|
||||||
return useTLS, host, port, opaque, nil
|
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
|
||||||
|
hostPort = u.Host
|
||||||
|
hostNoPort = u.Host
|
||||||
|
if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
|
||||||
|
hostNoPort = hostNoPort[:i]
|
||||||
|
} else {
|
||||||
|
if u.Scheme == "wss" {
|
||||||
|
hostPort += ":443"
|
||||||
|
} else {
|
||||||
|
hostPort += ":80"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return hostPort, hostNoPort
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultDialer is a dialer with all fields set to the default zero values.
|
// DefaultDialer is a dialer with all fields set to the default zero values.
|
||||||
|
@ -147,12 +157,13 @@ var DefaultDialer *Dialer
|
||||||
// non-nil *http.Response so that callers can handle redirects, authentication,
|
// non-nil *http.Response so that callers can handle redirects, authentication,
|
||||||
// etc.
|
// etc.
|
||||||
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
|
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
|
||||||
|
u, err := parseURL(urlStr)
|
||||||
useTLS, host, port, opaque, err := parseURL(urlStr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hostPort, hostNoPort := hostPortNoPort(u)
|
||||||
|
|
||||||
if d == nil {
|
if d == nil {
|
||||||
d = &Dialer{}
|
d = &Dialer{}
|
||||||
}
|
}
|
||||||
|
@ -168,7 +179,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
netDial = netDialer.Dial
|
netDial = netDialer.Dial
|
||||||
}
|
}
|
||||||
|
|
||||||
netConn, err := netDial("tcp", host+port)
|
netConn, err := netDial("tcp", hostPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -183,14 +194,14 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if useTLS {
|
if u.Scheme == "wss" {
|
||||||
cfg := d.TLSClientConfig
|
cfg := d.TLSClientConfig
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
cfg = &tls.Config{ServerName: host}
|
cfg = &tls.Config{ServerName: hostNoPort}
|
||||||
} else if cfg.ServerName == "" {
|
} else if cfg.ServerName == "" {
|
||||||
shallowCopy := *cfg
|
shallowCopy := *cfg
|
||||||
cfg = &shallowCopy
|
cfg = &shallowCopy
|
||||||
cfg.ServerName = host
|
cfg.ServerName = hostNoPort
|
||||||
}
|
}
|
||||||
tlsConn := tls.Client(netConn, cfg)
|
tlsConn := tls.Client(netConn, cfg)
|
||||||
netConn = tlsConn
|
netConn = tlsConn
|
||||||
|
@ -223,10 +234,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
requestHeader = h
|
requestHeader = h
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, resp, err := NewClient(
|
conn, resp, err := NewClient(netConn, u, requestHeader, readBufferSize, writeBufferSize)
|
||||||
netConn,
|
|
||||||
&url.URL{Host: host + port, Opaque: opaque},
|
|
||||||
requestHeader, readBufferSize, writeBufferSize)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, resp, err
|
return nil, resp, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
// Copyright 2014 The Gorilla WebSocket Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var parseURLTests = []struct {
|
||||||
|
s string
|
||||||
|
u *url.URL
|
||||||
|
}{
|
||||||
|
{"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
|
||||||
|
{"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}},
|
||||||
|
{"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}},
|
||||||
|
{"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}},
|
||||||
|
{"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}},
|
||||||
|
{"ss://example.com/a/b", nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseURL(t *testing.T) {
|
||||||
|
for _, tt := range parseURLTests {
|
||||||
|
u, err := parseURL(tt.s)
|
||||||
|
if tt.u != nil && err != nil {
|
||||||
|
t.Errorf("parseURL(%q) returned error %v", tt.s, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if tt.u == nil && err == nil {
|
||||||
|
t.Errorf("parseURL(%q) did not return error", tt.s)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(u, tt.u) {
|
||||||
|
t.Errorf("parseURL(%q) returned %v, want %v", tt.s, u, tt.u)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var hostPortNoPortTests = []struct {
|
||||||
|
u *url.URL
|
||||||
|
hostPort, hostNoPort string
|
||||||
|
}{
|
||||||
|
{&url.URL{Scheme: "ws", Host: "example.com"}, "example.com:80", "example.com"},
|
||||||
|
{&url.URL{Scheme: "wss", Host: "example.com"}, "example.com:443", "example.com"},
|
||||||
|
{&url.URL{Scheme: "ws", Host: "example.com:7777"}, "example.com:7777", "example.com"},
|
||||||
|
{&url.URL{Scheme: "wss", Host: "example.com:7777"}, "example.com:7777", "example.com"},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHostPortNoPort(t *testing.T) {
|
||||||
|
for _, tt := range hostPortNoPortTests {
|
||||||
|
hostPort, hostNoPort := hostPortNoPort(tt.u)
|
||||||
|
if hostPort != tt.hostPort {
|
||||||
|
t.Errorf("hostPortNoPort(%v) returned hostPort %q, want %q", tt.u, hostPort, tt.hostPort)
|
||||||
|
}
|
||||||
|
if hostNoPort != tt.hostNoPort {
|
||||||
|
t.Errorf("hostPortNoPort(%v) returned hostNoPort %q, want %q", tt.u, hostNoPort, tt.hostNoPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue