From 76d8f8c50354545de1b982b558b50a447bfa36e2 Mon Sep 17 00:00:00 2001 From: Cooper Oh Date: Thu, 20 Jun 2024 16:17:03 +0900 Subject: [PATCH] support http proxy correctly --- client.go | 25 +++++++++++++++++++++++-- proxy.go | 40 ++++++++++++++-------------------------- 2 files changed, 37 insertions(+), 28 deletions(-) diff --git a/client.go b/client.go index bef9434..e5c7df8 100644 --- a/client.go +++ b/client.go @@ -17,6 +17,8 @@ import ( "net/url" "strings" "time" + + "golang.org/x/net/proxy" ) // ErrBadHandshake is returned when the server response to opening handshake is @@ -281,11 +283,30 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h if err != nil { return nil, nil, err } - if proxyURL != nil { - netDial, err = proxyFromURL(proxyURL, netDial) + switch { + case proxyURL == nil: + // Do nothing. Not using a proxy. + case u.Scheme == "http": + if pa := proxyAuth(proxyURL.User); pa != "" { + req.Header.Set("Proxy-Authorization", pa) + } + netDial = func(ctx context.Context, network, addr string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, network, proxyURL.Host) + } + case u.Scheme == "https": + netDial = (&httpsProxyDialer{proxyURL: proxyURL, forwardDial: netDial}).DialContext + default: + dialer, err := proxy.FromURL(proxyURL, netDial) if err != nil { return nil, nil, err } + if d, ok := dialer.(proxy.ContextDialer); ok { + netDial = d.DialContext + } else { + netDial = func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + } + } } } diff --git a/proxy.go b/proxy.go index f113710..cce25da 100644 --- a/proxy.go +++ b/proxy.go @@ -14,8 +14,6 @@ import ( "net/http" "net/url" "strings" - - "golang.org/x/net/proxy" ) type netDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error) @@ -28,28 +26,12 @@ func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) ( return fn(ctx, network, addr) } -func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) { - if proxyURL.Scheme == "http" { - return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil - } - dialer, err := proxy.FromURL(proxyURL, forwardDial) - if err != nil { - return nil, err - } - if d, ok := dialer.(proxy.ContextDialer); ok { - return d.DialContext, nil - } - return func(ctx context.Context, net, addr string) (net.Conn, error) { - return dialer.Dial(net, addr) - }, nil -} - -type httpProxyDialer struct { +type httpsProxyDialer struct { proxyURL *url.URL forwardDial netDialerFunc } -func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { +func (hpd *httpsProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { hostPort, _ := hostPortNoPort(hpd.proxyURL) conn, err := hpd.forwardDial(ctx, network, hostPort) if err != nil { @@ -57,12 +39,8 @@ func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, add } connectHeader := make(http.Header) - if user := hpd.proxyURL.User; user != nil { - proxyUser := user.Username() - if proxyPassword, passwordSet := user.Password(); passwordSet { - credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) - connectHeader.Set("Proxy-Authorization", "Basic "+credential) - } + if pa := proxyAuth(hpd.proxyURL.User); pa != "" { + connectHeader.Set("Proxy-Authorization", pa) } connectReq := &http.Request{ @@ -103,3 +81,13 @@ func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, add } return conn, nil } + +func proxyAuth(user *url.Userinfo) string { + if user != nil { + proxyUser := user.Username() + if proxyPassword, passwordSet := user.Password(); passwordSet { + return "Basic " + base64.StdEncoding.EncodeToString([]byte(proxyUser+":"+proxyPassword)) + } + } + return "" +}