diff --git a/server.go b/server.go index 9251734..3201358 100644 --- a/server.go +++ b/server.go @@ -49,12 +49,14 @@ type Upgrader struct { } // Return an error depending on settings on the Upgrader -func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason error) { +func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) { + err := HandshakeError{reason} if u.Error != nil { - u.Error(w, r, status, reason) + u.Error(w, r, status, err) } else { - http.Error(w, reason.Error(), status) + http.Error(w, reason, status) } + return nil, err } // Check if the passed subprotocol is supported by the server @@ -98,21 +100,15 @@ func (u *Upgrader) checkSameOrigin(r *http.Request) bool { // an error message already has been returned to the caller. func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" { - err := HandshakeError{"websocket: version != 13"} - u.returnError(w, r, http.StatusBadRequest, err) - return nil, err + return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13") } if !tokenListContainsValue(r.Header, "Connection", "upgrade") { - err := HandshakeError{"websocket: connection header != upgrade"} - u.returnError(w, r, http.StatusBadRequest, err) - return nil, err + return u.returnError(w, r, http.StatusBadRequest, "websocket: connection header != upgrade") } if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { - err := HandshakeError{"websocket: upgrade != websocket"} - u.returnError(w, r, http.StatusBadRequest, err) - return nil, err + return u.returnError(w, r, http.StatusBadRequest, "websocket: upgrade != websocket") } checkOrigin := u.CheckOrigin @@ -120,17 +116,13 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade checkOrigin = u.checkSameOrigin } if !checkOrigin(r) { - err := HandshakeError{"websocket: origin not allowed"} - u.returnError(w, r, http.StatusForbidden, err) - return nil, err + return u.returnError(w, r, http.StatusForbidden, "websocket: origin not allowed") } var challengeKey string values := r.Header["Sec-Websocket-Key"] if len(values) == 0 || values[0] == "" { - err := HandshakeError{"websocket: key missing or blank"} - u.returnError(w, r, http.StatusBadRequest, err) - return nil, err + return u.returnError(w, r, http.StatusBadRequest, "websocket: key missing or blank") } challengeKey = values[0]