diff --git a/client.go b/client.go index ebca8ed..5bc27e1 100644 --- a/client.go +++ b/client.go @@ -228,6 +228,22 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re requestHeader = h } + if len(requestHeader["Host"]) > 0 { + // This can be used to supply a Host: header which is different from + // the dial address. + u.Host = requestHeader.Get("Host") + + // Drop "Host" header + h := http.Header{} + for k, v := range requestHeader { + if k == "Host" { + continue + } + h[k] = v + } + requestHeader = h + } + conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize) if err != nil { diff --git a/client_server_test.go b/client_server_test.go index 38a1afc..749ef20 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -288,3 +288,36 @@ func TestRespOnBadHandshake(t *testing.T) { t.Errorf("resp.Body=%s, want %s", p, expectedBody) } } + +// If the Host header is specified in `Dial()`, the server must receive it as +// the `Host:` header. +func TestHostHeader(t *testing.T) { + s := newServer(t) + defer s.Close() + + specifiedHost := make(chan string, 1) + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + specifiedHost <- r.Host + origHandler.ServeHTTP(w, r) + }) + + ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}}) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + + if resp.StatusCode != http.StatusSwitchingProtocols { + t.Fatalf("resp.StatusCode = %v, want http.StatusSwitchingProtocols", resp.StatusCode) + } + + if gotHost := <-specifiedHost; gotHost != "testhost" { + t.Fatalf("gotHost = %q, want \"testhost\"", gotHost) + } + + sendRecv(t, ws) +}