From 70eca1b8e76b4efbc03dd0711114e6d040668302 Mon Sep 17 00:00:00 2001 From: Mark Wolfe Date: Tue, 20 Oct 2015 20:29:17 +1100 Subject: [PATCH] Add Proxy support for websocket clients. - Uses `http.ProxyFromEnvironment` for configuration in line with the golang standard library. --- client.go | 64 +++++++++++++++++++++++++++++++++++++++---- client_server_test.go | 39 ++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 443e815..51acf24 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ package websocket import ( + "bufio" "bytes" "crypto/tls" "errors" @@ -49,6 +50,12 @@ type Dialer struct { // NetDial is nil, net.Dial is used. NetDial func(network, addr string) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*http.Request) (*url.URL, error) + // TLSClientConfig specifies the TLS configuration to use with tls.Client. // If nil, the default configuration is used. TLSClientConfig *tls.Config @@ -110,9 +117,12 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { hostNoPort = hostNoPort[:i] } else { - if u.Scheme == "wss" { + switch u.Scheme { + case "wss": hostPort += ":443" - } else { + case "https": + hostPort += ":443" + default: hostPort += ":80" } } @@ -120,7 +130,9 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { } // DefaultDialer is a dialer with all fields set to the default zero values. -var DefaultDialer = &Dialer{} +var DefaultDialer = &Dialer{ + Proxy: http.ProxyFromEnvironment, +} // Dial creates a new client connection. Use requestHeader to specify the // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). @@ -134,7 +146,9 @@ var DefaultDialer = &Dialer{} func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { if d == nil { - d = &Dialer{} + d = &Dialer{ + Proxy: http.ProxyFromEnvironment, + } } challengeKey, err := generateChallengeKey() @@ -194,6 +208,22 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re hostPort, hostNoPort := hostPortNoPort(u) + var proxyURL *url.URL + // Check wether the proxy method has been configured + if d.Proxy != nil { + proxyURL, err = d.Proxy(req) + } + if err != nil { + return nil, nil, err + } + + var targetHostPort string + if proxyURL != nil { + targetHostPort, _ = hostPortNoPort(proxyURL) + } else { + targetHostPort = hostPort + } + var deadline time.Time if d.HandshakeTimeout != 0 { deadline = time.Now().Add(d.HandshakeTimeout) @@ -205,7 +235,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re netDial = netDialer.Dial } - netConn, err := netDial("tcp", hostPort) + netConn, err := netDial("tcp", targetHostPort) if err != nil { return nil, nil, err } @@ -220,6 +250,30 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re return nil, nil, err } + if proxyURL != nil { + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: hostPort}, + Host: hostPort, + Header: make(http.Header), + } + + connectReq.Write(netConn) + + // Read response. + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(netConn) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + return nil, nil, err + } + if resp.StatusCode != 200 { + f := strings.SplitN(resp.Status, " ", 2) + return nil, nil, errors.New(f[1]) + } + } + if u.Scheme == "https" { cfg := d.TLSClientConfig if cfg == nil { diff --git a/client_server_test.go b/client_server_test.go index b6cb89a..ebcba9f 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -123,6 +123,45 @@ func sendRecv(t *testing.T, ws *Conn) { } } +func TestProxyDial(t *testing.T) { + + s := newServer(t) + defer s.Close() + + surl, _ := url.Parse(s.URL) + + cstDialer.Proxy = http.ProxyURL(surl) + + connect := false + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == "CONNECT" { + connect = true + w.WriteHeader(200) + return + } + + if !connect { + t.Log("connect not recieved") + http.Error(w, "connect not recieved", 405) + return + } + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) + + cstDialer.Proxy = http.ProxyFromEnvironment +} + func TestDial(t *testing.T) { s := newServer(t) defer s.Close()