mirror of https://github.com/gorilla/websocket.git
Merge 4d0a40247b
into ce903f6d1d
This commit is contained in:
commit
ba8e57da75
|
@ -304,7 +304,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
|||
return nil, nil, err
|
||||
}
|
||||
if proxyURL != nil {
|
||||
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
|
||||
proxyDialer := &netDialerFunc{fn: netDial}
|
||||
modifyProxyDialer(ctx, d, proxyURL, proxyDialer)
|
||||
dialer, err := proxy_FromURL(proxyURL, proxyDialer)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
//go:build go1.15
|
||||
// +build go1.15
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHttpsProxy(t *testing.T) {
|
||||
|
||||
sTLS := newTLSServer(t)
|
||||
defer sTLS.Close()
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
surlTLS, _ := url.Parse(sTLS.Server.URL)
|
||||
|
||||
cstDialer := cstDialer // make local copy for modification on next line.
|
||||
cstDialer.Proxy = http.ProxyURL(surlTLS)
|
||||
|
||||
connect := false
|
||||
origHandler := sTLS.Server.Config.Handler
|
||||
|
||||
// Capture the request Host header.
|
||||
sTLS.Server.Config.Handler = http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "CONNECT" {
|
||||
connect = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
if !connect {
|
||||
t.Log("connect not received")
|
||||
http.Error(w, "connect not received", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
origHandler.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
cstDialer.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, sTLS.Server)}
|
||||
ws, _, err := cstDialer.Dial(s.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer ws.Close()
|
||||
sendRecv(t, ws)
|
||||
}
|
28
proxy.go
28
proxy.go
|
@ -14,21 +14,37 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
type netDialerFunc func(network, addr string) (net.Conn, error)
|
||||
// proxyDialerEx extends the generated proxy_Dialer
|
||||
type proxyDialerEx interface {
|
||||
proxy_Dialer
|
||||
// UsesTLS indicates whether we expect to dial to a TLS proxy
|
||||
UsesTLS() bool
|
||||
}
|
||||
|
||||
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
|
||||
return fn(network, addr)
|
||||
type netDialerFunc struct {
|
||||
fn func(network, addr string) (net.Conn, error)
|
||||
usesTLS bool
|
||||
}
|
||||
|
||||
func (ndf *netDialerFunc) Dial(network, addr string) (net.Conn, error) {
|
||||
return ndf.fn(network, addr)
|
||||
}
|
||||
|
||||
func (ndf *netDialerFunc) UsesTLS() bool {
|
||||
return ndf.usesTLS
|
||||
}
|
||||
|
||||
func init() {
|
||||
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
||||
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
|
||||
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial, usesTLS: false}, nil
|
||||
})
|
||||
registerDialerHttps()
|
||||
}
|
||||
|
||||
type httpProxyDialer struct {
|
||||
proxyURL *url.URL
|
||||
forwardDial func(network, addr string) (net.Conn, error)
|
||||
usesTLS bool
|
||||
}
|
||||
|
||||
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
|
||||
|
@ -75,3 +91,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
|
|||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (hpd *httpProxyDialer) UsesTLS() bool {
|
||||
return hpd.usesTLS
|
||||
}
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
//go:build go1.15
|
||||
// +build go1.15
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func registerDialerHttps() {
|
||||
proxy_RegisterDialerType("https", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
||||
fwd := forwardDialer.Dial
|
||||
if dialerEx, ok := forwardDialer.(proxyDialerEx); !ok || !dialerEx.UsesTLS() {
|
||||
tlsDialer := &tls.Dialer{
|
||||
Config: &tls.Config{},
|
||||
NetDialer: &net.Dialer{},
|
||||
}
|
||||
fwd = tlsDialer.Dial
|
||||
}
|
||||
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: fwd, usesTLS: true}, nil
|
||||
})
|
||||
}
|
||||
|
||||
func modifyProxyDialer(ctx context.Context, d *Dialer, proxyURL *url.URL, proxyDialer *netDialerFunc) {
|
||||
if proxyURL.Scheme == "https" {
|
||||
proxyDialer.usesTLS = true
|
||||
proxyDialer.fn = func(network, addr string) (net.Conn, error) {
|
||||
t := tls.Dialer{}
|
||||
t.Config = d.TLSClientConfig
|
||||
t.NetDialer = &net.Dialer{}
|
||||
return t.DialContext(ctx, network, addr)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
//go:build !go1.15
|
||||
// +build !go1.15
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func registerDialerHttps() {
|
||||
}
|
||||
|
||||
func modifyProxyDialer(ctx context.Context, d *Dialer, proxyURL *url.URL, proxyDialer *netDialerFunc) {
|
||||
}
|
Loading…
Reference in New Issue