diff --git a/.circleci/config.yml b/.circleci/config.yml index a0eb0ed..ecb33f6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -16,8 +16,8 @@ jobs: type: string default: "" docker: - - image: "circleci/golang:<< parameters.version >>" - working_directory: /go/src/github.com/gorilla/websocket + - image: "cimg/go:<< parameters.version >>" + working_directory: /home/circleci/project/go/src/github.com/gorilla/websocket environment: GO111MODULE: "on" GOPROXY: "<< parameters.goproxy >>" @@ -67,4 +67,4 @@ workflows: - test: matrix: parameters: - version: ["latest", "1.17", "1.16", "1.15", "1.14", "1.13", "1.12", "1.11"] + version: ["1.18", "1.17", "1.16"] diff --git a/README.md b/README.md index 2517a28..d33ed7f 100644 --- a/README.md +++ b/README.md @@ -7,12 +7,6 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. ---- - -⚠️ **[The Gorilla WebSocket Package is looking for a new maintainer](https://github.com/gorilla/websocket/issues/370)** - ---- - ### Documentation * [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc) diff --git a/client.go b/client.go index 274bf6f..8790b91 100644 --- a/client.go +++ b/client.go @@ -9,6 +9,7 @@ import ( "context" "crypto/tls" "errors" + "fmt" "io" "io/ioutil" "net" @@ -320,14 +321,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } netConn, err := netDial("tcp", hostPort) + if err != nil { + return nil, nil, err + } if trace != nil && trace.GotConn != nil { trace.GotConn(httptrace.GotConnInfo{ Conn: netConn, }) } - if err != nil { - return nil, nil, err - } defer func() { if netConn != nil { @@ -372,6 +373,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h resp, err := http.ReadResponse(conn.br, req) if err != nil { + if d.TLSClientConfig != nil { + for _, proto := range d.TLSClientConfig.NextProtos { + if proto != "http/1.1" { + return nil, nil, fmt.Errorf( + "websocket: protocol %q was given but is not supported;"+ + "sharing tls.Config with net/http Transport can cause this error: %w", + proto, err, + ) + } + } + } return nil, nil, err } diff --git a/client_server_test.go b/client_server_test.go index e975e51..a47df48 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -1098,3 +1098,38 @@ func TestNetDialConnect(t *testing.T) { } } } +func TestNextProtos(t *testing.T) { + ts := httptest.NewUnstartedServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), + ) + ts.EnableHTTP2 = true + ts.StartTLS() + defer ts.Close() + + d := Dialer{ + TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig, + } + + r, err := ts.Client().Get(ts.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + r.Body.Close() + + // Asserts that Dialer.TLSClientConfig.NextProtos contains "h2" + // after the Client.Get call from net/http above. + var containsHTTP2 bool = false + for _, proto := range d.TLSClientConfig.NextProtos { + if proto == "h2" { + containsHTTP2 = true + } + } + if !containsHTTP2 { + t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"") + } + + _, _, err = d.Dial(makeWsProto(ts.URL), nil) + if err == nil { + t.Fatalf("Dial succeeded, expect fail ") + } +} diff --git a/conn.go b/conn.go index 331eebc..5161ef8 100644 --- a/conn.go +++ b/conn.go @@ -1189,8 +1189,16 @@ func (c *Conn) SetPongHandler(h func(appData string) error) { c.handlePong = h } +// NetConn returns the underlying connection that is wrapped by c. +// Note that writing to or reading from this connection directly will corrupt the +// WebSocket connection. +func (c *Conn) NetConn() net.Conn { + return c.conn +} + // UnderlyingConn returns the internal net.Conn. This can be used to further // modifications to connection specific flags. +// Deprecated: Use the NetConn method. func (c *Conn) UnderlyingConn() net.Conn { return c.conn } diff --git a/conn_test.go b/conn_test.go index bd96e0a..06e5184 100644 --- a/conn_test.go +++ b/conn_test.go @@ -562,7 +562,7 @@ func TestAddrs(t *testing.T) { } } -func TestUnderlyingConn(t *testing.T) { +func TestDeprecatedUnderlyingConn(t *testing.T) { var b1, b2 bytes.Buffer fc := fakeNetConn{Reader: &b1, Writer: &b2} c := newConn(fc, true, 1024, 1024, nil, nil, nil) @@ -572,6 +572,16 @@ func TestUnderlyingConn(t *testing.T) { } } +func TestNetConn(t *testing.T) { + var b1, b2 bytes.Buffer + fc := fakeNetConn{Reader: &b1, Writer: &b2} + c := newConn(fc, true, 1024, 1024, nil, nil, nil) + ul := c.NetConn() + if ul != fc { + t.Fatalf("Underlying conn is not what it should be.") + } +} + func TestBufioReadBytes(t *testing.T) { // Test calling bufio.ReadBytes for value longer than read buffer size. diff --git a/server_test.go b/server_test.go index 3029aa8..5804be1 100644 --- a/server_test.go +++ b/server_test.go @@ -111,7 +111,7 @@ func TestBufioReuse(t *testing.T) { if reuse := c.br == br; reuse != tt.reuse { t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse) } - writeBuf := bufioWriterBuffer(c.UnderlyingConn(), bw) + writeBuf := bufioWriterBuffer(c.NetConn(), bw) if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse { t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse) }