mirror of https://github.com/gorilla/websocket.git
Check for and report bad protocol in TLSClientConfig.NextProtos (#788)
* return an error when Dialer.TLSClientConfig.NextProtos contains a protocol that is not http/1.1
* include the likely cause of the error in the error message
* check for nil-ness of Dialer.TLSClientConfig before attempting to run the check
* addressing the review
* move the NextProtos test into a separate file so that it can be run conditionally on go versions >= 1.14
* moving the new error check into existing http response error block to reduce the possibility of false positives
* wrapping the error in %w
* using %v instead of %w for compatibility with older versions of go
* Revert "using %v instead of %w for compatibility with older versions of go"
This reverts commit d34dd940ee
.
* move the unit test back into the existing test code since golang build constraint is no longer necessary
Co-authored-by: Chan Kang <chankang@chankang17@gmail.com>
This commit is contained in:
parent
27d91a9be5
commit
bc7ce893c3
12
client.go
12
client.go
|
@ -9,6 +9,7 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
|
@ -370,6 +371,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
|
||||
resp, err := http.ReadResponse(conn.br, req)
|
||||
if err != nil {
|
||||
if d.TLSClientConfig != nil {
|
||||
for _, proto := range d.TLSClientConfig.NextProtos {
|
||||
if proto != "http/1.1" {
|
||||
return nil, nil, fmt.Errorf(
|
||||
"websocket: protocol %q was given but is not supported;"+
|
||||
"sharing tls.Config with net/http Transport can cause this error: %w",
|
||||
proto, err,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -1098,3 +1098,38 @@ func TestNetDialConnect(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
func TestNextProtos(t *testing.T) {
|
||||
ts := httptest.NewUnstartedServer(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
)
|
||||
ts.EnableHTTP2 = true
|
||||
ts.StartTLS()
|
||||
defer ts.Close()
|
||||
|
||||
d := Dialer{
|
||||
TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
|
||||
}
|
||||
|
||||
r, err := ts.Client().Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
r.Body.Close()
|
||||
|
||||
// Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
|
||||
// after the Client.Get call from net/http above.
|
||||
var containsHTTP2 bool = false
|
||||
for _, proto := range d.TLSClientConfig.NextProtos {
|
||||
if proto == "h2" {
|
||||
containsHTTP2 = true
|
||||
}
|
||||
}
|
||||
if !containsHTTP2 {
|
||||
t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
|
||||
}
|
||||
|
||||
_, _, err = d.Dial(makeWsProto(ts.URL), nil)
|
||||
if err == nil {
|
||||
t.Fatalf("Dial succeeded, expect fail ")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue