Reduce memory allocations in NextReader, NextWriter

Redo 8b209f6317 with support for old
versions of Go.
This commit is contained in:
Gary Burd 2016-05-31 05:14:41 -07:00
parent 50d660d6ac
commit be01041b66
3 changed files with 116 additions and 86 deletions

163
conn.go
View File

@ -238,16 +238,15 @@ type Conn struct {
writeBuf []byte // frame is constructed in this buffer. writeBuf []byte // frame is constructed in this buffer.
writePos int // end of data in writeBuf. writePos int // end of data in writeBuf.
writeFrameType int // type of the current frame. writeFrameType int // type of the current frame.
writeSeq int // incremented to invalidate message writers.
writeDeadline time.Time writeDeadline time.Time
isWriting bool // for best-effort concurrent write detection isWriting bool // for best-effort concurrent write detection
messageWriter *messageWriter // the current writer
// Read fields // Read fields
readErr error readErr error
br *bufio.Reader br *bufio.Reader
readRemaining int64 // bytes remaining in current frame. readRemaining int64 // bytes remaining in current frame.
readFinal bool // true the current message has more frames. readFinal bool // true the current message has more frames.
readSeq int // incremented to invalidate message readers.
readLength int64 // Message size. readLength int64 // Message size.
readLimit int64 // Maximum message size. readLimit int64 // Maximum message size.
readMaskPos int readMaskPos int
@ -255,6 +254,7 @@ type Conn struct {
handlePong func(string) error handlePong func(string) error
handlePing func(string) error handlePing func(string) error
readErrCount int readErrCount int
messageReader *messageReader // the current reader
} }
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
@ -264,6 +264,9 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
if readBufferSize == 0 { if readBufferSize == 0 {
readBufferSize = defaultReadBufferSize readBufferSize = defaultReadBufferSize
} }
if readBufferSize < maxControlFramePayloadSize {
readBufferSize = maxControlFramePayloadSize
}
if writeBufferSize == 0 { if writeBufferSize == 0 {
writeBufferSize = defaultWriteBufferSize writeBufferSize = defaultWriteBufferSize
} }
@ -390,8 +393,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
return hideTempErr(err) return hideTempErr(err)
} }
// NextWriter returns a writer for the next message to send. The writer's // NextWriter returns a writer for the next message to send. The writer's Close
// Close method flushes the complete message to the network. // method flushes the complete message to the network.
// //
// There can be at most one open writer on a connection. NextWriter closes the // There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so. // previous writer if the application has not already done so.
@ -411,7 +414,9 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
} }
c.writeFrameType = messageType c.writeFrameType = messageType
return messageWriter{c, c.writeSeq}, nil w := &messageWriter{c}
c.messageWriter = w
return w, nil
} }
func (c *Conn) flushFrame(final bool, extra []byte) error { func (c *Conn) flushFrame(final bool, extra []byte) error {
@ -420,7 +425,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
// Check for invalid control frames. // Check for invalid control frames.
if isControl(c.writeFrameType) && if isControl(c.writeFrameType) &&
(!final || length > maxControlFramePayloadSize) { (!final || length > maxControlFramePayloadSize) {
c.writeSeq++ c.messageWriter = nil
c.writeFrameType = noFrame c.writeFrameType = noFrame
c.writePos = maxFrameHeaderSize c.writePos = maxFrameHeaderSize
return errInvalidControlFrame return errInvalidControlFrame
@ -488,20 +493,17 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
c.writePos = maxFrameHeaderSize c.writePos = maxFrameHeaderSize
c.writeFrameType = continuationFrame c.writeFrameType = continuationFrame
if final { if final {
c.writeSeq++ c.messageWriter = nil
c.writeFrameType = noFrame c.writeFrameType = noFrame
} }
return c.writeErr return c.writeErr
} }
type messageWriter struct { type messageWriter struct{ c *Conn }
c *Conn
seq int
}
func (w messageWriter) err() error { func (w *messageWriter) err() error {
c := w.c c := w.c
if c.writeSeq != w.seq { if c.messageWriter != w {
return errWriteClosed return errWriteClosed
} }
if c.writeErr != nil { if c.writeErr != nil {
@ -510,7 +512,7 @@ func (w messageWriter) err() error {
return nil return nil
} }
func (w messageWriter) ncopy(max int) (int, error) { func (w *messageWriter) ncopy(max int) (int, error) {
n := len(w.c.writeBuf) - w.c.writePos n := len(w.c.writeBuf) - w.c.writePos
if n <= 0 { if n <= 0 {
if err := w.c.flushFrame(false, nil); err != nil { if err := w.c.flushFrame(false, nil); err != nil {
@ -524,7 +526,7 @@ func (w messageWriter) ncopy(max int) (int, error) {
return n, nil return n, nil
} }
func (w messageWriter) write(final bool, p []byte) (int, error) { func (w *messageWriter) write(final bool, p []byte) (int, error) {
if err := w.err(); err != nil { if err := w.err(); err != nil {
return 0, err return 0, err
} }
@ -551,11 +553,11 @@ func (w messageWriter) write(final bool, p []byte) (int, error) {
return nn, nil return nn, nil
} }
func (w messageWriter) Write(p []byte) (int, error) { func (w *messageWriter) Write(p []byte) (int, error) {
return w.write(false, p) return w.write(false, p)
} }
func (w messageWriter) WriteString(p string) (int, error) { func (w *messageWriter) WriteString(p string) (int, error) {
if err := w.err(); err != nil { if err := w.err(); err != nil {
return 0, err return 0, err
} }
@ -573,7 +575,7 @@ func (w messageWriter) WriteString(p string) (int, error) {
return nn, nil return nn, nil
} }
func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
if err := w.err(); err != nil { if err := w.err(); err != nil {
return 0, err return 0, err
} }
@ -598,7 +600,7 @@ func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
return nn, err return nn, err
} }
func (w messageWriter) Close() error { func (w *messageWriter) Close() error {
if err := w.err(); err != nil { if err := w.err(); err != nil {
return err return err
} }
@ -608,20 +610,22 @@ func (w messageWriter) Close() error {
// WriteMessage is a helper method for getting a writer using NextWriter, // WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer. // writing the message and closing the writer.
func (c *Conn) WriteMessage(messageType int, data []byte) error { func (c *Conn) WriteMessage(messageType int, data []byte) error {
wr, err := c.NextWriter(messageType) w, err := c.NextWriter(messageType)
if err != nil { if err != nil {
return err return err
} }
w := wr.(messageWriter) if _, ok := w.(*messageWriter); ok && c.isServer {
if _, err := w.write(true, data); err != nil { // Optimize write as a single frame.
n := copy(c.writeBuf[c.writePos:], data)
c.writePos += n
data = data[n:]
err = c.flushFrame(true, data)
return err return err
} }
if c.writeSeq == w.seq { if _, err = w.Write(data); err != nil {
if err := c.flushFrame(true, nil); err != nil { return err
return err
}
} }
return nil return w.Close()
} }
// SetWriteDeadline sets the write deadline on the underlying network // SetWriteDeadline sets the write deadline on the underlying network
@ -635,22 +639,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
// Read methods // Read methods
// readFull is like io.ReadFull except that io.EOF is never returned.
func (c *Conn) readFull(p []byte) (err error) {
var n int
for n < len(p) && err == nil {
var nn int
nn, err = c.br.Read(p[n:])
n += nn
}
if n == len(p) {
err = nil
} else if err == io.EOF {
err = errUnexpectedEOF
}
return
}
func (c *Conn) advanceFrame() (int, error) { func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame. // 1. Skip remainder of previous frame.
@ -663,16 +651,16 @@ func (c *Conn) advanceFrame() (int, error) {
// 2. Read and parse first two bytes of frame header. // 2. Read and parse first two bytes of frame header.
var b [8]byte p, err := c.read(2)
if err := c.readFull(b[:2]); err != nil { if err != nil {
return noFrame, err return noFrame, err
} }
final := b[0]&finalBit != 0 final := p[0]&finalBit != 0
frameType := int(b[0] & 0xf) frameType := int(p[0] & 0xf)
reserved := int((b[0] >> 4) & 0x7) reserved := int((p[0] >> 4) & 0x7)
mask := b[1]&maskBit != 0 mask := p[1]&maskBit != 0
c.readRemaining = int64(b[1] & 0x7f) c.readRemaining = int64(p[1] & 0x7f)
if reserved != 0 { if reserved != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved)) return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
@ -704,15 +692,17 @@ func (c *Conn) advanceFrame() (int, error) {
switch c.readRemaining { switch c.readRemaining {
case 126: case 126:
if err := c.readFull(b[:2]); err != nil { p, err := c.read(2)
if err != nil {
return noFrame, err return noFrame, err
} }
c.readRemaining = int64(binary.BigEndian.Uint16(b[:2])) c.readRemaining = int64(binary.BigEndian.Uint16(p))
case 127: case 127:
if err := c.readFull(b[:8]); err != nil { p, err := c.read(8)
if err != nil {
return noFrame, err return noFrame, err
} }
c.readRemaining = int64(binary.BigEndian.Uint64(b[:8])) c.readRemaining = int64(binary.BigEndian.Uint64(p))
} }
// 4. Handle frame masking. // 4. Handle frame masking.
@ -723,9 +713,11 @@ func (c *Conn) advanceFrame() (int, error) {
if mask { if mask {
c.readMaskPos = 0 c.readMaskPos = 0
if err := c.readFull(c.readMaskKey[:]); err != nil { p, err := c.read(len(c.readMaskKey))
if err != nil {
return noFrame, err return noFrame, err
} }
copy(c.readMaskKey[:], p)
} }
// 5. For text and binary messages, enforce read limit and return. // 5. For text and binary messages, enforce read limit and return.
@ -745,9 +737,9 @@ func (c *Conn) advanceFrame() (int, error) {
var payload []byte var payload []byte
if c.readRemaining > 0 { if c.readRemaining > 0 {
payload = make([]byte, c.readRemaining) payload, err = c.read(int(c.readRemaining))
c.readRemaining = 0 c.readRemaining = 0
if err := c.readFull(payload); err != nil { if err != nil {
return noFrame, err return noFrame, err
} }
if c.isServer { if c.isServer {
@ -805,7 +797,7 @@ func (c *Conn) handleProtocolError(message string) error {
// this method return the same error. // this method return the same error.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c.readSeq++ c.messageReader = nil
c.readLength = 0 c.readLength = 0
for c.readErr == nil { for c.readErr == nil {
@ -815,7 +807,9 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
break break
} }
if frameType == TextMessage || frameType == BinaryMessage { if frameType == TextMessage || frameType == BinaryMessage {
return frameType, messageReader{c, c.readSeq}, nil r := &messageReader{c}
c.messageReader = r
return frameType, r, nil
} }
} }
@ -830,51 +824,48 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
return noFrame, nil, c.readErr return noFrame, nil, c.readErr
} }
type messageReader struct { type messageReader struct{ c *Conn }
c *Conn
seq int
}
func (r messageReader) Read(b []byte) (int, error) { func (r *messageReader) Read(b []byte) (int, error) {
c := r.c
if r.seq != r.c.readSeq { if c.messageReader != r {
return 0, io.EOF return 0, io.EOF
} }
for r.c.readErr == nil { for c.readErr == nil {
if r.c.readRemaining > 0 { if c.readRemaining > 0 {
if int64(len(b)) > r.c.readRemaining { if int64(len(b)) > c.readRemaining {
b = b[:r.c.readRemaining] b = b[:c.readRemaining]
} }
n, err := r.c.br.Read(b) n, err := c.br.Read(b)
r.c.readErr = hideTempErr(err) c.readErr = hideTempErr(err)
if r.c.isServer { if c.isServer {
r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n]) c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
} }
r.c.readRemaining -= int64(n) c.readRemaining -= int64(n)
if r.c.readRemaining > 0 && r.c.readErr == io.EOF { if c.readRemaining > 0 && c.readErr == io.EOF {
r.c.readErr = errUnexpectedEOF c.readErr = errUnexpectedEOF
} }
return n, r.c.readErr return n, c.readErr
} }
if r.c.readFinal { if c.readFinal {
r.c.readSeq++ c.messageReader = nil
return 0, io.EOF return 0, io.EOF
} }
frameType, err := r.c.advanceFrame() frameType, err := c.advanceFrame()
switch { switch {
case err != nil: case err != nil:
r.c.readErr = hideTempErr(err) c.readErr = hideTempErr(err)
case frameType == TextMessage || frameType == BinaryMessage: case frameType == TextMessage || frameType == BinaryMessage:
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
} }
} }
err := r.c.readErr err := c.readErr
if err == io.EOF && r.seq == r.c.readSeq { if err == io.EOF && c.messageReader == r {
err = errUnexpectedEOF err = errUnexpectedEOF
} }
return 0, err return 0, err

18
conn_read.go Normal file
View File

@ -0,0 +1,18 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.5
package websocket
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
c.br.Discard(len(p))
return p, err
}

21
conn_read_legacy.go Normal file
View File

@ -0,0 +1,21 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.5
package websocket
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
if len(p) > 0 {
// advance over the bytes just read
io.ReadFull(c.br, p)
}
return p, err
}