diff --git a/client.go b/client.go index 0df5c31..0dc5119 100644 --- a/client.go +++ b/client.go @@ -256,8 +256,16 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } if proxyURL != nil { forwardDial := newNetDialerFunc(proxyURL.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext) - if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial) + if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil { + tlsClientConfig := cloneTLSConfig(d.TLSClientConfig) + if d.TLSClientConfig == nil { + tlsClientConfig = &tls.Config{ + ServerName: proxyURL.Hostname(), + } + } + netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, tlsClientConfig) + } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { + netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, nil) } else { dialer, err := proxy.FromURL(proxyURL, forwardDial) if err != nil { diff --git a/proxy.go b/proxy.go index 8638ab0..c3c0e45 100644 --- a/proxy.go +++ b/proxy.go @@ -8,6 +8,7 @@ import ( "bufio" "bytes" "context" + "crypto/tls" "encoding/base64" "errors" "net" @@ -46,7 +47,11 @@ func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) ( return fn(ctx, network, addr) } -func newHTTPProxyDialerFunc(proxyURL *url.URL, forwardDial netDialerFunc) netDialerFunc { +// newHTTPProxyDialerFunc returns a netDialerFunc that dials using the provided +// proxyURL. The forwardDial function is used to establish the connection to the +// proxy server. If tlsClientConfig is not nil, the connection to the proxy is +// upgraded to a TLS connection with tls.Client. +func newHTTPProxyDialerFunc(proxyURL *url.URL, forwardDial netDialerFunc, tlsClientConfig *tls.Config) netDialerFunc { return func(ctx context.Context, network, addr string) (net.Conn, error) { hostPort, _ := hostPortNoPort(proxyURL) conn, err := forwardDial(ctx, network, hostPort) @@ -54,6 +59,14 @@ func newHTTPProxyDialerFunc(proxyURL *url.URL, forwardDial netDialerFunc) netDia return nil, err } + if tlsClientConfig != nil { + tlsConn := tls.Client(conn, tlsClientConfig) + if err = tlsConn.HandshakeContext(ctx); err != nil { + return nil, err + } + conn = tlsConn + } + connectHeader := make(http.Header) if user := proxyURL.User; user != nil { proxyUser := user.Username()