mirror of https://github.com/gorilla/websocket.git
Merge pull request #166 from MaximeHeckel/add-cookiejar
add cookie jar to dialer
This commit is contained in:
commit
5df680c89f
20
client.go
20
client.go
|
@ -78,6 +78,11 @@ type Dialer struct {
|
||||||
// guarantee that compression will be supported. Currently only "no context
|
// guarantee that compression will be supported. Currently only "no context
|
||||||
// takeover" modes are supported.
|
// takeover" modes are supported.
|
||||||
EnableCompression bool
|
EnableCompression bool
|
||||||
|
|
||||||
|
// Jar specifies the cookie jar.
|
||||||
|
// If Jar is nil, cookies are not sent in requests and ignored
|
||||||
|
// in responses.
|
||||||
|
Jar http.CookieJar
|
||||||
}
|
}
|
||||||
|
|
||||||
var errMalformedURL = errors.New("malformed ws or wss URL")
|
var errMalformedURL = errors.New("malformed ws or wss URL")
|
||||||
|
@ -91,7 +96,6 @@ func parseURL(s string) (*url.URL, error) {
|
||||||
//
|
//
|
||||||
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
|
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
|
||||||
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
|
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
|
||||||
|
|
||||||
var u url.URL
|
var u url.URL
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(s, "ws://"):
|
case strings.HasPrefix(s, "ws://"):
|
||||||
|
@ -201,6 +205,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
Host: u.Host,
|
Host: u.Host,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set the cookies present in the cookie jar of the dialer
|
||||||
|
if d.Jar != nil {
|
||||||
|
for _, cookie := range d.Jar.Cookies(u) {
|
||||||
|
req.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Set the request headers using the capitalization for names and values in
|
// Set the request headers using the capitalization for names and values in
|
||||||
// RFC examples. Although the capitalization shouldn't matter, there are
|
// RFC examples. Although the capitalization shouldn't matter, there are
|
||||||
// servers that depend on it. The Header.Set method is not used because the
|
// servers that depend on it. The Header.Set method is not used because the
|
||||||
|
@ -337,6 +348,13 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if d.Jar != nil {
|
||||||
|
if rc := resp.Cookies(); len(rc) > 0 {
|
||||||
|
d.Jar.SetCookies(u, rc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 101 ||
|
if resp.StatusCode != 101 ||
|
||||||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
|
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
|
||||||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
|
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/cookiejar"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -228,6 +229,54 @@ func TestDial(t *testing.T) {
|
||||||
sendRecv(t, ws)
|
sendRecv(t, ws)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDialCookieJar(t *testing.T) {
|
||||||
|
s := newServer(t)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
jar, _ := cookiejar.New(nil)
|
||||||
|
d := cstDialer
|
||||||
|
d.Jar = jar
|
||||||
|
|
||||||
|
u, _ := parseURL(s.URL)
|
||||||
|
|
||||||
|
switch u.Scheme {
|
||||||
|
case "ws":
|
||||||
|
u.Scheme = "http"
|
||||||
|
case "wss":
|
||||||
|
u.Scheme = "https"
|
||||||
|
}
|
||||||
|
|
||||||
|
cookies := []*http.Cookie{&http.Cookie{Name: "gorilla", Value: "ws", Path: "/"}}
|
||||||
|
d.Jar.SetCookies(u, cookies)
|
||||||
|
|
||||||
|
ws, _, err := d.Dial(s.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Dial: %v", err)
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
|
||||||
|
var gorilla string
|
||||||
|
var sessionID string
|
||||||
|
for _, c := range d.Jar.Cookies(u) {
|
||||||
|
if c.Name == "gorilla" {
|
||||||
|
gorilla = c.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Name == "sessionID" {
|
||||||
|
sessionID = c.Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if gorilla != "ws" {
|
||||||
|
t.Error("Cookie not present in jar.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if sessionID != "1234" {
|
||||||
|
t.Error("Set-Cookie not received from the server.")
|
||||||
|
}
|
||||||
|
|
||||||
|
sendRecv(t, ws)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDialTLS(t *testing.T) {
|
func TestDialTLS(t *testing.T) {
|
||||||
s := newTLSServer(t)
|
s := newTLSServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
Loading…
Reference in New Issue