From be01041b667965201246c5a50911bf3ba3545e5f Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Tue, 31 May 2016 05:14:41 -0700 Subject: [PATCH] Reduce memory allocations in NextReader, NextWriter Redo 8b209f63177a963547dc3cee89350a327ead0412 with support for old versions of Go. --- conn.go | 163 +++++++++++++++++++++----------------------- conn_read.go | 18 +++++ conn_read_legacy.go | 21 ++++++ 3 files changed, 116 insertions(+), 86 deletions(-) create mode 100644 conn_read.go create mode 100644 conn_read_legacy.go diff --git a/conn.go b/conn.go index ed7736c..0bb6597 100644 --- a/conn.go +++ b/conn.go @@ -238,16 +238,15 @@ type Conn struct { writeBuf []byte // frame is constructed in this buffer. writePos int // end of data in writeBuf. writeFrameType int // type of the current frame. - writeSeq int // incremented to invalidate message writers. writeDeadline time.Time - isWriting bool // for best-effort concurrent write detection + isWriting bool // for best-effort concurrent write detection + messageWriter *messageWriter // the current writer // Read fields readErr error br *bufio.Reader readRemaining int64 // bytes remaining in current frame. readFinal bool // true the current message has more frames. - readSeq int // incremented to invalidate message readers. readLength int64 // Message size. readLimit int64 // Maximum message size. readMaskPos int @@ -255,6 +254,7 @@ type Conn struct { handlePong func(string) error handlePing func(string) error readErrCount int + messageReader *messageReader // the current reader } func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { @@ -264,6 +264,9 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) if readBufferSize == 0 { readBufferSize = defaultReadBufferSize } + if readBufferSize < maxControlFramePayloadSize { + readBufferSize = maxControlFramePayloadSize + } if writeBufferSize == 0 { writeBufferSize = defaultWriteBufferSize } @@ -390,8 +393,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er return hideTempErr(err) } -// NextWriter returns a writer for the next message to send. The writer's -// Close method flushes the complete message to the network. +// NextWriter returns a writer for the next message to send. The writer's Close +// method flushes the complete message to the network. // // There can be at most one open writer on a connection. NextWriter closes the // previous writer if the application has not already done so. @@ -411,7 +414,9 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { } c.writeFrameType = messageType - return messageWriter{c, c.writeSeq}, nil + w := &messageWriter{c} + c.messageWriter = w + return w, nil } func (c *Conn) flushFrame(final bool, extra []byte) error { @@ -420,7 +425,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { // Check for invalid control frames. if isControl(c.writeFrameType) && (!final || length > maxControlFramePayloadSize) { - c.writeSeq++ + c.messageWriter = nil c.writeFrameType = noFrame c.writePos = maxFrameHeaderSize return errInvalidControlFrame @@ -488,20 +493,17 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { c.writePos = maxFrameHeaderSize c.writeFrameType = continuationFrame if final { - c.writeSeq++ + c.messageWriter = nil c.writeFrameType = noFrame } return c.writeErr } -type messageWriter struct { - c *Conn - seq int -} +type messageWriter struct{ c *Conn } -func (w messageWriter) err() error { +func (w *messageWriter) err() error { c := w.c - if c.writeSeq != w.seq { + if c.messageWriter != w { return errWriteClosed } if c.writeErr != nil { @@ -510,7 +512,7 @@ func (w messageWriter) err() error { return nil } -func (w messageWriter) ncopy(max int) (int, error) { +func (w *messageWriter) ncopy(max int) (int, error) { n := len(w.c.writeBuf) - w.c.writePos if n <= 0 { if err := w.c.flushFrame(false, nil); err != nil { @@ -524,7 +526,7 @@ func (w messageWriter) ncopy(max int) (int, error) { return n, nil } -func (w messageWriter) write(final bool, p []byte) (int, error) { +func (w *messageWriter) write(final bool, p []byte) (int, error) { if err := w.err(); err != nil { return 0, err } @@ -551,11 +553,11 @@ func (w messageWriter) write(final bool, p []byte) (int, error) { return nn, nil } -func (w messageWriter) Write(p []byte) (int, error) { +func (w *messageWriter) Write(p []byte) (int, error) { return w.write(false, p) } -func (w messageWriter) WriteString(p string) (int, error) { +func (w *messageWriter) WriteString(p string) (int, error) { if err := w.err(); err != nil { return 0, err } @@ -573,7 +575,7 @@ func (w messageWriter) WriteString(p string) (int, error) { return nn, nil } -func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { +func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { if err := w.err(); err != nil { return 0, err } @@ -598,7 +600,7 @@ func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { return nn, err } -func (w messageWriter) Close() error { +func (w *messageWriter) Close() error { if err := w.err(); err != nil { return err } @@ -608,20 +610,22 @@ func (w messageWriter) Close() error { // WriteMessage is a helper method for getting a writer using NextWriter, // writing the message and closing the writer. func (c *Conn) WriteMessage(messageType int, data []byte) error { - wr, err := c.NextWriter(messageType) + w, err := c.NextWriter(messageType) if err != nil { return err } - w := wr.(messageWriter) - if _, err := w.write(true, data); err != nil { + if _, ok := w.(*messageWriter); ok && c.isServer { + // Optimize write as a single frame. + n := copy(c.writeBuf[c.writePos:], data) + c.writePos += n + data = data[n:] + err = c.flushFrame(true, data) return err } - if c.writeSeq == w.seq { - if err := c.flushFrame(true, nil); err != nil { - return err - } + if _, err = w.Write(data); err != nil { + return err } - return nil + return w.Close() } // SetWriteDeadline sets the write deadline on the underlying network @@ -635,22 +639,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { // Read methods -// readFull is like io.ReadFull except that io.EOF is never returned. -func (c *Conn) readFull(p []byte) (err error) { - var n int - for n < len(p) && err == nil { - var nn int - nn, err = c.br.Read(p[n:]) - n += nn - } - if n == len(p) { - err = nil - } else if err == io.EOF { - err = errUnexpectedEOF - } - return -} - func (c *Conn) advanceFrame() (int, error) { // 1. Skip remainder of previous frame. @@ -663,16 +651,16 @@ func (c *Conn) advanceFrame() (int, error) { // 2. Read and parse first two bytes of frame header. - var b [8]byte - if err := c.readFull(b[:2]); err != nil { + p, err := c.read(2) + if err != nil { return noFrame, err } - final := b[0]&finalBit != 0 - frameType := int(b[0] & 0xf) - reserved := int((b[0] >> 4) & 0x7) - mask := b[1]&maskBit != 0 - c.readRemaining = int64(b[1] & 0x7f) + final := p[0]&finalBit != 0 + frameType := int(p[0] & 0xf) + reserved := int((p[0] >> 4) & 0x7) + mask := p[1]&maskBit != 0 + c.readRemaining = int64(p[1] & 0x7f) if reserved != 0 { return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved)) @@ -704,15 +692,17 @@ func (c *Conn) advanceFrame() (int, error) { switch c.readRemaining { case 126: - if err := c.readFull(b[:2]); err != nil { + p, err := c.read(2) + if err != nil { return noFrame, err } - c.readRemaining = int64(binary.BigEndian.Uint16(b[:2])) + c.readRemaining = int64(binary.BigEndian.Uint16(p)) case 127: - if err := c.readFull(b[:8]); err != nil { + p, err := c.read(8) + if err != nil { return noFrame, err } - c.readRemaining = int64(binary.BigEndian.Uint64(b[:8])) + c.readRemaining = int64(binary.BigEndian.Uint64(p)) } // 4. Handle frame masking. @@ -723,9 +713,11 @@ func (c *Conn) advanceFrame() (int, error) { if mask { c.readMaskPos = 0 - if err := c.readFull(c.readMaskKey[:]); err != nil { + p, err := c.read(len(c.readMaskKey)) + if err != nil { return noFrame, err } + copy(c.readMaskKey[:], p) } // 5. For text and binary messages, enforce read limit and return. @@ -745,9 +737,9 @@ func (c *Conn) advanceFrame() (int, error) { var payload []byte if c.readRemaining > 0 { - payload = make([]byte, c.readRemaining) + payload, err = c.read(int(c.readRemaining)) c.readRemaining = 0 - if err := c.readFull(payload); err != nil { + if err != nil { return noFrame, err } if c.isServer { @@ -805,7 +797,7 @@ func (c *Conn) handleProtocolError(message string) error { // this method return the same error. func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { - c.readSeq++ + c.messageReader = nil c.readLength = 0 for c.readErr == nil { @@ -815,7 +807,9 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { break } if frameType == TextMessage || frameType == BinaryMessage { - return frameType, messageReader{c, c.readSeq}, nil + r := &messageReader{c} + c.messageReader = r + return frameType, r, nil } } @@ -830,51 +824,48 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { return noFrame, nil, c.readErr } -type messageReader struct { - c *Conn - seq int -} +type messageReader struct{ c *Conn } -func (r messageReader) Read(b []byte) (int, error) { - - if r.seq != r.c.readSeq { +func (r *messageReader) Read(b []byte) (int, error) { + c := r.c + if c.messageReader != r { return 0, io.EOF } - for r.c.readErr == nil { + for c.readErr == nil { - if r.c.readRemaining > 0 { - if int64(len(b)) > r.c.readRemaining { - b = b[:r.c.readRemaining] + if c.readRemaining > 0 { + if int64(len(b)) > c.readRemaining { + b = b[:c.readRemaining] } - n, err := r.c.br.Read(b) - r.c.readErr = hideTempErr(err) - if r.c.isServer { - r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n]) + n, err := c.br.Read(b) + c.readErr = hideTempErr(err) + if c.isServer { + c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) } - r.c.readRemaining -= int64(n) - if r.c.readRemaining > 0 && r.c.readErr == io.EOF { - r.c.readErr = errUnexpectedEOF + c.readRemaining -= int64(n) + if c.readRemaining > 0 && c.readErr == io.EOF { + c.readErr = errUnexpectedEOF } - return n, r.c.readErr + return n, c.readErr } - if r.c.readFinal { - r.c.readSeq++ + if c.readFinal { + c.messageReader = nil return 0, io.EOF } - frameType, err := r.c.advanceFrame() + frameType, err := c.advanceFrame() switch { case err != nil: - r.c.readErr = hideTempErr(err) + c.readErr = hideTempErr(err) case frameType == TextMessage || frameType == BinaryMessage: - r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") + c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") } } - err := r.c.readErr - if err == io.EOF && r.seq == r.c.readSeq { + err := c.readErr + if err == io.EOF && c.messageReader == r { err = errUnexpectedEOF } return 0, err diff --git a/conn_read.go b/conn_read.go new file mode 100644 index 0000000..1ea1505 --- /dev/null +++ b/conn_read.go @@ -0,0 +1,18 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.5 + +package websocket + +import "io" + +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.br.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + c.br.Discard(len(p)) + return p, err +} diff --git a/conn_read_legacy.go b/conn_read_legacy.go new file mode 100644 index 0000000..018541c --- /dev/null +++ b/conn_read_legacy.go @@ -0,0 +1,21 @@ +// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.5 + +package websocket + +import "io" + +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.br.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + if len(p) > 0 { + // advance over the bytes just read + io.ReadFull(c.br, p) + } + return p, err +}