merge: master

This commit is contained in:
misu 2019-02-02 13:48:39 +09:00
commit a8d45c1b96
9 changed files with 371 additions and 154 deletions

View File

@ -27,7 +27,7 @@ package API is stable.
### Protocol Compliance ### Protocol Compliance
The Gorilla WebSocket package passes the server tests in the [Autobahn Test The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](http://autobahn.ws/testsuite) using the application in the [examples/autobahn Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn). subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
### Gorilla WebSocket compared with other packages ### Gorilla WebSocket compared with other packages

View File

@ -71,7 +71,7 @@ type Dialer struct {
// HandshakeTimeout specifies the duration for the handshake to complete. // HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
// size is zero, then a useful default size is used. The I/O buffer sizes // size is zero, then a useful default size is used. The I/O buffer sizes
// do not limit the size of the messages that can be sent or received. // do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int ReadBufferSize, WriteBufferSize int
@ -144,7 +144,7 @@ var nilDialer = *DefaultDialer
// Use the response.Header to get the selected subprotocol // Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
// //
// The context will be used in the request and in the Dialer // The context will be used in the request and in the Dialer.
// //
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication, // non-nil *http.Response so that callers can handle redirects, authentication,

View File

@ -11,8 +11,10 @@ import (
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log"
"net" "net"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
@ -51,15 +53,10 @@ type cstHandler struct {
cstHandlerConfig cstHandlerConfig
} }
var cstDialerWithoutHandshakeTimeout = Dialer{
Subprotocols: []string{"p1", "p2"},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
type cstServer struct { type cstServer struct {
*httptest.Server *httptest.Server
URL string URL string
t *testing.T
} }
const ( const (
@ -341,10 +338,7 @@ func TestDialCookieJar(t *testing.T) {
sendRecv(t, ws) sendRecv(t, ws)
} }
func TestDialTLS(t *testing.T) { func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
s := newTLSServer(t, cstHandlerConfig{})
defer s.Close()
certs := x509.NewCertPool() certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates { for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
@ -355,35 +349,15 @@ func TestDialTLS(t *testing.T) {
certs.AddCert(root) certs.AddCert(root)
} }
} }
return certs
d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs}
ws, _, err := d.Dial(s.URL, nil)
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(t, ws)
} }
func xTestDialTLSBadCert(t *testing.T) { func TestDialTLS(t *testing.T) {
// This test is deactivated because of noisy logging from the net/http package.
s := newTLSServer(t, cstHandlerConfig{})
defer s.Close()
ws, _, err := cstDialer.Dial(s.URL, nil)
if err == nil {
ws.Close()
t.Fatalf("Dial: nil")
}
}
func TestDialTLSNoVerify(t *testing.T) {
s := newTLSServer(t, cstHandlerConfig{}) s := newTLSServer(t, cstHandlerConfig{})
defer s.Close() defer s.Close()
d := cstDialer d := cstDialer
d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
ws, _, err := d.Dial(s.URL, nil) ws, _, err := d.Dial(s.URL, nil)
if err != nil { if err != nil {
t.Fatalf("Dial: %v", err) t.Fatalf("Dial: %v", err)
@ -468,7 +442,8 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
s := newServer(t, cstHandlerConfig{}) s := newServer(t, cstHandlerConfig{})
defer s.Close() defer s.Close()
d := cstDialerWithoutHandshakeTimeout d := cstDialer
d.HandshakeTimeout = 0
d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) { d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
netDialer := &net.Dialer{} netDialer := &net.Dialer{}
c, err := netDialer.DialContext(ctx, n, a) c, err := netDialer.DialContext(ctx, n, a)
@ -619,33 +594,195 @@ func TestRespOnBadHandshake(t *testing.T) {
} }
} }
// TestHostHeader confirms that the host header provided in the call to Dial is type testLogWriter struct {
// sent to the server. t *testing.T
func TestHostHeader(t *testing.T) { }
s := newServer(t, cstHandlerConfig{})
defer s.Close()
specifiedHost := make(chan string, 1) func (w testLogWriter) Write(p []byte) (int, error) {
origHandler := s.Server.Config.Handler w.t.Logf("%s", p)
return len(p), nil
}
// Capture the request Host header. // TestHost tests handling of host names and confirms that it matches net/http.
s.Server.Config.Handler = http.HandlerFunc( func TestHost(t *testing.T) {
func(w http.ResponseWriter, r *http.Request) {
specifiedHost <- r.Host
origHandler.ServeHTTP(w, r)
})
ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}}) upgrader := Upgrader{}
if err != nil { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("Dial: %v", err) if IsWebSocketUpgrade(r) {
} c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
defer ws.Close() if err != nil {
t.Fatal(err)
}
c.Close()
} else {
w.Header().Set("X-Test-Host", r.Host)
}
})
if gotHost := <-specifiedHost; gotHost != "testhost" { server := httptest.NewServer(handler)
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost) defer server.Close()
tlsServer := httptest.NewTLSServer(handler)
defer tlsServer.Close()
addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
// Avoid log noise from net/http server by logging to testing.T
server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
tlsServer.Config.ErrorLog = server.Config.ErrorLog
cas := rootCAs(t, tlsServer)
tests := []struct {
fail bool // true if dial / get should fail
server *httptest.Server // server to use
url string // host for request URI
header string // optional request host header
tls string // optiona host for tls ServerName
wantAddr string // expected host for dial
wantHeader string // expected request header on server
insecureSkipVerify bool
}{
{
server: server,
url: addrs[server],
wantAddr: addrs[server],
wantHeader: addrs[server],
},
{
server: tlsServer,
url: addrs[tlsServer],
wantAddr: addrs[tlsServer],
wantHeader: addrs[tlsServer],
},
{
server: server,
url: addrs[server],
header: "badhost.com",
wantAddr: addrs[server],
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: addrs[tlsServer],
header: "badhost.com",
wantAddr: addrs[tlsServer],
wantHeader: "badhost.com",
},
{
server: server,
url: "example.com",
header: "badhost.com",
wantAddr: "example.com:80",
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: "example.com",
header: "badhost.com",
wantAddr: "example.com:443",
wantHeader: "badhost.com",
},
{
server: server,
url: "badhost.com",
header: "example.com",
wantAddr: "badhost.com:80",
wantHeader: "example.com",
},
{
fail: true,
server: tlsServer,
url: "badhost.com",
header: "example.com",
wantAddr: "badhost.com:443",
},
{
server: tlsServer,
url: "badhost.com",
insecureSkipVerify: true,
wantAddr: "badhost.com:443",
wantHeader: "badhost.com",
},
{
server: tlsServer,
url: "badhost.com",
tls: "example.com",
wantAddr: "badhost.com:443",
wantHeader: "badhost.com",
},
} }
sendRecv(t, ws) for i, tt := range tests {
tls := &tls.Config{
RootCAs: cas,
ServerName: tt.tls,
InsecureSkipVerify: tt.insecureSkipVerify,
}
var gotAddr string
dialer := Dialer{
NetDial: func(network, addr string) (net.Conn, error) {
gotAddr = addr
return net.Dial(network, addrs[tt.server])
},
TLSClientConfig: tls,
}
// Test websocket dial
h := http.Header{}
if tt.header != "" {
h.Set("Host", tt.header)
}
c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
if err == nil {
c.Close()
}
check := func(protos map[*httptest.Server]string) {
name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
if gotAddr != tt.wantAddr {
t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
}
switch {
case tt.fail && err == nil:
t.Errorf("%s: unexpected success", name)
case !tt.fail && err != nil:
t.Errorf("%s: unexpected error %v", name, err)
case !tt.fail && err == nil:
if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
}
}
}
check(wsProtos)
// Confirm that net/http has same result
transport := &http.Transport{
Dial: dialer.NetDial,
TLSClientConfig: dialer.TLSClientConfig,
}
req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
if tt.header != "" {
req.Host = tt.header
}
client := &http.Client{Transport: transport}
resp, err = client.Do(req)
if err == nil {
resp.Body.Close()
}
transport.CloseIdleConnections()
check(httpProtos)
}
} }
func TestDialCompression(t *testing.T) { func TestDialCompression(t *testing.T) {
@ -785,19 +922,8 @@ func TestTracingDialWithContext(t *testing.T) {
s := newTLSServer(t, cstHandlerConfig{}) s := newTLSServer(t, cstHandlerConfig{})
defer s.Close() defer s.Close()
certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
t.Fatalf("error parsing server's root cert: %v", err)
}
for _, root := range roots {
certs.AddCert(root)
}
}
d := cstDialer d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs} d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
ws, _, err := d.DialContext(ctx, s.URL, nil) ws, _, err := d.DialContext(ctx, s.URL, nil)
if err != nil { if err != nil {
@ -835,19 +961,8 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
s := newTLSServer(t, cstHandlerConfig{}) s := newTLSServer(t, cstHandlerConfig{})
defer s.Close() defer s.Close()
certs := x509.NewCertPool()
for _, c := range s.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
t.Fatalf("error parsing server's root cert: %v", err)
}
for _, root := range roots {
certs.AddCert(root)
}
}
d := cstDialer d := cstDialer
d.TLSClientConfig = &tls.Config{RootCAs: certs} d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
ws, _, err := d.DialContext(ctx, s.URL, nil) ws, _, err := d.DialContext(ctx, s.URL, nil)
if err != nil { if err != nil {

49
conn.go
View File

@ -453,7 +453,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
return err return err
} }
func (c *Conn) prepWrite(messageType int) error { // beginMessage prepares a connection and message writer for a new message.
func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
// Close previous writer if not already closed by the application. It's // Close previous writer if not already closed by the application. It's
// probably better to return an error in this situation, but we cannot // probably better to return an error in this situation, but we cannot
// change this without breaking existing applications. // change this without breaking existing applications.
@ -473,6 +474,10 @@ func (c *Conn) prepWrite(messageType int) error {
return err return err
} }
mw.c = c
mw.frameType = messageType
mw.pos = maxFrameHeaderSize
if c.writeBuf == nil { if c.writeBuf == nil {
wpd, ok := c.writePool.Get().(writePoolData) wpd, ok := c.writePool.Get().(writePoolData)
if ok { if ok {
@ -493,16 +498,11 @@ func (c *Conn) prepWrite(messageType int) error {
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
// PongMessage) are supported. // PongMessage) are supported.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
if err := c.prepWrite(messageType); err != nil { var mw messageWriter
if err := c.beginMessage(&mw, messageType); err != nil {
return nil, err return nil, err
} }
c.writer = &mw
mw := &messageWriter{
c: c,
frameType: messageType,
pos: maxFrameHeaderSize,
}
c.writer = mw
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
w := c.newCompressionWriter(c.writer, c.compressionLevel) w := c.newCompressionWriter(c.writer, c.compressionLevel)
mw.compress = true mw.compress = true
@ -519,10 +519,16 @@ type messageWriter struct {
err error err error
} }
func (w *messageWriter) fatal(err error) error { func (w *messageWriter) endMessage(err error) error {
if w.err != nil { if w.err != nil {
w.err = err return err
w.c.writer = nil }
c := w.c
w.err = err
c.writer = nil
if c.writePool != nil {
c.writePool.Put(writePoolData{buf: c.writeBuf})
c.writeBuf = nil
} }
return err return err
} }
@ -536,7 +542,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
// Check for invalid control frames. // Check for invalid control frames.
if isControl(w.frameType) && if isControl(w.frameType) &&
(!final || length > maxControlFramePayloadSize) { (!final || length > maxControlFramePayloadSize) {
return w.fatal(errInvalidControlFrame) return w.endMessage(errInvalidControlFrame)
} }
b0 := byte(w.frameType) b0 := byte(w.frameType)
@ -581,7 +587,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
if len(extra) > 0 { if len(extra) > 0 {
return c.writeFatal(errors.New("websocket: internal error, extra used in client mode")) return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
} }
} }
@ -602,15 +608,11 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
c.isWriting = false c.isWriting = false
if err != nil { if err != nil {
return w.fatal(err) return w.endMessage(err)
} }
if final { if final {
c.writer = nil w.endMessage(errWriteClosed)
if c.writePool != nil {
c.writePool.Put(writePoolData{buf: c.writeBuf})
c.writeBuf = nil
}
return nil return nil
} }
@ -711,7 +713,6 @@ func (w *messageWriter) Close() error {
if err := w.flushFrame(true, nil); err != nil { if err := w.flushFrame(true, nil); err != nil {
return err return err
} }
w.err = errWriteClosed
return nil return nil
} }
@ -744,10 +745,10 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
// Fast path with no allocations and single frame. // Fast path with no allocations and single frame.
if err := c.prepWrite(messageType); err != nil { var mw messageWriter
if err := c.beginMessage(&mw, messageType); err != nil {
return err return err
} }
mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
n := copy(c.writeBuf[mw.pos:], data) n := copy(c.writeBuf[mw.pos:], data)
mw.pos += n mw.pos += n
data = data[n:] data = data[n:]
@ -1043,7 +1044,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t) return c.conn.SetReadDeadline(t)
} }
// SetReadLimit sets the maximum size for a message read from the peer. If a // SetReadLimit sets the maximum size in bytes for a message read from the peer. If a
// message exceeds the limit, the connection sends a close message to the peer // message exceeds the limit, the connection sends a close message to the peer
// and returns ErrReadLimit to the application. // and returns ErrReadLimit to the application.
func (c *Conn) SetReadLimit(limit int64) { func (c *Conn) SetReadLimit(limit int64) {

View File

@ -227,11 +227,16 @@ func (p *simpleBufferPool) Put(v interface{}) {
} }
func TestWriteBufferPool(t *testing.T) { func TestWriteBufferPool(t *testing.T) {
const message = "Now is the time for all good people to come to the aid of the party."
var buf bytes.Buffer var buf bytes.Buffer
var pool simpleBufferPool var pool simpleBufferPool
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
rc := newTestConn(&buf, nil, false) rc := newTestConn(&buf, nil, false)
// Specify writeBufferSize smaller than message size to ensure that pooling
// works with fragmented messages.
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)
if wc.writeBuf != nil { if wc.writeBuf != nil {
t.Fatal("writeBuf not nil after create") t.Fatal("writeBuf not nil after create")
} }
@ -249,8 +254,6 @@ func TestWriteBufferPool(t *testing.T) {
writeBufAddr := &wc.writeBuf[0] writeBufAddr := &wc.writeBuf[0]
const message = "Hello World!"
if _, err := io.WriteString(w, message); err != nil { if _, err := io.WriteString(w, message); err != nil {
t.Fatalf("io.WriteString(w, message) returned %v", err) t.Fatalf("io.WriteString(w, message) returned %v", err)
} }
@ -300,6 +303,7 @@ func TestWriteBufferPool(t *testing.T) {
} }
} }
// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
func TestWriteBufferPoolSync(t *testing.T) { func TestWriteBufferPoolSync(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
var pool sync.Pool var pool sync.Pool
@ -321,6 +325,56 @@ func TestWriteBufferPoolSync(t *testing.T) {
} }
} }
// errorWriter is an io.Writer than returns an error on all writes.
type errorWriter struct{}
func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("Error!") }
// TestWriteBufferPoolError ensures that buffer is returned to pool after error
// on write.
func TestWriteBufferPoolError(t *testing.T) {
// Part 1: Test NextWriter/Write/Close
var pool simpleBufferPool
wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
w, err := wc.NextWriter(TextMessage)
if err != nil {
t.Fatalf("wc.NextWriter() returned %v", err)
}
if wc.writeBuf == nil {
t.Fatal("writeBuf is nil after NextWriter")
}
writeBufAddr := &wc.writeBuf[0]
if _, err := io.WriteString(w, "Hello"); err != nil {
t.Fatalf("io.WriteString(w, message) returned %v", err)
}
if err := w.Close(); err == nil {
t.Fatalf("w.Close() did not return error")
}
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
t.Fatal("writeBuf not returned to pool")
}
// Part 2: Test WriteMessage
wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
t.Fatalf("wc.WriteMessage did not return error")
}
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
t.Fatal("writeBuf not returned to pool")
}
}
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
const bufSize = 512 const bufSize = 512

View File

@ -1,6 +1,6 @@
# Test Server # Test Server
This package contains a server for the [Autobahn WebSockets Test Suite](http://autobahn.ws/testsuite). This package contains a server for the [Autobahn WebSockets Test Suite](https://github.com/crossbario/autobahn-testsuite).
To test the server, run To test the server, run

View File

@ -28,7 +28,7 @@ type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete. // HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
// size is zero, then buffers allocated by the HTTP server are used. The // size is zero, then buffers allocated by the HTTP server are used. The
// I/O buffer sizes do not limit the size of the messages that can be sent // I/O buffer sizes do not limit the size of the messages that can be sent
// or received. // or received.

132
util.go
View File

@ -31,68 +31,113 @@ func generateChallengeKey() (string, error) {
return base64.StdEncoding.EncodeToString(p), nil return base64.StdEncoding.EncodeToString(p), nil
} }
// Octet types from RFC 2616. // Token octets per RFC 2616.
var octetTypes [256]byte var isTokenOctet = [256]bool{
'!': true,
const ( '#': true,
isTokenOctet = 1 << iota '$': true,
isSpaceOctet '%': true,
) '&': true,
'\'': true,
func init() { '*': true,
// From RFC 2616 '+': true,
// '-': true,
// OCTET = <any 8-bit sequence of data> '.': true,
// CHAR = <any US-ASCII character (octets 0 - 127)> '0': true,
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)> '1': true,
// CR = <US-ASCII CR, carriage return (13)> '2': true,
// LF = <US-ASCII LF, linefeed (10)> '3': true,
// SP = <US-ASCII SP, space (32)> '4': true,
// HT = <US-ASCII HT, horizontal-tab (9)> '5': true,
// <"> = <US-ASCII double-quote mark (34)> '6': true,
// CRLF = CR LF '7': true,
// LWS = [CRLF] 1*( SP | HT ) '8': true,
// TEXT = <any OCTET except CTLs, but including LWS> '9': true,
// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <"> 'A': true,
// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT 'B': true,
// token = 1*<any CHAR except CTLs or separators> 'C': true,
// qdtext = <any TEXT except <">> 'D': true,
'E': true,
for c := 0; c < 256; c++ { 'F': true,
var t byte 'G': true,
isCtl := c <= 31 || c == 127 'H': true,
isChar := 0 <= c && c <= 127 'I': true,
isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0 'J': true,
if strings.IndexRune(" \t\r\n", rune(c)) >= 0 { 'K': true,
t |= isSpaceOctet 'L': true,
} 'M': true,
if isChar && !isCtl && !isSeparator { 'N': true,
t |= isTokenOctet 'O': true,
} 'P': true,
octetTypes[c] = t 'Q': true,
} 'R': true,
'S': true,
'T': true,
'U': true,
'W': true,
'V': true,
'X': true,
'Y': true,
'Z': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'|': true,
'~': true,
} }
// skipSpace returns a slice of the string s with all leading RFC 2616 linear
// whitespace removed.
func skipSpace(s string) (rest string) { func skipSpace(s string) (rest string) {
i := 0 i := 0
for ; i < len(s); i++ { for ; i < len(s); i++ {
if octetTypes[s[i]]&isSpaceOctet == 0 { if b := s[i]; b != ' ' && b != '\t' {
break break
} }
} }
return s[i:] return s[i:]
} }
// nextToken returns the leading RFC 2616 token of s and the string following
// the token.
func nextToken(s string) (token, rest string) { func nextToken(s string) (token, rest string) {
i := 0 i := 0
for ; i < len(s); i++ { for ; i < len(s); i++ {
if octetTypes[s[i]]&isTokenOctet == 0 { if !isTokenOctet[s[i]] {
break break
} }
} }
return s[:i], s[i:] return s[:i], s[i:]
} }
// nextTokenOrQuoted returns the leading token or quoted string per RFC 2616
// and the string following the token or quoted string.
func nextTokenOrQuoted(s string) (value string, rest string) { func nextTokenOrQuoted(s string) (value string, rest string) {
if !strings.HasPrefix(s, "\"") { if !strings.HasPrefix(s, "\"") {
return nextToken(s) return nextToken(s)
@ -128,7 +173,8 @@ func nextTokenOrQuoted(s string) (value string, rest string) {
return "", "" return "", ""
} }
// equalASCIIFold returns true if s is equal to t with ASCII case folding. // equalASCIIFold returns true if s is equal to t with ASCII case folding as
// defined in RFC 4790.
func equalASCIIFold(s, t string) bool { func equalASCIIFold(s, t string) bool {
for s != "" && t != "" { for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s) sr, size := utf8.DecodeRuneInString(s)

View File

@ -17,6 +17,7 @@ var equalASCIIFoldTests = []struct {
{"WebSocket", "websocket", true}, {"WebSocket", "websocket", true},
{"websocket", "WebSocket", true}, {"websocket", "WebSocket", true},
{"Öyster", "öyster", false}, {"Öyster", "öyster", false},
{"WebSocket", "WetSocket", false},
} }
func TestEqualASCIIFold(t *testing.T) { func TestEqualASCIIFold(t *testing.T) {