diff --git a/client.go b/client.go index 901a24b..2e32fd5 100644 --- a/client.go +++ b/client.go @@ -103,7 +103,7 @@ type Dialer struct { // Dial creates a new client connection by calling DialContext with a background context. func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { - return d.DialContext(urlStr, requestHeader, context.Background()) + return d.DialContext(context.Background(), urlStr, requestHeader) } var errMalformedURL = errors.New("malformed ws or wss URL") @@ -146,7 +146,7 @@ var nilDialer = *DefaultDialer // non-nil *http.Response so that callers can handle redirects, authentication, // etcetera. The response body may not contain the entire response and does not // need to be closed by the application. -func (d *Dialer) DialContext(urlStr string, requestHeader http.Header, ctx context.Context) (*Conn, *http.Response, error) { +func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { if d == nil { d = &nilDialer } diff --git a/client_server_test.go b/client_server_test.go index f829566..564fbbe 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -424,7 +424,7 @@ func TestHandshakeTimeoutInContext(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) defer cancel() - ws, _, err := d.DialContext(s.URL, nil, ctx) + ws, _, err := d.DialContext(ctx, s.URL, nil) if err != nil { t.Fatal("Dial:", err) } @@ -730,7 +730,7 @@ func TestTracingDialWithContext(t *testing.T) { d := cstDialer d.TLSClientConfig = &tls.Config{RootCAs: certs} - ws, _, err := d.DialContext(s.URL, nil, ctx) + ws, _, err := d.DialContext(ctx, s.URL, nil) if err != nil { t.Fatalf("Dial: %v", err) } @@ -780,7 +780,7 @@ func TestEmptyTracingDialWithContext(t *testing.T) { d := cstDialer d.TLSClientConfig = &tls.Config{RootCAs: certs} - ws, _, err := d.DialContext(s.URL, nil, ctx) + ws, _, err := d.DialContext(ctx, s.URL, nil) if err != nil { t.Fatalf("Dial: %v", err) }