mirror of https://github.com/gorilla/websocket.git
Merge pull request #166 from MaximeHeckel/add-cookiejar
add cookie jar to dialer
This commit is contained in:
commit
5df680c89f
20
client.go
20
client.go
|
@ -78,6 +78,11 @@ type Dialer struct {
|
|||
// guarantee that compression will be supported. Currently only "no context
|
||||
// takeover" modes are supported.
|
||||
EnableCompression bool
|
||||
|
||||
// Jar specifies the cookie jar.
|
||||
// If Jar is nil, cookies are not sent in requests and ignored
|
||||
// in responses.
|
||||
Jar http.CookieJar
|
||||
}
|
||||
|
||||
var errMalformedURL = errors.New("malformed ws or wss URL")
|
||||
|
@ -91,7 +96,6 @@ func parseURL(s string) (*url.URL, error) {
|
|||
//
|
||||
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
|
||||
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
|
||||
|
||||
var u url.URL
|
||||
switch {
|
||||
case strings.HasPrefix(s, "ws://"):
|
||||
|
@ -201,6 +205,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
Host: u.Host,
|
||||
}
|
||||
|
||||
// Set the cookies present in the cookie jar of the dialer
|
||||
if d.Jar != nil {
|
||||
for _, cookie := range d.Jar.Cookies(u) {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
}
|
||||
|
||||
// Set the request headers using the capitalization for names and values in
|
||||
// RFC examples. Although the capitalization shouldn't matter, there are
|
||||
// servers that depend on it. The Header.Set method is not used because the
|
||||
|
@ -337,6 +348,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if d.Jar != nil {
|
||||
if rc := resp.Cookies(); len(rc) > 0 {
|
||||
d.Jar.SetCookies(u, rc)
|
||||
}
|
||||
}
|
||||
|
||||
if resp.StatusCode != 101 ||
|
||||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
|
||||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
|
@ -228,6 +229,54 @@ func TestDial(t *testing.T) {
|
|||
sendRecv(t, ws)
|
||||
}
|
||||
|
||||
func TestDialCookieJar(t *testing.T) {
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
||||
jar, _ := cookiejar.New(nil)
|
||||
d := cstDialer
|
||||
d.Jar = jar
|
||||
|
||||
u, _ := parseURL(s.URL)
|
||||
|
||||
switch u.Scheme {
|
||||
case "ws":
|
||||
u.Scheme = "http"
|
||||
case "wss":
|
||||
u.Scheme = "https"
|
||||
}
|
||||
|
||||
cookies := []*http.Cookie{&http.Cookie{Name: "gorilla", Value: "ws", Path: "/"}}
|
||||
d.Jar.SetCookies(u, cookies)
|
||||
|
||||
ws, _, err := d.Dial(s.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
var gorilla string
|
||||
var sessionID string
|
||||
for _, c := range d.Jar.Cookies(u) {
|
||||
if c.Name == "gorilla" {
|
||||
gorilla = c.Value
|
||||
}
|
||||
|
||||
if c.Name == "sessionID" {
|
||||
sessionID = c.Value
|
||||
}
|
||||
}
|
||||
if gorilla != "ws" {
|
||||
t.Error("Cookie not present in jar.")
|
||||
}
|
||||
|
||||
if sessionID != "1234" {
|
||||
t.Error("Set-Cookie not received from the server.")
|
||||
}
|
||||
|
||||
sendRecv(t, ws)
|
||||
}
|
||||
|
||||
func TestDialTLS(t *testing.T) {
|
||||
s := newTLSServer(t)
|
||||
defer s.Close()
|
||||
|
|
Loading…
Reference in New Issue