This commit is contained in:
mcqueen 2015-10-23 18:20:45 -07:00
commit 9cff039c6d
2 changed files with 98 additions and 5 deletions

View File

@ -5,6 +5,7 @@
package websocket package websocket
import ( import (
"bufio"
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"errors" "errors"
@ -49,6 +50,12 @@ type Dialer struct {
// NetDial is nil, net.Dial is used. // NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error) NetDial func(network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
// If Proxy is nil or returns a nil *URL, no proxy is used.
Proxy func(*http.Request) (*url.URL, error)
// TLSClientConfig specifies the TLS configuration to use with tls.Client. // TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used. // If nil, the default configuration is used.
TLSClientConfig *tls.Config TLSClientConfig *tls.Config
@ -110,9 +117,12 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
hostNoPort = hostNoPort[:i] hostNoPort = hostNoPort[:i]
} else { } else {
if u.Scheme == "wss" { switch u.Scheme {
case "wss":
hostPort += ":443" hostPort += ":443"
} else { case "https":
hostPort += ":443"
default:
hostPort += ":80" hostPort += ":80"
} }
} }
@ -120,7 +130,9 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
} }
// DefaultDialer is a dialer with all fields set to the default zero values. // DefaultDialer is a dialer with all fields set to the default zero values.
var DefaultDialer = &Dialer{} var DefaultDialer = &Dialer{
Proxy: http.ProxyFromEnvironment,
}
// Dial creates a new client connection. Use requestHeader to specify the // Dial creates a new client connection. Use requestHeader to specify the
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
@ -134,7 +146,9 @@ var DefaultDialer = &Dialer{}
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
if d == nil { if d == nil {
d = &Dialer{} d = &Dialer{
Proxy: http.ProxyFromEnvironment,
}
} }
challengeKey, err := generateChallengeKey() challengeKey, err := generateChallengeKey()
@ -194,6 +208,22 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
hostPort, hostNoPort := hostPortNoPort(u) hostPort, hostNoPort := hostPortNoPort(u)
var proxyURL *url.URL
// Check wether the proxy method has been configured
if d.Proxy != nil {
proxyURL, err = d.Proxy(req)
}
if err != nil {
return nil, nil, err
}
var targetHostPort string
if proxyURL != nil {
targetHostPort, _ = hostPortNoPort(proxyURL)
} else {
targetHostPort = hostPort
}
var deadline time.Time var deadline time.Time
if d.HandshakeTimeout != 0 { if d.HandshakeTimeout != 0 {
deadline = time.Now().Add(d.HandshakeTimeout) deadline = time.Now().Add(d.HandshakeTimeout)
@ -205,7 +235,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
netDial = netDialer.Dial netDial = netDialer.Dial
} }
netConn, err := netDial("tcp", hostPort) netConn, err := netDial("tcp", targetHostPort)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -220,6 +250,30 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
return nil, nil, err return nil, nil, err
} }
if proxyURL != nil {
connectReq := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: hostPort},
Host: hostPort,
Header: make(http.Header),
}
connectReq.Write(netConn)
// Read response.
// Okay to use and discard buffered reader here, because
// TLS server will not speak until spoken to.
br := bufio.NewReader(netConn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
return nil, nil, err
}
if resp.StatusCode != 200 {
f := strings.SplitN(resp.Status, " ", 2)
return nil, nil, errors.New(f[1])
}
}
if u.Scheme == "https" { if u.Scheme == "https" {
cfg := d.TLSClientConfig cfg := d.TLSClientConfig
if cfg == nil { if cfg == nil {

View File

@ -123,6 +123,45 @@ func sendRecv(t *testing.T, ws *Conn) {
} }
} }
func TestProxyDial(t *testing.T) {
s := newServer(t)
defer s.Close()
surl, _ := url.Parse(s.URL)
cstDialer.Proxy = http.ProxyURL(surl)
connect := false
origHandler := s.Server.Config.Handler
// Capture the request Host header.
s.Server.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" {
connect = true
w.WriteHeader(200)
return
}
if !connect {
t.Log("connect not recieved")
http.Error(w, "connect not recieved", 405)
return
}
origHandler.ServeHTTP(w, r)
})
ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
cstDialer.Proxy = http.ProxyFromEnvironment
}
func TestDial(t *testing.T) { func TestDial(t *testing.T) {
s := newServer(t) s := newServer(t)
defer s.Close() defer s.Close()