Compare request header tokens with ASCII case folding

Compare request header tokens with ASCII case folding per the WebSocket
RFC.
This commit is contained in:
Gary Burd 2017-11-22 21:55:10 -08:00
parent aa5ed01c91
commit 447c2df769
2 changed files with 44 additions and 2 deletions

28
util.go
View File

@ -11,6 +11,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"unicode/utf8"
) )
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
@ -127,8 +128,31 @@ func nextTokenOrQuoted(s string) (value string, rest string) {
return "", "" return "", ""
} }
// equalASCIIFold returns true if s is equal to t with ASCII case folding.
func equalASCIIFold(s, t string) bool {
for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s)
s = s[size:]
tr, size := utf8.DecodeRuneInString(t)
t = t[size:]
if sr == tr {
continue
}
if 'A' <= sr && sr <= 'Z' {
sr = sr + 'a' - 'A'
}
if 'A' <= tr && tr <= 'Z' {
tr = tr + 'a' - 'A'
}
if sr != tr {
return false
}
}
return s == t
}
// tokenListContainsValue returns true if the 1#token header with the given // tokenListContainsValue returns true if the 1#token header with the given
// name contains token. // name contains a token equal to value with ASCII case folding.
func tokenListContainsValue(header http.Header, name string, value string) bool { func tokenListContainsValue(header http.Header, name string, value string) bool {
headers: headers:
for _, s := range header[name] { for _, s := range header[name] {
@ -142,7 +166,7 @@ headers:
if s != "" && s[0] != ',' { if s != "" && s[0] != ',' {
continue headers continue headers
} }
if strings.EqualFold(t, value) { if equalASCIIFold(t, value) {
return true return true
} }
if s == "" { if s == "" {

View File

@ -10,6 +10,24 @@ import (
"testing" "testing"
) )
var equalASCIIFoldTests = []struct {
t, s string
eq bool
}{
{"WebSocket", "websocket", true},
{"websocket", "WebSocket", true},
{"Öyster", "öyster", false},
}
func TestEqualASCIIFold(t *testing.T) {
for _, tt := range equalASCIIFoldTests {
eq := equalASCIIFold(tt.s, tt.t)
if eq != tt.eq {
t.Errorf("equalASCIIFold(%q, %q) = %v, want %v", tt.s, tt.t, eq, tt.eq)
}
}
}
var tokenListContainsValueTests = []struct { var tokenListContainsValueTests = []struct {
value string value string
ok bool ok bool