mirror of https://github.com/gorilla/websocket.git
Compare request header tokens with ASCII case folding
Compare request header tokens with ASCII case folding per the WebSocket RFC.
This commit is contained in:
parent
aa5ed01c91
commit
447c2df769
28
util.go
28
util.go
|
@ -11,6 +11,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
|
||||||
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||||
|
@ -127,8 +128,31 @@ func nextTokenOrQuoted(s string) (value string, rest string) {
|
||||||
return "", ""
|
return "", ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// equalASCIIFold returns true if s is equal to t with ASCII case folding.
|
||||||
|
func equalASCIIFold(s, t string) bool {
|
||||||
|
for s != "" && t != "" {
|
||||||
|
sr, size := utf8.DecodeRuneInString(s)
|
||||||
|
s = s[size:]
|
||||||
|
tr, size := utf8.DecodeRuneInString(t)
|
||||||
|
t = t[size:]
|
||||||
|
if sr == tr {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if 'A' <= sr && sr <= 'Z' {
|
||||||
|
sr = sr + 'a' - 'A'
|
||||||
|
}
|
||||||
|
if 'A' <= tr && tr <= 'Z' {
|
||||||
|
tr = tr + 'a' - 'A'
|
||||||
|
}
|
||||||
|
if sr != tr {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s == t
|
||||||
|
}
|
||||||
|
|
||||||
// tokenListContainsValue returns true if the 1#token header with the given
|
// tokenListContainsValue returns true if the 1#token header with the given
|
||||||
// name contains token.
|
// name contains a token equal to value with ASCII case folding.
|
||||||
func tokenListContainsValue(header http.Header, name string, value string) bool {
|
func tokenListContainsValue(header http.Header, name string, value string) bool {
|
||||||
headers:
|
headers:
|
||||||
for _, s := range header[name] {
|
for _, s := range header[name] {
|
||||||
|
@ -142,7 +166,7 @@ headers:
|
||||||
if s != "" && s[0] != ',' {
|
if s != "" && s[0] != ',' {
|
||||||
continue headers
|
continue headers
|
||||||
}
|
}
|
||||||
if strings.EqualFold(t, value) {
|
if equalASCIIFold(t, value) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if s == "" {
|
if s == "" {
|
||||||
|
|
18
util_test.go
18
util_test.go
|
@ -10,6 +10,24 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var equalASCIIFoldTests = []struct {
|
||||||
|
t, s string
|
||||||
|
eq bool
|
||||||
|
}{
|
||||||
|
{"WebSocket", "websocket", true},
|
||||||
|
{"websocket", "WebSocket", true},
|
||||||
|
{"Öyster", "öyster", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEqualASCIIFold(t *testing.T) {
|
||||||
|
for _, tt := range equalASCIIFoldTests {
|
||||||
|
eq := equalASCIIFold(tt.s, tt.t)
|
||||||
|
if eq != tt.eq {
|
||||||
|
t.Errorf("equalASCIIFold(%q, %q) = %v, want %v", tt.s, tt.t, eq, tt.eq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var tokenListContainsValueTests = []struct {
|
var tokenListContainsValueTests = []struct {
|
||||||
value string
|
value string
|
||||||
ok bool
|
ok bool
|
||||||
|
|
Loading…
Reference in New Issue