diff --git a/client.go b/client.go index efcdc5d..ce07329 100644 --- a/client.go +++ b/client.go @@ -275,15 +275,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } if proxyURL != nil { proxyDialer := &netDialerFunc{fn: netDial} - if proxyURL.Scheme == "https" { - proxyDialer.usesTLS = true - proxyDialer.fn = func(network, addr string) (net.Conn, error) { - t := tls.Dialer{} - t.Config = d.TLSClientConfig - t.NetDialer = &net.Dialer{} - return t.DialContext(ctx, network, addr) - } - } + modifyProxyDialer(ctx, d, proxyURL, proxyDialer) dialer, err := proxy_FromURL(proxyURL, proxyDialer) if err != nil { return nil, nil, err diff --git a/client_server_httpsproxy_test.go b/client_server_httpsproxy_test.go new file mode 100644 index 0000000..214ce33 --- /dev/null +++ b/client_server_httpsproxy_test.go @@ -0,0 +1,51 @@ +// +build go1.15 + +package websocket + +import ( + "crypto/tls" + "net/http" + "net/url" + "testing" +) + +func TestHttpsProxy(t *testing.T) { + + sTLS := newTLSServer(t) + defer sTLS.Close() + s := newServer(t) + defer s.Close() + + surlTLS, _ := url.Parse(sTLS.Server.URL) + + cstDialer := cstDialer // make local copy for modification on next line. + cstDialer.Proxy = http.ProxyURL(surlTLS) + + connect := false + origHandler := sTLS.Server.Config.Handler + + // Capture the request Host header. + sTLS.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == "CONNECT" { + connect = true + w.WriteHeader(http.StatusOK) + return + } + + if !connect { + t.Log("connect not received") + http.Error(w, "connect not received", http.StatusMethodNotAllowed) + return + } + origHandler.ServeHTTP(w, r) + }) + + cstDialer.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, sTLS.Server)} + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} diff --git a/client_server_test.go b/client_server_test.go index 116b043..5fd2c85 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -185,47 +185,6 @@ func TestProxyDial(t *testing.T) { sendRecv(t, ws) } -func TestHttpsProxy(t *testing.T) { - - sTLS := newTLSServer(t) - defer sTLS.Close() - s := newServer(t) - defer s.Close() - - surlTLS, _ := url.Parse(sTLS.Server.URL) - - cstDialer := cstDialer // make local copy for modification on next line. - cstDialer.Proxy = http.ProxyURL(surlTLS) - - connect := false - origHandler := sTLS.Server.Config.Handler - - // Capture the request Host header. - sTLS.Server.Config.Handler = http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.Method == "CONNECT" { - connect = true - w.WriteHeader(http.StatusOK) - return - } - - if !connect { - t.Log("connect not received") - http.Error(w, "connect not received", http.StatusMethodNotAllowed) - return - } - origHandler.ServeHTTP(w, r) - }) - - cstDialer.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, sTLS.Server)} - ws, _, err := cstDialer.Dial(s.URL, nil) - if err != nil { - t.Fatalf("Dial: %v", err) - } - defer ws.Close() - sendRecv(t, ws) -} - func TestProxyAuthorizationDial(t *testing.T) { s := newServer(t) defer s.Close() diff --git a/proxy.go b/proxy.go index c66b46c..8abf01d 100644 --- a/proxy.go +++ b/proxy.go @@ -6,7 +6,6 @@ package websocket import ( "bufio" - "crypto/tls" "encoding/base64" "errors" "net" @@ -39,17 +38,7 @@ func init() { proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial, usesTLS: false}, nil }) - proxy_RegisterDialerType("https", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { - fwd := forwardDialer.Dial - if dialerEx, ok := forwardDialer.(proxyDialerEx); !ok || !dialerEx.UsesTLS() { - tlsDialer := &tls.Dialer{ - Config: &tls.Config{}, - NetDialer: &net.Dialer{}, - } - fwd = tlsDialer.Dial - } - return &httpProxyDialer{proxyURL: proxyURL, forwardDial: fwd, usesTLS: true}, nil - }) + registerDialerHttps() } type httpProxyDialer struct { diff --git a/proxy_https.go b/proxy_https.go new file mode 100644 index 0000000..3cc2baa --- /dev/null +++ b/proxy_https.go @@ -0,0 +1,36 @@ +// +build go1.15 + +package websocket + +import ( + "context" + "crypto/tls" + "net" + "net/url" +) + +func registerDialerHttps() { + proxy_RegisterDialerType("https", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { + fwd := forwardDialer.Dial + if dialerEx, ok := forwardDialer.(proxyDialerEx); !ok || !dialerEx.UsesTLS() { + tlsDialer := &tls.Dialer{ + Config: &tls.Config{}, + NetDialer: &net.Dialer{}, + } + fwd = tlsDialer.Dial + } + return &httpProxyDialer{proxyURL: proxyURL, forwardDial: fwd, usesTLS: true}, nil + }) +} + +func modifyProxyDialer(ctx context.Context, d *Dialer, proxyURL *url.URL, proxyDialer *netDialerFunc) { + if proxyURL.Scheme == "https" { + proxyDialer.usesTLS = true + proxyDialer.fn = func(network, addr string) (net.Conn, error) { + t := tls.Dialer{} + t.Config = d.TLSClientConfig + t.NetDialer = &net.Dialer{} + return t.DialContext(ctx, network, addr) + } + } +} diff --git a/proxy_https_legacy.go b/proxy_https_legacy.go new file mode 100644 index 0000000..2c8595c --- /dev/null +++ b/proxy_https_legacy.go @@ -0,0 +1,15 @@ +// +build !go1.15 + +package websocket + +import ( + "context" + "net/url" +) + +func registerDialerHttps() { +} + +func modifyProxyDialer(ctx context.Context, d *Dialer, proxyURL *url.URL, proxyDialer *netDialerFunc) { + return nil +}