diff --git a/client.go b/client.go index 24bd7ff..b1ff322 100644 --- a/client.go +++ b/client.go @@ -317,6 +317,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } }() + var tlsState *tls.ConnectionState if u.Scheme == "https" && d.NetDialTLSContext == nil { // If NetDialTLSContext is set, assume that the TLS handshake has already been done @@ -330,13 +331,18 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h if trace != nil && trace.TLSHandshakeStart != nil { trace.TLSHandshakeStart() } - err := doHandshake(ctx, tlsConn, cfg) - if trace != nil && trace.TLSHandshakeDone != nil { - trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) - } - if err != nil { + if err := doHandshake(ctx, tlsConn, cfg); err != nil { + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tls.ConnectionState{}, err) + } return nil, nil, err + } else { + cs := tlsConn.ConnectionState() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(cs, nil) + } + tlsState = &cs } } @@ -374,6 +380,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } } + if resp.TLS == nil && tlsState != nil { + resp.TLS = tlsState + } + if resp.StatusCode != 101 || !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || !tokenListContainsValue(resp.Header, "Connection", "upgrade") || diff --git a/client_server_test.go b/client_server_test.go index e4546ae..0c2fbc5 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -330,6 +330,25 @@ func TestDialTLS(t *testing.T) { sendRecv(t, ws) } +func TestDialTLSConnState(t *testing.T) { + s := newTLSServer(t) + defer s.Close() + + d := cstDialer + d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)} + ws, resp, err := d.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + if resp.TLS == nil { + t.Errorf("http response tls is nil") + } else if len(resp.TLS.PeerCertificates) == 0 { + t.Errorf("http response PeerCertificates count is 0") + } + defer ws.Close() + sendRecv(t, ws) +} + func TestDialTimeout(t *testing.T) { s := newServer(t) defer s.Close()