mirror of https://github.com/gorilla/websocket.git
Expose tls connection state in http response
This commit is contained in:
parent
5e00238113
commit
75d81364b7
20
client.go
20
client.go
|
@ -317,6 +317,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
}
|
||||
}()
|
||||
|
||||
var tlsState *tls.ConnectionState
|
||||
if u.Scheme == "https" && d.NetDialTLSContext == nil {
|
||||
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
|
||||
|
||||
|
@ -330,13 +331,18 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
if trace != nil && trace.TLSHandshakeStart != nil {
|
||||
trace.TLSHandshakeStart()
|
||||
}
|
||||
err := doHandshake(ctx, tlsConn, cfg)
|
||||
if trace != nil && trace.TLSHandshakeDone != nil {
|
||||
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err := doHandshake(ctx, tlsConn, cfg); err != nil {
|
||||
if trace != nil && trace.TLSHandshakeDone != nil {
|
||||
trace.TLSHandshakeDone(tls.ConnectionState{}, err)
|
||||
}
|
||||
return nil, nil, err
|
||||
} else {
|
||||
cs := tlsConn.ConnectionState()
|
||||
if trace != nil && trace.TLSHandshakeDone != nil {
|
||||
trace.TLSHandshakeDone(cs, nil)
|
||||
}
|
||||
tlsState = &cs
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -374,6 +380,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
}
|
||||
}
|
||||
|
||||
if resp.TLS == nil && tlsState != nil {
|
||||
resp.TLS = tlsState
|
||||
}
|
||||
|
||||
if resp.StatusCode != 101 ||
|
||||
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
|
||||
!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
|
||||
|
|
|
@ -330,6 +330,25 @@ func TestDialTLS(t *testing.T) {
|
|||
sendRecv(t, ws)
|
||||
}
|
||||
|
||||
func TestDialTLSConnState(t *testing.T) {
|
||||
s := newTLSServer(t)
|
||||
defer s.Close()
|
||||
|
||||
d := cstDialer
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
|
||||
ws, resp, err := d.Dial(s.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
if resp.TLS == nil {
|
||||
t.Errorf("http response tls is nil")
|
||||
} else if len(resp.TLS.PeerCertificates) == 0 {
|
||||
t.Errorf("http response PeerCertificates count is 0")
|
||||
}
|
||||
defer ws.Close()
|
||||
sendRecv(t, ws)
|
||||
}
|
||||
|
||||
func TestDialTimeout(t *testing.T) {
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
|
Loading…
Reference in New Issue