// Copyright 2013 Gary Burd. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package websocket import ( "crypto/tls" "errors" "net" "net/http" "net/url" "strings" "time" ) // ErrBadHandshake is returned when the server response to opening handshake is // invalid. var ErrBadHandshake = errors.New("websocket: bad handshake") // NewClient creates a new client connection using the given net connection. // The URL u specifies the host and request URI. Use requestHeader to specify // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies // (Cookie). Use the response.Header to get the selected subprotocol // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). // // If the WebSocket handshake fails, ErrBadHandshake is returned along with a // non-nil *http.Response so that callers can handle redirects, authentication, // etc. func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { challengeKey, err := generateChallengeKey() if err != nil { return nil, nil, err } acceptKey := computeAcceptKey(challengeKey) c = newConn(netConn, false, readBufSize, writeBufSize) p := c.writeBuf[:0] p = append(p, "GET "...) p = append(p, u.RequestURI()...) p = append(p, " HTTP/1.1\r\nHost: "...) p = append(p, u.Host...) // "Upgrade" is capitalized for servers that do not use case insensitive // comparisons on header tokens. p = append(p, "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "...) p = append(p, challengeKey...) p = append(p, "\r\n"...) for k, vs := range requestHeader { for _, v := range vs { p = append(p, k...) p = append(p, ": "...) p = append(p, v...) p = append(p, "\r\n"...) } } p = append(p, "\r\n"...) if _, err := netConn.Write(p); err != nil { return nil, nil, err } resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u}) if err != nil { return nil, nil, err } if resp.StatusCode != 101 || !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != acceptKey { return nil, resp, ErrBadHandshake } c.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") return c, resp, nil } type Dialer struct { // NetDial specifies the dial function for creating TCP connections. If // NetDial is nil, net.Dial is used. NetDial func(network, addr string) (net.Conn, error) // TLSClientConfig specifies the TLS configuration to use with tls.Client. // If nil, the default configuration is used. TLSClientConfig *tls.Config // HandshakeTimeout specifies the duration for the handshake to complete. HandshakeTimeout time.Duration // Input and output buffer sizes. If the buffer size is zero, then a // default value of 4096 is used. ReadBufferSize, WriteBufferSize int // Subprotocols specifies the client's requested subprotocols. Subprotocols []string } var errMalformedURL = errors.New("malformed ws or wss URL") func parseURL(u string) (useTLS bool, host, port, opaque string, err error) { // From the RFC: // // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] // // We don't use the net/url parser here because the dialer interface does // not provide a way for applications to work around percent deocding in // the net/url parser. switch { case strings.HasPrefix(u, "ws://"): u = u[len("ws://"):] case strings.HasPrefix(u, "wss://"): u = u[len("wss://"):] useTLS = true default: return false, "", "", "", errMalformedURL } hostPort := u opaque = "/" if i := strings.Index(u, "/"); i >= 0 { hostPort = u[:i] opaque = u[i:] } host = hostPort port = ":80" if i := strings.LastIndex(hostPort, ":"); i > strings.LastIndex(hostPort, "]") { host = hostPort[:i] port = hostPort[i:] } else if useTLS { port = ":443" } return useTLS, host, port, opaque, nil } var DefaultDialer *Dialer // Dial creates a new client connection. Use requestHeader to specify the // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). // Use the response.Header to get the selected subprotocol // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). // // If the WebSocket handshake fails, ErrBadHandshake is returned along with a // non-nil *http.Response so that callers can handle redirects, authentication, // etc. func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { useTLS, host, port, opaque, err := parseURL(urlStr) if err != nil { return nil, nil, err } if d == nil { d = &Dialer{} } var deadline time.Time if d.HandshakeTimeout != 0 { deadline = time.Now().Add(d.HandshakeTimeout) } netDial := d.NetDial if netDial == nil { netDialer := &net.Dialer{Deadline: deadline} netDial = netDialer.Dial } netConn, err := netDial("tcp", host+port) if err != nil { return nil, nil, err } defer func() { if netConn != nil { netConn.Close() } }() if err := netConn.SetDeadline(deadline); err != nil { return nil, nil, err } if useTLS { cfg := d.TLSClientConfig if cfg == nil { cfg = &tls.Config{ServerName: host} } else if cfg.ServerName == "" { shallowCopy := *cfg cfg = &shallowCopy cfg.ServerName = host } tlsConn := tls.Client(netConn, cfg) netConn = tlsConn if err := tlsConn.Handshake(); err != nil { return nil, nil, err } if !cfg.InsecureSkipVerify { if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { return nil, nil, err } } } readBufferSize := d.ReadBufferSize if readBufferSize == 0 { readBufferSize = 4096 } writeBufferSize := d.WriteBufferSize if writeBufferSize == 0 { writeBufferSize = 4096 } if len(d.Subprotocols) > 0 { h := http.Header{} for k, v := range requestHeader { h[k] = v } h.Set("Sec-Websocket-Protocol", strings.Join(d.Subprotocols, ", ")) requestHeader = h } conn, resp, err := NewClient( netConn, &url.URL{Host: host + port, Opaque: opaque}, requestHeader, readBufferSize, writeBufferSize) if err != nil { return nil, resp, err } netConn.SetDeadline(time.Time{}) netConn = nil // to avoid close in defer. return conn, resp, nil }