mirror of https://github.com/gorilla/websocket.git
make it more intuitive for tls proxy
This commit is contained in:
parent
2a082eee69
commit
7f3a5bcae0
|
@ -274,15 +274,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if proxyURL != nil {
|
if proxyURL != nil {
|
||||||
|
proxyDialer := &netDialerFunc{fn: netDial}
|
||||||
if proxyURL.Scheme == "https" {
|
if proxyURL.Scheme == "https" {
|
||||||
netDial = func(network, addr string) (net.Conn, error) {
|
proxyDialer.usesTLS = true
|
||||||
|
proxyDialer.fn = func(network, addr string) (net.Conn, error) {
|
||||||
t := tls.Dialer{}
|
t := tls.Dialer{}
|
||||||
t.Config = d.TLSClientConfig
|
t.Config = d.TLSClientConfig
|
||||||
t.NetDialer = &net.Dialer{}
|
t.NetDialer = &net.Dialer{}
|
||||||
return t.DialContext(ctx, network, addr)
|
return t.DialContext(ctx, network, addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
|
dialer, err := proxy_FromURL(proxyURL, proxyDialer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
31
proxy.go
31
proxy.go
|
@ -6,6 +6,7 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"crypto/tls"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
@ -14,24 +15,40 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type netDialerFunc func(network, addr string) (net.Conn, error)
|
type netDialerFunc struct {
|
||||||
|
fn func(network, addr string) (net.Conn, error)
|
||||||
|
usesTLS bool
|
||||||
|
}
|
||||||
|
|
||||||
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
|
func (ndf *netDialerFunc) Dial(network, addr string) (net.Conn, error) {
|
||||||
return fn(network, addr)
|
return ndf.fn(network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ndf *netDialerFunc) UsesTLS() bool {
|
||||||
|
return ndf.usesTLS
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
||||||
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
|
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial, usesTLS: forwardDialer.UsesTLS()}, nil
|
||||||
})
|
})
|
||||||
proxy_RegisterDialerType("https", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
proxy_RegisterDialerType("https", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
||||||
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
|
fwd := forwardDialer.Dial
|
||||||
|
if !forwardDialer.UsesTLS() {
|
||||||
|
tlsDialer := &tls.Dialer{
|
||||||
|
Config: &tls.Config{},
|
||||||
|
NetDialer: &net.Dialer{},
|
||||||
|
}
|
||||||
|
fwd = tlsDialer.Dial
|
||||||
|
}
|
||||||
|
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: fwd, usesTLS: true}, nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpProxyDialer struct {
|
type httpProxyDialer struct {
|
||||||
proxyURL *url.URL
|
proxyURL *url.URL
|
||||||
forwardDial func(network, addr string) (net.Conn, error)
|
forwardDial func(network, addr string) (net.Conn, error)
|
||||||
|
usesTLS bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
|
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
|
||||||
|
@ -78,3 +95,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (hpd *httpProxyDialer) UsesTLS() bool {
|
||||||
|
return hpd.usesTLS
|
||||||
|
}
|
||||||
|
|
|
@ -27,6 +27,10 @@ func (proxy_direct) Dial(network, addr string) (net.Conn, error) {
|
||||||
return net.Dial(network, addr)
|
return net.Dial(network, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (proxy_direct) UsesTLS() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// A PerHost directs connections to a default Dialer unless the host name
|
// A PerHost directs connections to a default Dialer unless the host name
|
||||||
// requested matches one of a number of exceptions.
|
// requested matches one of a number of exceptions.
|
||||||
type proxy_PerHost struct {
|
type proxy_PerHost struct {
|
||||||
|
@ -59,6 +63,10 @@ func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) {
|
||||||
return p.dialerForRequest(host).Dial(network, addr)
|
return p.dialerForRequest(host).Dial(network, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *proxy_PerHost) UsesTLS() bool {
|
||||||
|
return p.def.UsesTLS() || p.bypass.UsesTLS()
|
||||||
|
}
|
||||||
|
|
||||||
func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer {
|
func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer {
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
for _, net := range p.bypassNetworks {
|
for _, net := range p.bypassNetworks {
|
||||||
|
@ -161,6 +169,8 @@ func (p *proxy_PerHost) AddHost(host string) {
|
||||||
type proxy_Dialer interface {
|
type proxy_Dialer interface {
|
||||||
// Dial connects to the given address via the proxy.
|
// Dial connects to the given address via the proxy.
|
||||||
Dial(network, addr string) (c net.Conn, err error)
|
Dial(network, addr string) (c net.Conn, err error)
|
||||||
|
// UsesTLS indicates whether we expect to dial to a TLS proxy
|
||||||
|
UsesTLS() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Auth contains authentication parameters that specific Dialers may require.
|
// Auth contains authentication parameters that specific Dialers may require.
|
||||||
|
@ -338,6 +348,10 @@ func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) {
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *proxy_socks5) UsesTLS() bool {
|
||||||
|
return s.forward.UsesTLS()
|
||||||
|
}
|
||||||
|
|
||||||
// connect takes an existing connection to a socks5 proxy server,
|
// connect takes an existing connection to a socks5 proxy server,
|
||||||
// and commands the server to extend that connection to target,
|
// and commands the server to extend that connection to target,
|
||||||
// which must be a canonical address with a host and port.
|
// which must be a canonical address with a host and port.
|
||||||
|
|
Loading…
Reference in New Issue