diff --git a/client.go b/client.go index 7023e11..41bd4ad 100644 --- a/client.go +++ b/client.go @@ -67,6 +67,12 @@ type Dialer struct { // TLSClientConfig is ignored. NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + // ProxyTLSConnection specifies the dial function for creating TLS connections through a Proxy. If + // ProxyTLSConnection is nil, NetDialTLSContext is used. + // If ProxyTLSConnection is set, Dial assumes the TLS handshake is done there and + // TLSClientConfig is ignored. + ProxyTLSConnection func(ctx context.Context, hostPort string, proxyConn net.Conn) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. @@ -335,26 +341,34 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } }() - if u.Scheme == "https" && d.NetDialTLSContext == nil { - // If NetDialTLSContext is set, assume that the TLS handshake has already been done + if u.Scheme == "https" { + if d.ProxyTLSConnection != nil && d.Proxy != nil { + // If we are connected to a proxy, perform the TLS handshake through the existing tunnel + netConn, err = d.ProxyTLSConnection(ctx, hostPort, netConn) + if err != nil { + return nil, nil, err + } + } else if d.NetDialTLSContext == nil { + // If NetDialTLSContext is set, assume that the TLS handshake has already been done - cfg := cloneTLSConfig(d.TLSClientConfig) - if cfg.ServerName == "" { - cfg.ServerName = hostNoPort - } - tlsConn := tls.Client(netConn, cfg) - netConn = tlsConn + cfg := cloneTLSConfig(d.TLSClientConfig) + if cfg.ServerName == "" { + cfg.ServerName = hostNoPort + } + tlsConn := tls.Client(netConn, cfg) + netConn = tlsConn - if trace != nil && trace.TLSHandshakeStart != nil { - trace.TLSHandshakeStart() - } - err := doHandshake(ctx, tlsConn, cfg) - if trace != nil && trace.TLSHandshakeDone != nil { - trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) - } + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := doHandshake(ctx, tlsConn, cfg) + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) + } - if err != nil { - return nil, nil, err + if err != nil { + return nil, nil, err + } } } diff --git a/proxy.go b/proxy.go index 3c570c2..ebe0b4b 100644 --- a/proxy.go +++ b/proxy.go @@ -35,7 +35,7 @@ type httpProxyDialer struct { func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { hostPort, _ := hostPortNoPort(hpd.proxyURL) - conn, err := hpd.forwardDial(network, hostPort) + conn, err := net.Dial(network, hostPort) if err != nil { return nil, err }