Simplified returning of errors.

This commit is contained in:
Joachim Bauch 2014-04-18 00:20:46 +02:00
parent 018944708b
commit 2b15a66741
1 changed files with 10 additions and 18 deletions

View File

@ -49,12 +49,14 @@ type Upgrader struct {
} }
// Return an error depending on settings on the Upgrader // Return an error depending on settings on the Upgrader
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason error) { func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
err := HandshakeError{reason}
if u.Error != nil { if u.Error != nil {
u.Error(w, r, status, reason) u.Error(w, r, status, err)
} else { } else {
http.Error(w, reason.Error(), status) http.Error(w, reason, status)
} }
return nil, err
} }
// Check if the passed subprotocol is supported by the server // Check if the passed subprotocol is supported by the server
@ -98,21 +100,15 @@ func (u *Upgrader) checkSameOrigin(r *http.Request) bool {
// an error message already has been returned to the caller. // an error message already has been returned to the caller.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" { if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" {
err := HandshakeError{"websocket: version != 13"} return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13")
u.returnError(w, r, http.StatusBadRequest, err)
return nil, err
} }
if !tokenListContainsValue(r.Header, "Connection", "upgrade") { if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
err := HandshakeError{"websocket: connection header != upgrade"} return u.returnError(w, r, http.StatusBadRequest, "websocket: connection header != upgrade")
u.returnError(w, r, http.StatusBadRequest, err)
return nil, err
} }
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
err := HandshakeError{"websocket: upgrade != websocket"} return u.returnError(w, r, http.StatusBadRequest, "websocket: upgrade != websocket")
u.returnError(w, r, http.StatusBadRequest, err)
return nil, err
} }
checkOrigin := u.CheckOrigin checkOrigin := u.CheckOrigin
@ -120,17 +116,13 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
checkOrigin = u.checkSameOrigin checkOrigin = u.checkSameOrigin
} }
if !checkOrigin(r) { if !checkOrigin(r) {
err := HandshakeError{"websocket: origin not allowed"} return u.returnError(w, r, http.StatusForbidden, "websocket: origin not allowed")
u.returnError(w, r, http.StatusForbidden, err)
return nil, err
} }
var challengeKey string var challengeKey string
values := r.Header["Sec-Websocket-Key"] values := r.Header["Sec-Websocket-Key"]
if len(values) == 0 || values[0] == "" { if len(values) == 0 || values[0] == "" {
err := HandshakeError{"websocket: key missing or blank"} return u.returnError(w, r, http.StatusBadRequest, "websocket: key missing or blank")
u.returnError(w, r, http.StatusBadRequest, err)
return nil, err
} }
challengeKey = values[0] challengeKey = values[0]