Add support for custom TLS implementations

Add the new optional field DialTLS to the Dialer. If it is set, it
will be used for TLS connection setup instead of net.Dial and
tls.Client.
This commit is contained in:
Renato Aguiar 2020-04-14 16:20:17 -07:00
parent 80c2d40e9b
commit cde7c32826
1 changed files with 88 additions and 73 deletions

View File

@ -57,6 +57,11 @@ type Dialer struct {
// NetDialContext is nil, net.DialContext is used. // NetDialContext is nil, net.DialContext is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// DialTLS specifies the dial function for creating TLS connections
// for non-proxied HTTPS requests. If DialTLS is nil, net.Dial and
// net.TLSClientConfig are used.
DialTLS func(network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given // Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the // Request. If the function returns a non-nil error, the
// request is aborted with the provided error. // request is aborted with the provided error.
@ -234,6 +239,26 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
defer cancel() defer cancel()
} }
hostPort, hostNoPort := hostPortNoPort(u)
trace := httptrace.ContextClientTrace(ctx)
if trace != nil && trace.GetConn != nil {
trace.GetConn(hostPort)
}
var netConn net.Conn
defer func() {
if netConn != nil {
netConn.Close()
}
}()
if u.Scheme == "https" && d.DialTLS != nil {
tlsConn, err := d.DialTLS("tcp", hostPort)
if err != nil {
return nil, nil, err
}
netConn = tlsConn
} else {
// Get network dial function. // Get network dial function.
var netDial func(network, add string) (net.Conn, error) var netDial func(network, add string) (net.Conn, error)
@ -282,13 +307,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
} }
} }
hostPort, hostNoPort := hostPortNoPort(u) var err error
trace := httptrace.ContextClientTrace(ctx) netConn, err = netDial("tcp", hostPort)
if trace != nil && trace.GetConn != nil {
trace.GetConn(hostPort)
}
netConn, err := netDial("tcp", hostPort)
if trace != nil && trace.GotConn != nil { if trace != nil && trace.GotConn != nil {
trace.GotConn(httptrace.GotConnInfo{ trace.GotConn(httptrace.GotConnInfo{
Conn: netConn, Conn: netConn,
@ -298,12 +318,6 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return nil, nil, err return nil, nil, err
} }
defer func() {
if netConn != nil {
netConn.Close()
}
}()
if u.Scheme == "https" { if u.Scheme == "https" {
cfg := cloneTLSConfig(d.TLSClientConfig) cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" { if cfg.ServerName == "" {
@ -323,6 +337,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return nil, nil, err return nil, nil, err
} }
} }
}
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)