// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package websocket import ( "bufio" "errors" "fmt" "net" "net/http" "net/url" "strings" "time" ) // HandshakeError describes an error with the handshake from the peer. type HandshakeError struct { message string } func (e HandshakeError) Error() string { return e.message } // Upgrader specifies parameters for upgrading an HTTP connection to a // WebSocket connection. type Upgrader struct { // HandshakeTimeout specifies the duration for the handshake to complete. HandshakeTimeout time.Duration // ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer // size is zero, then a default value of 4096 is used. The I/O buffer sizes // do not limit the size of the messages that can be sent or received. ReadBufferSize, WriteBufferSize int // Subprotocols specifies the server's supported protocols in order of // preference. If this field is set, then the Upgrade method negotiates a // subprotocol by selecting the first match in this list with a protocol // requested by the client. Subprotocols []string // Server supported extensions (e.g. permessage-deflate) // Currently only 'permessage-deflate; server_no_context_takeover; client_no_context_takeover' // is supported. i.e. no deflate options at this time. Extensions []string // Error specifies the function for generating HTTP error responses. If Error // is nil, then http.Error is used to generate the HTTP response. Error func(w http.ResponseWriter, r *http.Request, status int, reason error) // CheckOrigin returns true if the request Origin header is acceptable. If // CheckOrigin is nil, the host in the Origin header must not be set or // must match the host of the request. CheckOrigin func(r *http.Request) bool } func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { err := HandshakeError{reason} if u.Error != nil { u.Error(w, r, status, err) } else { http.Error(w, http.StatusText(status), status) } return nil, err } // checkSameOrigin returns true if the origin is not set or is equal to the request host. func checkSameOrigin(r *http.Request) bool { origin := r.Header["Origin"] if len(origin) == 0 { return true } u, err := url.Parse(origin[0]) if err != nil { return false } return u.Host == r.Host } func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { if u.Subprotocols != nil { clientProtocols := Subprotocols(r) for _, serverProtocol := range u.Subprotocols { for _, clientProtocol := range clientProtocols { if clientProtocol == serverProtocol { return clientProtocol } } } } else if responseHeader != nil { return responseHeader.Get("Sec-Websocket-Protocol") } return "" } // Check for compression support, select compression type and // return extension with server supported options along with // whether valid compression headers were found. func (u *Upgrader) selectCompressionExtension(r *http.Request) (string, bool, error) { extensions := r.Header.Get("Sec-WebSocket-Extensions") if u.Extensions != nil && len(extensions) > 0 { extOpts := strings.Split(extensions, " ") if len(extOpts) > 0 && strings.HasPrefix(extOpts[0], "permessage-deflate") { ext := strings.TrimSuffix(extOpts[0], ";") // Find and return extension with supported options. for _, e := range u.Extensions { // Check if server supports supplied extension if strings.HasPrefix(e, ext) { if e != CompressPermessageDeflate { return "", false, fmt.Errorf("Compression options not supported: %s", e) } return e, true, nil } } } } return "", false, nil } // Upgrade upgrades the HTTP server connection to the WebSocket protocol. // // The responseHeader is included in the response to the client's upgrade // request. Use the responseHeader to specify cookies (Set-Cookie) and the // application negotiated subprotocol (Sec-Websocket-Protocol). func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { if r.Method != "GET" { return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET") } if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" { return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13") } if !tokenListContainsValue(r.Header, "Connection", "upgrade") { return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find connection header with token 'upgrade'") } if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { return u.returnError(w, r, http.StatusBadRequest, "websocket: could not find upgrade header with token 'websocket'") } checkOrigin := u.CheckOrigin if checkOrigin == nil { checkOrigin = checkSameOrigin } if !checkOrigin(r) { return u.returnError(w, r, http.StatusForbidden, "websocket: origin not allowed") } challengeKey := r.Header.Get("Sec-Websocket-Key") if challengeKey == "" { return u.returnError(w, r, http.StatusBadRequest, "websocket: key missing or blank") } subprotocol := u.selectSubprotocol(r, responseHeader) var ( netConn net.Conn br *bufio.Reader err error ) h, ok := w.(http.Hijacker) if !ok { return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") } var rw *bufio.ReadWriter netConn, rw, err = h.Hijack() if err != nil { return u.returnError(w, r, http.StatusInternalServerError, err.Error()) } br = rw.Reader if br.Buffered() > 0 { netConn.Close() return nil, errors.New("websocket: client sent data before handshake is complete") } c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize) c.subprotocol = subprotocol _, c.compressionNegotiated, err = u.selectCompressionExtension(r) if err != nil { return u.returnError(w, r, http.StatusInternalServerError, err.Error()) } p := c.writeBuf[:0] p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) p = append(p, computeAcceptKey(challengeKey)...) p = append(p, "\r\n"...) if c.subprotocol != "" { p = append(p, "Sec-Websocket-Protocol: "...) p = append(p, c.subprotocol...) p = append(p, "\r\n"...) } for k, vs := range responseHeader { if k == "Sec-Websocket-Protocol" { continue } for _, v := range vs { p = append(p, k...) p = append(p, ": "...) for i := 0; i < len(v); i++ { b := v[i] if b <= 31 { // prevent response splitting. b = ' ' } p = append(p, b) } p = append(p, "\r\n"...) } } // Set the selected compression header if enabled. if c.compressionNegotiated { // Turn compression on by default c.writeCompressionEnabled = true p = append(p, "Sec-WebSocket-Extensions: "+CompressPermessageDeflate+"\r\n"...) } p = append(p, "\r\n"...) // Clear deadlines set by HTTP server. netConn.SetDeadline(time.Time{}) if u.HandshakeTimeout > 0 { netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) } if _, err = netConn.Write(p); err != nil { netConn.Close() return nil, err } if u.HandshakeTimeout > 0 { netConn.SetWriteDeadline(time.Time{}) } return c, nil } // Upgrade upgrades the HTTP server connection to the WebSocket protocol. // // This function is deprecated, use websocket.Upgrader instead. // // The application is responsible for checking the request origin before // calling Upgrade. An example implementation of the same origin policy is: // // if req.Header.Get("Origin") != "http://"+req.Host { // http.Error(w, "Origin not allowed", 403) // return // } // // If the endpoint supports subprotocols, then the application is responsible // for negotiating the protocol used on the connection. Use the Subprotocols() // function to get the subprotocols requested by the client. Use the // Sec-Websocket-Protocol response header to specify the subprotocol selected // by the application. // // The responseHeader is included in the response to the client's upgrade // request. Use the responseHeader to specify cookies (Set-Cookie) and the // negotiated subprotocol (Sec-Websocket-Protocol). // // The connection buffers IO to the underlying network connection. The // readBufSize and writeBufSize parameters specify the size of the buffers to // use. Messages can be larger than the buffers. // // If the request is not a valid WebSocket handshake, then Upgrade returns an // error of type HandshakeError. Applications should handle this error by // replying to the client with an HTTP error response. func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) { u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { // don't return errors to maintain backwards compatibility } u.CheckOrigin = func(r *http.Request) bool { // allow all connections by default return true } return u.Upgrade(w, r, responseHeader) } // Subprotocols returns the subprotocols requested by the client in the // Sec-Websocket-Protocol header. func Subprotocols(r *http.Request) []string { h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol")) if h == "" { return nil } protocols := strings.Split(h, ",") for i := range protocols { protocols[i] = strings.TrimSpace(protocols[i]) } return protocols }