Make subprotocol negotiation more flexible.

This commit is contained in:
Joachim Bauch 2014-04-18 00:25:32 +02:00
parent 2b15a66741
commit 0a7cd15dd1
1 changed files with 9 additions and 24 deletions

View File

@ -34,9 +34,10 @@ type Upgrader struct {
// default values will be used. // default values will be used.
ReadBufferSize, WriteBufferSize int ReadBufferSize, WriteBufferSize int
// Subprotocols specifies the server's supported protocols. If Subprotocols // NegotiateSubprotocol specifies the function to negotiate a subprotocol
// is nil, then Upgrade does not negotiate a subprotocol. // based on a request. If NegotiateSubprotocol is nil, then no subprotocol
Subprotocols []string // will be used.
NegotiateSubprotocol func(r *http.Request) (string, error)
// Error specifies the function for generating HTTP error responses. If Error // Error specifies the function for generating HTTP error responses. If Error
// is nil, then http.Error is used to generate the HTTP response. // is nil, then http.Error is used to generate the HTTP response.
@ -59,21 +60,6 @@ func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status in
return nil, err return nil, err
} }
// Check if the passed subprotocol is supported by the server
func (u *Upgrader) hasSubprotocol(subprotocol string) bool {
if u.Subprotocols == nil {
return false
}
for _, s := range u.Subprotocols {
if s == subprotocol {
return true
}
}
return false
}
// Check if host in Origin header matches host of request // Check if host in Origin header matches host of request
func (u *Upgrader) checkSameOrigin(r *http.Request) bool { func (u *Upgrader) checkSameOrigin(r *http.Request) bool {
origin := r.Header.Get("Origin") origin := r.Header.Get("Origin")
@ -155,12 +141,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
c := newConn(netConn, true, readBufSize, writeBufSize) c := newConn(netConn, true, readBufSize, writeBufSize)
if u.Subprotocols != nil { if u.NegotiateSubprotocol != nil {
for _, proto := range Subprotocols(r) { c.subprotocol, err = u.NegotiateSubprotocol(r)
if u.hasSubprotocol(proto) { if err != nil {
c.subprotocol = proto netConn.Close()
break return nil, err
}
} }
} else if responseHeader != nil { } else if responseHeader != nil {
c.subprotocol = responseHeader.Get("Sec-Websocket-Protocol") c.subprotocol = responseHeader.Get("Sec-Websocket-Protocol")