forked from mirror/websocket
Add comprehensive host test (#429)
Add table driven test for handling of host in request URL, request header and TLS server name. In addition to testing various uses of host names, this test also confirms that host names are handled the same as the net/http client. The new table driven test replaces TestDialTLS, TestDialTLSNoverify, TestDialTLSBadCert and TestHostHeader. Eliminate duplicated code for constructing root CA.
This commit is contained in:
parent
66b9c49e59
commit
cdd40f587d
|
@ -11,8 +11,10 @@ import (
|
|||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
|
@ -42,17 +44,12 @@ var cstDialer = Dialer{
|
|||
HandshakeTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
var cstDialerWithoutHandshakeTimeout = Dialer{
|
||||
Subprotocols: []string{"p1", "p2"},
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
|
||||
type cstHandler struct{ *testing.T }
|
||||
|
||||
type cstServer struct {
|
||||
*httptest.Server
|
||||
URL string
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -288,10 +285,7 @@ func TestDialCookieJar(t *testing.T) {
|
|||
sendRecv(t, ws)
|
||||
}
|
||||
|
||||
func TestDialTLS(t *testing.T) {
|
||||
s := newTLSServer(t)
|
||||
defer s.Close()
|
||||
|
||||
func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
|
||||
certs := x509.NewCertPool()
|
||||
for _, c := range s.TLS.Certificates {
|
||||
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
|
||||
|
@ -302,35 +296,15 @@ func TestDialTLS(t *testing.T) {
|
|||
certs.AddCert(root)
|
||||
}
|
||||
}
|
||||
|
||||
d := cstDialer
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: certs}
|
||||
ws, _, err := d.Dial(s.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer ws.Close()
|
||||
sendRecv(t, ws)
|
||||
return certs
|
||||
}
|
||||
|
||||
func xTestDialTLSBadCert(t *testing.T) {
|
||||
// This test is deactivated because of noisy logging from the net/http package.
|
||||
s := newTLSServer(t)
|
||||
defer s.Close()
|
||||
|
||||
ws, _, err := cstDialer.Dial(s.URL, nil)
|
||||
if err == nil {
|
||||
ws.Close()
|
||||
t.Fatalf("Dial: nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialTLSNoVerify(t *testing.T) {
|
||||
func TestDialTLS(t *testing.T) {
|
||||
s := newTLSServer(t)
|
||||
defer s.Close()
|
||||
|
||||
d := cstDialer
|
||||
d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
|
||||
ws, _, err := d.Dial(s.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
|
@ -415,7 +389,8 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
|
|||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
d := cstDialerWithoutHandshakeTimeout
|
||||
d := cstDialer
|
||||
d.HandshakeTimeout = 0
|
||||
d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
|
||||
netDialer := &net.Dialer{}
|
||||
c, err := netDialer.DialContext(ctx, n, a)
|
||||
|
@ -566,33 +541,195 @@ func TestRespOnBadHandshake(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestHostHeader confirms that the host header provided in the call to Dial is
|
||||
// sent to the server.
|
||||
func TestHostHeader(t *testing.T) {
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
type testLogWriter struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
specifiedHost := make(chan string, 1)
|
||||
origHandler := s.Server.Config.Handler
|
||||
func (w testLogWriter) Write(p []byte) (int, error) {
|
||||
w.t.Logf("%s", p)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Capture the request Host header.
|
||||
s.Server.Config.Handler = http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
specifiedHost <- r.Host
|
||||
origHandler.ServeHTTP(w, r)
|
||||
})
|
||||
// TestHost tests handling of host names and confirms that it matches net/http.
|
||||
func TestHost(t *testing.T) {
|
||||
|
||||
ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer ws.Close()
|
||||
upgrader := Upgrader{}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if IsWebSocketUpgrade(r) {
|
||||
c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c.Close()
|
||||
} else {
|
||||
w.Header().Set("X-Test-Host", r.Host)
|
||||
}
|
||||
})
|
||||
|
||||
if gotHost := <-specifiedHost; gotHost != "testhost" {
|
||||
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
|
||||
tlsServer := httptest.NewTLSServer(handler)
|
||||
defer tlsServer.Close()
|
||||
|
||||
addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
|
||||
wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
|
||||
httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
|
||||
|
||||
// Avoid log noise from net/http server by logging to testing.T
|
||||
server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
|
||||
tlsServer.Config.ErrorLog = server.Config.ErrorLog
|
||||
|
||||
cas := rootCAs(t, tlsServer)
|
||||
|
||||
tests := []struct {
|
||||
fail bool // true if dial / get should fail
|
||||
server *httptest.Server // server to use
|
||||
url string // host for request URI
|
||||
header string // optional request host header
|
||||
tls string // optiona host for tls ServerName
|
||||
wantAddr string // expected host for dial
|
||||
wantHeader string // expected request header on server
|
||||
insecureSkipVerify bool
|
||||
}{
|
||||
{
|
||||
server: server,
|
||||
url: addrs[server],
|
||||
wantAddr: addrs[server],
|
||||
wantHeader: addrs[server],
|
||||
},
|
||||
{
|
||||
server: tlsServer,
|
||||
url: addrs[tlsServer],
|
||||
wantAddr: addrs[tlsServer],
|
||||
wantHeader: addrs[tlsServer],
|
||||
},
|
||||
|
||||
{
|
||||
server: server,
|
||||
url: addrs[server],
|
||||
header: "badhost.com",
|
||||
wantAddr: addrs[server],
|
||||
wantHeader: "badhost.com",
|
||||
},
|
||||
{
|
||||
server: tlsServer,
|
||||
url: addrs[tlsServer],
|
||||
header: "badhost.com",
|
||||
wantAddr: addrs[tlsServer],
|
||||
wantHeader: "badhost.com",
|
||||
},
|
||||
|
||||
{
|
||||
server: server,
|
||||
url: "example.com",
|
||||
header: "badhost.com",
|
||||
wantAddr: "example.com:80",
|
||||
wantHeader: "badhost.com",
|
||||
},
|
||||
{
|
||||
server: tlsServer,
|
||||
url: "example.com",
|
||||
header: "badhost.com",
|
||||
wantAddr: "example.com:443",
|
||||
wantHeader: "badhost.com",
|
||||
},
|
||||
|
||||
{
|
||||
server: server,
|
||||
url: "badhost.com",
|
||||
header: "example.com",
|
||||
wantAddr: "badhost.com:80",
|
||||
wantHeader: "example.com",
|
||||
},
|
||||
{
|
||||
fail: true,
|
||||
server: tlsServer,
|
||||
url: "badhost.com",
|
||||
header: "example.com",
|
||||
wantAddr: "badhost.com:443",
|
||||
},
|
||||
{
|
||||
server: tlsServer,
|
||||
url: "badhost.com",
|
||||
insecureSkipVerify: true,
|
||||
wantAddr: "badhost.com:443",
|
||||
wantHeader: "badhost.com",
|
||||
},
|
||||
{
|
||||
server: tlsServer,
|
||||
url: "badhost.com",
|
||||
tls: "example.com",
|
||||
wantAddr: "badhost.com:443",
|
||||
wantHeader: "badhost.com",
|
||||
},
|
||||
}
|
||||
|
||||
sendRecv(t, ws)
|
||||
for i, tt := range tests {
|
||||
|
||||
tls := &tls.Config{
|
||||
RootCAs: cas,
|
||||
ServerName: tt.tls,
|
||||
InsecureSkipVerify: tt.insecureSkipVerify,
|
||||
}
|
||||
|
||||
var gotAddr string
|
||||
dialer := Dialer{
|
||||
NetDial: func(network, addr string) (net.Conn, error) {
|
||||
gotAddr = addr
|
||||
return net.Dial(network, addrs[tt.server])
|
||||
},
|
||||
TLSClientConfig: tls,
|
||||
}
|
||||
|
||||
// Test websocket dial
|
||||
|
||||
h := http.Header{}
|
||||
if tt.header != "" {
|
||||
h.Set("Host", tt.header)
|
||||
}
|
||||
c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
|
||||
if err == nil {
|
||||
c.Close()
|
||||
}
|
||||
|
||||
check := func(protos map[*httptest.Server]string) {
|
||||
name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
|
||||
if gotAddr != tt.wantAddr {
|
||||
t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
|
||||
}
|
||||
switch {
|
||||
case tt.fail && err == nil:
|
||||
t.Errorf("%s: unexpected success", name)
|
||||
case !tt.fail && err != nil:
|
||||
t.Errorf("%s: unexpected error %v", name, err)
|
||||
case !tt.fail && err == nil:
|
||||
if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
|
||||
t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
check(wsProtos)
|
||||
|
||||
// Confirm that net/http has same result
|
||||
|
||||
transport := &http.Transport{
|
||||
Dial: dialer.NetDial,
|
||||
TLSClientConfig: dialer.TLSClientConfig,
|
||||
}
|
||||
req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
|
||||
if tt.header != "" {
|
||||
req.Host = tt.header
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
resp, err = client.Do(req)
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
transport.CloseIdleConnections()
|
||||
check(httpProtos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialCompression(t *testing.T) {
|
||||
|
@ -716,19 +853,8 @@ func TestTracingDialWithContext(t *testing.T) {
|
|||
s := newTLSServer(t)
|
||||
defer s.Close()
|
||||
|
||||
certs := x509.NewCertPool()
|
||||
for _, c := range s.TLS.Certificates {
|
||||
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing server's root cert: %v", err)
|
||||
}
|
||||
for _, root := range roots {
|
||||
certs.AddCert(root)
|
||||
}
|
||||
}
|
||||
|
||||
d := cstDialer
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: certs}
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
|
||||
|
||||
ws, _, err := d.DialContext(ctx, s.URL, nil)
|
||||
if err != nil {
|
||||
|
@ -766,19 +892,8 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
|
|||
s := newTLSServer(t)
|
||||
defer s.Close()
|
||||
|
||||
certs := x509.NewCertPool()
|
||||
for _, c := range s.TLS.Certificates {
|
||||
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing server's root cert: %v", err)
|
||||
}
|
||||
for _, root := range roots {
|
||||
certs.AddCert(root)
|
||||
}
|
||||
}
|
||||
|
||||
d := cstDialer
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: certs}
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
|
||||
|
||||
ws, _, err := d.DialContext(ctx, s.URL, nil)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue