mirror of https://github.com/gorilla/websocket.git
fix issue #479
This commit is contained in:
parent
b65e62901f
commit
aa46640059
|
@ -99,6 +99,9 @@ type Dialer struct {
|
||||||
// If Jar is nil, cookies are not sent in requests and ignored
|
// If Jar is nil, cookies are not sent in requests and ignored
|
||||||
// in responses.
|
// in responses.
|
||||||
Jar http.CookieJar
|
Jar http.CookieJar
|
||||||
|
|
||||||
|
// Custom proxy connect header
|
||||||
|
ProxyConnectHeader http.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial creates a new client connection by calling DialContext with a background context.
|
// Dial creates a new client connection by calling DialContext with a background context.
|
||||||
|
@ -274,7 +277,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if proxyURL != nil {
|
if proxyURL != nil {
|
||||||
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
|
dialer, err := proxy_FromURL(proxyURL, &netDialer{d.ProxyConnectHeader, netDial})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -156,6 +156,9 @@ func TestProxyDial(t *testing.T) {
|
||||||
|
|
||||||
cstDialer := cstDialer // make local copy for modification on next line.
|
cstDialer := cstDialer // make local copy for modification on next line.
|
||||||
cstDialer.Proxy = http.ProxyURL(surl)
|
cstDialer.Proxy = http.ProxyURL(surl)
|
||||||
|
cstDialer.ProxyConnectHeader = map[string][]string{
|
||||||
|
"User-Agents": {"xxx"},
|
||||||
|
}
|
||||||
|
|
||||||
connect := false
|
connect := false
|
||||||
origHandler := s.Server.Config.Handler
|
origHandler := s.Server.Config.Handler
|
||||||
|
@ -166,6 +169,10 @@ func TestProxyDial(t *testing.T) {
|
||||||
if r.Method == "CONNECT" {
|
if r.Method == "CONNECT" {
|
||||||
connect = true
|
connect = true
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if r.Header.Get("User-Agents") != "xxx" {
|
||||||
|
t.Log("xxx not found in the request header")
|
||||||
|
http.Error(w, "header xxx not found", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
20
proxy.go
20
proxy.go
|
@ -14,21 +14,29 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type netDialerFunc func(network, addr string) (net.Conn, error)
|
type netDialer struct {
|
||||||
|
proxyHeader http.Header
|
||||||
|
f func(network, addr string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
|
func (n netDialer) Dial(network, addr string) (net.Conn, error) {
|
||||||
return fn(network, addr)
|
return n.f(network, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
|
||||||
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
|
p, ok := forwardDialer.(*netDialer)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("type assertion failed when ini proxy info")
|
||||||
|
}
|
||||||
|
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial, proxyHeader: p.proxyHeader}, nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpProxyDialer struct {
|
type httpProxyDialer struct {
|
||||||
proxyURL *url.URL
|
proxyURL *url.URL
|
||||||
forwardDial func(network, addr string) (net.Conn, error)
|
forwardDial func(network, addr string) (net.Conn, error)
|
||||||
|
proxyHeader http.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
|
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
|
||||||
|
@ -47,6 +55,10 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for k, v := range hpd.proxyHeader {
|
||||||
|
connectHeader[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
connectReq := &http.Request{
|
connectReq := &http.Request{
|
||||||
Method: "CONNECT",
|
Method: "CONNECT",
|
||||||
URL: &url.URL{Opaque: addr},
|
URL: &url.URL{Opaque: addr},
|
||||||
|
|
Loading…
Reference in New Issue