diff --git a/conn.go b/conn.go
index c8aee1c..624ac8a 100644
--- a/conn.go
+++ b/conn.go
@@ -6,6 +6,7 @@ package websocket
 import (
 	"bufio"
+	"bytes"
 	"encoding/binary"
 	"errors"
 	"io"
@@ -659,12 +660,138 @@ func (w *messageWriter) Close() error {
 	return nil
 }
 
+// PreparedMessage caches a payload already framed as a WebSocket message so
+// it can be sent to many connections with WritePreparedMessage. This avoids
+// framing the same payload once per connection in broadcast (PUB/SUB)
+// scenarios with many active subscribers, and is especially valuable when
+// per-message compression is enabled, since compression is CPU and memory
+// intensive.
+type PreparedMessage struct {
+	messageType       int
+	compression       bool
+	compressionLevel  int
+	payload           []byte
+	compressedPayload []byte
+}
+
+// netConn is a minimal net.Conn implementation used to capture the framed
+// bytes produced by a *Conn into an in-memory buffer while preparing a
+// message.
+type netConn struct {
+	io.Reader
+	io.Writer
+}
+
+// netAddr is a no-op net.Addr implementation used by netConn.
+type netAddr int
+
+func (a netAddr) Network() string { return "" }
+func (a netAddr) String() string  { return "" }
+
+func (c netConn) Close() error                       { return nil }
+func (c netConn) LocalAddr() net.Addr                { return netAddr(0) }
+func (c netConn) RemoteAddr() net.Addr               { return netAddr(0) }
+func (c netConn) SetDeadline(t time.Time) error      { return nil }
+func (c netConn) SetReadDeadline(t time.Time) error  { return nil }
+func (c netConn) SetWriteDeadline(t time.Time) error { return nil }
+
+var (
+	// preparingServerConnPool recycles the scratch server-side connections
+	// used to frame prepared messages; every Get is paired with a
+	// reset-and-Put in NewPreparedMessage.
+	preparingServerConnPool = sync.Pool{New: func() interface{} {
+		var buf bytes.Buffer
+		return newConn(&netConn{Reader: nil, Writer: &buf}, true, 0, 0)
+	}}
+	// NOTE(review): preparingClientConnPool is not referenced anywhere in
+	// this patch — confirm it is needed for client-side prepared messages,
+	// or drop it as dead code.
+	preparingClientConnPool = sync.Pool{New: func() interface{} {
+		var buf bytes.Buffer
+		return newConn(&netConn{Reader: nil, Writer: &buf}, false, 0, 0)
+	}}
+)
+
+// NewPreparedMessage returns a ready to use PreparedMessage holding an
+// uncompressed prebuilt payload (always) and a compressed one (only when
+// compression is true). compressionLevel is consulted only when compression
+// is true.
+func NewPreparedMessage(messageType int, data []byte, compression bool, compressionLevel int) (*PreparedMessage, error) {
+	m := &PreparedMessage{messageType: messageType}
+
+	c := preparingServerConnPool.Get().(*Conn)
+	defer func() {
+		// Restore the pooled conn to its pristine state before returning it.
+		c.conn.(*netConn).Writer.(*bytes.Buffer).Reset()
+		c.enableWriteCompression = false
+		c.newCompressionWriter = nil
+		preparingServerConnPool.Put(c)
+	}()
+
+	w, err := c.NextWriter(messageType)
+	if err != nil {
+		return nil, err
+	}
+	if _, err = w.Write(data); err != nil {
+		return nil, err
+	}
+	if err = w.Close(); err != nil {
+		return nil, err
+	}
+
+	// We always need the uncompressed payload: even when the application
+	// enables compression, it may not be negotiated with every client.
+	// Copy the bytes out of the pooled buffer — Bytes() aliases the buffer's
+	// backing array, which the deferred Reset hands to the next pool user.
+	m.payload = append([]byte(nil), c.conn.(*netConn).Writer.(*bytes.Buffer).Bytes()...)
+
+	if compression {
+		// Build the compressed payload only when the application uses compression.
+
+		m.compression = true
+		m.compressionLevel = compressionLevel
+
+		c.conn.(*netConn).Writer.(*bytes.Buffer).Reset()
+		c.enableWriteCompression = true
+		c.newCompressionWriter = compressNoContextTakeover
+		c.SetCompressionLevel(compressionLevel)
+
+		w, err = c.NextWriter(messageType)
+		if err != nil {
+			return nil, err
+		}
+		if _, err = w.Write(data); err != nil {
+			return nil, err
+		}
+		if err = w.Close(); err != nil {
+			return nil, err
+		}
+		// Copy out of the pooled buffer — Bytes() aliases storage that the
+		// deferred Reset hands to the next pool user.
+		m.compressedPayload = append([]byte(nil), c.conn.(*netConn).Writer.(*bytes.Buffer).Bytes()...)
+	}
+
+	return m, nil
+}
+
+// WritePreparedMessage writes a prepared message to the connection. The
+// compressed payload is used only when the message was actually prepared
+// with compression AND this connection negotiated per-message compression;
+// otherwise the uncompressed payload is sent. (Checking msg.compression is
+// required: a message prepared without compression has a nil
+// compressedPayload.)
+func (c *Conn) WritePreparedMessage(msg *PreparedMessage) error {
+	if c.isWriting {
+		panic("concurrent write to websocket connection")
+	}
+	c.isWriting = true
+
+	var err error
+	if msg.compression && c.newCompressionWriter != nil && c.enableWriteCompression && isData(msg.messageType) {
+		err = c.write(msg.messageType, c.writeDeadline, msg.compressedPayload)
+	} else {
+		err = c.write(msg.messageType, c.writeDeadline, msg.payload)
+	}
+
+	if !c.isWriting {
+		panic("concurrent write to websocket connection")
+	}
+	c.isWriting = false
+
+	return err
+}
+
 // WriteMessage is a helper method for getting a writer using NextWriter,
 // writing the message and closing the writer.
 func (c *Conn) WriteMessage(messageType int, data []byte) error {
 	if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
 		// Fast path with no allocations and single frame.
 		if err := c.prepWrite(messageType); err != nil {
diff --git a/conn_test.go b/conn_test.go
index 7431383..04883be 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -13,6 +13,7 @@ import (
 	"io/ioutil"
 	"net"
 	"reflect"
+	"sync/atomic"
 	"testing"
 	"testing/iotest"
 	"time"
@@ -463,3 +464,237 @@ func TestFailedConnectionReadPanic(t *testing.T) {
 	}
 	t.Fatal("should not get here")
 }
+
+// testConn pairs a connection with a channel of raw payloads to broadcast.
+type testConn struct {
+	conn     *Conn
+	messages chan []byte
+}
+
+func newTestConn(c *Conn, bufferSize int) *testConn {
+	return &testConn{
+		conn:     c,
+		messages: make(chan []byte, bufferSize),
+	}
+}
+
+// testPreparedConn pairs a connection with a channel of prepared messages.
+type testPreparedConn struct {
+	conn     *Conn
+	messages chan *PreparedMessage
+}
+
+func newTestPreparedConn(c *Conn, bufferSize int) *testPreparedConn {
+	return &testPreparedConn{
+		conn:     c,
+		messages: make(chan *PreparedMessage, bufferSize),
+	}
+}
+
+const (
+	testBroadcastNumConns             = 10000
+	testBroadcastNumMessages          = 1
+	testBroadcastConnBufferSize       = 256
+	testBroadcastNumDifferentMessages = 100
+)
+
+// broadcastBench contains the common fields and methods used by the
+// broadcast benchmarks below. Each benchmark starts testBroadcastNumConns
+// connections and broadcasts testBroadcastNumMessages messages to every
+// connection, simulating applications where many connections listen to the
+// same data — i.e. PUB/SUB scenarios with many subscribers.
+type broadcastBench struct {
+	w           io.Writer
+	numConns    int
+	numMessages int
+	messages    [][]byte
+	done        chan struct{}
+	tick        chan struct{}
+	count       int32 // updated atomically by the writer goroutines
+}
+
+func newBroadcastBench() *broadcastBench {
+	return &broadcastBench{
+		w:           ioutil.Discard,
+		numConns:    testBroadcastNumConns,
+		numMessages: testBroadcastNumMessages,
+		messages:    textMessages(testBroadcastNumDifferentMessages),
+		done:        make(chan struct{}),
+		tick:        make(chan struct{}),
+	}
+}
+
+// makeConns starts numConns writer goroutines that send raw payloads with
+// WriteMessage; one value is sent on tick after every complete broadcast
+// round (numConns*numMessages writes).
+func (b *broadcastBench) makeConns(withCompression bool) []*testConn {
+	conns := make([]*testConn, b.numConns)
+
+	for i := 0; i < b.numConns; i++ {
+		c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024)
+		if withCompression {
+			c.enableWriteCompression = true
+			c.newCompressionWriter = compressNoContextTakeover
+		}
+		conns[i] = newTestConn(c, b.numMessages)
+		go func(c *testConn) {
+			for {
+				select {
+				case msg := <-c.messages:
+					c.conn.WriteMessage(TextMessage, msg)
+					val := atomic.AddInt32(&b.count, 1)
+					if val%int32(b.numConns*b.numMessages) == 0 {
+						b.tick <- struct{}{}
+					}
+				case <-b.done:
+					return
+				}
+			}
+		}(conns[i])
+	}
+	return conns
+}
+
+// makePreparedConns mirrors makeConns but the goroutines write
+// *PreparedMessage values with WritePreparedMessage.
+func (b *broadcastBench) makePreparedConns(withCompression bool) []*testPreparedConn {
+	conns := make([]*testPreparedConn, b.numConns)
+
+	for i := 0; i < b.numConns; i++ {
+		c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024)
+		if withCompression {
+			c.enableWriteCompression = true
+			c.newCompressionWriter = compressNoContextTakeover
+		}
+		conns[i] = newTestPreparedConn(c, b.numMessages)
+		go func(c *testPreparedConn) {
+			for {
+				select {
+				case msg := <-c.messages:
+					c.conn.WritePreparedMessage(msg)
+					val := atomic.AddInt32(&b.count, 1)
+					if val%int32(b.numConns*b.numMessages) == 0 {
+						b.tick <- struct{}{}
+					}
+				case <-b.done:
+					return
+				}
+			}
+		}(conns[i])
+	}
+	return conns
+}
+
+func BenchmarkBroadcastNoCompression(b *testing.B) {
+	bench := newBroadcastBench()
+	conns := bench.makeConns(false)
+	b.ResetTimer()
+	for j := 0; j < b.N; j++ {
+		for i := 0; i < bench.numMessages; i++ {
+			msg := bench.messages[i%len(bench.messages)]
+			for _, c := range conns {
+				c.messages <- msg
+			}
+		}
+		<-bench.tick
+	}
+	b.ReportAllocs()
+	close(bench.done)
+}
+
+func BenchmarkBroadcastWithCompression(b *testing.B) {
+	bench := newBroadcastBench()
+	conns := bench.makeConns(true)
+	b.ResetTimer()
+	for j := 0; j < b.N; j++ {
+		for i := 0; i < bench.numMessages; i++ {
+			msg := bench.messages[i%len(bench.messages)]
+			for _, c := range conns {
+				c.messages <- msg
+			}
+		}
+		<-bench.tick
+	}
+	b.ReportAllocs()
+	close(bench.done)
+}
+
+func BenchmarkBroadcastNoCompressionPrepared(b *testing.B) {
+	bench := newBroadcastBench()
+	conns := bench.makePreparedConns(false)
+	b.ResetTimer()
+	for j := 0; j < b.N; j++ {
+		for i := 0; i < bench.numMessages; i++ {
+			msg := bench.messages[i%len(bench.messages)]
+			preparedMsg, _ := NewPreparedMessage(TextMessage, msg, false, 1)
+			for _, c := range conns {
+				c.messages <- preparedMsg
+			}
+		}
+		<-bench.tick
+	}
+	b.ReportAllocs()
+	close(bench.done)
+}
+
+func BenchmarkBroadcastWithCompressionPrepared(b *testing.B) {
+	bench := newBroadcastBench()
+	// Connections must have compression enabled (true, not false) so this
+	// benchmark actually exercises the compressed prepared-message path.
+	conns := bench.makePreparedConns(true)
+	b.ResetTimer()
+	for j := 0; j < b.N; j++ {
+		for i := 0; i < bench.numMessages; i++ {
+			msg := bench.messages[i%len(bench.messages)]
+			preparedMsg, _ := NewPreparedMessage(TextMessage, msg, true, 1)
+			for _, c := range conns {
+				c.messages <- preparedMsg
+			}
+		}
+		<-bench.tick
+	}
+	b.ReportAllocs()
+	close(bench.done)
+}
+
+// TestPreparedMessageBytesStreamUncompressed verifies that the byte stream
+// produced by WritePreparedMessage is identical to the one produced by
+// WriteMessage when compression is not used.
+func TestPreparedMessageBytesStreamUncompressed(t *testing.T) {
+	messages := textMessages(100)
+
+	var b1 bytes.Buffer
+	c := newConn(fakeNetConn{Reader: nil, Writer: &b1}, true, 1024, 1024)
+	for _, msg := range messages {
+		preparedMsg, _ := NewPreparedMessage(TextMessage, msg, false, 1)
+		c.WritePreparedMessage(preparedMsg)
+	}
+	out1 := b1.Bytes()
+
+	var b2 bytes.Buffer
+	c = newConn(fakeNetConn{Reader: nil, Writer: &b2}, true, 1024, 1024)
+	for _, msg := range messages {
+		c.WriteMessage(TextMessage, msg)
+	}
+	out2 := b2.Bytes()
+
+	if !reflect.DeepEqual(out1, out2) {
+		t.Errorf("Connection bytes stream must be equal when using preparing message and not")
+	}
+}
+
+// TestPreparedMessageBytesStreamCompressed is the compressed counterpart of
+// the test above.
+func TestPreparedMessageBytesStreamCompressed(t *testing.T) {
+	messages := textMessages(100)
+
+	var b1 bytes.Buffer
+	c := newConn(fakeNetConn{Reader: nil, Writer: &b1}, true, 1024, 1024)
+	c.enableWriteCompression = true
+	c.newCompressionWriter = compressNoContextTakeover
+	for _, msg := range messages {
+		preparedMsg, _ := NewPreparedMessage(TextMessage, msg, true, 1)
+		c.WritePreparedMessage(preparedMsg)
+	}
+	out1 := b1.Bytes()
+
+	var b2 bytes.Buffer
+	c = newConn(fakeNetConn{Reader: nil, Writer: &b2}, true, 1024, 1024)
+	c.enableWriteCompression = true
+	c.newCompressionWriter = compressNoContextTakeover
+	for _, msg := range messages {
+		c.WriteMessage(TextMessage, msg)
+	}
+	out2 := b2.Bytes()
+
+	if !reflect.DeepEqual(out1, out2) {
+		t.Errorf("Connection bytes stream must be equal when using preparing message and not")
+	}
+}