diff --git a/client.go b/client.go index c0c3073..a910356 100644 --- a/client.go +++ b/client.go @@ -190,8 +190,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h for _, proto := range d.TLSClientConfig.NextProtos { if proto != "http/1.1" { return nil, nil, fmt.Errorf( - `protocol %q was given but is not supported; - sharing tls.Config with net/http Transport can cause this error`, + "websocket: protocol %q was given but is not supported; "+ + "sharing tls.Config with net/http Transport can cause this error", proto, ) } diff --git a/client_server_test.go b/client_server_test.go index e975e51..f6a0664 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -1098,3 +1098,39 @@ func TestNetDialConnect(t *testing.T) { } } } + +func TestNextProtos(t *testing.T) { + ts := httptest.NewUnstartedServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + ) + ts.EnableHTTP2 = true + ts.StartTLS() + defer ts.Close() + + d := Dialer{ + TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig, + } + + r, err := ts.Client().Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + r.Body.Close() + + // Asserts that Dialer.TLSClientConfig.NextProtos contains "h2" + // after the Client.Get call from net/http above. + var containsHTTP2 bool = false + for _, proto := range d.TLSClientConfig.NextProtos { + if proto == "h2" { + containsHTTP2 = true + } + } + if !containsHTTP2 { + t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"") + } + + _, _, err = d.Dial(makeWsProto(ts.URL), nil) + if err == nil { + t.Fatalf("Dial succeeded, expect fail ") + } +}