From 804cb600d06b10672f2fbc0a336a7bee507a428e Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Tue, 14 Feb 2017 09:41:18 -0800 Subject: [PATCH] Prepared Messages (#211) --- conn.go | 23 ++++- conn_broadcast_test.go | 134 +++++++++++++++++++++++++++ examples/autobahn/fuzzingclient.json | 1 + examples/autobahn/server.go | 29 ++++-- prepared.go | 103 ++++++++++++++++++++ prepared_test.go | 74 +++++++++++++++ 6 files changed, 357 insertions(+), 7 deletions(-) create mode 100644 conn_broadcast_test.go create mode 100644 prepared.go create mode 100644 prepared_test.go diff --git a/conn.go b/conn.go index c8aee1c..4c0933b 100644 --- a/conn.go +++ b/conn.go @@ -659,12 +659,33 @@ func (w *messageWriter) Close() error { return nil } +// WritePreparedMessage writes prepared message into connection. +func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { + frameType, frameData, err := pm.frame(prepareKey{ + isServer: c.isServer, + compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), + compressionLevel: c.compressionLevel, + }) + if err != nil { + return err + } + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + err = c.write(frameType, c.writeDeadline, frameData, nil) + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + return err +} + // WriteMessage is a helper method for getting a writer using NextWriter, // writing the message and closing the writer. func (c *Conn) WriteMessage(messageType int, data []byte) error { if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { - // Fast path with no allocations and single frame. if err := c.prepWrite(messageType); err != nil { diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go new file mode 100644 index 0000000..45038e4 --- /dev/null +++ b/conn_broadcast_test.go @@ -0,0 +1,134 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.7 + +package websocket + +import ( + "io" + "io/ioutil" + "sync/atomic" + "testing" +) + +// broadcastBench allows to run broadcast benchmarks. +// In every broadcast benchmark we create many connections, then send the same +// message into every connection and wait for all writes complete. This emulates +// an application where many connections listen to the same data - i.e. PUB/SUB +// scenarios with many subscribers in one channel. +type broadcastBench struct { + w io.Writer + message *broadcastMessage + closeCh chan struct{} + doneCh chan struct{} + count int32 + conns []*broadcastConn + compression bool + usePrepared bool +} + +type broadcastMessage struct { + payload []byte + prepared *PreparedMessage +} + +type broadcastConn struct { + conn *Conn + msgCh chan *broadcastMessage +} + +func newBroadcastConn(c *Conn) *broadcastConn { + return &broadcastConn{ + conn: c, + msgCh: make(chan *broadcastMessage, 1), + } +} + +func newBroadcastBench(usePrepared, compression bool) *broadcastBench { + bench := &broadcastBench{ + w: ioutil.Discard, + doneCh: make(chan struct{}), + closeCh: make(chan struct{}), + usePrepared: usePrepared, + compression: compression, + } + msg := &broadcastMessage{ + payload: textMessages(1)[0], + } + if usePrepared { + pm, _ := NewPreparedMessage(TextMessage, msg.payload) + msg.prepared = pm + } + bench.message = msg + bench.makeConns(10000) + return bench +} + +func (b *broadcastBench) makeConns(numConns int) { + conns := make([]*broadcastConn, numConns) + + for i := 0; i < numConns; i++ { + c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024) + if b.compression { + c.enableWriteCompression = true + c.newCompressionWriter = compressNoContextTakeover + } + conns[i] = newBroadcastConn(c) + go func(c *broadcastConn) { + for { + select { + case msg := <-c.msgCh: + if b.usePrepared { + c.conn.WritePreparedMessage(msg.prepared) + } else { + c.conn.WriteMessage(TextMessage, msg.payload) + } + val := atomic.AddInt32(&b.count, 1) + if val%int32(numConns) == 0 { + b.doneCh <- struct{}{} + } + case <-b.closeCh: + return + } + } + }(conns[i]) + } + b.conns = conns +} + +func (b *broadcastBench) close() { + close(b.closeCh) +} + +func (b *broadcastBench) runOnce() { + for _, c := range b.conns { + c.msgCh <- b.message + } + <-b.doneCh +} + +func BenchmarkBroadcast(b *testing.B) { + benchmarks := []struct { + name string + usePrepared bool + compression bool + }{ + {"NoCompression", false, false}, + {"WithCompression", false, true}, + {"NoCompressionPrepared", true, false}, + {"WithCompressionPrepared", true, true}, + } + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + bench := newBroadcastBench(bm.usePrepared, bm.compression) + defer bench.close() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bench.runOnce() + } + b.ReportAllocs() + }) + } +} diff --git a/examples/autobahn/fuzzingclient.json b/examples/autobahn/fuzzingclient.json index 27d5a5b..aa3a0bc 100644 --- a/examples/autobahn/fuzzingclient.json +++ b/examples/autobahn/fuzzingclient.json @@ -4,6 +4,7 @@ "outdir": "./reports/clients", "servers": [ {"agent": "ReadAllWriteMessage", "url": "ws://localhost:9000/m", "options": {"version": 18}}, + {"agent": "ReadAllWritePreparedMessage", "url": "ws://localhost:9000/p", "options": {"version": 18}}, {"agent": "ReadAllWrite", "url": "ws://localhost:9000/r", "options": {"version": 18}}, {"agent": "CopyFull", "url": "ws://localhost:9000/f", "options": {"version": 18}}, {"agent": "CopyWriterOnly", "url": "ws://localhost:9000/c", "options": {"version": 18}} diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go index e98563b..3db880f 100644 --- a/examples/autobahn/server.go +++ b/examples/autobahn/server.go @@ -85,7 +85,7 @@ func echoCopyFull(w http.ResponseWriter, r *http.Request) { // echoReadAll echoes messages from the client by reading the entire message // with ioutil.ReadAll. -func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) { +func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage, writePrepared bool) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Println("Upgrade:", err) @@ -109,9 +109,21 @@ func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) { } } if writeMessage { - err = conn.WriteMessage(mt, b) - if err != nil { - log.Println("WriteMessage:", err) + if !writePrepared { + err = conn.WriteMessage(mt, b) + if err != nil { + log.Println("WriteMessage:", err) + } + } else { + pm, err := websocket.NewPreparedMessage(mt, b) + if err != nil { + log.Println("NewPreparedMessage:", err) + return + } + err = conn.WritePreparedMessage(pm) + if err != nil { + log.Println("WritePreparedMessage:", err) + } } } else { w, err := conn.NextWriter(mt) @@ -132,11 +144,15 @@ func echoReadAll(w http.ResponseWriter, r *http.Request, writeMessage bool) { } func echoReadAllWriter(w http.ResponseWriter, r *http.Request) { - echoReadAll(w, r, false) + echoReadAll(w, r, false, false) } func echoReadAllWriteMessage(w http.ResponseWriter, r *http.Request) { - echoReadAll(w, r, true) + echoReadAll(w, r, true, false) +} + +func echoReadAllWritePreparedMessage(w http.ResponseWriter, r *http.Request) { + echoReadAll(w, r, true, true) } func serveHome(w http.ResponseWriter, r *http.Request) { @@ -161,6 +177,7 @@ func main() { http.HandleFunc("/f", echoCopyFull) http.HandleFunc("/r", echoReadAllWriter) http.HandleFunc("/m", echoReadAllWriteMessage) + http.HandleFunc("/p", echoReadAllWritePreparedMessage) err := http.ListenAndServe(*addr, nil) if err != nil { log.Fatal("ListenAndServe: ", err) diff --git a/prepared.go b/prepared.go new file mode 100644 index 0000000..1efffbd --- /dev/null +++ b/prepared.go @@ -0,0 +1,103 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "net" + "sync" + "time" +) + +// PreparedMessage caches on the wire representations of a message payload. +// Use PreparedMessage to efficiently send a message payload to multiple +// connections. PreparedMessage is especially useful when compression is used +// because the CPU and memory expensive compression operation can be executed +// once for a given set of compression options. +type PreparedMessage struct { + messageType int + data []byte + err error + mu sync.Mutex + frames map[prepareKey]*preparedFrame +} + +// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. +type prepareKey struct { + isServer bool + compress bool + compressionLevel int +} + +// preparedFrame contains data in wire representation. +type preparedFrame struct { + once sync.Once + data []byte +} + +// NewPreparedMessage returns an initialized PreparedMessage. You can then send +// it to connection using WritePreparedMessage method. Valid wire +// representation will be calculated lazily only once for a set of current +// connection options. +func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { + pm := &PreparedMessage{ + messageType: messageType, + frames: make(map[prepareKey]*preparedFrame), + data: data, + } + + // Prepare a plain server frame. + _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) + if err != nil { + return nil, err + } + + // To protect against caller modifying the data argument, remember the data + // copied to the plain server frame. + pm.data = frameData[len(frameData)-len(data):] + return pm, nil +} + +func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { + pm.mu.Lock() + frame, ok := pm.frames[key] + if !ok { + frame = &preparedFrame{} + pm.frames[key] = frame + } + pm.mu.Unlock() + + var err error + frame.once.Do(func() { + // Prepare a frame using a 'fake' connection. + // TODO: Refactor code in conn.go to allow more direct construction of + // the frame. + mu := make(chan bool, 1) + mu <- true + var nc prepareConn + c := &Conn{ + conn: &nc, + mu: mu, + isServer: key.isServer, + compressionLevel: key.compressionLevel, + enableWriteCompression: true, + writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), + } + if key.compress { + c.newCompressionWriter = compressNoContextTakeover + } + err = c.WriteMessage(pm.messageType, pm.data) + frame.data = nc.buf.Bytes() + }) + return pm.messageType, frame.data, err +} + +type prepareConn struct { + buf bytes.Buffer + net.Conn +} + +func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) } +func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/prepared_test.go b/prepared_test.go new file mode 100644 index 0000000..cf98c6c --- /dev/null +++ b/prepared_test.go @@ -0,0 +1,74 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package websocket + +import ( + "bytes" + "compress/flate" + "math/rand" + "testing" +) + +var preparedMessageTests = []struct { + messageType int + isServer bool + enableWriteCompression bool + compressionLevel int +}{ + // Server + {TextMessage, true, false, flate.BestSpeed}, + {TextMessage, true, true, flate.BestSpeed}, + {TextMessage, true, true, flate.BestCompression}, + {PingMessage, true, false, flate.BestSpeed}, + {PingMessage, true, true, flate.BestSpeed}, + + // Client + {TextMessage, false, false, flate.BestSpeed}, + {TextMessage, false, true, flate.BestSpeed}, + {TextMessage, false, true, flate.BestCompression}, + {PingMessage, false, false, flate.BestSpeed}, + {PingMessage, false, true, flate.BestSpeed}, +} + +func TestPreparedMessage(t *testing.T) { + for _, tt := range preparedMessageTests { + var data = []byte("this is a test") + var buf bytes.Buffer + c := newConn(fakeNetConn{Reader: nil, Writer: &buf}, tt.isServer, 1024, 1024) + if tt.enableWriteCompression { + c.newCompressionWriter = compressNoContextTakeover + } + c.SetCompressionLevel(tt.compressionLevel) + + // Seed random number generator for consistent frame mask. + rand.Seed(1234) + + if err := c.WriteMessage(tt.messageType, data); err != nil { + t.Fatal(err) + } + want := buf.String() + + pm, err := NewPreparedMessage(tt.messageType, data) + if err != nil { + t.Fatal(err) + } + + // Scribble on data to ensure that NewPreparedMessage takes a snapshot. + copy(data, "hello world") + + // Seed random number generator for consistent frame mask. + rand.Seed(1234) + + buf.Reset() + if err := c.WritePreparedMessage(pm); err != nil { + t.Fatal(err) + } + got := buf.String() + + if got != want { + t.Errorf("write message != prepared message for %+v", tt) + } + } +}