mirror of https://github.com/gorilla/websocket.git
add tests
This commit is contained in:
parent
bad5b0af7f
commit
5e557d257e
|
@ -258,10 +258,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
forwardDial := newNetDialerFunc(proxyURL.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext)
|
forwardDial := newNetDialerFunc(proxyURL.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext)
|
||||||
if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil {
|
if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil {
|
||||||
tlsClientConfig := cloneTLSConfig(d.TLSClientConfig)
|
tlsClientConfig := cloneTLSConfig(d.TLSClientConfig)
|
||||||
if d.TLSClientConfig == nil {
|
if tlsClientConfig.ServerName == "" {
|
||||||
tlsClientConfig = &tls.Config{
|
_, hostNoPort := hostPortNoPort(proxyURL)
|
||||||
ServerName: proxyURL.Hostname(),
|
tlsClientConfig.ServerName = hostNoPort
|
||||||
}
|
|
||||||
}
|
}
|
||||||
netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, tlsClientConfig)
|
netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, tlsClientConfig)
|
||||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||||
|
@ -369,7 +368,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
if proto != "http/1.1" {
|
if proto != "http/1.1" {
|
||||||
return nil, nil, fmt.Errorf(
|
return nil, nil, fmt.Errorf(
|
||||||
"websocket: protocol %q was given but is not supported;"+
|
"websocket: protocol %q was given but is not supported;"+
|
||||||
"sharing tls.Config with net/http Transport can cause this error: %w",
|
"sharing tlsServerName.Config with net/http Transport can cause this error: %w",
|
||||||
proto, err,
|
proto, err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -85,6 +85,51 @@ func newTLSServer(t *testing.T) *cstServer {
|
||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type cstProxyServer struct{}
|
||||||
|
|
||||||
|
func (s *cstProxyServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.Method != http.MethodConnect {
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, _, err := w.(http.Hijacker).Hijack()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(conn, "HTTP/1.1 502 Bad Gateway\r\n\r\n")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer upstream.Close()
|
||||||
|
|
||||||
|
_, _ = fmt.Fprintf(conn, "HTTP/1.1 200 Connection established\r\n\r\n")
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = io.Copy(upstream, conn)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = io.Copy(conn, upstream)
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newProxyServer() *httptest.Server {
|
||||||
|
return httptest.NewServer(&cstProxyServer{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTLSProxyServer() *httptest.Server {
|
||||||
|
return httptest.NewTLSServer(&cstProxyServer{})
|
||||||
|
}
|
||||||
|
|
||||||
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// Because tests wait for a response from a server, we are guaranteed that
|
// Because tests wait for a response from a server, we are guaranteed that
|
||||||
// the wait group count is incremented before the test waits on the group
|
// the wait group count is incremented before the test waits on the group
|
||||||
|
@ -165,7 +210,6 @@ func sendRecv(t *testing.T, ws *Conn) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyDial(t *testing.T) {
|
func TestProxyDial(t *testing.T) {
|
||||||
|
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
|
@ -202,6 +246,106 @@ func TestProxyDial(t *testing.T) {
|
||||||
sendRecv(t, ws)
|
sendRecv(t, ws)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyDialer(t *testing.T) {
|
||||||
|
testcases := []struct {
|
||||||
|
name string
|
||||||
|
isTLS bool
|
||||||
|
tlsServerName string
|
||||||
|
insecureSkipVerify bool
|
||||||
|
netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
}{{
|
||||||
|
name: "http",
|
||||||
|
isTLS: false,
|
||||||
|
}, {
|
||||||
|
name: "https",
|
||||||
|
isTLS: true,
|
||||||
|
}, {
|
||||||
|
name: "https with ServerName",
|
||||||
|
isTLS: true,
|
||||||
|
tlsServerName: "example.com",
|
||||||
|
}, {
|
||||||
|
name: "https with insecureSkipVerify",
|
||||||
|
isTLS: true,
|
||||||
|
insecureSkipVerify: true,
|
||||||
|
}, {
|
||||||
|
name: "https with netDialTLSContext",
|
||||||
|
isTLS: true,
|
||||||
|
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
dialer := &tls.Dialer{
|
||||||
|
Config: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testcases {
|
||||||
|
t.Run(tc.name, func(tt *testing.T) {
|
||||||
|
s := newServer(tt)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
var ps *httptest.Server
|
||||||
|
if tc.isTLS {
|
||||||
|
ps = newTLSProxyServer()
|
||||||
|
} else {
|
||||||
|
ps = newProxyServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
psurl, _ := url.Parse(ps.URL)
|
||||||
|
|
||||||
|
netDialCalled := false
|
||||||
|
|
||||||
|
cstDialer := cstDialer // make local copy for modification on next line.
|
||||||
|
cstDialer.Proxy = http.ProxyURL(psurl)
|
||||||
|
if tc.isTLS {
|
||||||
|
cstDialer.TLSClientConfig = &tls.Config{
|
||||||
|
RootCAs: rootCAs(tt, ps),
|
||||||
|
ServerName: tc.tlsServerName,
|
||||||
|
InsecureSkipVerify: tc.insecureSkipVerify,
|
||||||
|
}
|
||||||
|
if tc.netDialTLSContext != nil {
|
||||||
|
cstDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
netDialCalled = true
|
||||||
|
return tc.netDialTLSContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
netDialCalled = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
netDialCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
connect := false
|
||||||
|
origHandler := ps.Config.Handler
|
||||||
|
|
||||||
|
// Capture the request Host header.
|
||||||
|
ps.Config.Handler = http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == http.MethodConnect {
|
||||||
|
connect = true
|
||||||
|
}
|
||||||
|
|
||||||
|
origHandler.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
ws, _, err := cstDialer.Dial(s.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
tt.Fatalf("Dial: %v", err)
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
sendRecv(tt, ws)
|
||||||
|
|
||||||
|
if !connect {
|
||||||
|
tt.Error("connect not received")
|
||||||
|
}
|
||||||
|
if !netDialCalled {
|
||||||
|
tt.Error("netDialTLSContext not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestProxyAuthorizationDial(t *testing.T) {
|
func TestProxyAuthorizationDial(t *testing.T) {
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
Loading…
Reference in New Issue