websocket/conn_test.go

701 lines
18 KiB
Go

// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"reflect"
"sync/atomic"
"testing"
"testing/iotest"
"time"
)
// Compile-time assertion that errWriteTimeout satisfies the net.Error
// interface; the blank identifier means no variable is actually kept.
var _ net.Error = errWriteTimeout
// fakeNetConn provides the net.Conn method set for tests. Read and Write come
// from the embedded io.Reader and io.Writer; every other method is a stub that
// reports success without doing anything.
type fakeNetConn struct {
	io.Reader
	io.Writer
}

// Close is a no-op that always succeeds.
func (fakeNetConn) Close() error {
	return nil
}

// LocalAddr reports the fixed package-level localAddr.
func (fakeNetConn) LocalAddr() net.Addr {
	return localAddr
}

// RemoteAddr reports the fixed package-level remoteAddr.
func (fakeNetConn) RemoteAddr() net.Addr {
	return remoteAddr
}

// SetDeadline ignores the deadline and always succeeds.
func (fakeNetConn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline ignores the deadline and always succeeds.
func (fakeNetConn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline and always succeeds.
func (fakeNetConn) SetWriteDeadline(time.Time) error {
	return nil
}
// fakeAddr is a placeholder address type whose Network and String methods
// return fixed strings regardless of the underlying value.
type fakeAddr int

// Fixed addresses handed out by the fake connection used in these tests.
var localAddr, remoteAddr fakeAddr = 1, 2

// Network always returns "net".
func (fakeAddr) Network() string { return "net" }

// String always returns "str".
func (fakeAddr) String() string { return "str" }
// TestFraming writes messages of many sizes on one Conn and reads them back on
// a peer Conn through a shared in-memory buffer, checking that the payload
// arrives unchanged. Every combination of compression on/off, server/client
// role, read chunking strategy and write method is exercised.
func TestFraming(t *testing.T) {
	// Payload sizes, clustered around 0, 125 and 65535.
	frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}

	// Wrappers controlling how many bytes each Read on the reader side returns.
	var readChunkers = []struct {
		name string
		f    func(io.Reader) io.Reader
	}{
		{"half", iotest.HalfReader},
		{"one", iotest.OneByteReader},
		{"asis", func(r io.Reader) io.Reader { return r }},
	}

	// writeBuf[i] == byte(i), so the reader can verify content by offset alone.
	writeBuf := make([]byte, 65537)
	for i := range writeBuf {
		writeBuf[i] = byte(i)
	}

	// Different ways of pushing the first n bytes of writeBuf into a message
	// writer.
	var writers = []struct {
		name string
		f    func(w io.Writer, n int) (int, error)
	}{
		{"iocopy", func(w io.Writer, n int) (int, error) {
			nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
			return int(nn), err
		}},
		{"write", func(w io.Writer, n int) (int, error) {
			return w.Write(writeBuf[:n])
		}},
		{"string", func(w io.Writer, n int) (int, error) {
			return io.WriteString(w, string(writeBuf[:n]))
		}},
	}

	for _, compress := range []bool{false, true} {
		for _, isServer := range []bool{true, false} {
			for _, chunker := range readChunkers {

				// The writer and reader take opposite roles and share connBuf.
				var connBuf bytes.Buffer
				wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
				rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
				if compress {
					wc.newCompressionWriter = compressNoContextTakeover
					rc.newDecompressionReader = decompressNoContextTakeover
				}
				for _, n := range frameSizes {
					for _, writer := range writers {
						name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)

						w, err := wc.NextWriter(TextMessage)
						if err != nil {
							t.Errorf("%s: wc.NextWriter() returned %v", name, err)
							continue
						}
						nn, err := writer.f(w, n)
						if err != nil || nn != n {
							t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
							continue
						}
						err = w.Close()
						if err != nil {
							t.Errorf("%s: w.Close() returned %v", name, err)
							continue
						}

						opCode, r, err := rc.NextReader()
						if err != nil || opCode != TextMessage {
							t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
							continue
						}
						rbuf, err := ioutil.ReadAll(r)
						if err != nil {
							// The message names the function that is actually called.
							t.Errorf("%s: ReadAll() returned rbuf, %v", name, err)
							continue
						}

						if len(rbuf) != n {
							t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
							continue
						}

						// Report only the first corrupted offset per message.
						for i, b := range rbuf {
							if byte(i) != b {
								t.Errorf("%s: bad byte at offset %d", name, i)
								break
							}
						}
					}
				}
			}
		}
	}
}
// TestControl sends a pong control message from one Conn and checks that the
// pong handler on the peer Conn receives the same payload. The pong is sent
// either with WriteControl or through a message writer obtained from
// NextWriter, in both the server and client roles.
func TestControl(t *testing.T) {
	const message = "this is a ping/pong messsage"
	for _, isServer := range []bool{true, false} {
		for _, isWriteControl := range []bool{true, false} {
			name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
			var connBuf bytes.Buffer
			wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
			rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
			if isWriteControl {
				if err := wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)); err != nil {
					t.Errorf("%s: wc.WriteControl() returned %v", name, err)
					continue
				}
			} else {
				w, err := wc.NextWriter(PongMessage)
				if err != nil {
					t.Errorf("%s: wc.NextWriter() returned %v", name, err)
					continue
				}
				if _, err := w.Write([]byte(message)); err != nil {
					t.Errorf("%s: w.Write() returned %v", name, err)
					continue
				}
				if err := w.Close(); err != nil {
					t.Errorf("%s: w.Close() returned %v", name, err)
					continue
				}
			}

			// Verify delivery for BOTH write paths. Previously this check lived
			// inside the else branch, so the WriteControl path was never
			// verified.
			var actualMessage string
			rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
			// The result is ignored on purpose: only the pong handler's side
			// effect is checked, and no data message follows the pong.
			rc.NextReader()
			if actualMessage != message {
				t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
			}
		}
	}
}
// TestCloseFrameBeforeFinalMessageFrame starts a binary message that is larger
// than the writer's buffer, sends a close control message before the message
// writer is closed, and expects the reader to surface the close error both
// while draining the message and on the following NextReader call.
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
	const bufSize = 512

	want := &CloseError{Code: CloseNormalClosure, Text: "hello"}

	var toReader, fromReader bytes.Buffer
	sender := newConn(fakeNetConn{Writer: &toReader}, false, 1024, bufSize)
	receiver := newConn(fakeNetConn{Reader: &toReader, Writer: &fromReader}, true, 1024, 1024)

	// Results of the writes below are intentionally not inspected.
	mw, _ := sender.NextWriter(BinaryMessage)
	mw.Write(make([]byte, bufSize+bufSize/2))
	sender.WriteControl(CloseMessage, FormatCloseMessage(want.Code, want.Text), time.Now().Add(10*time.Second))
	mw.Close()

	op, r, err := receiver.NextReader()
	if err != nil || op != BinaryMessage {
		t.Fatalf("NextReader() returned %d, %v", op, err)
	}

	if _, err := io.Copy(ioutil.Discard, r); !reflect.DeepEqual(err, want) {
		t.Fatalf("io.Copy() returned %v, want %v", err, want)
	}

	if _, _, err := receiver.NextReader(); !reflect.DeepEqual(err, want) {
		t.Fatalf("NextReader() returned %v, want %v", err, want)
	}
}
// TestEOFWithinFrame writes one binary message, truncates the encoded bytes to
// every length shorter than the full encoding, and expects errUnexpectedEOF
// either from NextReader directly or while draining the message and on the
// NextReader call after that.
func TestEOFWithinFrame(t *testing.T) {
	const bufSize = 64

	for cut := 0; ; cut++ {
		var wire bytes.Buffer
		sender := newConn(fakeNetConn{Writer: &wire}, false, 1024, 1024)
		receiver := newConn(fakeNetConn{Reader: &wire}, true, 1024, 1024)

		mw, _ := sender.NextWriter(BinaryMessage)
		mw.Write(make([]byte, bufSize))
		mw.Close()

		// Stop once the cut point reaches the full encoded length.
		if wire.Len() <= cut {
			break
		}
		wire.Truncate(cut)

		op, r, err := receiver.NextReader()
		switch {
		case err == errUnexpectedEOF:
			// Truncation hit before a reader could be returned.
			continue
		case op != BinaryMessage || err != nil:
			t.Fatalf("%d: NextReader() returned %d, %v", cut, op, err)
		}

		if _, err := io.Copy(ioutil.Discard, r); err != errUnexpectedEOF {
			t.Fatalf("%d: io.Copy() returned %v, want %v", cut, err, errUnexpectedEOF)
		}

		if _, _, err := receiver.NextReader(); err != errUnexpectedEOF {
			t.Fatalf("%d: NextReader() returned %v, want %v", cut, err, errUnexpectedEOF)
		}
	}
}
// TestEOFBeforeFinalFrame writes more than one buffer's worth of a binary
// message without ever closing the message writer, then expects the reader to
// report errUnexpectedEOF while draining the message and again on the next
// NextReader call.
func TestEOFBeforeFinalFrame(t *testing.T) {
	const bufSize = 512

	var toReader, fromReader bytes.Buffer
	sender := newConn(fakeNetConn{Writer: &toReader}, false, 1024, bufSize)
	receiver := newConn(fakeNetConn{Reader: &toReader, Writer: &fromReader}, true, 1024, 1024)

	// The message writer is deliberately left open.
	mw, _ := sender.NextWriter(BinaryMessage)
	mw.Write(make([]byte, bufSize+bufSize/2))

	op, r, err := receiver.NextReader()
	if err != nil || op != BinaryMessage {
		t.Fatalf("NextReader() returned %d, %v", op, err)
	}

	if _, err := io.Copy(ioutil.Discard, r); err != errUnexpectedEOF {
		t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
	}

	if _, _, err := receiver.NextReader(); err != errUnexpectedEOF {
		t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
	}
}
// TestWriteAfterMessageWriterClose checks that writing to a message writer
// fails once that writer has been closed, whether it was closed explicitly or
// by requesting the next writer from the connection.
func TestWriteAfterMessageWriterClose(t *testing.T) {
	conn := newConn(fakeNetConn{Writer: &bytes.Buffer{}}, false, 1024, 1024)

	// Case 1: the writer is closed explicitly.
	first, _ := conn.NextWriter(BinaryMessage)
	io.WriteString(first, "hello")
	if err := first.Close(); err != nil {
		t.Fatalf("unxpected error closing message writer, %v", err)
	}
	if _, err := io.WriteString(first, "world"); err == nil {
		t.Fatalf("no error writing after close")
	}

	// Case 2: the writer is closed by asking for the next writer.
	second, _ := conn.NextWriter(BinaryMessage)
	io.WriteString(second, "hello")
	if _, err := conn.NextWriter(BinaryMessage); err != nil {
		t.Fatalf("unexpected error getting next writer, %v", err)
	}
	if _, err := io.WriteString(second, "world"); err == nil {
		t.Fatalf("no error writing after close")
	}
}
// TestReadLimit sets a read limit on the receiving Conn, then sends one
// message of exactly that size (with a pong interleaved) followed by one
// message that is a byte larger. Draining the second message must fail with
// ErrReadLimit.
func TestReadLimit(t *testing.T) {
	const readLimit = 512
	payload := make([]byte, readLimit+1)

	var toReader, fromReader bytes.Buffer
	sender := newConn(fakeNetConn{Writer: &toReader}, false, 1024, readLimit-2)
	receiver := newConn(fakeNetConn{Reader: &toReader, Writer: &fromReader}, true, 1024, 1024)
	receiver.SetReadLimit(readLimit)

	// First message: readLimit bytes in total, with a pong sent between the
	// two writes.
	mw, _ := sender.NextWriter(BinaryMessage)
	mw.Write(payload[:readLimit-1])
	sender.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
	mw.Write(payload[:1])
	mw.Close()

	// Second message: one byte over the limit.
	sender.WriteMessage(BinaryMessage, payload[:readLimit+1])

	if op, _, err := receiver.NextReader(); op != BinaryMessage || err != nil {
		t.Fatalf("1: NextReader() returned %d, %v", op, err)
	}

	op, r, err := receiver.NextReader()
	if op != BinaryMessage || err != nil {
		t.Fatalf("2: NextReader() returned %d, %v", op, err)
	}

	if _, err := io.Copy(ioutil.Discard, r); err != ErrReadLimit {
		t.Fatalf("io.Copy() returned %v", err)
	}
}
// TestAddrs verifies that Conn.LocalAddr and Conn.RemoteAddr report the
// addresses supplied by the underlying fake connection.
func TestAddrs(t *testing.T) {
	conn := newConn(&fakeNetConn{}, true, 1024, 1024)

	if got := conn.LocalAddr(); got != localAddr {
		t.Errorf("LocalAddr = %v, want %v", got, localAddr)
	}
	if got := conn.RemoteAddr(); got != remoteAddr {
		t.Errorf("RemoteAddr = %v, want %v", got, remoteAddr)
	}
}
// TestUnderlyingConn verifies that Conn.UnderlyingConn returns the same
// connection value that was handed to newConn.
func TestUnderlyingConn(t *testing.T) {
	var in, out bytes.Buffer
	raw := fakeNetConn{Reader: &in, Writer: &out}

	conn := newConn(raw, true, 1024, 1024)
	if got := conn.UnderlyingConn(); got != raw {
		t.Fatalf("Underlying conn is not what it should be.")
	}
}
// TestBufioReadBytes wraps a message reader in a bufio.Reader and calls
// ReadBytes for a line that is longer than the receiving connection's
// configured buffer sizes, expecting the complete line back.
func TestBufioReadBytes(t *testing.T) {
	// A 512-byte line whose only newline is the final byte.
	line := make([]byte, 512)
	line[len(line)-1] = '\n'

	var toReader, fromReader bytes.Buffer
	sender := newConn(fakeNetConn{Writer: &toReader}, false, len(line)+64, len(line)+64)
	receiver := newConn(fakeNetConn{Reader: &toReader, Writer: &fromReader}, true, len(line)-64, len(line)-64)

	mw, _ := sender.NextWriter(BinaryMessage)
	mw.Write(line)
	mw.Close()

	op, r, err := receiver.NextReader()
	if err != nil || op != BinaryMessage {
		t.Fatalf("NextReader() returned %d, %v", op, err)
	}

	got, err := bufio.NewReader(r).ReadBytes('\n')
	if err != nil {
		t.Fatalf("ReadBytes() returned %v", err)
	}
	if len(got) != len(line) {
		t.Fatalf("read returnd %d bytes, want %d bytes", len(got), len(line))
	}
}
// closeErrorTests pairs an error and a list of close codes with the result
// IsCloseError is expected to return for them.
var closeErrorTests = []struct {
	err   error
	codes []int
	ok    bool
}{
	{
		err:   &CloseError{Code: CloseNormalClosure},
		codes: []int{CloseNormalClosure},
		ok:    true,
	},
	{
		err:   &CloseError{Code: CloseNormalClosure},
		codes: []int{CloseNoStatusReceived},
		ok:    false,
	},
	{
		err:   &CloseError{Code: CloseNormalClosure},
		codes: []int{CloseNoStatusReceived, CloseNormalClosure},
		ok:    true,
	},
	{
		err:   errors.New("hello"),
		codes: []int{CloseNormalClosure},
		ok:    false,
	},
}
// TestCloseError runs IsCloseError over every entry in closeErrorTests and
// reports any result that differs from the expected one.
func TestCloseError(t *testing.T) {
	for i := range closeErrorTests {
		tc := closeErrorTests[i]
		if got := IsCloseError(tc.err, tc.codes...); got != tc.ok {
			t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tc.err, tc.codes, got, tc.ok)
		}
	}
}
// unexpectedCloseErrorTests pairs an error and a list of close codes with the
// result IsUnexpectedCloseError is expected to return for them.
var unexpectedCloseErrorTests = []struct {
	err   error
	codes []int
	ok    bool
}{
	{
		err:   &CloseError{Code: CloseNormalClosure},
		codes: []int{CloseNormalClosure},
		ok:    false,
	},
	{
		err:   &CloseError{Code: CloseNormalClosure},
		codes: []int{CloseNoStatusReceived},
		ok:    true,
	},
	{
		err:   &CloseError{Code: CloseNormalClosure},
		codes: []int{CloseNoStatusReceived, CloseNormalClosure},
		ok:    false,
	},
	{
		err:   errors.New("hello"),
		codes: []int{CloseNormalClosure},
		ok:    false,
	},
}
// TestUnexpectedCloseErrors runs IsUnexpectedCloseError over every entry in
// unexpectedCloseErrorTests and reports any result that differs from the
// expected one.
func TestUnexpectedCloseErrors(t *testing.T) {
	for i := range unexpectedCloseErrorTests {
		tc := unexpectedCloseErrorTests[i]
		if got := IsUnexpectedCloseError(tc.err, tc.codes...); got != tc.ok {
			t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tc.err, tc.codes, got, tc.ok)
		}
	}
}
// blockingWriter is an io.Writer whose Write announces that it has started by
// closing c1 and then blocks until c2 is closed.
type blockingWriter struct {
	c1 chan struct{} // closed by Write on entry
	c2 chan struct{} // Write waits for this to be closed
}

// Write closes c1, waits for c2 to be closed, and then reports the whole of p
// as written. Because it closes c1, a second call on the same value panics.
func (w blockingWriter) Write(p []byte) (int, error) {
	written := len(p)

	close(w.c1) // let the test's main goroutine proceed
	<-w.c2      // stay blocked until the test releases us

	return written, nil
}
// TestConcurrentWritePanic starts a WriteMessage that blocks inside the
// underlying writer and then issues a second WriteMessage from the test
// goroutine. The second call is expected to panic; reaching the line after it
// fails the test.
func TestConcurrentWritePanic(t *testing.T) {
	bw := blockingWriter{c1: make(chan struct{}), c2: make(chan struct{})}
	conn := newConn(fakeNetConn{Writer: bw}, false, 1024, 1024)

	go func() {
		conn.WriteMessage(TextMessage, []byte{})
	}()

	// Proceed only once the background write is blocked inside bw.Write.
	<-bw.c1

	defer func() {
		// Release the blocked writer, then absorb the expected panic (if any).
		close(bw.c2)
		_ = recover()
	}()

	conn.WriteMessage(TextMessage, []byte{})
	t.Fatal("should not get here")
}
// failingReader is an io.Reader that never yields data.
type failingReader struct{}

// Read reads nothing and always reports io.EOF.
func (failingReader) Read([]byte) (int, error) {
	return 0, io.EOF
}
// TestFailedConnectionReadPanic calls ReadMessage up to 20000 times on a
// connection whose reader always fails and expects a panic before the loop
// finishes; completing the loop fails the test.
func TestFailedConnectionReadPanic(t *testing.T) {
	conn := newConn(fakeNetConn{Reader: failingReader{}}, false, 1024, 1024)

	// Absorb the expected panic so that it ends the test successfully.
	defer func() {
		_ = recover()
	}()

	for attempt := 0; attempt < 20000; attempt++ {
		conn.ReadMessage()
	}
	t.Fatal("should not get here")
}
// testConn couples a Conn with a channel of raw message payloads destined for
// it.
type testConn struct {
	conn     *Conn
	messages chan []byte
}

// newTestConn wraps c in a testConn whose messages channel has capacity
// bufferSize.
func newTestConn(c *Conn, bufferSize int) *testConn {
	tc := new(testConn)
	tc.conn = c
	tc.messages = make(chan []byte, bufferSize)
	return tc
}
// testPreparedConn couples a Conn with a channel of prepared messages destined
// for it.
type testPreparedConn struct {
	conn     *Conn
	messages chan *PreparedMessage
}

// newTestPreparedConn wraps c in a testPreparedConn whose messages channel has
// capacity bufferSize.
func newTestPreparedConn(c *Conn, bufferSize int) *testPreparedConn {
	tc := new(testPreparedConn)
	tc.conn = c
	tc.messages = make(chan *PreparedMessage, bufferSize)
	return tc
}
// Parameters used by the broadcast benchmarks.
const (
// testBroadcastNumConns is the number of connections a benchmark creates
// (copied into broadcastBench.numConns).
testBroadcastNumConns = 10000
// testBroadcastNumMessages is the number of messages sent to every
// connection per benchmark iteration (copied into broadcastBench.numMessages).
testBroadcastNumMessages = 1
// testBroadcastConnBufferSize is not referenced in the visible part of this
// file. NOTE(review): confirm whether it is still needed.
testBroadcastConnBufferSize = 256
// testBroadcastNumDifferentMessages is the count passed to textMessages when
// the benchmark's message pool is built.
testBroadcastNumDifferentMessages = 100
)
// broadcastBench contains all common fields and methods to run broadcast
// benchmarks below. In every broadcast benchmark we start many connections
// (testBroadcastNumConns) and then broadcast testBroadcastNumMessages
// messages to every connection. This simulates an application where many
// connections listen to the same data - i.e. PUB/SUB scenarios with many
// subscribers.
type broadcastBench struct {
// w is the writer every benchmark connection writes to.
w io.Writer
// numConns is the number of connections to create.
numConns int
// numMessages is the number of messages each connection receives per
// benchmark iteration; it is also the capacity of each connection's channel.
numMessages int
// messages is the pool of payloads the benchmarks pick from.
messages [][]byte
// done is closed by a benchmark to stop the per-connection goroutines.
done chan struct{}
// tick receives a value each time count reaches a multiple of
// numConns*numMessages.
tick chan struct{}
// count is the total number of writes performed; it is updated atomically.
count int32
}
// newBroadcastBench builds a broadcastBench that discards all written bytes
// and is sized by the testBroadcast* constants.
func newBroadcastBench() *broadcastBench {
	bench := new(broadcastBench)
	bench.w = ioutil.Discard
	bench.numConns = testBroadcastNumConns
	bench.numMessages = testBroadcastNumMessages
	bench.messages = textMessages(testBroadcastNumDifferentMessages)
	bench.done = make(chan struct{})
	bench.tick = make(chan struct{})
	return bench
}
// makeConns creates b.numConns server-side connections that write to b.w and
// starts one goroutine per connection. Each goroutine writes every payload it
// receives on the connection's messages channel as a text message, bumps
// b.count, and signals b.tick whenever the count reaches a multiple of
// numConns*numMessages. The goroutines exit when b.done is closed.
//
// When withCompression is true, write compression is enabled on every
// connection.
func (b *broadcastBench) makeConns(withCompression bool) []*testConn {
	// Per-connection worker loop.
	serve := func(tc *testConn) {
		for {
			select {
			case <-b.done:
				return
			case payload := <-tc.messages:
				tc.conn.WriteMessage(TextMessage, payload)
				written := atomic.AddInt32(&b.count, 1)
				if written%int32(b.numConns*b.numMessages) == 0 {
					b.tick <- struct{}{}
				}
			}
		}
	}

	conns := make([]*testConn, 0, b.numConns)
	for len(conns) < b.numConns {
		conn := newConn(fakeNetConn{Writer: b.w}, true, 1024, 1024)
		if withCompression {
			conn.enableWriteCompression = true
			conn.newCompressionWriter = compressNoContextTakeover
		}
		tc := newTestConn(conn, b.numMessages)
		conns = append(conns, tc)
		go serve(tc)
	}
	return conns
}
// makePreparedConns creates b.numConns server-side connections that write to
// b.w and starts one goroutine per connection. Each goroutine writes every
// prepared message it receives on the connection's messages channel, bumps
// b.count, and signals b.tick whenever the count reaches a multiple of
// numConns*numMessages. The goroutines exit when b.done is closed.
//
// When withCompression is true, write compression is enabled on every
// connection.
func (b *broadcastBench) makePreparedConns(withCompression bool) []*testPreparedConn {
	// Per-connection worker loop.
	serve := func(tc *testPreparedConn) {
		for {
			select {
			case <-b.done:
				return
			case pm := <-tc.messages:
				tc.conn.WritePreparedMessage(pm)
				written := atomic.AddInt32(&b.count, 1)
				if written%int32(b.numConns*b.numMessages) == 0 {
					b.tick <- struct{}{}
				}
			}
		}
	}

	conns := make([]*testPreparedConn, 0, b.numConns)
	for len(conns) < b.numConns {
		conn := newConn(fakeNetConn{Writer: b.w}, true, 1024, 1024)
		if withCompression {
			conn.enableWriteCompression = true
			conn.newCompressionWriter = compressNoContextTakeover
		}
		tc := newTestPreparedConn(conn, b.numMessages)
		conns = append(conns, tc)
		go serve(tc)
	}
	return conns
}
// BenchmarkBroadcastNoCompression measures sending each message to every
// connection with write compression left off.
func BenchmarkBroadcastNoCompression(b *testing.B) {
	bench := newBroadcastBench()
	conns := bench.makeConns(false)

	b.ResetTimer()
	for round := 0; round < b.N; round++ {
		for m := 0; m < bench.numMessages; m++ {
			payload := bench.messages[m%len(bench.messages)]
			for k := range conns {
				conns[k].messages <- payload
			}
		}
		// Wait for the workers to signal that this round's writes are done.
		<-bench.tick
	}
	b.ReportAllocs()

	// Stop the per-connection goroutines.
	close(bench.done)
}
// BenchmarkBroadcastWithCompression measures sending each message to every
// connection with write compression enabled on the connections.
func BenchmarkBroadcastWithCompression(b *testing.B) {
	bench := newBroadcastBench()
	conns := bench.makeConns(true)

	b.ResetTimer()
	for round := 0; round < b.N; round++ {
		for m := 0; m < bench.numMessages; m++ {
			payload := bench.messages[m%len(bench.messages)]
			for k := range conns {
				conns[k].messages <- payload
			}
		}
		// Wait for the workers to signal that this round's writes are done.
		<-bench.tick
	}
	b.ReportAllocs()

	// Stop the per-connection goroutines.
	close(bench.done)
}
// BenchmarkBroadcastNoCompressionPrepared measures broadcasting with prepared
// messages: each message is prepared once per round and the same prepared
// value is handed to every connection. Compression is off on both the
// connections and the prepared messages.
func BenchmarkBroadcastNoCompressionPrepared(b *testing.B) {
	bench := newBroadcastBench()
	conns := bench.makePreparedConns(false)

	b.ResetTimer()
	for round := 0; round < b.N; round++ {
		for m := 0; m < bench.numMessages; m++ {
			payload := bench.messages[m%len(bench.messages)]
			pm, _ := NewPreparedMessage(TextMessage, payload, false, 1)
			for k := range conns {
				conns[k].messages <- pm
			}
		}
		// Wait for the workers to signal that this round's writes are done.
		<-bench.tick
	}
	b.ReportAllocs()

	// Stop the per-connection goroutines.
	close(bench.done)
}
// BenchmarkBroadcastWithCompressionPrepared measures broadcasting prepared
// messages created with the compression flag set, to connections that have
// write compression enabled. Each message is prepared once per round and the
// same prepared value is handed to every connection.
func BenchmarkBroadcastWithCompressionPrepared(b *testing.B) {
	bench := newBroadcastBench()
	// Stop the per-connection goroutines on every exit path, including the
	// b.Fatalf below.
	defer close(bench.done)

	// Enable compression on the connections too, matching
	// BenchmarkBroadcastWithCompression. This previously passed false, so the
	// connections were created without write compression.
	conns := bench.makePreparedConns(true)

	b.ResetTimer()
	for j := 0; j < b.N; j++ {
		for i := 0; i < bench.numMessages; i++ {
			msg := bench.messages[i%len(bench.messages)]
			preparedMsg, err := NewPreparedMessage(TextMessage, msg, true, 1)
			if err != nil {
				b.Fatalf("NewPreparedMessage() returned %v", err)
			}
			for _, c := range conns {
				c.messages <- preparedMsg
			}
		}
		// Wait for the workers to signal that this round's writes are done.
		<-bench.tick
	}
	b.ReportAllocs()
}
// TestPreparedMessageBytesStreamUncompressed writes the same 100 text messages
// once as prepared messages and once with WriteMessage, each on a fresh
// connection, and requires the two resulting byte streams to be identical.
func TestPreparedMessageBytesStreamUncompressed(t *testing.T) {
	payloads := textMessages(100)

	// Stream produced via prepared messages.
	var preparedBuf bytes.Buffer
	preparedConn := newConn(fakeNetConn{Writer: &preparedBuf}, true, 1024, 1024)
	for i := range payloads {
		pm, _ := NewPreparedMessage(TextMessage, payloads[i], false, 1)
		preparedConn.WritePreparedMessage(pm)
	}
	preparedStream := preparedBuf.Bytes()

	// Stream produced via plain WriteMessage calls.
	var plainBuf bytes.Buffer
	plainConn := newConn(fakeNetConn{Writer: &plainBuf}, true, 1024, 1024)
	for i := range payloads {
		plainConn.WriteMessage(TextMessage, payloads[i])
	}
	plainStream := plainBuf.Bytes()

	if !reflect.DeepEqual(preparedStream, plainStream) {
		t.Errorf("Connection bytes stream must be equal when using preparing message and not")
	}
}
// TestPreparedMessageBytesStreamCompressed writes the same 100 text messages
// once as prepared messages and once with WriteMessage, each on a fresh
// connection with write compression enabled, and requires the two resulting
// byte streams to be identical.
func TestPreparedMessageBytesStreamCompressed(t *testing.T) {
	payloads := textMessages(100)

	// Stream produced via prepared messages.
	var preparedBuf bytes.Buffer
	preparedConn := newConn(fakeNetConn{Writer: &preparedBuf}, true, 1024, 1024)
	preparedConn.enableWriteCompression = true
	preparedConn.newCompressionWriter = compressNoContextTakeover
	for i := range payloads {
		pm, _ := NewPreparedMessage(TextMessage, payloads[i], true, 1)
		preparedConn.WritePreparedMessage(pm)
	}
	preparedStream := preparedBuf.Bytes()

	// Stream produced via plain WriteMessage calls.
	var plainBuf bytes.Buffer
	plainConn := newConn(fakeNetConn{Writer: &plainBuf}, true, 1024, 1024)
	plainConn.enableWriteCompression = true
	plainConn.newCompressionWriter = compressNoContextTakeover
	for i := range payloads {
		plainConn.WriteMessage(TextMessage, payloads[i])
	}
	plainStream := plainBuf.Bytes()

	if !reflect.DeepEqual(preparedStream, plainStream) {
		t.Errorf("Connection bytes stream must be equal when using preparing message and not")
	}
}