add cookie jar to dialer

This commit is contained in:
Maxime Heckel 2016-10-17 11:19:52 -07:00
parent 8003df83ee
commit 56d95f2940
2 changed files with 68 additions and 1 deletions

View File

@ -78,6 +78,11 @@ type Dialer struct {
// guarantee that compression will be supported. Currently only "no context // guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported. // takeover" modes are supported.
EnableCompression bool EnableCompression bool
// Jar specifies the cookie jar.
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar http.CookieJar
} }
var errMalformedURL = errors.New("malformed ws or wss URL") var errMalformedURL = errors.New("malformed ws or wss URL")
@ -91,7 +96,6 @@ func parseURL(s string) (*url.URL, error) {
// //
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
var u url.URL var u url.URL
switch { switch {
case strings.HasPrefix(s, "ws://"): case strings.HasPrefix(s, "ws://"):
@ -201,6 +205,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
Host: u.Host, Host: u.Host,
} }
// Set the cookies present in the cookie jar of the dialer
if d.Jar != nil {
for _, cookie := range d.Jar.Cookies(u) {
req.AddCookie(cookie)
}
}
// Set the request headers using the capitalization for names and values in // Set the request headers using the capitalization for names and values in
// RFC examples. Although the capitalization shouldn't matter, there are // RFC examples. Although the capitalization shouldn't matter, there are
// servers that depend on it. The Header.Set method is not used because the // servers that depend on it. The Header.Set method is not used because the
@ -337,6 +348,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if d.Jar != nil {
if rc := resp.Cookies(); len(rc) > 0 {
d.Jar.SetCookies(u, rc)
}
}
if resp.StatusCode != 101 || if resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||

View File

@ -11,6 +11,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/cookiejar"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect" "reflect"
@ -228,6 +229,54 @@ func TestDial(t *testing.T) {
sendRecv(t, ws) sendRecv(t, ws)
} }
func TestDialCookieJar(t *testing.T) {
s := newServer(t)
defer s.Close()
jar, _ := cookiejar.New(nil)
d := cstDialer
d.Jar = jar
u, _ := parseURL(s.URL)
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
}
cookies := []*http.Cookie{&http.Cookie{Name: "gorilla", Value: "ws", Path: "/"}}
d.Jar.SetCookies(u, cookies)
ws, _, err := d.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
var gorilla string
var sessionID string
for _, c := range d.Jar.Cookies(u) {
if c.Name == "gorilla" {
gorilla = c.Value
}
if c.Name == "sessionID" {
sessionID = c.Value
}
}
if gorilla != "ws" {
t.Error("Cookie not present in jar.")
}
if sessionID != "1234" {
t.Error("Set-Cookie not received from the server.")
}
sendRecv(t, ws)
}
func TestDialTLS(t *testing.T) { func TestDialTLS(t *testing.T) {
s := newTLSServer(t) s := newTLSServer(t)
defer s.Close() defer s.Close()