From 5e557d257ee50dce975600b89d927ba049df3c84 Mon Sep 17 00:00:00 2001 From: Cooper Oh Date: Thu, 18 Jul 2024 23:43:04 +0900 Subject: [PATCH] add tests --- client.go | 9 ++- client_server_test.go | 146 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 149 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 0dc5119..043e456 100644 --- a/client.go +++ b/client.go @@ -258,10 +258,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h forwardDial := newNetDialerFunc(proxyURL.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext) if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil { tlsClientConfig := cloneTLSConfig(d.TLSClientConfig) - if d.TLSClientConfig == nil { - tlsClientConfig = &tls.Config{ - ServerName: proxyURL.Hostname(), - } + if tlsClientConfig.ServerName == "" { + _, hostNoPort := hostPortNoPort(proxyURL) + tlsClientConfig.ServerName = hostNoPort } netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, tlsClientConfig) } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { @@ -369,7 +368,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h if proto != "http/1.1" { return nil, nil, fmt.Errorf( "websocket: protocol %q was given but is not supported;"+ - "sharing tls.Config with net/http Transport can cause this error: %w", + "sharing tlsServerName.Config with net/http Transport can cause this error: %w", proto, err, ) } diff --git a/client_server_test.go b/client_server_test.go index e4546ae..5a2113f 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -85,6 +85,51 @@ func newTLSServer(t *testing.T) *cstServer { return &s } +type cstProxyServer struct{} + +func (s *cstProxyServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodConnect { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer conn.Close() + + upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host) + if err != nil { + _, _ = fmt.Fprintf(conn, "HTTP/1.1 502 Bad Gateway\r\n\r\n") + return + } + defer upstream.Close() + + _, _ = fmt.Fprintf(conn, "HTTP/1.1 200 Connection established\r\n\r\n") + + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + _, _ = io.Copy(upstream, conn) + }() + go func() { + defer wg.Done() + _, _ = io.Copy(conn, upstream) + }() + wg.Wait() +} + +func newProxyServer() *httptest.Server { + return httptest.NewServer(&cstProxyServer{}) +} + +func newTLSProxyServer() *httptest.Server { + return httptest.NewTLSServer(&cstProxyServer{}) +} + func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Because tests wait for a response from a server, we are guaranteed that // the wait group count is incremented before the test waits on the group @@ -165,7 +210,6 @@ func sendRecv(t *testing.T, ws *Conn) { } func TestProxyDial(t *testing.T) { - s := newServer(t) defer s.Close() @@ -202,6 +246,106 @@ func TestProxyDial(t *testing.T) { sendRecv(t, ws) } +func TestProxyDialer(t *testing.T) { + testcases := []struct { + name string + isTLS bool + tlsServerName string + insecureSkipVerify bool + netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + }{{ + name: "http", + isTLS: false, + }, { + name: "https", + isTLS: true, + }, { + name: "https with ServerName", + isTLS: true, + tlsServerName: "example.com", + }, { + name: "https with insecureSkipVerify", + isTLS: true, + insecureSkipVerify: true, + }, { + name: "https with netDialTLSContext", + isTLS: true, + netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := &tls.Dialer{ + Config: &tls.Config{ + InsecureSkipVerify: true, + }, + } + return dialer.DialContext(ctx, network, addr) + }, + }} + + for _, tc := range testcases { + t.Run(tc.name, func(tt *testing.T) { + s := newServer(tt) + defer s.Close() + + var ps *httptest.Server + if tc.isTLS { + ps = newTLSProxyServer() + } else { + ps = newProxyServer() + } + + psurl, _ := url.Parse(ps.URL) + + netDialCalled := false + + cstDialer := cstDialer // make local copy for modification on next line. + cstDialer.Proxy = http.ProxyURL(psurl) + if tc.isTLS { + cstDialer.TLSClientConfig = &tls.Config{ + RootCAs: rootCAs(tt, ps), + ServerName: tc.tlsServerName, + InsecureSkipVerify: tc.insecureSkipVerify, + } + if tc.netDialTLSContext != nil { + cstDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + netDialCalled = true + return tc.netDialTLSContext(ctx, network, addr) + } + } else { + netDialCalled = true + } + } else { + netDialCalled = true + } + + connect := false + origHandler := ps.Config.Handler + + // Capture the request Host header. + ps.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodConnect { + connect = true + } + + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + tt.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(tt, ws) + + if !connect { + tt.Error("connect not received") + } + if !netDialCalled { + tt.Error("netDialTLSContext not called") + } + }) + } +} + func TestProxyAuthorizationDial(t *testing.T) { s := newServer(t) defer s.Close()