diff --git a/client.go b/client.go index 5071a0d..3bf9b2e 100644 --- a/client.go +++ b/client.go @@ -197,11 +197,18 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} } for k, vs := range requestHeader { - if k == "Host" { + switch { + case k == "Host": if len(vs) > 0 { req.Host = vs[0] } - } else { + case k == "Upgrade" || + k == "Connection" || + k == "Sec-Websocket-Key" || + k == "Sec-Websocket-Version" || + (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): + return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) + default: req.Header[k] = vs } } diff --git a/client_server_test.go b/client_server_test.go index ebcba9f..7d72fd9 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -268,6 +268,25 @@ func TestDialBadOrigin(t *testing.T) { } } +func TestDialBadHeader(t *testing.T) { + s := newServer(t) + defer s.Close() + + for _, k := range []string{"Upgrade", + "Connection", + "Sec-Websocket-Key", + "Sec-Websocket-Version", + "Sec-Websocket-Protocol"} { + h := http.Header{} + h.Set(k, "bad") + ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) + if err == nil { + ws.Close() + t.Errorf("Dial with header %s returned nil", k) + } + } +} + func TestHandshake(t *testing.T) { s := newServer(t) defer s.Close()