add tests

This commit is contained in:
Cooper Oh 2024-07-18 23:43:04 +09:00
parent bad5b0af7f
commit 5e557d257e
2 changed files with 149 additions and 6 deletions

View File

@ -258,10 +258,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
forwardDial := newNetDialerFunc(proxyURL.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext)
if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil {
tlsClientConfig := cloneTLSConfig(d.TLSClientConfig)
if d.TLSClientConfig == nil {
tlsClientConfig = &tls.Config{
ServerName: proxyURL.Hostname(),
}
if tlsClientConfig.ServerName == "" {
_, hostNoPort := hostPortNoPort(proxyURL)
tlsClientConfig.ServerName = hostNoPort
}
netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, tlsClientConfig)
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
@ -369,7 +368,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
if proto != "http/1.1" {
return nil, nil, fmt.Errorf(
"websocket: protocol %q was given but is not supported;"+
"sharing tls.Config with net/http Transport can cause this error: %w",
"sharing tlsServerName.Config with net/http Transport can cause this error: %w",
proto, err,
)
}

View File

@ -85,6 +85,51 @@ func newTLSServer(t *testing.T) *cstServer {
return &s
}
type cstProxyServer struct{}
func (s *cstProxyServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodConnect {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
conn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer conn.Close()
upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host)
if err != nil {
_, _ = fmt.Fprintf(conn, "HTTP/1.1 502 Bad Gateway\r\n\r\n")
return
}
defer upstream.Close()
_, _ = fmt.Fprintf(conn, "HTTP/1.1 200 Connection established\r\n\r\n")
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
defer wg.Done()
_, _ = io.Copy(upstream, conn)
}()
go func() {
defer wg.Done()
_, _ = io.Copy(conn, upstream)
}()
wg.Wait()
}
func newProxyServer() *httptest.Server {
return httptest.NewServer(&cstProxyServer{})
}
func newTLSProxyServer() *httptest.Server {
return httptest.NewTLSServer(&cstProxyServer{})
}
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Because tests wait for a response from a server, we are guaranteed that
// the wait group count is incremented before the test waits on the group
@ -165,7 +210,6 @@ func sendRecv(t *testing.T, ws *Conn) {
}
func TestProxyDial(t *testing.T) {
s := newServer(t)
defer s.Close()
@ -202,6 +246,106 @@ func TestProxyDial(t *testing.T) {
sendRecv(t, ws)
}
func TestProxyDialer(t *testing.T) {
testcases := []struct {
name string
isTLS bool
tlsServerName string
insecureSkipVerify bool
netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
}{{
name: "http",
isTLS: false,
}, {
name: "https",
isTLS: true,
}, {
name: "https with ServerName",
isTLS: true,
tlsServerName: "example.com",
}, {
name: "https with insecureSkipVerify",
isTLS: true,
insecureSkipVerify: true,
}, {
name: "https with netDialTLSContext",
isTLS: true,
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &tls.Dialer{
Config: &tls.Config{
InsecureSkipVerify: true,
},
}
return dialer.DialContext(ctx, network, addr)
},
}}
for _, tc := range testcases {
t.Run(tc.name, func(tt *testing.T) {
s := newServer(tt)
defer s.Close()
var ps *httptest.Server
if tc.isTLS {
ps = newTLSProxyServer()
} else {
ps = newProxyServer()
}
psurl, _ := url.Parse(ps.URL)
netDialCalled := false
cstDialer := cstDialer // make local copy for modification on next line.
cstDialer.Proxy = http.ProxyURL(psurl)
if tc.isTLS {
cstDialer.TLSClientConfig = &tls.Config{
RootCAs: rootCAs(tt, ps),
ServerName: tc.tlsServerName,
InsecureSkipVerify: tc.insecureSkipVerify,
}
if tc.netDialTLSContext != nil {
cstDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
netDialCalled = true
return tc.netDialTLSContext(ctx, network, addr)
}
} else {
netDialCalled = true
}
} else {
netDialCalled = true
}
connect := false
origHandler := ps.Config.Handler
// Capture the request Host header.
ps.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodConnect {
connect = true
}
origHandler.ServeHTTP(w, r)
})
ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
tt.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(tt, ws)
if !connect {
tt.Error("connect not received")
}
if !netDialCalled {
tt.Error("netDialTLSContext not called")
}
})
}
}
func TestProxyAuthorizationDial(t *testing.T) {
s := newServer(t)
defer s.Close()