mirror of https://github.com/gorilla/websocket.git
add tests
This commit is contained in:
parent
bad5b0af7f
commit
5e557d257e
|
@ -258,10 +258,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
forwardDial := newNetDialerFunc(proxyURL.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext)
|
||||
if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil {
|
||||
tlsClientConfig := cloneTLSConfig(d.TLSClientConfig)
|
||||
if d.TLSClientConfig == nil {
|
||||
tlsClientConfig = &tls.Config{
|
||||
ServerName: proxyURL.Hostname(),
|
||||
}
|
||||
if tlsClientConfig.ServerName == "" {
|
||||
_, hostNoPort := hostPortNoPort(proxyURL)
|
||||
tlsClientConfig.ServerName = hostNoPort
|
||||
}
|
||||
netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, tlsClientConfig)
|
||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
|
@ -369,7 +368,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
if proto != "http/1.1" {
|
||||
return nil, nil, fmt.Errorf(
|
||||
"websocket: protocol %q was given but is not supported;"+
|
||||
"sharing tls.Config with net/http Transport can cause this error: %w",
|
||||
"sharing tlsServerName.Config with net/http Transport can cause this error: %w",
|
||||
proto, err,
|
||||
)
|
||||
}
|
||||
|
|
|
@ -85,6 +85,51 @@ func newTLSServer(t *testing.T) *cstServer {
|
|||
return &s
|
||||
}
|
||||
|
||||
type cstProxyServer struct{}
|
||||
|
||||
func (s *cstProxyServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method != http.MethodConnect {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
conn, _, err := w.(http.Hijacker).Hijack()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(conn, "HTTP/1.1 502 Bad Gateway\r\n\r\n")
|
||||
return
|
||||
}
|
||||
defer upstream.Close()
|
||||
|
||||
_, _ = fmt.Fprintf(conn, "HTTP/1.1 200 Connection established\r\n\r\n")
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(upstream, conn)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = io.Copy(conn, upstream)
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func newProxyServer() *httptest.Server {
|
||||
return httptest.NewServer(&cstProxyServer{})
|
||||
}
|
||||
|
||||
func newTLSProxyServer() *httptest.Server {
|
||||
return httptest.NewTLSServer(&cstProxyServer{})
|
||||
}
|
||||
|
||||
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Because tests wait for a response from a server, we are guaranteed that
|
||||
// the wait group count is incremented before the test waits on the group
|
||||
|
@ -165,7 +210,6 @@ func sendRecv(t *testing.T, ws *Conn) {
|
|||
}
|
||||
|
||||
func TestProxyDial(t *testing.T) {
|
||||
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
|
@ -202,6 +246,106 @@ func TestProxyDial(t *testing.T) {
|
|||
sendRecv(t, ws)
|
||||
}
|
||||
|
||||
func TestProxyDialer(t *testing.T) {
|
||||
testcases := []struct {
|
||||
name string
|
||||
isTLS bool
|
||||
tlsServerName string
|
||||
insecureSkipVerify bool
|
||||
netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}{{
|
||||
name: "http",
|
||||
isTLS: false,
|
||||
}, {
|
||||
name: "https",
|
||||
isTLS: true,
|
||||
}, {
|
||||
name: "https with ServerName",
|
||||
isTLS: true,
|
||||
tlsServerName: "example.com",
|
||||
}, {
|
||||
name: "https with insecureSkipVerify",
|
||||
isTLS: true,
|
||||
insecureSkipVerify: true,
|
||||
}, {
|
||||
name: "https with netDialTLSContext",
|
||||
isTLS: true,
|
||||
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &tls.Dialer{
|
||||
Config: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
}}
|
||||
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(tt *testing.T) {
|
||||
s := newServer(tt)
|
||||
defer s.Close()
|
||||
|
||||
var ps *httptest.Server
|
||||
if tc.isTLS {
|
||||
ps = newTLSProxyServer()
|
||||
} else {
|
||||
ps = newProxyServer()
|
||||
}
|
||||
|
||||
psurl, _ := url.Parse(ps.URL)
|
||||
|
||||
netDialCalled := false
|
||||
|
||||
cstDialer := cstDialer // make local copy for modification on next line.
|
||||
cstDialer.Proxy = http.ProxyURL(psurl)
|
||||
if tc.isTLS {
|
||||
cstDialer.TLSClientConfig = &tls.Config{
|
||||
RootCAs: rootCAs(tt, ps),
|
||||
ServerName: tc.tlsServerName,
|
||||
InsecureSkipVerify: tc.insecureSkipVerify,
|
||||
}
|
||||
if tc.netDialTLSContext != nil {
|
||||
cstDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
netDialCalled = true
|
||||
return tc.netDialTLSContext(ctx, network, addr)
|
||||
}
|
||||
} else {
|
||||
netDialCalled = true
|
||||
}
|
||||
} else {
|
||||
netDialCalled = true
|
||||
}
|
||||
|
||||
connect := false
|
||||
origHandler := ps.Config.Handler
|
||||
|
||||
// Capture the request Host header.
|
||||
ps.Config.Handler = http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodConnect {
|
||||
connect = true
|
||||
}
|
||||
|
||||
origHandler.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
ws, _, err := cstDialer.Dial(s.URL, nil)
|
||||
if err != nil {
|
||||
tt.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer ws.Close()
|
||||
sendRecv(tt, ws)
|
||||
|
||||
if !connect {
|
||||
tt.Error("connect not received")
|
||||
}
|
||||
if !netDialCalled {
|
||||
tt.Error("netDialTLSContext not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyAuthorizationDial(t *testing.T) {
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
|
Loading…
Reference in New Issue