diff --git a/client.go b/client.go index 6138906..a353e18 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ import ( "bufio" "bytes" "crypto/tls" + "encoding/base64" "errors" "io" "io/ioutil" @@ -265,11 +266,19 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } if proxyURL != nil { + connectHeader := make(http.Header) + if user := proxyURL.User; user != nil { + proxyUser := user.Username() + if proxyPassword, passwordSet := user.Password(); passwordSet { + credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) + connectHeader.Set("Proxy-Authorization", "Basic "+credential) + } + } connectReq := &http.Request{ Method: "CONNECT", URL: &url.URL{Opaque: hostPort}, Host: hostPort, - Header: make(http.Header), + Header: connectHeader, } connectReq.Write(netConn) diff --git a/client_server_test.go b/client_server_test.go index c67550e..3f7345d 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -7,6 +7,7 @@ package websocket import ( "crypto/tls" "crypto/x509" + "encoding/base64" "io" "io/ioutil" "net" @@ -175,6 +176,46 @@ func TestProxyDial(t *testing.T) { cstDialer.Proxy = http.ProxyFromEnvironment } +func TestProxyAuthorizationDial(t *testing.T) { + s := newServer(t) + defer s.Close() + + surl, _ := url.Parse(s.URL) + surl.User = url.UserPassword("username", "password") + cstDialer.Proxy = http.ProxyURL(surl) + + connect := false + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + proxyAuth := r.Header.Get("Proxy-Authorization") + expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) + if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth { + connect = true + w.WriteHeader(200) + return + } + + if !connect { + t.Log("connect with proxy authorization not recieved") + http.Error(w, "connect with proxy authorization not recieved", 405) + return + } + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) + + cstDialer.Proxy = http.ProxyFromEnvironment +} + func TestDial(t *testing.T) { s := newServer(t) defer s.Close()