diff --git a/server.go b/server.go index 3201358..fb25f2e 100644 --- a/server.go +++ b/server.go @@ -34,9 +34,10 @@ type Upgrader struct { // default values will be used. ReadBufferSize, WriteBufferSize int - // Subprotocols specifies the server's supported protocols. If Subprotocols - // is nil, then Upgrade does not negotiate a subprotocol. - Subprotocols []string + // NegotiateSubprotocol specifies the function to negotiate a subprotocol + // based on a request. If NegotiateSubprotocol is nil, then no subprotocol + // will be used. + NegotiateSubprotocol func(r *http.Request) (string, error) // Error specifies the function for generating HTTP error responses. If Error // is nil, then http.Error is used to generate the HTTP response. @@ -59,21 +60,6 @@ func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status in return nil, err } -// Check if the passed subprotocol is supported by the server -func (u *Upgrader) hasSubprotocol(subprotocol string) bool { - if u.Subprotocols == nil { - return false - } - - for _, s := range u.Subprotocols { - if s == subprotocol { - return true - } - } - - return false -} - // Check if host in Origin header matches host of request func (u *Upgrader) checkSameOrigin(r *http.Request) bool { origin := r.Header.Get("Origin") @@ -155,12 +141,11 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade } c := newConn(netConn, true, readBufSize, writeBufSize) - if u.Subprotocols != nil { - for _, proto := range Subprotocols(r) { - if u.hasSubprotocol(proto) { - c.subprotocol = proto - break - } + if u.NegotiateSubprotocol != nil { + c.subprotocol, err = u.NegotiateSubprotocol(r) + if err != nil { + netConn.Close() + return nil, err } } else if responseHeader != nil { c.subprotocol = responseHeader.Get("Sec-Websocket-Protocol")