forked from mirror/websocket
Add check for Sec-WebSocket-Key header (#752)
* add Sec-WebSocket-Key header verification * add testcase to Sec-WebSocket-Key header verification
This commit is contained in:
parent
9111bb834a
commit
69d0eb9187
|
@ -154,8 +154,8 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
|
|||
}
|
||||
|
||||
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
||||
if challengeKey == "" {
|
||||
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank")
|
||||
if !isValidChallengeKey(challengeKey) {
|
||||
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
|
||||
}
|
||||
|
||||
subprotocol := u.selectSubprotocol(r, responseHeader)
|
||||
|
|
15
util.go
15
util.go
|
@ -281,3 +281,18 @@ headers:
|
|||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// isValidChallengeKey checks if the argument meets RFC6455 specification.
|
||||
func isValidChallengeKey(s string) bool {
|
||||
// From RFC6455:
|
||||
//
|
||||
// A |Sec-WebSocket-Key| header field with a base64-encoded (see
|
||||
// Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
|
||||
// length.
|
||||
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(s)
|
||||
return err == nil && len(decoded) == 16
|
||||
}
|
||||
|
|
19
util_test.go
19
util_test.go
|
@ -53,6 +53,25 @@ func TestTokenListContainsValue(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
var isValidChallengeKeyTests = []struct {
|
||||
key string
|
||||
ok bool
|
||||
}{
|
||||
{"dGhlIHNhbXBsZSBub25jZQ==", true},
|
||||
{"", false},
|
||||
{"InvalidKey", false},
|
||||
{"WHQ4eXhscUtKYjBvOGN3WEdtOEQ=", false},
|
||||
}
|
||||
|
||||
func TestIsValidChallengeKey(t *testing.T) {
|
||||
for _, tt := range isValidChallengeKeyTests {
|
||||
ok := isValidChallengeKey(tt.key)
|
||||
if ok != tt.ok {
|
||||
t.Errorf("isValidChallengeKey returns %v, want %v", ok, tt.ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var parseExtensionTests = []struct {
|
||||
value string
|
||||
extensions []map[string]string
|
||||
|
|
Loading…
Reference in New Issue