mirror of https://github.com/gorilla/websocket.git
Add support for custom TLS implementations
Add the new optional field DialTLS to the Dialer. If it is set, it will be used for TLS connection setup instead of net.Dial and tls.Client.
This commit is contained in:
parent
80c2d40e9b
commit
cde7c32826
41
client.go
41
client.go
|
@ -57,6 +57,11 @@ type Dialer struct {
|
|||
// NetDialContext is nil, net.DialContext is used.
|
||||
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// DialTLS specifies the dial function for creating TLS connections
|
||||
// for non-proxied HTTPS requests. If DialTLS is nil, net.Dial and
|
||||
// net.TLSClientConfig are used.
|
||||
DialTLS func(network, addr string) (net.Conn, error)
|
||||
|
||||
// Proxy specifies a function to return a proxy for a given
|
||||
// Request. If the function returns a non-nil error, the
|
||||
// request is aborted with the provided error.
|
||||
|
@ -234,6 +239,26 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
defer cancel()
|
||||
}
|
||||
|
||||
hostPort, hostNoPort := hostPortNoPort(u)
|
||||
trace := httptrace.ContextClientTrace(ctx)
|
||||
if trace != nil && trace.GetConn != nil {
|
||||
trace.GetConn(hostPort)
|
||||
}
|
||||
|
||||
var netConn net.Conn
|
||||
defer func() {
|
||||
if netConn != nil {
|
||||
netConn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if u.Scheme == "https" && d.DialTLS != nil {
|
||||
tlsConn, err := d.DialTLS("tcp", hostPort)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
netConn = tlsConn
|
||||
} else {
|
||||
// Get network dial function.
|
||||
var netDial func(network, add string) (net.Conn, error)
|
||||
|
||||
|
@ -282,13 +307,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
}
|
||||
}
|
||||
|
||||
hostPort, hostNoPort := hostPortNoPort(u)
|
||||
trace := httptrace.ContextClientTrace(ctx)
|
||||
if trace != nil && trace.GetConn != nil {
|
||||
trace.GetConn(hostPort)
|
||||
}
|
||||
|
||||
netConn, err := netDial("tcp", hostPort)
|
||||
var err error
|
||||
netConn, err = netDial("tcp", hostPort)
|
||||
if trace != nil && trace.GotConn != nil {
|
||||
trace.GotConn(httptrace.GotConnInfo{
|
||||
Conn: netConn,
|
||||
|
@ -298,12 +318,6 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if netConn != nil {
|
||||
netConn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if u.Scheme == "https" {
|
||||
cfg := cloneTLSConfig(d.TLSClientConfig)
|
||||
if cfg.ServerName == "" {
|
||||
|
@ -323,6 +337,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
|
||||
|
||||
|
|
Loading…
Reference in New Issue