Check for and report bad protocol in TLSClientConfig.NextProtos (#788)

* return an error when Dialer.TLSClientConfig.NextProtos contains a protocol that is not http/1.1

* include the likely cause of the error in the error message

* check for nil-ness of Dialer.TLSClientConfig before attempting to run the check

* addressing the review

* move the NextProtos test into a separate file so that it can be run conditionally on go versions >= 1.14

* moving the new error check into existing http response error block to reduce the possibility of false positives

* wrapping the error in %w

* using %v instead of %w for compatibility with older versions of go

* Revert "using %v instead of %w for compatibility with older versions of go"

This reverts commit d34dd940ee.

* move the unit test back into the existing test code since golang build constraint is no longer necessary

Co-authored-by: Chan Kang <chankang@chankang17@gmail.com>
This commit is contained in:
Chan Kang 2022-06-21 13:54:14 -04:00 committed by GitHub
parent 27d91a9be5
commit bc7ce893c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 0 deletions

View File

@ -9,6 +9,7 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
@ -370,6 +371,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
resp, err := http.ReadResponse(conn.br, req)
if err != nil {
if d.TLSClientConfig != nil {
for _, proto := range d.TLSClientConfig.NextProtos {
if proto != "http/1.1" {
return nil, nil, fmt.Errorf(
"websocket: protocol %q was given but is not supported;"+
"sharing tls.Config with net/http Transport can cause this error: %w",
proto, err,
)
}
}
}
return nil, nil, err
}

View File

@ -1098,3 +1098,38 @@ func TestNetDialConnect(t *testing.T) {
}
}
}
func TestNextProtos(t *testing.T) {
ts := httptest.NewUnstartedServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
)
ts.EnableHTTP2 = true
ts.StartTLS()
defer ts.Close()
d := Dialer{
TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
}
r, err := ts.Client().Get(ts.URL)
if err != nil {
t.Fatalf("Get: %v", err)
}
r.Body.Close()
// Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
// after the Client.Get call from net/http above.
var containsHTTP2 bool = false
for _, proto := range d.TLSClientConfig.NextProtos {
if proto == "h2" {
containsHTTP2 = true
}
}
if !containsHTTP2 {
t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
}
_, _, err = d.Dial(makeWsProto(ts.URL), nil)
if err == nil {
t.Fatalf("Dial succeeded, expect fail ")
}
}