mirror of https://github.com/gorilla/websocket.git
Expose tls connection state in http response
This commit is contained in:
parent
5e00238113
commit
75d81364b7
20
client.go
20
client.go
|
@ -317,6 +317,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var tlsState *tls.ConnectionState
|
||||||
if u.Scheme == "https" && d.NetDialTLSContext == nil {
|
if u.Scheme == "https" && d.NetDialTLSContext == nil {
|
||||||
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
|
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
|
||||||
|
|
||||||
|
@ -330,13 +331,18 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
if trace != nil && trace.TLSHandshakeStart != nil {
|
if trace != nil && trace.TLSHandshakeStart != nil {
|
||||||
trace.TLSHandshakeStart()
|
trace.TLSHandshakeStart()
|
||||||
}
|
}
|
||||||
err := doHandshake(ctx, tlsConn, cfg)
|
|
||||||
if trace != nil && trace.TLSHandshakeDone != nil {
|
|
||||||
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err := doHandshake(ctx, tlsConn, cfg); err != nil {
|
||||||
|
if trace != nil && trace.TLSHandshakeDone != nil {
|
||||||
|
trace.TLSHandshakeDone(tls.ConnectionState{}, err)
|
||||||
|
}
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
} else {
|
||||||
|
cs := tlsConn.ConnectionState()
|
||||||
|
if trace != nil && trace.TLSHandshakeDone != nil {
|
||||||
|
trace.TLSHandshakeDone(cs, nil)
|
||||||
|
}
|
||||||
|
tlsState = &cs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -374,6 +380,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if resp.TLS == nil && tlsState != nil {
|
||||||
|
resp.TLS = tlsState
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 101 ||
|
if resp.StatusCode != 101 ||
|
||||||
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
|
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
|
||||||
!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
|
!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
|
||||||
|
|
|
@ -330,6 +330,25 @@ func TestDialTLS(t *testing.T) {
|
||||||
sendRecv(t, ws)
|
sendRecv(t, ws)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDialTLSConnState(t *testing.T) {
|
||||||
|
s := newTLSServer(t)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
d := cstDialer
|
||||||
|
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
|
||||||
|
ws, resp, err := d.Dial(s.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial: %v", err)
|
||||||
|
}
|
||||||
|
if resp.TLS == nil {
|
||||||
|
t.Errorf("http response tls is nil")
|
||||||
|
} else if len(resp.TLS.PeerCertificates) == 0 {
|
||||||
|
t.Errorf("http response PeerCertificates count is 0")
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
sendRecv(t, ws)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDialTimeout(t *testing.T) {
|
func TestDialTimeout(t *testing.T) {
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
Loading…
Reference in New Issue