This commit is contained in:
Sleeyax 2024-06-19 17:37:57 +10:00 committed by GitHub
commit 883b0a2814
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 18 deletions

View File

@ -65,6 +65,12 @@ type Dialer struct {
// TLSClientConfig is ignored.
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
// ProxyTLSConnection specifies the dial function for creating TLS connections through a Proxy. If
// ProxyTLSConnection is nil, NetDialTLSContext is used.
// If ProxyTLSConnection is set, Dial assumes the TLS handshake is done there and
// TLSClientConfig is ignored.
ProxyTLSConnection func(ctx context.Context, hostPort string, proxyConn net.Conn) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
@ -333,7 +339,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
}()
if u.Scheme == "https" && d.NetDialTLSContext == nil {
if u.Scheme == "https" {
if d.ProxyTLSConnection != nil && d.Proxy != nil {
// If we are connected to a proxy, perform the TLS handshake through the existing tunnel
netConn, err = d.ProxyTLSConnection(ctx, hostPort, netConn)
if err != nil {
return nil, nil, err
}
} else if d.NetDialTLSContext == nil {
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
cfg := cloneTLSConfig(d.TLSClientConfig)
@ -355,6 +368,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return nil, nil, err
}
}
}
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)

View File

@ -34,7 +34,7 @@ type httpProxyDialer struct {
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
hostPort, _ := hostPortNoPort(hpd.proxyURL)
conn, err := hpd.forwardDial(network, hostPort)
conn, err := net.Dial(network, hostPort)
if err != nil {
return nil, err
}