added Conn.BufferMessage() and Conn.FlushMessages()

This commit is contained in:
Joël Gähwiler 2016-06-05 00:39:12 +02:00
parent 3ddc984058
commit a86136ccd5
1 changed files with 63 additions and 13 deletions

76
conn.go
View File

@ -235,6 +235,7 @@ type Conn struct {
// Message writer fields. // Message writer fields.
writeErr error writeErr error
bw *bufio.Writer
writeBuf []byte // frame is constructed in this buffer. writeBuf []byte // frame is constructed in this buffer.
writePos int // end of data in writeBuf. writePos int // end of data in writeBuf.
writeFrameType int // type of the current frame. writeFrameType int // type of the current frame.
@ -273,6 +274,7 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
c := &Conn{ c := &Conn{
isServer: isServer, isServer: isServer,
bw: bufio.NewWriterSize(conn, writeBufferSize),
br: bufio.NewReaderSize(conn, readBufferSize), br: bufio.NewReaderSize(conn, readBufferSize),
conn: conn, conn: conn,
mu: mu, mu: mu,
@ -308,7 +310,7 @@ func (c *Conn) RemoteAddr() net.Addr {
// Write methods // Write methods
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { func (c *Conn) write(frameType int, deadline time.Time, flush bool, bufs ...[]byte) error {
<-c.mu <-c.mu
defer func() { c.mu <- true }() defer func() { c.mu <- true }()
@ -321,7 +323,7 @@ func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
c.conn.SetWriteDeadline(deadline) c.conn.SetWriteDeadline(deadline)
for _, buf := range bufs { for _, buf := range bufs {
if len(buf) > 0 { if len(buf) > 0 {
n, err := c.conn.Write(buf) n, err := c.bw.Write(buf)
if n != len(buf) { if n != len(buf) {
// Close on partial write. // Close on partial write.
c.conn.Close() c.conn.Close()
@ -331,6 +333,14 @@ func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
} }
} }
} }
if flush {
err := c.bw.Flush()
if err != nil {
return err
}
}
return nil return nil
} }
@ -386,10 +396,16 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
} }
c.conn.SetWriteDeadline(deadline) c.conn.SetWriteDeadline(deadline)
n, err := c.conn.Write(buf) n, err := c.bw.Write(buf)
if n != 0 && n != len(buf) { if n != 0 && n != len(buf) {
c.conn.Close() c.conn.Close()
} }
if err != nil {
return hideTempErr(err)
}
// flush
err = c.bw.Flush()
return hideTempErr(err) return hideTempErr(err)
} }
@ -404,7 +420,7 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
} }
if c.writeFrameType != noFrame { if c.writeFrameType != noFrame {
if err := c.flushFrame(true, nil); err != nil { if err := c.flushFrame(true, nil, true); err != nil {
return nil, err return nil, err
} }
} }
@ -414,12 +430,12 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
} }
c.writeFrameType = messageType c.writeFrameType = messageType
w := &messageWriter{c} w := &messageWriter{c: c, flush: true}
c.messageWriter = w c.messageWriter = w
return w, nil return w, nil
} }
func (c *Conn) flushFrame(final bool, extra []byte) error { func (c *Conn) flushFrame(final bool, extra []byte, flush bool) error {
length := c.writePos - maxFrameHeaderSize + len(extra) length := c.writePos - maxFrameHeaderSize + len(extra)
// Check for invalid control frames. // Check for invalid control frames.
@ -482,7 +498,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
} }
c.isWriting = true c.isWriting = true
c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra) c.writeErr = c.write(c.writeFrameType, c.writeDeadline, flush, c.writeBuf[framePos:c.writePos], extra)
if !c.isWriting { if !c.isWriting {
panic("concurrent write to websocket connection") panic("concurrent write to websocket connection")
@ -499,7 +515,10 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
return c.writeErr return c.writeErr
} }
type messageWriter struct{ c *Conn } type messageWriter struct{
c *Conn
flush bool
}
func (w *messageWriter) err() error { func (w *messageWriter) err() error {
c := w.c c := w.c
@ -515,7 +534,7 @@ func (w *messageWriter) err() error {
func (w *messageWriter) ncopy(max int) (int, error) { func (w *messageWriter) ncopy(max int) (int, error) {
n := len(w.c.writeBuf) - w.c.writePos n := len(w.c.writeBuf) - w.c.writePos
if n <= 0 { if n <= 0 {
if err := w.c.flushFrame(false, nil); err != nil { if err := w.c.flushFrame(false, nil, w.flush); err != nil {
return 0, err return 0, err
} }
n = len(w.c.writeBuf) - w.c.writePos n = len(w.c.writeBuf) - w.c.writePos
@ -533,7 +552,7 @@ func (w *messageWriter) write(final bool, p []byte) (int, error) {
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
// Don't buffer large messages. // Don't buffer large messages.
err := w.c.flushFrame(final, p) err := w.c.flushFrame(final, p, w.flush)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -581,7 +600,7 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
} }
for { for {
if w.c.writePos == len(w.c.writeBuf) { if w.c.writePos == len(w.c.writeBuf) {
err = w.c.flushFrame(false, nil) err = w.c.flushFrame(false, nil, w.flush)
if err != nil { if err != nil {
break break
} }
@ -604,7 +623,7 @@ func (w *messageWriter) Close() error {
if err := w.err(); err != nil { if err := w.err(); err != nil {
return err return err
} }
return w.c.flushFrame(true, nil) return w.c.flushFrame(true, nil, w.flush)
} }
// WriteMessage is a helper method for getting a writer using NextWriter, // WriteMessage is a helper method for getting a writer using NextWriter,
@ -619,7 +638,7 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
n := copy(c.writeBuf[c.writePos:], data) n := copy(c.writeBuf[c.writePos:], data)
c.writePos += n c.writePos += n
data = data[n:] data = data[n:]
err = c.flushFrame(true, data) err = c.flushFrame(true, data, true)
return err return err
} }
if _, err = w.Write(data); err != nil { if _, err = w.Write(data); err != nil {
@ -628,6 +647,37 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
return w.Close() return w.Close()
} }
// BufferMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer without flushing the internal
// buffer. This functions allows in conjunction with FlushMessages() are more
// fine grained control over writing performance.
func (c *Conn) BufferMessage(messageType int, data []byte) error {
w, err := c.NextWriter(messageType)
if err != nil {
return err
}
ww := w.(*messageWriter)
ww.flush = false
if _, ok := w.(*messageWriter); ok && c.isServer {
// Optimize write as a single frame.
n := copy(c.writeBuf[c.writePos:], data)
c.writePos += n
data = data[n:]
err = c.flushFrame(true, data, false)
return err
}
if _, err = w.Write(data); err != nil {
return err
}
return w.Close()
}
func (c *Conn) FlushMessages() error {
return c.bw.Flush()
}
// SetWriteDeadline sets the write deadline on the underlying network // SetWriteDeadline sets the write deadline on the underlying network
// connection. After a write has timed out, the websocket state is corrupt and // connection. After a write has timed out, the websocket state is corrupt and
// all future writes will return an error. A zero value for t means writes will // all future writes will return an error. A zero value for t means writes will