diff --git a/client.go b/client.go index e5c7df8..f49b8fd 100644 --- a/client.go +++ b/client.go @@ -283,7 +283,27 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h if err != nil { return nil, nil, err } + + getDefaultDialerFunc := func() (netDialerFunc, error) { + dialer, err := proxy.FromURL(proxyURL, netDial) + if err != nil { + return nil, err + } + if d, ok := dialer.(proxy.ContextDialer); ok { + return d.DialContext, nil + } else { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + }, nil + } + } + switch { + case proxyURL.Scheme == "socks5": + netDial, err = getDefaultDialerFunc() + if err != nil { + return nil, nil, err + } case proxyURL == nil: // Do nothing. Not using a proxy. case u.Scheme == "http": @@ -296,17 +316,10 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h case u.Scheme == "https": netDial = (&httpsProxyDialer{proxyURL: proxyURL, forwardDial: netDial}).DialContext default: - dialer, err := proxy.FromURL(proxyURL, netDial) + netDial, err = getDefaultDialerFunc() if err != nil { return nil, nil, err } - if d, ok := dialer.(proxy.ContextDialer); ok { - netDial = d.DialContext - } else { - netDial = func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - } - } } } diff --git a/client_server_test.go b/client_server_test.go index 610fbe2..678be10 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -173,24 +173,19 @@ func TestProxyDial(t *testing.T) { cstDialer := cstDialer // make local copy for modification on next line. cstDialer.Proxy = http.ProxyURL(surl) - connect := false origHandler := s.Server.Config.Handler // Capture the request Host header. s.Server.Config.Handler = http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodConnect { - connect = true + // HTTPS_PROXY comes here. w.WriteHeader(http.StatusOK) - return } - if !connect { - t.Log("connect not received") - http.Error(w, "connect not received", http.StatusMethodNotAllowed) - return - } + // HTTP_PROXY comes here. origHandler.ServeHTTP(w, r) + return }) ws, _, err := cstDialer.Dial(s.URL, nil) @@ -211,7 +206,6 @@ func TestProxyAuthorizationDial(t *testing.T) { cstDialer := cstDialer // make local copy for modification on next line. cstDialer.Proxy = http.ProxyURL(surl) - connect := false origHandler := s.Server.Config.Handler // Capture the request Host header. @@ -219,17 +213,22 @@ func TestProxyAuthorizationDial(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { proxyAuth := r.Header.Get("Proxy-Authorization") expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) - if r.Method == http.MethodConnect && proxyAuth == expectedProxyAuth { - connect = true - w.WriteHeader(http.StatusOK) + if proxyAuth != expectedProxyAuth { + msg := fmt.Sprintf("expected proxy authorization is %q, but %q is given", expectedProxyAuth, proxyAuth) + + t.Log(msg) + http.Error( + w, + msg, + http.StatusProxyAuthRequired, + ) return } - if !connect { - t.Log("connect with proxy authorization not received") - http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed) - return + if r.Method == http.MethodConnect { + w.WriteHeader(http.StatusOK) } + origHandler.ServeHTTP(w, r) })