Implement subprotocol selection according to RFC6455 section 4.2.2 (TODO: update documentation)

This commit is contained in:
Jernej Jakob 2018-07-24 19:10:44 +02:00
parent 5ed622c449
commit 8a3691e53c
1 changed files with 49 additions and 27 deletions

View File

@ -83,20 +83,53 @@ func checkSameOrigin(r *http.Request) bool {
return equalASCIIFold(u.Host, r.Host)
}
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
if u.Subprotocols != nil {
clientProtocols := Subprotocols(r)
for _, serverProtocol := range u.Subprotocols {
for _, clientProtocol := range clientProtocols {
if clientProtocol == serverProtocol {
return clientProtocol
}
// firstMatching returns the first matching element present in both slices and true/false whether a match has been found.
func firstMatching(as []string, bs []string) (string, bool) {
for _, a := range as {
for _, b := range bs {
if a == b {
return a, true
}
}
} else if responseHeader != nil {
return responseHeader.Get("Sec-Websocket-Protocol")
}
return ""
return "", false
}
// Subprotocols returns the subprotocols requested by the client in the
// Sec-WebSocket-Protocol header.
func Subprotocols(r *http.Request) []string {
h := strings.TrimSpace(r.Header.Get("Sec-WebSocket-Protocol"))
if h == "" {
return nil
}
protocols := strings.Split(h, ",")
for i := range protocols {
protocols[i] = strings.TrimSpace(protocols[i])
}
return protocols
}
// selectSubprotocol returns the first matching subprotocol found, in the following way:
// - if Subprotocols in the Upgrader struct is unset and the client's subprotocol is unset (or empty),
// it returns ""
// - if Subprotocols in the Upgrader struct is set and responseHeader is unset,
// it returns the first matching subprotocol from Subprotocols and the r *http.Request
// - if responseHeader is set, it returns the first matching subprotocol from the ResponseHeader (ignoring Subprotocols)
// In any other case, e.g. no matching subprotocols are found, it returns "" and false.
// The second return value is of type bool, true = match found, false = no match found.
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) (string, bool) {
clientProtocols := Subprotocols(r)
if responseProtocols, ok := responseHeader["Sec-WebSocket-Protocol"]; ok {
return firstMatching(responseProtocols, clientProtocols)
} else if u.Subprotocols != nil {
return firstMatching(u.Subprotocols, clientProtocols)
} else if clientProtocols == nil {
return "", true
}
return "", false
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
@ -140,10 +173,13 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank")
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank")
}
subprotocol := u.selectSubprotocol(r, responseHeader)
subprotocol, ok := u.selectSubprotocol(r, responseHeader)
if !ok {
return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported client subprotocol")
}
// Negotiate PMCE
var compress bool
@ -276,20 +312,6 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header,
return u.Upgrade(w, r, responseHeader)
}
// Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header.
func Subprotocols(r *http.Request) []string {
h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
if h == "" {
return nil
}
protocols := strings.Split(h, ",")
for i := range protocols {
protocols[i] = strings.TrimSpace(protocols[i])
}
return protocols
}
// IsWebSocketUpgrade returns true if the client requested upgrade to the
// WebSocket protocol.
func IsWebSocketUpgrade(r *http.Request) bool {