mirror of https://github.com/gorilla/websocket.git
Refactor client handshake
- To take advantage of the Host header cleanup in the net/http Request.Write method, use a net/http Request to write the handshake to the wire. - Move code from the deprecated NewClientConn function to Dialer.Dial. This change makes it easier to add proxy support to Dialer.Dial. Add comment noting that NewClientConn is deprecated. - Update the code so that parseURL can be replaced with net/url Parse. We need to wait until we can require 1.5 before making the swap.
This commit is contained in:
parent
423912737d
commit
5ed2f4547d
181
client.go
181
client.go
|
@ -30,50 +30,17 @@ var ErrBadHandshake = errors.New("websocket: bad handshake")
|
|||
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
|
||||
// non-nil *http.Response so that callers can handle redirects, authentication,
|
||||
// etc.
|
||||
//
|
||||
// Deprecated: Use Dialer instead.
|
||||
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
|
||||
challengeKey, err := generateChallengeKey()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
d := Dialer{
|
||||
ReadBufferSize: readBufSize,
|
||||
WriteBufferSize: writeBufSize,
|
||||
NetDial: func(net, addr string) (net.Conn, error) {
|
||||
return netConn, nil
|
||||
},
|
||||
}
|
||||
acceptKey := computeAcceptKey(challengeKey)
|
||||
|
||||
c = newConn(netConn, false, readBufSize, writeBufSize)
|
||||
p := c.writeBuf[:0]
|
||||
p = append(p, "GET "...)
|
||||
p = append(p, u.RequestURI()...)
|
||||
p = append(p, " HTTP/1.1\r\nHost: "...)
|
||||
p = append(p, u.Host...)
|
||||
// "Upgrade" is capitalized for servers that do not use case insensitive
|
||||
// comparisons on header tokens.
|
||||
p = append(p, "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "...)
|
||||
p = append(p, challengeKey...)
|
||||
p = append(p, "\r\n"...)
|
||||
for k, vs := range requestHeader {
|
||||
for _, v := range vs {
|
||||
p = append(p, k...)
|
||||
p = append(p, ": "...)
|
||||
p = append(p, v...)
|
||||
p = append(p, "\r\n"...)
|
||||
}
|
||||
}
|
||||
p = append(p, "\r\n"...)
|
||||
|
||||
if _, err := netConn.Write(p); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if resp.StatusCode != 101 ||
|
||||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
|
||||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
|
||||
resp.Header.Get("Sec-Websocket-Accept") != acceptKey {
|
||||
return nil, resp, ErrBadHandshake
|
||||
}
|
||||
c.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
||||
return c, resp, nil
|
||||
return d.Dial(u.String(), requestHeader)
|
||||
}
|
||||
|
||||
// A Dialer contains options for connecting to WebSocket server.
|
||||
|
@ -99,17 +66,15 @@ type Dialer struct {
|
|||
|
||||
var errMalformedURL = errors.New("malformed ws or wss URL")
|
||||
|
||||
// parseURL parses the URL. The url.Parse function is not used here because
|
||||
// url.Parse mangles the path.
|
||||
// parseURL parses the URL.
|
||||
//
|
||||
// This function is a replacement for the standard library url.Parse function.
|
||||
// In Go 1.4 and earlier, url.Parse loses information from the path.
|
||||
func parseURL(s string) (*url.URL, error) {
|
||||
// From the RFC:
|
||||
//
|
||||
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
|
||||
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
|
||||
//
|
||||
// We don't use the net/url parser here because the dialer interface does
|
||||
// not provide a way for applications to work around percent deocding in
|
||||
// the net/url parser.
|
||||
|
||||
var u url.URL
|
||||
switch {
|
||||
|
@ -131,7 +96,8 @@ func parseURL(s string) (*url.URL, error) {
|
|||
}
|
||||
|
||||
if strings.Contains(u.Host, "@") {
|
||||
// WebSocket URIs do not contain user information.
|
||||
// Don't bother parsing user information because user information is
|
||||
// not allowed in websocket URIs.
|
||||
return nil, errMalformedURL
|
||||
}
|
||||
|
||||
|
@ -166,17 +132,68 @@ var DefaultDialer = &Dialer{}
|
|||
// etcetera. The response body may not contain the entire response and does not
|
||||
// need to be closed by the application.
|
||||
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
|
||||
|
||||
if d == nil {
|
||||
d = &Dialer{}
|
||||
}
|
||||
|
||||
challengeKey, err := generateChallengeKey()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
u, err := parseURL(urlStr)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
hostPort, hostNoPort := hostPortNoPort(u)
|
||||
|
||||
if d == nil {
|
||||
d = &Dialer{}
|
||||
switch u.Scheme {
|
||||
case "ws":
|
||||
u.Scheme = "http"
|
||||
case "wss":
|
||||
u.Scheme = "https"
|
||||
default:
|
||||
return nil, nil, errMalformedURL
|
||||
}
|
||||
|
||||
if u.User != nil {
|
||||
// User name and password are not allowed in websocket URIs.
|
||||
return nil, nil, errMalformedURL
|
||||
}
|
||||
|
||||
req := &http.Request{
|
||||
Method: "GET",
|
||||
URL: u,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: make(http.Header),
|
||||
Host: u.Host,
|
||||
}
|
||||
|
||||
// Set the request headers using the capitalization for names and values in
|
||||
// RFC examples. Although the capitalization shouldn't matter, there are
|
||||
// servers that depend on it. The Header.Set method is not used because the
|
||||
// method canonicalizes the header names.
|
||||
req.Header["Upgrade"] = []string{"websocket"}
|
||||
req.Header["Connection"] = []string{"Upgrade"}
|
||||
req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
|
||||
req.Header["Sec-WebSocket-Version"] = []string{"13"}
|
||||
if len(d.Subprotocols) > 0 {
|
||||
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
|
||||
}
|
||||
for k, vs := range requestHeader {
|
||||
if k == "Host" {
|
||||
if len(vs) > 0 {
|
||||
req.Host = vs[0]
|
||||
}
|
||||
} else {
|
||||
req.Header[k] = vs
|
||||
}
|
||||
}
|
||||
|
||||
hostPort, hostNoPort := hostPortNoPort(u)
|
||||
|
||||
var deadline time.Time
|
||||
if d.HandshakeTimeout != 0 {
|
||||
deadline = time.Now().Add(d.HandshakeTimeout)
|
||||
|
@ -203,7 +220,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
if u.Scheme == "wss" {
|
||||
if u.Scheme == "https" {
|
||||
cfg := d.TLSClientConfig
|
||||
if cfg == nil {
|
||||
cfg = &tls.Config{ServerName: hostNoPort}
|
||||
|
@ -224,44 +241,32 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
|||
}
|
||||
}
|
||||
|
||||
if len(d.Subprotocols) > 0 {
|
||||
h := http.Header{}
|
||||
for k, v := range requestHeader {
|
||||
h[k] = v
|
||||
}
|
||||
h.Set("Sec-Websocket-Protocol", strings.Join(d.Subprotocols, ", "))
|
||||
requestHeader = h
|
||||
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
|
||||
|
||||
if err := req.Write(netConn); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if len(requestHeader["Host"]) > 0 {
|
||||
// This can be used to supply a Host: header which is different from
|
||||
// the dial address.
|
||||
u.Host = requestHeader.Get("Host")
|
||||
|
||||
// Drop "Host" header
|
||||
h := http.Header{}
|
||||
for k, v := range requestHeader {
|
||||
if k == "Host" {
|
||||
continue
|
||||
}
|
||||
h[k] = v
|
||||
}
|
||||
requestHeader = h
|
||||
}
|
||||
|
||||
conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize)
|
||||
|
||||
resp, err := http.ReadResponse(conn.br, req)
|
||||
if err != nil {
|
||||
if err == ErrBadHandshake {
|
||||
// Before closing the network connection on return from this
|
||||
// function, slurp up some of the response to aid application
|
||||
// debugging.
|
||||
buf := make([]byte, 1024)
|
||||
n, _ := io.ReadFull(resp.Body, buf)
|
||||
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
|
||||
}
|
||||
return nil, resp, err
|
||||
return nil, nil, err
|
||||
}
|
||||
if resp.StatusCode != 101 ||
|
||||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
|
||||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
|
||||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
|
||||
// Before closing the network connection on return from this
|
||||
// function, slurp up some of the response to aid application
|
||||
// debugging.
|
||||
buf := make([]byte, 1024)
|
||||
n, _ := io.ReadFull(resp.Body, buf)
|
||||
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
|
||||
return nil, resp, ErrBadHandshake
|
||||
} else {
|
||||
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
||||
}
|
||||
|
||||
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
||||
|
||||
netConn.SetDeadline(time.Time{})
|
||||
netConn = nil // to avoid close in defer.
|
||||
|
|
|
@ -289,8 +289,8 @@ func TestRespOnBadHandshake(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// If the Host header is specified in `Dial()`, the server must receive it as
|
||||
// the `Host:` header.
|
||||
// TestHostHeader confirms that the host header provided in the call to Dial is
|
||||
// sent to the server.
|
||||
func TestHostHeader(t *testing.T) {
|
||||
s := newServer(t)
|
||||
defer s.Close()
|
||||
|
@ -305,16 +305,12 @@ func TestHostHeader(t *testing.T) {
|
|||
origHandler.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
|
||||
ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
t.Fatalf("resp.StatusCode = %v, want http.StatusSwitchingProtocols", resp.StatusCode)
|
||||
}
|
||||
|
||||
if gotHost := <-specifiedHost; gotHost != "testhost" {
|
||||
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue