From 8a3691e53c290ed66c0e1806248155c6495e4e46 Mon Sep 17 00:00:00 2001 From: Jernej Jakob Date: Tue, 24 Jul 2018 19:10:44 +0200 Subject: [PATCH] Implement subprotocol selection according to RFC6455 section 4.2.2 (TODO: update documentation) --- server.go | 76 +++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 49 insertions(+), 27 deletions(-) diff --git a/server.go b/server.go index aee2705..9ac4f00 100644 --- a/server.go +++ b/server.go @@ -83,20 +83,53 @@ func checkSameOrigin(r *http.Request) bool { return equalASCIIFold(u.Host, r.Host) } -func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { - if u.Subprotocols != nil { - clientProtocols := Subprotocols(r) - for _, serverProtocol := range u.Subprotocols { - for _, clientProtocol := range clientProtocols { - if clientProtocol == serverProtocol { - return clientProtocol - } +// firstMatching returns the first matching element present in both slices and true/false whether a match has been found. +func firstMatching(as []string, bs []string) (string, bool) { + for _, a := range as { + for _, b := range bs { + if a == b { + return a, true } } - } else if responseHeader != nil { - return responseHeader.Get("Sec-Websocket-Protocol") } - return "" + return "", false +} + +// Subprotocols returns the subprotocols requested by the client in the +// Sec-WebSocket-Protocol header. +func Subprotocols(r *http.Request) []string { + h := strings.TrimSpace(r.Header.Get("Sec-WebSocket-Protocol")) + if h == "" { + return nil + } + + protocols := strings.Split(h, ",") + for i := range protocols { + protocols[i] = strings.TrimSpace(protocols[i]) + } + return protocols +} + +// selectSubprotocol returns the first matching subprotocol found, in the following way: +// - if Subprotocols in the Upgrader struct is unset and the client's subprotocol is unset (or empty), +// it returns "" +// - if Subprotocols in the Upgrader struct is set and responseHeader is unset, +// it returns the first matching subprotocol from Subprotocols and the r *http.Request +// - if responseHeader is set, it returns the first matching subprotocol from the ResponseHeader (ignoring Subprotocols) +// In any other case, e.g. no matching subprotocols are found, it returns "" and false. +// The second return value is of type bool, true = match found, false = no match found. +func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) (string, bool) { + clientProtocols := Subprotocols(r) + + if responseProtocols, ok := responseHeader["Sec-WebSocket-Protocol"]; ok { + return firstMatching(responseProtocols, clientProtocols) + } else if u.Subprotocols != nil { + return firstMatching(u.Subprotocols, clientProtocols) + } else if clientProtocols == nil { + return "", true + } + + return "", false } // Upgrade upgrades the HTTP server connection to the WebSocket protocol. @@ -140,10 +173,13 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade challengeKey := r.Header.Get("Sec-Websocket-Key") if challengeKey == "" { - return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank") + return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank") } - subprotocol := u.selectSubprotocol(r, responseHeader) + subprotocol, ok := u.selectSubprotocol(r, responseHeader) + if !ok { + return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported client subprotocol") + } // Negotiate PMCE var compress bool @@ -276,20 +312,6 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, return u.Upgrade(w, r, responseHeader) } -// Subprotocols returns the subprotocols requested by the client in the -// Sec-Websocket-Protocol header. -func Subprotocols(r *http.Request) []string { - h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol")) - if h == "" { - return nil - } - protocols := strings.Split(h, ",") - for i := range protocols { - protocols[i] = strings.TrimSpace(protocols[i]) - } - return protocols -} - // IsWebSocketUpgrade returns true if the client requested upgrade to the // WebSocket protocol. func IsWebSocketUpgrade(r *http.Request) bool {