forked from mirror/websocket
Add check for Sec-WebSocket-Key header (#752)
* add Sec-WebSocket-Key header verification * add testcase to Sec-WebSocket-Key header verification
This commit is contained in:
parent
9111bb834a
commit
69d0eb9187
|
@ -154,8 +154,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
||||||
}
|
}
|
||||||
|
|
||||||
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
||||||
if challengeKey == "" {
|
if !isValidChallengeKey(challengeKey) {
|
||||||
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank")
|
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
|
||||||
}
|
}
|
||||||
|
|
||||||
subprotocol := u.selectSubprotocol(r, responseHeader)
|
subprotocol := u.selectSubprotocol(r, responseHeader)
|
||||||
|
|
15
util.go
15
util.go
|
@ -281,3 +281,18 @@ headers:
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isValidChallengeKey checks if the argument meets RFC6455 specification.
|
||||||
|
func isValidChallengeKey(s string) bool {
|
||||||
|
// From RFC6455:
|
||||||
|
//
|
||||||
|
// A |Sec-WebSocket-Key| header field with a base64-encoded (see
|
||||||
|
// Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
|
||||||
|
// length.
|
||||||
|
|
||||||
|
if s == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(s)
|
||||||
|
return err == nil && len(decoded) == 16
|
||||||
|
}
|
||||||
|
|
19
util_test.go
19
util_test.go
|
@ -53,6 +53,25 @@ func TestTokenListContainsValue(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var isValidChallengeKeyTests = []struct {
|
||||||
|
key string
|
||||||
|
ok bool
|
||||||
|
}{
|
||||||
|
{"dGhlIHNhbXBsZSBub25jZQ==", true},
|
||||||
|
{"", false},
|
||||||
|
{"InvalidKey", false},
|
||||||
|
{"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValidChallengeKey(t *testing.T) {
|
||||||
|
for _, tt := range isValidChallengeKeyTests {
|
||||||
|
ok := isValidChallengeKey(tt.key)
|
||||||
|
if ok != tt.ok {
|
||||||
|
t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var parseExtensionTests = []struct {
|
var parseExtensionTests = []struct {
|
||||||
value string
|
value string
|
||||||
extensions []map[string]string
|
extensions []map[string]string
|
||||||
|
|
Loading…
Reference in New Issue