diff --git a/client.go b/client.go index 24bd7ff..8666562 100644 --- a/client.go +++ b/client.go @@ -17,6 +17,8 @@ import ( "net/url" "strings" "time" + + "golang.org/x/net/proxy" ) // ErrBadHandshake is returned when the server response to opening handshake is @@ -282,7 +284,21 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h return nil, nil, err } if proxyURL != nil { - netDial, err = proxyFromURL(proxyURL, netDial) + netDial, err = func(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) { + if proxyURL.Scheme == "http" { + return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil + } + dialer, err := proxy.FromURL(proxyURL, forwardDial) + if err != nil { + return nil, err + } + if d, ok := dialer.(proxy.ContextDialer); ok { + return d.DialContext, nil + } + return func(ctx context.Context, net, addr string) (net.Conn, error) { + return dialer.Dial(net, addr) + }, nil + }(proxyURL, netDial) if err != nil { return nil, nil, err } diff --git a/proxy.go b/proxy.go index b4683b9..8c8a27e 100644 --- a/proxy.go +++ b/proxy.go @@ -14,8 +14,6 @@ import ( "net/http" "net/url" "strings" - - "golang.org/x/net/proxy" ) type netDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error) @@ -28,22 +26,6 @@ func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) ( return fn(ctx, network, addr) } -func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) { - if proxyURL.Scheme == "http" { - return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil - } - dialer, err := proxy.FromURL(proxyURL, forwardDial) - if err != nil { - return nil, err - } - if d, ok := dialer.(proxy.ContextDialer); ok { - return d.DialContext, nil - } - return func(ctx context.Context, net, addr string) (net.Conn, error) { - return dialer.Dial(net, addr) - }, nil -} - type httpProxyDialer struct { proxyURL *url.URL forwardDial netDialerFunc