fix tls handshake on proxy

This commit is contained in:
Cooper Oh 2024-07-18 23:02:07 +09:00
parent 75fbe70bee
commit bad5b0af7f
2 changed files with 24 additions and 3 deletions

View File

@ -256,8 +256,16 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
} }
if proxyURL != nil { if proxyURL != nil {
forwardDial := newNetDialerFunc(proxyURL.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext) forwardDial := newNetDialerFunc(proxyURL.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext)
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil {
netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial) tlsClientConfig := cloneTLSConfig(d.TLSClientConfig)
if d.TLSClientConfig == nil {
tlsClientConfig = &tls.Config{
ServerName: proxyURL.Hostname(),
}
}
netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, tlsClientConfig)
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, nil)
} else { } else {
dialer, err := proxy.FromURL(proxyURL, forwardDial) dialer, err := proxy.FromURL(proxyURL, forwardDial)
if err != nil { if err != nil {

View File

@ -8,6 +8,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
"net" "net"
@ -46,7 +47,11 @@ func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (
return fn(ctx, network, addr) return fn(ctx, network, addr)
} }
func newHTTPProxyDialerFunc(proxyURL *url.URL, forwardDial netDialerFunc) netDialerFunc { // newHTTPProxyDialerFunc returns a netDialerFunc that dials using the provided
// proxyURL. The forwardDial function is used to establish the connection to the
// proxy server. If tlsClientConfig is not nil, the connection to the proxy is
// upgraded to a TLS connection with tls.Client.
func newHTTPProxyDialerFunc(proxyURL *url.URL, forwardDial netDialerFunc, tlsClientConfig *tls.Config) netDialerFunc {
return func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, addr string) (net.Conn, error) {
hostPort, _ := hostPortNoPort(proxyURL) hostPort, _ := hostPortNoPort(proxyURL)
conn, err := forwardDial(ctx, network, hostPort) conn, err := forwardDial(ctx, network, hostPort)
@ -54,6 +59,14 @@ func newHTTPProxyDialerFunc(proxyURL *url.URL, forwardDial netDialerFunc) netDia
return nil, err return nil, err
} }
if tlsClientConfig != nil {
tlsConn := tls.Client(conn, tlsClientConfig)
if err = tlsConn.HandshakeContext(ctx); err != nil {
return nil, err
}
conn = tlsConn
}
connectHeader := make(http.Header) connectHeader := make(http.Header)
if user := proxyURL.User; user != nil { if user := proxyURL.User; user != nil {
proxyUser := user.Username() proxyUser := user.Username()