mirror of https://github.com/gorilla/websocket.git
Implement subprotocol selection according to RFC6455 section 4.2.2 (TODO: update documentation)
This commit is contained in:
parent
5ed622c449
commit
8a3691e53c
78
server.go
78
server.go
|
@ -83,20 +83,53 @@ func checkSameOrigin(r *http.Request) bool {
|
||||||
return equalASCIIFold(u.Host, r.Host)
|
return equalASCIIFold(u.Host, r.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
|
// firstMatching returns the first matching element present in both slices and true/false whether a match has been found.
|
||||||
if u.Subprotocols != nil {
|
func firstMatching(as []string, bs []string) (string, bool) {
|
||||||
|
for _, a := range as {
|
||||||
|
for _, b := range bs {
|
||||||
|
if a == b {
|
||||||
|
return a, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subprotocols returns the subprotocols requested by the client in the
|
||||||
|
// Sec-WebSocket-Protocol header.
|
||||||
|
func Subprotocols(r *http.Request) []string {
|
||||||
|
h := strings.TrimSpace(r.Header.Get("Sec-WebSocket-Protocol"))
|
||||||
|
if h == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
protocols := strings.Split(h, ",")
|
||||||
|
for i := range protocols {
|
||||||
|
protocols[i] = strings.TrimSpace(protocols[i])
|
||||||
|
}
|
||||||
|
return protocols
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectSubprotocol returns the first matching subprotocol found, in the following way:
|
||||||
|
// - if Subprotocols in the Upgrader struct is unset and the client's subprotocol is unset (or empty),
|
||||||
|
// it returns ""
|
||||||
|
// - if Subprotocols in the Upgrader struct is set and responseHeader is unset,
|
||||||
|
// it returns the first matching subprotocol from Subprotocols and the r *http.Request
|
||||||
|
// - if responseHeader is set, it returns the first matching subprotocol from the ResponseHeader (ignoring Subprotocols)
|
||||||
|
// In any other case, e.g. no matching subprotocols are found, it returns "" and false.
|
||||||
|
// The second return value is of type bool, true = match found, false = no match found.
|
||||||
|
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) (string, bool) {
|
||||||
clientProtocols := Subprotocols(r)
|
clientProtocols := Subprotocols(r)
|
||||||
for _, serverProtocol := range u.Subprotocols {
|
|
||||||
for _, clientProtocol := range clientProtocols {
|
if responseProtocols, ok := responseHeader["Sec-WebSocket-Protocol"]; ok {
|
||||||
if clientProtocol == serverProtocol {
|
return firstMatching(responseProtocols, clientProtocols)
|
||||||
return clientProtocol
|
} else if u.Subprotocols != nil {
|
||||||
|
return firstMatching(u.Subprotocols, clientProtocols)
|
||||||
|
} else if clientProtocols == nil {
|
||||||
|
return "", true
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
return "", false
|
||||||
} else if responseHeader != nil {
|
|
||||||
return responseHeader.Get("Sec-Websocket-Protocol")
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
|
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
|
||||||
|
@ -140,10 +173,13 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
|
|
||||||
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
||||||
if challengeKey == "" {
|
if challengeKey == "" {
|
||||||
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank")
|
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank")
|
||||||
}
|
}
|
||||||
|
|
||||||
subprotocol := u.selectSubprotocol(r, responseHeader)
|
subprotocol, ok := u.selectSubprotocol(r, responseHeader)
|
||||||
|
if !ok {
|
||||||
|
return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported client subprotocol")
|
||||||
|
}
|
||||||
|
|
||||||
// Negotiate PMCE
|
// Negotiate PMCE
|
||||||
var compress bool
|
var compress bool
|
||||||
|
@ -276,20 +312,6 @@ func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header,
|
||||||
return u.Upgrade(w, r, responseHeader)
|
return u.Upgrade(w, r, responseHeader)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subprotocols returns the subprotocols requested by the client in the
|
|
||||||
// Sec-Websocket-Protocol header.
|
|
||||||
func Subprotocols(r *http.Request) []string {
|
|
||||||
h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
|
|
||||||
if h == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
protocols := strings.Split(h, ",")
|
|
||||||
for i := range protocols {
|
|
||||||
protocols[i] = strings.TrimSpace(protocols[i])
|
|
||||||
}
|
|
||||||
return protocols
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsWebSocketUpgrade returns true if the client requested upgrade to the
|
// IsWebSocketUpgrade returns true if the client requested upgrade to the
|
||||||
// WebSocket protocol.
|
// WebSocket protocol.
|
||||||
func IsWebSocketUpgrade(r *http.Request) bool {
|
func IsWebSocketUpgrade(r *http.Request) bool {
|
||||||
|
|
Loading…
Reference in New Issue