diff --git a/client.go b/client.go index 23678e3..efcdc5d 100644 --- a/client.go +++ b/client.go @@ -274,15 +274,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h return nil, nil, err } if proxyURL != nil { + proxyDialer := &netDialerFunc{fn: netDial} if proxyURL.Scheme == "https" { - netDial = func(network, addr string) (net.Conn, error) { + proxyDialer.usesTLS = true + proxyDialer.fn = func(network, addr string) (net.Conn, error) { t := tls.Dialer{} t.Config = d.TLSClientConfig t.NetDialer = &net.Dialer{} return t.DialContext(ctx, network, addr) } } - dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) + dialer, err := proxy_FromURL(proxyURL, proxyDialer) if err != nil { return nil, nil, err } diff --git a/proxy.go b/proxy.go index f84d881..c332633 100644 --- a/proxy.go +++ b/proxy.go @@ -6,6 +6,7 @@ package websocket import ( "bufio" + "crypto/tls" "encoding/base64" "errors" "net" @@ -14,24 +15,40 @@ import ( "strings" ) -type netDialerFunc func(network, addr string) (net.Conn, error) +type netDialerFunc struct { + fn func(network, addr string) (net.Conn, error) + usesTLS bool +} -func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { - return fn(network, addr) +func (ndf *netDialerFunc) Dial(network, addr string) (net.Conn, error) { + return ndf.fn(network, addr) +} + +func (ndf *netDialerFunc) UsesTLS() bool { + return ndf.usesTLS } func init() { proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { - return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil + return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial, usesTLS: forwardDialer.UsesTLS()}, nil }) proxy_RegisterDialerType("https", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { - return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil + fwd := forwardDialer.Dial + if !forwardDialer.UsesTLS() { + tlsDialer := &tls.Dialer{ + Config: &tls.Config{}, + NetDialer: &net.Dialer{}, + } + fwd = tlsDialer.Dial + } + return &httpProxyDialer{proxyURL: proxyURL, forwardDial: fwd, usesTLS: true}, nil }) } type httpProxyDialer struct { proxyURL *url.URL forwardDial func(network, addr string) (net.Conn, error) + usesTLS bool } func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { @@ -78,3 +95,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) } return conn, nil } + +func (hpd *httpProxyDialer) UsesTLS() bool { + return hpd.usesTLS +} diff --git a/x_net_proxy.go b/x_net_proxy.go index 2e668f6..f05a113 100644 --- a/x_net_proxy.go +++ b/x_net_proxy.go @@ -27,6 +27,10 @@ func (proxy_direct) Dial(network, addr string) (net.Conn, error) { return net.Dial(network, addr) } +func (proxy_direct) UsesTLS() bool { + return false +} + // A PerHost directs connections to a default Dialer unless the host name // requested matches one of a number of exceptions. type proxy_PerHost struct { @@ -59,6 +63,10 @@ func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) { return p.dialerForRequest(host).Dial(network, addr) } +func (p *proxy_PerHost) UsesTLS() bool { + return p.def.UsesTLS() || p.bypass.UsesTLS() +} + func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer { if ip := net.ParseIP(host); ip != nil { for _, net := range p.bypassNetworks { @@ -161,6 +169,8 @@ func (p *proxy_PerHost) AddHost(host string) { type proxy_Dialer interface { // Dial connects to the given address via the proxy. Dial(network, addr string) (c net.Conn, err error) + // UsesTLS indicates whether we expect to dial to a TLS proxy + UsesTLS() bool } // Auth contains authentication parameters that specific Dialers may require. @@ -338,6 +348,10 @@ func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) { return conn, nil } +func (s *proxy_socks5) UsesTLS() bool { + return s.forward.UsesTLS() +} + // connect takes an existing connection to a socks5 proxy server, // and commands the server to extend that connection to target, // which must be a canonical address with a host and port.