diff --git a/conn.go b/conn.go index 624ac8a..6bd84cc 100644 --- a/conn.go +++ b/conn.go @@ -660,27 +660,13 @@ func (w *messageWriter) Close() error { return nil } -// PreparedMessage allows to prepare message to be sent into connections -// using WritePreparedMessage method. By doing so, you can avoid the overhead -// of framing the same payload into WebSocket messages multiple times when -// that same payload is to be sent out on multiple connections - i.e. PUB/SUB -// scenarios with many active subscribers. -// This is especially useful when compression is used as permessage compression -// is pretty CPU and memory expensive. -type PreparedMessage struct { - messageType int - compression bool - compressionLevel int - payload []byte - compressedPayload []byte +// netConn is a fake network connection used to get PreparedMessage +// prebuilt payloads. +type netConn struct { + bytes.Buffer } -// netConn is a fake connection to be used to get PreparedMessage prebuilt payloads. -// TODO: this is a simplest solution I've found. Is it hacky? Better to refactor a package in some way? -type netConn struct { - io.Reader - io.Writer -} +func (netConn) Read(p []byte) (int, error) { return 0, nil } // netAddr is a fake net.Addr implementation to be used in netConn. type netAddr int @@ -698,92 +684,121 @@ func (c netConn) SetWriteDeadline(t time.Time) error { return nil } var ( preparingServerConnPool = sync.Pool{New: func() interface{} { var buf bytes.Buffer - return newConn(&netConn{Reader: nil, Writer: &buf}, true, 0, 0) + return newConn(&netConn{Buffer: buf}, true, 0, 0) }} preparingClientConnPool = sync.Pool{New: func() interface{} { var buf bytes.Buffer - return newConn(&netConn{Reader: nil, Writer: &buf}, false, 0, 0) + return newConn(&netConn{Buffer: buf}, false, 0, 0) }} ) -// NewPreparedMessage returns ready to use PreparedMessage with uncompressed (always) -// and compressed (only if compression flag is true) prebuilt payloads. -// TODO: client or server message? Options as last argument (with compression level only at moment). -func NewPreparedMessage(messageType int, data []byte, compression bool, compressionLevel int) (*PreparedMessage, error) { - m := &PreparedMessage{messageType: messageType} +// PreparedMessage caches on the wire representations of a message payload. +// Use PreparedMessage to efficiently send a message payload to multiple +// connections. PreparedMessage is especially useful when compression +// is used because the CPU and memory expensive compression operation +// can be executed once for a given set of compression options. +type PreparedMessage struct { + frameType int + data []byte + mu sync.Mutex + frames map[frameKey]*preparedFrame +} - c := preparingServerConnPool.Get().(*Conn) - defer func() { - c.conn.(*netConn).Writer.(*bytes.Buffer).Reset() - c.enableWriteCompression = false - c.newCompressionWriter = nil - preparingServerConnPool.Put(c) - }() +// frameKey defines a unique set of options to cache prepared frames in PreparedMessage. +type frameKey struct { + isServer bool + compress bool + compressionLevel int +} - w, err := c.NextWriter(messageType) - if err != nil { - return nil, err +// preparedFrame contains data in wire representation. +type preparedFrame struct { + once sync.Once + data []byte +} + +// NewPreparedMessage returns initialized PreparedMessage. You can then send +// it to connection using WritePreparedMessage method. Valid wire representation +// will be calculated lazily only once for a set of current connection options. +func NewPreparedMessage(messageType int, data []byte) *PreparedMessage { + if !isData(messageType) { + panic("Prepared message type can only be TextMessage or BinaryMessage") } - if _, err = w.Write(data); err != nil { - return nil, err + return &PreparedMessage{ + frameType: messageType, + data: data, + frames: make(map[frameKey]*preparedFrame), } - err = w.Close() - if err != nil { - return nil, err +} + +func (pm *PreparedMessage) frame(key frameKey) (int, []byte, error) { + pm.mu.Lock() + frame, ok := pm.frames[key] + if !ok { + frame = &preparedFrame{} + pm.frames[key] = frame } + pm.mu.Unlock() - // We always need uncompressed payload because even if application enables - // compression we can't guarantee it will be negotiated with client. - m.payload = c.conn.(*netConn).Writer.(*bytes.Buffer).Bytes() + var writeErr error - if compression { - // Create compressed payload only if application uses compression. - - m.compression = true - m.compressionLevel = compressionLevel - - c.conn.(*netConn).Writer.(*bytes.Buffer).Reset() - c.enableWriteCompression = true - c.newCompressionWriter = compressNoContextTakeover - c.SetCompressionLevel(compressionLevel) - - w, err = c.NextWriter(messageType) - if err != nil { - return nil, err + frame.once.Do(func() { + // Create frame data once for a given frameKey. + var c *Conn + if key.isServer { + c = preparingServerConnPool.Get().(*Conn) + } else { + c = preparingClientConnPool.Get().(*Conn) } - if _, err = w.Write(data); err != nil { - return nil, err - } - err = w.Close() - if err != nil { - return nil, err - } - m.compressedPayload = c.conn.(*netConn).Writer.(*bytes.Buffer).Bytes() - } - return m, nil + defer func() { + c.conn.(*netConn).Buffer.Reset() + c.enableWriteCompression = false + c.newCompressionWriter = nil + c.SetCompressionLevel(0) + if key.isServer { + preparingServerConnPool.Put(c) + } else { + preparingClientConnPool.Put(c) + } + }() + + if key.compress { + c.enableWriteCompression = true + c.newCompressionWriter = compressNoContextTakeover + c.SetCompressionLevel(key.compressionLevel) + } + writeErr := c.WriteMessage(pm.frameType, pm.data) + if writeErr == nil { + preparedData := c.conn.(*netConn).Buffer.Bytes() + data := make([]byte, len(preparedData)) + copy(data, preparedData) + frame.data = data + } + }) + + return pm.frameType, frame.data, writeErr } // WritePreparedMessage writes prepared message into connection. func (c *Conn) WritePreparedMessage(msg *PreparedMessage) error { - + frameType, frameData, err := msg.frame(frameKey{ + isServer: c.isServer, + compress: c.newCompressionWriter != nil && c.enableWriteCompression, + compressionLevel: c.compressionLevel, + }) + if err != nil { + return err + } if c.isWriting { panic("concurrent write to websocket connection") } c.isWriting = true - - var err error - if c.newCompressionWriter != nil && c.enableWriteCompression && isData(msg.messageType) { - err = c.write(msg.messageType, c.writeDeadline, msg.compressedPayload) - } else { - err = c.write(msg.messageType, c.writeDeadline, msg.payload) - } - + err = c.write(frameType, c.writeDeadline, frameData, nil) if !c.isWriting { panic("concurrent write to websocket connection") } c.isWriting = false - return err } diff --git a/conn_test.go b/conn_test.go index 04883be..943b59d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -620,7 +620,7 @@ func BenchmarkBroadcastNoCompressionPrepared(b *testing.B) { for j := 0; j < b.N; j++ { for i := 0; i < bench.numMessages; i++ { msg := bench.messages[i%len(bench.messages)] - preparedMsg, _ := NewPreparedMessage(TextMessage, msg, false, 1) + preparedMsg := NewPreparedMessage(TextMessage, msg) for _, c := range conns { c.messages <- preparedMsg } @@ -638,7 +638,7 @@ func BenchmarkBroadcastWithCompressionPrepared(b *testing.B) { for j := 0; j < b.N; j++ { for i := 0; i < bench.numMessages; i++ { msg := bench.messages[i%len(bench.messages)] - preparedMsg, _ := NewPreparedMessage(TextMessage, msg, true, 1) + preparedMsg := NewPreparedMessage(TextMessage, msg) for _, c := range conns { c.messages <- preparedMsg } @@ -655,7 +655,7 @@ func TestPreparedMessageBytesStreamUncompressed(t *testing.T) { var b1 bytes.Buffer c := newConn(fakeNetConn{Reader: nil, Writer: &b1}, true, 1024, 1024) for _, msg := range messages { - preparedMsg, _ := NewPreparedMessage(TextMessage, msg, false, 1) + preparedMsg := NewPreparedMessage(TextMessage, msg) c.WritePreparedMessage(preparedMsg) } out1 := b1.Bytes() @@ -679,8 +679,10 @@ func TestPreparedMessageBytesStreamCompressed(t *testing.T) { c := newConn(fakeNetConn{Reader: nil, Writer: &b1}, true, 1024, 1024) c.enableWriteCompression = true c.newCompressionWriter = compressNoContextTakeover - for _, msg := range messages { - preparedMsg, _ := NewPreparedMessage(TextMessage, msg, true, 1) + for i, msg := range messages { + preparedMsg := NewPreparedMessage(TextMessage, msg) + level := i%(maxCompressionLevel-minCompressionLevel+1) - 2 + c.SetCompressionLevel(level) c.WritePreparedMessage(preparedMsg) } out1 := b1.Bytes() @@ -689,7 +691,9 @@ func TestPreparedMessageBytesStreamCompressed(t *testing.T) { c = newConn(fakeNetConn{Reader: nil, Writer: &b2}, true, 1024, 1024) c.enableWriteCompression = true c.newCompressionWriter = compressNoContextTakeover - for _, msg := range messages { + for i, msg := range messages { + level := i%(maxCompressionLevel-minCompressionLevel+1) - 2 + c.SetCompressionLevel(level) c.WriteMessage(TextMessage, msg) } out2 := b2.Bytes()