mirror of https://github.com/gorilla/websocket.git
merge: master
This commit is contained in:
commit
a8d45c1b96
|
@ -27,7 +27,7 @@ package API is stable.
|
|||
### Protocol Compliance
|
||||
|
||||
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
|
||||
Suite](http://autobahn.ws/testsuite) using the application in the [examples/autobahn
|
||||
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
|
||||
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
|
||||
|
||||
### Gorilla WebSocket compared with other packages
|
||||
|
|
|
@ -71,7 +71,7 @@ type Dialer struct {
|
|||
// HandshakeTimeout specifies the duration for the handshake to complete.
|
||||
HandshakeTimeout time.Duration
|
||||
|
||||
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
|
||||
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
|
||||
// size is zero, then a useful default size is used. The I/O buffer sizes
|
||||
// do not limit the size of the messages that can be sent or received.
|
||||
ReadBufferSize, WriteBufferSize int
|
||||
|
@ -144,7 +144,7 @@ var nilDialer = *DefaultDialer
|
|||
// Use the response.Header to get the selected subprotocol
|
||||
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
|
||||
//
|
||||
// The context will be used in the request and in the Dialer
|
||||
// The context will be used in the request and in the Dialer.
|
||||
//
|
||||
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
|
||||
// non-nil *http.Response so that callers can handle redirects, authentication,
|
||||
|
|
|
@ -11,8 +11,10 @@ import (
|
|||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
|
@ -51,15 +53,10 @@ type cstHandler struct {
|
|||
cstHandlerConfig
|
||||
}
|
||||
|
||||
var cstDialerWithoutHandshakeTimeout = Dialer{
|
||||
Subprotocols: []string{"p1", "p2"},
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
|
||||
type cstServer struct {
|
||||
*httptest.Server
|
||||
URL string
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -341,10 +338,7 @@ func TestDialCookieJar(t *testing.T) {
|
|||
sendRecv(t, ws)
|
||||
}
|
||||
|
||||
func TestDialTLS(t *testing.T) {
|
||||
s := newTLSServer(t, cstHandlerConfig{})
|
||||
defer s.Close()
|
||||
|
||||
func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
|
||||
certs := x509.NewCertPool()
|
||||
for _, c := range s.TLS.Certificates {
|
||||
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
|
||||
|
@ -355,35 +349,15 @@ func TestDialTLS(t *testing.T) {
|
|||
certs.AddCert(root)
|
||||
}
|
||||
}
|
||||
|
||||
d := cstDialer
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: certs}
|
||||
ws, _, err := d.Dial(s.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer ws.Close()
|
||||
sendRecv(t, ws)
|
||||
return certs
|
||||
}
|
||||
|
||||
func xTestDialTLSBadCert(t *testing.T) {
|
||||
// This test is deactivated because of noisy logging from the net/http package.
|
||||
s := newTLSServer(t, cstHandlerConfig{})
|
||||
defer s.Close()
|
||||
|
||||
ws, _, err := cstDialer.Dial(s.URL, nil)
|
||||
if err == nil {
|
||||
ws.Close()
|
||||
t.Fatalf("Dial: nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialTLSNoVerify(t *testing.T) {
|
||||
func TestDialTLS(t *testing.T) {
|
||||
s := newTLSServer(t, cstHandlerConfig{})
|
||||
defer s.Close()
|
||||
|
||||
d := cstDialer
|
||||
d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
|
||||
ws, _, err := d.Dial(s.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
|
@ -468,7 +442,8 @@ func TestHandshakeTimeoutInContext(t *testing.T) {
|
|||
s := newServer(t, cstHandlerConfig{})
|
||||
defer s.Close()
|
||||
|
||||
d := cstDialerWithoutHandshakeTimeout
|
||||
d := cstDialer
|
||||
d.HandshakeTimeout = 0
|
||||
d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
|
||||
netDialer := &net.Dialer{}
|
||||
c, err := netDialer.DialContext(ctx, n, a)
|
||||
|
@ -619,33 +594,195 @@ func TestRespOnBadHandshake(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestHostHeader confirms that the host header provided in the call to Dial is
|
||||
// sent to the server.
|
||||
func TestHostHeader(t *testing.T) {
|
||||
s := newServer(t, cstHandlerConfig{})
|
||||
defer s.Close()
|
||||
type testLogWriter struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
specifiedHost := make(chan string, 1)
|
||||
origHandler := s.Server.Config.Handler
|
||||
func (w testLogWriter) Write(p []byte) (int, error) {
|
||||
w.t.Logf("%s", p)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Capture the request Host header.
|
||||
s.Server.Config.Handler = http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
specifiedHost <- r.Host
|
||||
origHandler.ServeHTTP(w, r)
|
||||
})
|
||||
// TestHost tests handling of host names and confirms that it matches net/http.
|
||||
func TestHost(t *testing.T) {
|
||||
|
||||
ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
defer ws.Close()
|
||||
upgrader := Upgrader{}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if IsWebSocketUpgrade(r) {
|
||||
c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c.Close()
|
||||
} else {
|
||||
w.Header().Set("X-Test-Host", r.Host)
|
||||
}
|
||||
})
|
||||
|
||||
if gotHost := <-specifiedHost; gotHost != "testhost" {
|
||||
t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
|
||||
tlsServer := httptest.NewTLSServer(handler)
|
||||
defer tlsServer.Close()
|
||||
|
||||
addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
|
||||
wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
|
||||
httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
|
||||
|
||||
// Avoid log noise from net/http server by logging to testing.T
|
||||
server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
|
||||
tlsServer.Config.ErrorLog = server.Config.ErrorLog
|
||||
|
||||
cas := rootCAs(t, tlsServer)
|
||||
|
||||
tests := []struct {
|
||||
fail bool // true if dial / get should fail
|
||||
server *httptest.Server // server to use
|
||||
url string // host for request URI
|
||||
header string // optional request host header
|
||||
tls string // optiona host for tls ServerName
|
||||
wantAddr string // expected host for dial
|
||||
wantHeader string // expected request header on server
|
||||
insecureSkipVerify bool
|
||||
}{
|
||||
{
|
||||
server: server,
|
||||
url: addrs[server],
|
||||
wantAddr: addrs[server],
|
||||
wantHeader: addrs[server],
|
||||
},
|
||||
{
|
||||
server: tlsServer,
|
||||
url: addrs[tlsServer],
|
||||
wantAddr: addrs[tlsServer],
|
||||
wantHeader: addrs[tlsServer],
|
||||
},
|
||||
|
||||
{
|
||||
server: server,
|
||||
url: addrs[server],
|
||||
header: "badhost.com",
|
||||
wantAddr: addrs[server],
|
||||
wantHeader: "badhost.com",
|
||||
},
|
||||
{
|
||||
server: tlsServer,
|
||||
url: addrs[tlsServer],
|
||||
header: "badhost.com",
|
||||
wantAddr: addrs[tlsServer],
|
||||
wantHeader: "badhost.com",
|
||||
},
|
||||
|
||||
{
|
||||
server: server,
|
||||
url: "example.com",
|
||||
header: "badhost.com",
|
||||
wantAddr: "example.com:80",
|
||||
wantHeader: "badhost.com",
|
||||
},
|
||||
{
|
||||
server: tlsServer,
|
||||
url: "example.com",
|
||||
header: "badhost.com",
|
||||
wantAddr: "example.com:443",
|
||||
wantHeader: "badhost.com",
|
||||
},
|
||||
|
||||
{
|
||||
server: server,
|
||||
url: "badhost.com",
|
||||
header: "example.com",
|
||||
wantAddr: "badhost.com:80",
|
||||
wantHeader: "example.com",
|
||||
},
|
||||
{
|
||||
fail: true,
|
||||
server: tlsServer,
|
||||
url: "badhost.com",
|
||||
header: "example.com",
|
||||
wantAddr: "badhost.com:443",
|
||||
},
|
||||
{
|
||||
server: tlsServer,
|
||||
url: "badhost.com",
|
||||
insecureSkipVerify: true,
|
||||
wantAddr: "badhost.com:443",
|
||||
wantHeader: "badhost.com",
|
||||
},
|
||||
{
|
||||
server: tlsServer,
|
||||
url: "badhost.com",
|
||||
tls: "example.com",
|
||||
wantAddr: "badhost.com:443",
|
||||
wantHeader: "badhost.com",
|
||||
},
|
||||
}
|
||||
|
||||
sendRecv(t, ws)
|
||||
for i, tt := range tests {
|
||||
|
||||
tls := &tls.Config{
|
||||
RootCAs: cas,
|
||||
ServerName: tt.tls,
|
||||
InsecureSkipVerify: tt.insecureSkipVerify,
|
||||
}
|
||||
|
||||
var gotAddr string
|
||||
dialer := Dialer{
|
||||
NetDial: func(network, addr string) (net.Conn, error) {
|
||||
gotAddr = addr
|
||||
return net.Dial(network, addrs[tt.server])
|
||||
},
|
||||
TLSClientConfig: tls,
|
||||
}
|
||||
|
||||
// Test websocket dial
|
||||
|
||||
h := http.Header{}
|
||||
if tt.header != "" {
|
||||
h.Set("Host", tt.header)
|
||||
}
|
||||
c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
|
||||
if err == nil {
|
||||
c.Close()
|
||||
}
|
||||
|
||||
check := func(protos map[*httptest.Server]string) {
|
||||
name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
|
||||
if gotAddr != tt.wantAddr {
|
||||
t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
|
||||
}
|
||||
switch {
|
||||
case tt.fail && err == nil:
|
||||
t.Errorf("%s: unexpected success", name)
|
||||
case !tt.fail && err != nil:
|
||||
t.Errorf("%s: unexpected error %v", name, err)
|
||||
case !tt.fail && err == nil:
|
||||
if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
|
||||
t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
check(wsProtos)
|
||||
|
||||
// Confirm that net/http has same result
|
||||
|
||||
transport := &http.Transport{
|
||||
Dial: dialer.NetDial,
|
||||
TLSClientConfig: dialer.TLSClientConfig,
|
||||
}
|
||||
req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
|
||||
if tt.header != "" {
|
||||
req.Host = tt.header
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
resp, err = client.Do(req)
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
transport.CloseIdleConnections()
|
||||
check(httpProtos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialCompression(t *testing.T) {
|
||||
|
@ -785,19 +922,8 @@ func TestTracingDialWithContext(t *testing.T) {
|
|||
s := newTLSServer(t, cstHandlerConfig{})
|
||||
defer s.Close()
|
||||
|
||||
certs := x509.NewCertPool()
|
||||
for _, c := range s.TLS.Certificates {
|
||||
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing server's root cert: %v", err)
|
||||
}
|
||||
for _, root := range roots {
|
||||
certs.AddCert(root)
|
||||
}
|
||||
}
|
||||
|
||||
d := cstDialer
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: certs}
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
|
||||
|
||||
ws, _, err := d.DialContext(ctx, s.URL, nil)
|
||||
if err != nil {
|
||||
|
@ -835,19 +961,8 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
|
|||
s := newTLSServer(t, cstHandlerConfig{})
|
||||
defer s.Close()
|
||||
|
||||
certs := x509.NewCertPool()
|
||||
for _, c := range s.TLS.Certificates {
|
||||
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing server's root cert: %v", err)
|
||||
}
|
||||
for _, root := range roots {
|
||||
certs.AddCert(root)
|
||||
}
|
||||
}
|
||||
|
||||
d := cstDialer
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: certs}
|
||||
d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
|
||||
|
||||
ws, _, err := d.DialContext(ctx, s.URL, nil)
|
||||
if err != nil {
|
||||
|
|
49
conn.go
49
conn.go
|
@ -453,7 +453,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
|||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) prepWrite(messageType int) error {
|
||||
// beginMessage prepares a connection and message writer for a new message.
|
||||
func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
|
||||
// Close previous writer if not already closed by the application. It's
|
||||
// probably better to return an error in this situation, but we cannot
|
||||
// change this without breaking existing applications.
|
||||
|
@ -473,6 +474,10 @@ func (c *Conn) prepWrite(messageType int) error {
|
|||
return err
|
||||
}
|
||||
|
||||
mw.c = c
|
||||
mw.frameType = messageType
|
||||
mw.pos = maxFrameHeaderSize
|
||||
|
||||
if c.writeBuf == nil {
|
||||
wpd, ok := c.writePool.Get().(writePoolData)
|
||||
if ok {
|
||||
|
@ -493,16 +498,11 @@ func (c *Conn) prepWrite(messageType int) error {
|
|||
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
|
||||
// PongMessage) are supported.
|
||||
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
||||
if err := c.prepWrite(messageType); err != nil {
|
||||
var mw messageWriter
|
||||
if err := c.beginMessage(&mw, messageType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mw := &messageWriter{
|
||||
c: c,
|
||||
frameType: messageType,
|
||||
pos: maxFrameHeaderSize,
|
||||
}
|
||||
c.writer = mw
|
||||
c.writer = &mw
|
||||
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
||||
w := c.newCompressionWriter(c.writer, c.compressionLevel)
|
||||
mw.compress = true
|
||||
|
@ -519,10 +519,16 @@ type messageWriter struct {
|
|||
err error
|
||||
}
|
||||
|
||||
func (w *messageWriter) fatal(err error) error {
|
||||
func (w *messageWriter) endMessage(err error) error {
|
||||
if w.err != nil {
|
||||
w.err = err
|
||||
w.c.writer = nil
|
||||
return err
|
||||
}
|
||||
c := w.c
|
||||
w.err = err
|
||||
c.writer = nil
|
||||
if c.writePool != nil {
|
||||
c.writePool.Put(writePoolData{buf: c.writeBuf})
|
||||
c.writeBuf = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@ -536,7 +542,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
|||
// Check for invalid control frames.
|
||||
if isControl(w.frameType) &&
|
||||
(!final || length > maxControlFramePayloadSize) {
|
||||
return w.fatal(errInvalidControlFrame)
|
||||
return w.endMessage(errInvalidControlFrame)
|
||||
}
|
||||
|
||||
b0 := byte(w.frameType)
|
||||
|
@ -581,7 +587,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
|||
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
|
||||
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
|
||||
if len(extra) > 0 {
|
||||
return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
|
||||
return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -602,15 +608,11 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
|||
c.isWriting = false
|
||||
|
||||
if err != nil {
|
||||
return w.fatal(err)
|
||||
return w.endMessage(err)
|
||||
}
|
||||
|
||||
if final {
|
||||
c.writer = nil
|
||||
if c.writePool != nil {
|
||||
c.writePool.Put(writePoolData{buf: c.writeBuf})
|
||||
c.writeBuf = nil
|
||||
}
|
||||
w.endMessage(errWriteClosed)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -711,7 +713,6 @@ func (w *messageWriter) Close() error {
|
|||
if err := w.flushFrame(true, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
w.err = errWriteClosed
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -744,10 +745,10 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
|
|||
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
|
||||
// Fast path with no allocations and single frame.
|
||||
|
||||
if err := c.prepWrite(messageType); err != nil {
|
||||
var mw messageWriter
|
||||
if err := c.beginMessage(&mw, messageType); err != nil {
|
||||
return err
|
||||
}
|
||||
mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
|
||||
n := copy(c.writeBuf[mw.pos:], data)
|
||||
mw.pos += n
|
||||
data = data[n:]
|
||||
|
@ -1043,7 +1044,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
|
|||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetReadLimit sets the maximum size for a message read from the peer. If a
|
||||
// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a
|
||||
// message exceeds the limit, the connection sends a close message to the peer
|
||||
// and returns ErrReadLimit to the application.
|
||||
func (c *Conn) SetReadLimit(limit int64) {
|
||||
|
|
60
conn_test.go
60
conn_test.go
|
@ -227,11 +227,16 @@ func (p *simpleBufferPool) Put(v interface{}) {
|
|||
}
|
||||
|
||||
func TestWriteBufferPool(t *testing.T) {
|
||||
const message = "Now is the time for all good people to come to the aid of the party."
|
||||
|
||||
var buf bytes.Buffer
|
||||
var pool simpleBufferPool
|
||||
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
|
||||
rc := newTestConn(&buf, nil, false)
|
||||
|
||||
// Specify writeBufferSize smaller than message size to ensure that pooling
|
||||
// works with fragmented messages.
|
||||
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)
|
||||
|
||||
if wc.writeBuf != nil {
|
||||
t.Fatal("writeBuf not nil after create")
|
||||
}
|
||||
|
@ -249,8 +254,6 @@ func TestWriteBufferPool(t *testing.T) {
|
|||
|
||||
writeBufAddr := &wc.writeBuf[0]
|
||||
|
||||
const message = "Hello World!"
|
||||
|
||||
if _, err := io.WriteString(w, message); err != nil {
|
||||
t.Fatalf("io.WriteString(w, message) returned %v", err)
|
||||
}
|
||||
|
@ -300,6 +303,7 @@ func TestWriteBufferPool(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
|
||||
func TestWriteBufferPoolSync(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
var pool sync.Pool
|
||||
|
@ -321,6 +325,56 @@ func TestWriteBufferPoolSync(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// errorWriter is an io.Writer than returns an error on all writes.
|
||||
type errorWriter struct{}
|
||||
|
||||
func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("Error!") }
|
||||
|
||||
// TestWriteBufferPoolError ensures that buffer is returned to pool after error
|
||||
// on write.
|
||||
func TestWriteBufferPoolError(t *testing.T) {
|
||||
|
||||
// Part 1: Test NextWriter/Write/Close
|
||||
|
||||
var pool simpleBufferPool
|
||||
wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
|
||||
|
||||
w, err := wc.NextWriter(TextMessage)
|
||||
if err != nil {
|
||||
t.Fatalf("wc.NextWriter() returned %v", err)
|
||||
}
|
||||
|
||||
if wc.writeBuf == nil {
|
||||
t.Fatal("writeBuf is nil after NextWriter")
|
||||
}
|
||||
|
||||
writeBufAddr := &wc.writeBuf[0]
|
||||
|
||||
if _, err := io.WriteString(w, "Hello"); err != nil {
|
||||
t.Fatalf("io.WriteString(w, message) returned %v", err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err == nil {
|
||||
t.Fatalf("w.Close() did not return error")
|
||||
}
|
||||
|
||||
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
|
||||
t.Fatal("writeBuf not returned to pool")
|
||||
}
|
||||
|
||||
// Part 2: Test WriteMessage
|
||||
|
||||
wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
|
||||
|
||||
if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
|
||||
t.Fatalf("wc.WriteMessage did not return error")
|
||||
}
|
||||
|
||||
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
|
||||
t.Fatal("writeBuf not returned to pool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||||
const bufSize = 512
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Test Server
|
||||
|
||||
This package contains a server for the [Autobahn WebSockets Test Suite](http://autobahn.ws/testsuite).
|
||||
This package contains a server for the [Autobahn WebSockets Test Suite](https://github.com/crossbario/autobahn-testsuite).
|
||||
|
||||
To test the server, run
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ type Upgrader struct {
|
|||
// HandshakeTimeout specifies the duration for the handshake to complete.
|
||||
HandshakeTimeout time.Duration
|
||||
|
||||
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes. If a buffer
|
||||
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
|
||||
// size is zero, then buffers allocated by the HTTP server are used. The
|
||||
// I/O buffer sizes do not limit the size of the messages that can be sent
|
||||
// or received.
|
||||
|
|
132
util.go
132
util.go
|
@ -31,68 +31,113 @@ func generateChallengeKey() (string, error) {
|
|||
return base64.StdEncoding.EncodeToString(p), nil
|
||||
}
|
||||
|
||||
// Octet types from RFC 2616.
|
||||
var octetTypes [256]byte
|
||||
|
||||
const (
|
||||
isTokenOctet = 1 << iota
|
||||
isSpaceOctet
|
||||
)
|
||||
|
||||
func init() {
|
||||
// From RFC 2616
|
||||
//
|
||||
// OCTET = <any 8-bit sequence of data>
|
||||
// CHAR = <any US-ASCII character (octets 0 - 127)>
|
||||
// CTL = <any US-ASCII control character (octets 0 - 31) and DEL (127)>
|
||||
// CR = <US-ASCII CR, carriage return (13)>
|
||||
// LF = <US-ASCII LF, linefeed (10)>
|
||||
// SP = <US-ASCII SP, space (32)>
|
||||
// HT = <US-ASCII HT, horizontal-tab (9)>
|
||||
// <"> = <US-ASCII double-quote mark (34)>
|
||||
// CRLF = CR LF
|
||||
// LWS = [CRLF] 1*( SP | HT )
|
||||
// TEXT = <any OCTET except CTLs, but including LWS>
|
||||
// separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <">
|
||||
// | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT
|
||||
// token = 1*<any CHAR except CTLs or separators>
|
||||
// qdtext = <any TEXT except <">>
|
||||
|
||||
for c := 0; c < 256; c++ {
|
||||
var t byte
|
||||
isCtl := c <= 31 || c == 127
|
||||
isChar := 0 <= c && c <= 127
|
||||
isSeparator := strings.IndexRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) >= 0
|
||||
if strings.IndexRune(" \t\r\n", rune(c)) >= 0 {
|
||||
t |= isSpaceOctet
|
||||
}
|
||||
if isChar && !isCtl && !isSeparator {
|
||||
t |= isTokenOctet
|
||||
}
|
||||
octetTypes[c] = t
|
||||
}
|
||||
// Token octets per RFC 2616.
|
||||
var isTokenOctet = [256]bool{
|
||||
'!': true,
|
||||
'#': true,
|
||||
'$': true,
|
||||
'%': true,
|
||||
'&': true,
|
||||
'\'': true,
|
||||
'*': true,
|
||||
'+': true,
|
||||
'-': true,
|
||||
'.': true,
|
||||
'0': true,
|
||||
'1': true,
|
||||
'2': true,
|
||||
'3': true,
|
||||
'4': true,
|
||||
'5': true,
|
||||
'6': true,
|
||||
'7': true,
|
||||
'8': true,
|
||||
'9': true,
|
||||
'A': true,
|
||||
'B': true,
|
||||
'C': true,
|
||||
'D': true,
|
||||
'E': true,
|
||||
'F': true,
|
||||
'G': true,
|
||||
'H': true,
|
||||
'I': true,
|
||||
'J': true,
|
||||
'K': true,
|
||||
'L': true,
|
||||
'M': true,
|
||||
'N': true,
|
||||
'O': true,
|
||||
'P': true,
|
||||
'Q': true,
|
||||
'R': true,
|
||||
'S': true,
|
||||
'T': true,
|
||||
'U': true,
|
||||
'W': true,
|
||||
'V': true,
|
||||
'X': true,
|
||||
'Y': true,
|
||||
'Z': true,
|
||||
'^': true,
|
||||
'_': true,
|
||||
'`': true,
|
||||
'a': true,
|
||||
'b': true,
|
||||
'c': true,
|
||||
'd': true,
|
||||
'e': true,
|
||||
'f': true,
|
||||
'g': true,
|
||||
'h': true,
|
||||
'i': true,
|
||||
'j': true,
|
||||
'k': true,
|
||||
'l': true,
|
||||
'm': true,
|
||||
'n': true,
|
||||
'o': true,
|
||||
'p': true,
|
||||
'q': true,
|
||||
'r': true,
|
||||
's': true,
|
||||
't': true,
|
||||
'u': true,
|
||||
'v': true,
|
||||
'w': true,
|
||||
'x': true,
|
||||
'y': true,
|
||||
'z': true,
|
||||
'|': true,
|
||||
'~': true,
|
||||
}
|
||||
|
||||
// skipSpace returns a slice of the string s with all leading RFC 2616 linear
|
||||
// whitespace removed.
|
||||
func skipSpace(s string) (rest string) {
|
||||
i := 0
|
||||
for ; i < len(s); i++ {
|
||||
if octetTypes[s[i]]&isSpaceOctet == 0 {
|
||||
if b := s[i]; b != ' ' && b != '\t' {
|
||||
break
|
||||
}
|
||||
}
|
||||
return s[i:]
|
||||
}
|
||||
|
||||
// nextToken returns the leading RFC 2616 token of s and the string following
|
||||
// the token.
|
||||
func nextToken(s string) (token, rest string) {
|
||||
i := 0
|
||||
for ; i < len(s); i++ {
|
||||
if octetTypes[s[i]]&isTokenOctet == 0 {
|
||||
if !isTokenOctet[s[i]] {
|
||||
break
|
||||
}
|
||||
}
|
||||
return s[:i], s[i:]
|
||||
}
|
||||
|
||||
// nextTokenOrQuoted returns the leading token or quoted string per RFC 2616
|
||||
// and the string following the token or quoted string.
|
||||
func nextTokenOrQuoted(s string) (value string, rest string) {
|
||||
if !strings.HasPrefix(s, "\"") {
|
||||
return nextToken(s)
|
||||
|
@ -128,7 +173,8 @@ func nextTokenOrQuoted(s string) (value string, rest string) {
|
|||
return "", ""
|
||||
}
|
||||
|
||||
// equalASCIIFold returns true if s is equal to t with ASCII case folding.
|
||||
// equalASCIIFold returns true if s is equal to t with ASCII case folding as
|
||||
// defined in RFC 4790.
|
||||
func equalASCIIFold(s, t string) bool {
|
||||
for s != "" && t != "" {
|
||||
sr, size := utf8.DecodeRuneInString(s)
|
||||
|
|
|
@ -17,6 +17,7 @@ var equalASCIIFoldTests = []struct {
|
|||
{"WebSocket", "websocket", true},
|
||||
{"websocket", "WebSocket", true},
|
||||
{"Öyster", "öyster", false},
|
||||
{"WebSocket", "WetSocket", false},
|
||||
}
|
||||
|
||||
func TestEqualASCIIFold(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue