Add Sec-WebSocket-Extensions header parser

Also, improve token list header parser.
This commit is contained in:
Gary Burd 2016-05-31 09:32:45 -07:00
parent be01041b66
commit 8b29b78138
2 changed files with 223 additions and 13 deletions

196
util.go
View File

@ -13,19 +13,6 @@ import (
"strings"
)
// tokenListContainsValue returns true if the 1#token header with the given
// name contains token.
func tokenListContainsValue(header http.Header, name string, value string) bool {
for _, v := range header[name] {
for _, s := range strings.Split(v, ",") {
if strings.EqualFold(value, strings.TrimSpace(s)) {
return true
}
}
}
return false
}
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {
@ -42,3 +29,186 @@ func generateChallengeKey() (string, error) {
}
return base64.StdEncoding.EncodeToString(p), nil
}
// Octet types from RFC 2616.
var octetTypes [256]byte
const (
isTokenOctet = 1 << iota
isSpaceOctet
)
func init() {
// From RFC 2616
//
// OCTET = <any 8-bit sequence of data>
// CHAR = <any US-ASCII character (octets 0 - 127)>
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
// CR = <US-ASCII CR, carriage return (13)>
// LF = <US-ASCII LF, linefeed (10)>
// SP = <US-ASCII SP, space (32)>
// HT = <US-ASCII HT, horizontal-tab (9)>
// <"> = <US-ASCII double-quote mark (34)>
// CRLF = CR LF
// LWS = [CRLF] 1*( SP | HT )
// TEXT = <any OCTET except CTLs, but including LWS>
// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
// token = 1*<any CHAR except CTLs or separators>
// qdtext = <any TEXT except <">>
for c := 0; c < 256; c++ {
var t byte
isCtl := c <= 31 || c == 127
isChar := 0 <= c && c <= 127
isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0
if strings.IndexRune(" \t\r\n", rune(c)) >= 0 {
t |= isSpaceOctet
}
if isChar && !isCtl && !isSeparator {
t |= isTokenOctet
}
octetTypes[c] = t
}
}
func skipSpace(s string) (rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isSpaceOctet == 0 {
break
}
}
return s[i:]
}
func nextToken(s string) (token, rest string) {
i := 0
for ; i < len(s); i++ {
if octetTypes[s[i]]&isTokenOctet == 0 {
break
}
}
return s[:i], s[i:]
}
func nextTokenOrQuoted(s string) (value string, rest string) {
if !strings.HasPrefix(s, "\"") {
return nextToken(s)
}
s = s[1:]
for i := 0; i < len(s); i++ {
switch s[i] {
case '"':
return s[:i], s[i+1:]
case '\\':
p := make([]byte, len(s)-1)
j := copy(p, s[:i])
escape := true
for i = i + 1; i < len(s); i++ {
b := s[i]
switch {
case escape:
escape = false
p[j] = b
j += 1
case b == '\\':
escape = true
case b == '"':
return string(p[:j]), s[i+1:]
default:
p[j] = b
j += 1
}
}
return "", ""
}
}
return "", ""
}
// tokenListContainsValue returns true if the 1#token header with the given
// name contains token.
func tokenListContainsValue(header http.Header, name string, value string) bool {
headers:
for _, s := range header[name] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
s = skipSpace(s)
if s != "" && s[0] != ',' {
continue headers
}
if strings.EqualFold(t, value) {
return true
}
if s == "" {
continue headers
}
s = s[1:]
}
}
return false
}
// parseExtensiosn parses WebSocket extensions from a header.
func parseExtensions(header http.Header) []map[string]string {
// From RFC 6455:
//
// Sec-WebSocket-Extensions = extension-list
// extension-list = 1#extension
// extension = extension-token *( ";" extension-param )
// extension-token = registered-token
// registered-token = token
// extension-param = token [ "=" (token | quoted-string) ]
// ;When using the quoted-string syntax variant, the value
// ;after quoted-string unescaping MUST conform to the
// ;'token' ABNF.
var result []map[string]string
headers:
for _, s := range header["Sec-Websocket-Extensions"] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
ext := map[string]string{"": t}
for {
s = skipSpace(s)
if !strings.HasPrefix(s, ";") {
break
}
var k string
k, s = nextToken(skipSpace(s[1:]))
if k == "" {
continue headers
}
s = skipSpace(s)
var v string
if strings.HasPrefix(s, "=") {
v, s = nextTokenOrQuoted(skipSpace(s[1:]))
s = skipSpace(s)
}
if s != "" && s[0] != ',' && s[0] != ';' {
continue headers
}
ext[k] = v
}
if s != "" && s[0] != ',' {
continue headers
}
result = append(result, ext)
if s == "" {
continue headers
}
s = s[1:]
}
}
return result
}

View File

@ -6,6 +6,7 @@ package websocket
import (
"net/http"
"reflect"
"testing"
)
@ -32,3 +33,42 @@ func TestTokenListContainsValue(t *testing.T) {
}
}
}
var parseExtensionTests = []struct {
value string
extensions []map[string]string
}{
{`foo`, []map[string]string{map[string]string{"": "foo"}}},
{`foo, bar; baz=2`, []map[string]string{
map[string]string{"": "foo"},
map[string]string{"": "bar", "baz": "2"}}},
{`foo; bar="b,a;z"`, []map[string]string{
map[string]string{"": "foo", "bar": "b,a;z"}}},
{`foo , bar; baz = 2`, []map[string]string{
map[string]string{"": "foo"},
map[string]string{"": "bar", "baz": "2"}}},
{`foo, bar; baz=2 junk`, []map[string]string{
map[string]string{"": "foo"}}},
{`foo junk, bar; baz=2 junk`, nil},
{`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{
map[string]string{"": "mux", "max-channels": "4", "flow-control": ""},
map[string]string{"": "deflate-stream"}}},
{`permessage-foo; x="10"`, []map[string]string{
map[string]string{"": "permessage-foo", "x": "10"}}},
{`permessage-foo; use_y, permessage-foo`, []map[string]string{
map[string]string{"": "permessage-foo", "use_y": ""},
map[string]string{"": "permessage-foo"}}},
{`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{
map[string]string{"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"},
map[string]string{"": "permessage-deflate", "client_max_window_bits": ""}}},
}
func TestParseExtensions(t *testing.T) {
for _, tt := range parseExtensionTests {
h := http.Header{http.CanonicalHeaderKey("Sec-WebSocket-Extensions"): {tt.value}}
extensions := parseExtensions(h)
if !reflect.DeepEqual(extensions, tt.extensions) {
t.Errorf("parseExtensions(%q)\n = %v,\nwant %v", tt.value, extensions, tt.extensions)
}
}
}