diff --git a/client.go b/client.go index 170301d..9b6e18a 100644 --- a/client.go +++ b/client.go @@ -303,7 +303,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h return nil, nil, err } if proxyURL != nil { - dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) + proxyDialer := &netDialerFunc{fn: netDial} + modifyProxyDialer(ctx, d, proxyURL, proxyDialer) + dialer, err := proxy_FromURL(proxyURL, proxyDialer) if err != nil { return nil, nil, err } diff --git a/client_server_httpsproxy_test.go b/client_server_httpsproxy_test.go new file mode 100644 index 0000000..71f0d55 --- /dev/null +++ b/client_server_httpsproxy_test.go @@ -0,0 +1,52 @@ +//go:build go1.15 +// +build go1.15 + +package websocket + +import ( + "crypto/tls" + "net/http" + "net/url" + "testing" +) + +func TestHttpsProxy(t *testing.T) { + + sTLS := newTLSServer(t) + defer sTLS.Close() + s := newServer(t) + defer s.Close() + + surlTLS, _ := url.Parse(sTLS.Server.URL) + + cstDialer := cstDialer // make local copy for modification on next line. + cstDialer.Proxy = http.ProxyURL(surlTLS) + + connect := false + origHandler := sTLS.Server.Config.Handler + + // Capture the request Host header. + sTLS.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == "CONNECT" { + connect = true + w.WriteHeader(http.StatusOK) + return + } + + if !connect { + t.Log("connect not received") + http.Error(w, "connect not received", http.StatusMethodNotAllowed) + return + } + origHandler.ServeHTTP(w, r) + }) + + cstDialer.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, sTLS.Server)} + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} diff --git a/proxy.go b/proxy.go index 18abf6e..a370b31 100644 --- a/proxy.go +++ b/proxy.go @@ -15,21 +15,37 @@ import ( "strings" ) -type netDialerFunc func(network, addr string) (net.Conn, error) +// proxyDialerEx extends the generated proxy_Dialer +type proxyDialerEx interface { + proxy_Dialer + // UsesTLS indicates whether we expect to dial to a TLS proxy + UsesTLS() bool +} -func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { - return fn(network, addr) +type netDialerFunc struct { + fn func(network, addr string) (net.Conn, error) + usesTLS bool +} + +func (ndf *netDialerFunc) Dial(network, addr string) (net.Conn, error) { + return ndf.fn(network, addr) +} + +func (ndf *netDialerFunc) UsesTLS() bool { + return ndf.usesTLS } func init() { proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { - return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil + return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial, usesTLS: false}, nil }) + registerDialerHttps() } type httpProxyDialer struct { proxyURL *url.URL forwardDial func(network, addr string) (net.Conn, error) + usesTLS bool } func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { @@ -86,3 +102,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) } return conn, nil } + +func (hpd *httpProxyDialer) UsesTLS() bool { + return hpd.usesTLS +} diff --git a/proxy_https.go b/proxy_https.go new file mode 100644 index 0000000..f74258a --- /dev/null +++ b/proxy_https.go @@ -0,0 +1,37 @@ +//go:build go1.15 +// +build go1.15 + +package websocket + +import ( + "context" + "crypto/tls" + "net" + "net/url" +) + +func registerDialerHttps() { + proxy_RegisterDialerType("https", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { + fwd := forwardDialer.Dial + if dialerEx, ok := forwardDialer.(proxyDialerEx); !ok || !dialerEx.UsesTLS() { + tlsDialer := &tls.Dialer{ + Config: &tls.Config{}, + NetDialer: &net.Dialer{}, + } + fwd = tlsDialer.Dial + } + return &httpProxyDialer{proxyURL: proxyURL, forwardDial: fwd, usesTLS: true}, nil + }) +} + +func modifyProxyDialer(ctx context.Context, d *Dialer, proxyURL *url.URL, proxyDialer *netDialerFunc) { + if proxyURL.Scheme == "https" { + proxyDialer.usesTLS = true + proxyDialer.fn = func(network, addr string) (net.Conn, error) { + t := tls.Dialer{} + t.Config = d.TLSClientConfig + t.NetDialer = &net.Dialer{} + return t.DialContext(ctx, network, addr) + } + } +} diff --git a/proxy_https_legacy.go b/proxy_https_legacy.go new file mode 100644 index 0000000..40bc5e2 --- /dev/null +++ b/proxy_https_legacy.go @@ -0,0 +1,15 @@ +//go:build !go1.15 +// +build !go1.15 + +package websocket + +import ( + "context" + "net/url" +) + +func registerDialerHttps() { +} + +func modifyProxyDialer(ctx context.Context, d *Dialer, proxyURL *url.URL, proxyDialer *netDialerFunc) { +}