From a86136ccd5094350d408511667506aac10b275fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joe=CC=88l=20Ga=CC=88hwiler?= Date: Sun, 5 Jun 2016 00:39:12 +0200 Subject: [PATCH] added Conn.BufferMessage() and Conn.FlushMessages() --- conn.go | 76 +++++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 63 insertions(+), 13 deletions(-) diff --git a/conn.go b/conn.go index 794c2ef..f61b0c2 100644 --- a/conn.go +++ b/conn.go @@ -235,6 +235,7 @@ type Conn struct { // Message writer fields. writeErr error + bw *bufio.Writer writeBuf []byte // frame is constructed in this buffer. writePos int // end of data in writeBuf. writeFrameType int // type of the current frame. @@ -273,6 +274,7 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) c := &Conn{ isServer: isServer, + bw: bufio.NewWriterSize(conn, writeBufferSize), br: bufio.NewReaderSize(conn, readBufferSize), conn: conn, mu: mu, @@ -308,7 +310,7 @@ func (c *Conn) RemoteAddr() net.Addr { // Write methods -func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { +func (c *Conn) write(frameType int, deadline time.Time, flush bool, bufs ...[]byte) error { <-c.mu defer func() { c.mu <- true }() @@ -321,7 +323,7 @@ func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { c.conn.SetWriteDeadline(deadline) for _, buf := range bufs { if len(buf) > 0 { - n, err := c.conn.Write(buf) + n, err := c.bw.Write(buf) if n != len(buf) { // Close on partial write. c.conn.Close() @@ -331,6 +333,14 @@ func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { } } } + + if flush { + err := c.bw.Flush() + if err != nil { + return err + } + } + return nil } @@ -386,10 +396,16 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er } c.conn.SetWriteDeadline(deadline) - n, err := c.conn.Write(buf) + n, err := c.bw.Write(buf) if n != 0 && n != len(buf) { c.conn.Close() } + if err != nil { + return hideTempErr(err) + } + + // flush + err = c.bw.Flush() return hideTempErr(err) } @@ -404,7 +420,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { } if c.writeFrameType != noFrame { - if err := c.flushFrame(true, nil); err != nil { + if err := c.flushFrame(true, nil, true); err != nil { return nil, err } } @@ -414,12 +430,12 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { } c.writeFrameType = messageType - w := &messageWriter{c} + w := &messageWriter{c: c, flush: true} c.messageWriter = w return w, nil } -func (c *Conn) flushFrame(final bool, extra []byte) error { +func (c *Conn) flushFrame(final bool, extra []byte, flush bool) error { length := c.writePos - maxFrameHeaderSize + len(extra) // Check for invalid control frames. @@ -482,7 +498,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { } c.isWriting = true - c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra) + c.writeErr = c.write(c.writeFrameType, c.writeDeadline, flush, c.writeBuf[framePos:c.writePos], extra) if !c.isWriting { panic("concurrent write to websocket connection") @@ -499,7 +515,10 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { return c.writeErr } -type messageWriter struct{ c *Conn } +type messageWriter struct{ + c *Conn + flush bool +} func (w *messageWriter) err() error { c := w.c @@ -515,7 +534,7 @@ func (w *messageWriter) err() error { func (w *messageWriter) ncopy(max int) (int, error) { n := len(w.c.writeBuf) - w.c.writePos if n <= 0 { - if err := w.c.flushFrame(false, nil); err != nil { + if err := w.c.flushFrame(false, nil, w.flush); err != nil { return 0, err } n = len(w.c.writeBuf) - w.c.writePos @@ -533,7 +552,7 @@ func (w *messageWriter) write(final bool, p []byte) (int, error) { if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { // Don't buffer large messages. - err := w.c.flushFrame(final, p) + err := w.c.flushFrame(final, p, w.flush) if err != nil { return 0, err } @@ -581,7 +600,7 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { } for { if w.c.writePos == len(w.c.writeBuf) { - err = w.c.flushFrame(false, nil) + err = w.c.flushFrame(false, nil, w.flush) if err != nil { break } @@ -604,7 +623,7 @@ func (w *messageWriter) Close() error { if err := w.err(); err != nil { return err } - return w.c.flushFrame(true, nil) + return w.c.flushFrame(true, nil, w.flush) } // WriteMessage is a helper method for getting a writer using NextWriter, @@ -619,7 +638,7 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { n := copy(c.writeBuf[c.writePos:], data) c.writePos += n data = data[n:] - err = c.flushFrame(true, data) + err = c.flushFrame(true, data, true) return err } if _, err = w.Write(data); err != nil { @@ -628,6 +647,37 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { return w.Close() } +// BufferMessage is a helper method for getting a writer using NextWriter, +// writing the message and closing the writer without flushing the internal +// buffer. This functions allows in conjunction with FlushMessages() are more +// fine grained control over writing performance. +func (c *Conn) BufferMessage(messageType int, data []byte) error { + w, err := c.NextWriter(messageType) + if err != nil { + return err + } + + ww := w.(*messageWriter) + ww.flush = false + + if _, ok := w.(*messageWriter); ok && c.isServer { + // Optimize write as a single frame. + n := copy(c.writeBuf[c.writePos:], data) + c.writePos += n + data = data[n:] + err = c.flushFrame(true, data, false) + return err + } + if _, err = w.Write(data); err != nil { + return err + } + return w.Close() +} + +func (c *Conn) FlushMessages() error { + return c.bw.Flush() +} + // SetWriteDeadline sets the write deadline on the underlying network // connection. After a write has timed out, the websocket state is corrupt and // all future writes will return an error. A zero value for t means writes will