mirror of https://github.com/gorilla/websocket.git
Refactor client handshake
- To take advantage of the Host header cleanup in the net/http Request.Write method, use a net/http Request to write the handshake to the wire. - Move code from the deprecated NewClientConn function to Dialer.Dial. This change makes it easier to add proxy support to Dialer.Dial. Add comment noting that NewClientConn is deprecated. - Update the code so that parseURL can be replaced with net/url Parse. We need to wait until we can require 1.5 before making the swap.
This commit is contained in:
parent
423912737d
commit
5ed2f4547d
169
client.go
169
client.go
|
@ -30,50 +30,17 @@ var ErrBadHandshake = errors.New("websocket: bad handshake")
|
||||||
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
|
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
|
||||||
// non-nil *http.Response so that callers can handle redirects, authentication,
|
// non-nil *http.Response so that callers can handle redirects, authentication,
|
||||||
// etc.
|
// etc.
|
||||||
|
//
|
||||||
|
// Deprecated: Use Dialer instead.
|
||||||
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
|
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
|
||||||
challengeKey, err := generateChallengeKey()
|
d := Dialer{
|
||||||
if err != nil {
|
ReadBufferSize: readBufSize,
|
||||||
return nil, nil, err
|
WriteBufferSize: writeBufSize,
|
||||||
|
NetDial: func(net, addr string) (net.Conn, error) {
|
||||||
|
return netConn, nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
acceptKey := computeAcceptKey(challengeKey)
|
return d.Dial(u.String(), requestHeader)
|
||||||
|
|
||||||
c = newConn(netConn, false, readBufSize, writeBufSize)
|
|
||||||
p := c.writeBuf[:0]
|
|
||||||
p = append(p, "GET "...)
|
|
||||||
p = append(p, u.RequestURI()...)
|
|
||||||
p = append(p, " HTTP/1.1\r\nHost: "...)
|
|
||||||
p = append(p, u.Host...)
|
|
||||||
// "Upgrade" is capitalized for servers that do not use case insensitive
|
|
||||||
// comparisons on header tokens.
|
|
||||||
p = append(p, "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "...)
|
|
||||||
p = append(p, challengeKey...)
|
|
||||||
p = append(p, "\r\n"...)
|
|
||||||
for k, vs := range requestHeader {
|
|
||||||
for _, v := range vs {
|
|
||||||
p = append(p, k...)
|
|
||||||
p = append(p, ": "...)
|
|
||||||
p = append(p, v...)
|
|
||||||
p = append(p, "\r\n"...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
p = append(p, "\r\n"...)
|
|
||||||
|
|
||||||
if _, err := netConn.Write(p); err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u})
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
if resp.StatusCode != 101 ||
|
|
||||||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
|
|
||||||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
|
|
||||||
resp.Header.Get("Sec-Websocket-Accept") != acceptKey {
|
|
||||||
return nil, resp, ErrBadHandshake
|
|
||||||
}
|
|
||||||
c.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
|
||||||
return c, resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// A Dialer contains options for connecting to WebSocket server.
|
// A Dialer contains options for connecting to WebSocket server.
|
||||||
|
@ -99,17 +66,15 @@ type Dialer struct {
|
||||||
|
|
||||||
var errMalformedURL = errors.New("malformed ws or wss URL")
|
var errMalformedURL = errors.New("malformed ws or wss URL")
|
||||||
|
|
||||||
// parseURL parses the URL. The url.Parse function is not used here because
|
// parseURL parses the URL.
|
||||||
// url.Parse mangles the path.
|
//
|
||||||
|
// This function is a replacement for the standard library url.Parse function.
|
||||||
|
// In Go 1.4 and earlier, url.Parse loses information from the path.
|
||||||
func parseURL(s string) (*url.URL, error) {
|
func parseURL(s string) (*url.URL, error) {
|
||||||
// From the RFC:
|
// From the RFC:
|
||||||
//
|
//
|
||||||
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
|
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
|
||||||
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
|
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
|
||||||
//
|
|
||||||
// We don't use the net/url parser here because the dialer interface does
|
|
||||||
// not provide a way for applications to work around percent deocding in
|
|
||||||
// the net/url parser.
|
|
||||||
|
|
||||||
var u url.URL
|
var u url.URL
|
||||||
switch {
|
switch {
|
||||||
|
@ -131,7 +96,8 @@ func parseURL(s string) (*url.URL, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(u.Host, "@") {
|
if strings.Contains(u.Host, "@") {
|
||||||
// WebSocket URIs do not contain user information.
|
// Don't bother parsing user information because user information is
|
||||||
|
// not allowed in websocket URIs.
|
||||||
return nil, errMalformedURL
|
return nil, errMalformedURL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -166,17 +132,68 @@ var DefaultDialer = &Dialer{}
|
||||||
// etcetera. The response body may not contain the entire response and does not
|
// etcetera. The response body may not contain the entire response and does not
|
||||||
// need to be closed by the application.
|
// need to be closed by the application.
|
||||||
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
|
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
|
||||||
|
|
||||||
|
if d == nil {
|
||||||
|
d = &Dialer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
challengeKey, err := generateChallengeKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
u, err := parseURL(urlStr)
|
u, err := parseURL(urlStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
hostPort, hostNoPort := hostPortNoPort(u)
|
switch u.Scheme {
|
||||||
|
case "ws":
|
||||||
if d == nil {
|
u.Scheme = "http"
|
||||||
d = &Dialer{}
|
case "wss":
|
||||||
|
u.Scheme = "https"
|
||||||
|
default:
|
||||||
|
return nil, nil, errMalformedURL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if u.User != nil {
|
||||||
|
// User name and password are not allowed in websocket URIs.
|
||||||
|
return nil, nil, errMalformedURL
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &http.Request{
|
||||||
|
Method: "GET",
|
||||||
|
URL: u,
|
||||||
|
Proto: "HTTP/1.1",
|
||||||
|
ProtoMajor: 1,
|
||||||
|
ProtoMinor: 1,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Host: u.Host,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the request headers using the capitalization for names and values in
|
||||||
|
// RFC examples. Although the capitalization shouldn't matter, there are
|
||||||
|
// servers that depend on it. The Header.Set method is not used because the
|
||||||
|
// method canonicalizes the header names.
|
||||||
|
req.Header["Upgrade"] = []string{"websocket"}
|
||||||
|
req.Header["Connection"] = []string{"Upgrade"}
|
||||||
|
req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
|
||||||
|
req.Header["Sec-WebSocket-Version"] = []string{"13"}
|
||||||
|
if len(d.Subprotocols) > 0 {
|
||||||
|
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
|
||||||
|
}
|
||||||
|
for k, vs := range requestHeader {
|
||||||
|
if k == "Host" {
|
||||||
|
if len(vs) > 0 {
|
||||||
|
req.Host = vs[0]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
req.Header[k] = vs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hostPort, hostNoPort := hostPortNoPort(u)
|
||||||
|
|
||||||
var deadline time.Time
|
var deadline time.Time
|
||||||
if d.HandshakeTimeout != 0 {
|
if d.HandshakeTimeout != 0 {
|
||||||
deadline = time.Now().Add(d.HandshakeTimeout)
|
deadline = time.Now().Add(d.HandshakeTimeout)
|
||||||
|
@ -203,7 +220,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.Scheme == "wss" {
|
if u.Scheme == "https" {
|
||||||
cfg := d.TLSClientConfig
|
cfg := d.TLSClientConfig
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
cfg = &tls.Config{ServerName: hostNoPort}
|
cfg = &tls.Config{ServerName: hostNoPort}
|
||||||
|
@ -224,44 +241,32 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(d.Subprotocols) > 0 {
|
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
|
||||||
h := http.Header{}
|
|
||||||
for k, v := range requestHeader {
|
if err := req.Write(netConn); err != nil {
|
||||||
h[k] = v
|
return nil, nil, err
|
||||||
}
|
|
||||||
h.Set("Sec-Websocket-Protocol", strings.Join(d.Subprotocols, ", "))
|
|
||||||
requestHeader = h
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(requestHeader["Host"]) > 0 {
|
resp, err := http.ReadResponse(conn.br, req)
|
||||||
// This can be used to supply a Host: header which is different from
|
|
||||||
// the dial address.
|
|
||||||
u.Host = requestHeader.Get("Host")
|
|
||||||
|
|
||||||
// Drop "Host" header
|
|
||||||
h := http.Header{}
|
|
||||||
for k, v := range requestHeader {
|
|
||||||
if k == "Host" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
h[k] = v
|
|
||||||
}
|
|
||||||
requestHeader = h
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == ErrBadHandshake {
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode != 101 ||
|
||||||
|
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
|
||||||
|
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
|
||||||
|
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
|
||||||
// Before closing the network connection on return from this
|
// Before closing the network connection on return from this
|
||||||
// function, slurp up some of the response to aid application
|
// function, slurp up some of the response to aid application
|
||||||
// debugging.
|
// debugging.
|
||||||
buf := make([]byte, 1024)
|
buf := make([]byte, 1024)
|
||||||
n, _ := io.ReadFull(resp.Body, buf)
|
n, _ := io.ReadFull(resp.Body, buf)
|
||||||
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
|
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
|
||||||
|
return nil, resp, ErrBadHandshake
|
||||||
|
} else {
|
||||||
|
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
||||||
}
|
}
|
||||||
return nil, resp, err
|
|
||||||
}
|
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
|
||||||
|
|
||||||
netConn.SetDeadline(time.Time{})
|
netConn.SetDeadline(time.Time{})
|
||||||
netConn = nil // to avoid close in defer.
|
netConn = nil // to avoid close in defer.
|
||||||
|
|
|
@ -289,8 +289,8 @@ func TestRespOnBadHandshake(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the Host header is specified in `Dial()`, the server must receive it as
|
// TestHostHeader confirms that the host header provided in the call to Dial is
|
||||||
// the `Host:` header.
|
// sent to the server.
|
||||||
func TestHostHeader(t *testing.T) {
|
func TestHostHeader(t *testing.T) {
|
||||||
s := newServer(t)
|
s := newServer(t)
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
@ -305,16 +305,12 @@ func TestHostHeader(t *testing.T) {
|
||||||
origHandler.ServeHTTP(w, r)
|
origHandler.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
|
|
||||||
ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
|
ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Dial: %v", err)
|
t.Fatalf("Dial: %v", err)
|
||||||
}
|
}
|
||||||
defer ws.Close()
|
defer ws.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
|
||||||
t.Fatalf("resp.StatusCode = %v, want http.StatusSwitchingProtocols", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
if gotHost := <-specifiedHost; gotHost != "testhost" {
|
if gotHost := <-specifiedHost; gotHost != "testhost" {
|
||||||
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
|
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue