From bc7ce893c36595e095de550a211feb5993e6ef92 Mon Sep 17 00:00:00 2001 From: Chan Kang Date: Tue, 21 Jun 2022 13:54:14 -0400 Subject: [PATCH] Check for and report bad protocol in TLSClientConfig.NextProtos (#788) * return an error when Dialer.TLSClientConfig.NextProtos contains a protocol that is not http/1.1 * include the likely cause of the error in the error message * check for nil-ness of Dialer.TLSClientConfig before attempting to run the check * addressing the review * move the NextProtos test into a separate file so that it can be run conditionally on go versions >= 1.14 * moving the new error check into existing http response error block to reduce the possibility of false positives * wrapping the error in %w * using %v instead of %w for compatibility with older versions of go * Revert "using %v instead of %w for compatibility with older versions of go" This reverts commit d34dd940eeb29b6f4d21d3ab9148893b4019afd1. * move the unit test back into the existing test code since golang build constraint is no longer necessary Co-authored-by: Chan Kang --- client.go | 12 ++++++++++++ client_server_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/client.go b/client.go index 2efd835..f79ac98 100644 --- a/client.go +++ b/client.go @@ -9,6 +9,7 @@ import ( "context" "crypto/tls" "errors" + "fmt" "io" "io/ioutil" "net" @@ -370,6 +371,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h resp, err := http.ReadResponse(conn.br, req) if err != nil { + if d.TLSClientConfig != nil { + for _, proto := range d.TLSClientConfig.NextProtos { + if proto != "http/1.1" { + return nil, nil, fmt.Errorf( + "websocket: protocol %q was given but is not supported;"+ + "sharing tls.Config with net/http Transport can cause this error: %w", + proto, err, + ) + } + } + } return nil, nil, err } diff --git a/client_server_test.go b/client_server_test.go index e975e51..a47df48 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -1098,3 +1098,38 @@ func TestNetDialConnect(t *testing.T) { } } } +func TestNextProtos(t *testing.T) { + ts := httptest.NewUnstartedServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + ) + ts.EnableHTTP2 = true + ts.StartTLS() + defer ts.Close() + + d := Dialer{ + TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig, + } + + r, err := ts.Client().Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + r.Body.Close() + + // Asserts that Dialer.TLSClientConfig.NextProtos contains "h2" + // after the Client.Get call from net/http above. + var containsHTTP2 bool = false + for _, proto := range d.TLSClientConfig.NextProtos { + if proto == "h2" { + containsHTTP2 = true + } + } + if !containsHTTP2 { + t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"") + } + + _, _, err = d.Dial(makeWsProto(ts.URL), nil) + if err == nil { + t.Fatalf("Dial succeeded, expect fail ") + } +}