Add comprehensive host test (#429)

Add table driven test for handling of host in request URL, request
header and TLS server name. In addition to testing various uses of host
names, this test also confirms that host names are handled the same as
the net/http client.

The new table driven test replaces TestDialTLS, TestDialTLSNoverify,
TestDialTLSBadCert and TestHostHeader.

Eliminate duplicated code for constructing root CA.
This commit is contained in:
Steven Scott 2018-09-24 16:10:46 -07:00 committed by Gary Burd
parent 66b9c49e59
commit cdd40f587d
1 changed files with 194 additions and 79 deletions

View File

@ -11,8 +11,10 @@ import (
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log"
"net" "net"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
@ -42,17 +44,12 @@ var cstDialer = Dialer{
HandshakeTimeout: 30 * time.Second, HandshakeTimeout: 30 * time.Second,
} }
var cstDialerWithoutHandshakeTimeout = Dialer{
Subprotocols: []string{"p1", "p2"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
type cstHandler struct{ *testing.T } type cstHandler struct{ *testing.T }
type cstServer struct { type cstServer struct {
*httptest.Server *httptest.Server
URL string URL string
t *testing.T
} }
const ( const (
@ -288,10 +285,7 @@ func TestDialCookieJar(t *testing.T) {
sendRecv(t, ws) sendRecv(t, ws)
} }
func TestDialTLS(t *testing.T) { func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
s := newTLSServer(t)
defer s.Close()
certs := x509.NewCertPool() certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates { for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
@ -302,35 +296,15 @@ func TestDialTLS(t *testing.T) {
certs.AddCert(root) certs.AddCert(root)
} }
} }
return certs
d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs}
ws, _, err := d.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
} }
func xTestDialTLSBadCert(t *testing.T) { func TestDialTLS(t *testing.T) {
// This test is deactivated because of noisy logging from the net/http package.
s := newTLSServer(t)
defer s.Close()
ws, _, err := cstDialer.Dial(s.URL, nil)
if err == nil {
ws.Close()
t.Fatalf("Dial: nil")
}
}
func TestDialTLSNoVerify(t *testing.T) {
s := newTLSServer(t) s := newTLSServer(t)
defer s.Close() defer s.Close()
d := cstDialer d := cstDialer
d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
ws, _, err := d.Dial(s.URL, nil) ws, _, err := d.Dial(s.URL, nil)
if err != nil { if err != nil {
t.Fatalf("Dial: %v", err) t.Fatalf("Dial: %v", err)
@ -415,7 +389,8 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()
d := cstDialerWithoutHandshakeTimeout d := cstDialer
d.HandshakeTimeout = 0
d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) { d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
netDialer := &net.Dialer{} netDialer := &net.Dialer{}
c, err := netDialer.DialContext(ctx, n, a) c, err := netDialer.DialContext(ctx, n, a)
@ -566,33 +541,195 @@ func TestRespOnBadHandshake(t *testing.T) {
} }
} }
// TestHostHeader confirms that the host header provided in the call to Dial is type testLogWriter struct {
// sent to the server. t *testing.T
func TestHostHeader(t *testing.T) { }
s := newServer(t)
defer s.Close()
specifiedHost := make(chan string, 1) func (w testLogWriter) Write(p []byte) (int, error) {
origHandler := s.Server.Config.Handler w.t.Logf("%s", p)
return len(p), nil
}
// Capture the request Host header. // TestHost tests handling of host names and confirms that it matches net/http.
s.Server.Config.Handler = http.HandlerFunc( func TestHost(t *testing.T) {
func(w http.ResponseWriter, r *http.Request) {
specifiedHost <- r.Host upgrader := Upgrader{}
origHandler.ServeHTTP(w, r) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if IsWebSocketUpgrade(r) {
c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
if err != nil {
t.Fatal(err)
}
c.Close()
} else {
w.Header().Set("X-Test-Host", r.Host)
}
}) })
ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}}) server := httptest.NewServer(handler)
if err != nil { defer server.Close()
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
if gotHost := <-specifiedHost; gotHost != "testhost" { tlsServer := httptest.NewTLSServer(handler)
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost) defer tlsServer.Close()
addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
// Avoid log noise from net/http server by logging to testing.T
server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
tlsServer.Config.ErrorLog = server.Config.ErrorLog
cas := rootCAs(t, tlsServer)
tests := []struct {
fail bool // true if dial / get should fail
server *httptest.Server // server to use
url string // host for request URI
header string // optional request host header
tls string // optiona host for tls ServerName
wantAddr string // expected host for dial
wantHeader string // expected request header on server
insecureSkipVerify bool
}{
{
server: server,
url: addrs[server],
wantAddr: addrs[server],
wantHeader: addrs[server],
},
{
server: tlsServer,
url: addrs[tlsServer],
wantAddr: addrs[tlsServer],
wantHeader: addrs[tlsServer],
},
{
server: server,
url: addrs[server],
header: "badhost.com",
wantAddr: addrs[server],
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: addrs[tlsServer],
header: "badhost.com",
wantAddr: addrs[tlsServer],
wantHeader: "badhost.com",
},
{
server: server,
url: "example.com",
header: "badhost.com",
wantAddr: "example.com:80",
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: "example.com",
header: "badhost.com",
wantAddr: "example.com:443",
wantHeader: "badhost.com",
},
{
server: server,
url: "badhost.com",
header: "example.com",
wantAddr: "badhost.com:80",
wantHeader: "example.com",
},
{
fail: true,
server: tlsServer,
url: "badhost.com",
header: "example.com",
wantAddr: "badhost.com:443",
},
{
server: tlsServer,
url: "badhost.com",
insecureSkipVerify: true,
wantAddr: "badhost.com:443",
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: "badhost.com",
tls: "example.com",
wantAddr: "badhost.com:443",
wantHeader: "badhost.com",
},
} }
sendRecv(t, ws) for i, tt := range tests {
tls := &tls.Config{
RootCAs: cas,
ServerName: tt.tls,
InsecureSkipVerify: tt.insecureSkipVerify,
}
var gotAddr string
dialer := Dialer{
NetDial: func(network, addr string) (net.Conn, error) {
gotAddr = addr
return net.Dial(network, addrs[tt.server])
},
TLSClientConfig: tls,
}
// Test websocket dial
h := http.Header{}
if tt.header != "" {
h.Set("Host", tt.header)
}
c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
if err == nil {
c.Close()
}
check := func(protos map[*httptest.Server]string) {
name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
if gotAddr != tt.wantAddr {
t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
}
switch {
case tt.fail && err == nil:
t.Errorf("%s: unexpected success", name)
case !tt.fail && err != nil:
t.Errorf("%s: unexpected error %v", name, err)
case !tt.fail && err == nil:
if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
}
}
}
check(wsProtos)
// Confirm that net/http has same result
transport := &http.Transport{
Dial: dialer.NetDial,
TLSClientConfig: dialer.TLSClientConfig,
}
req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
if tt.header != "" {
req.Host = tt.header
}
client := &http.Client{Transport: transport}
resp, err = client.Do(req)
if err == nil {
resp.Body.Close()
}
transport.CloseIdleConnections()
check(httpProtos)
}
} }
func TestDialCompression(t *testing.T) { func TestDialCompression(t *testing.T) {
@ -716,19 +853,8 @@ func TestTracingDialWithContext(t *testing.T) {
s := newTLSServer(t) s := newTLSServer(t)
defer s.Close() defer s.Close()
certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
t.Fatalf("error parsing server's root cert: %v", err)
}
for _, root := range roots {
certs.AddCert(root)
}
}
d := cstDialer d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs} d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
ws, _, err := d.DialContext(ctx, s.URL, nil) ws, _, err := d.DialContext(ctx, s.URL, nil)
if err != nil { if err != nil {
@ -766,19 +892,8 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
s := newTLSServer(t) s := newTLSServer(t)
defer s.Close() defer s.Close()
certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
t.Fatalf("error parsing server's root cert: %v", err)
}
for _, root := range roots {
certs.AddCert(root)
}
}
d := cstDialer d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs} d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
ws, _, err := d.DialContext(ctx, s.URL, nil) ws, _, err := d.DialContext(ctx, s.URL, nil)
if err != nil { if err != nil {