forked from mirror/websocket
Reduce memory allocations in NextReader, NextWriter
Redo 8b209f6317
with support for old
versions of Go.
This commit is contained in:
parent
50d660d6ac
commit
be01041b66
159
conn.go
159
conn.go
|
@ -238,16 +238,15 @@ type Conn struct {
|
||||||
writeBuf []byte // frame is constructed in this buffer.
|
writeBuf []byte // frame is constructed in this buffer.
|
||||||
writePos int // end of data in writeBuf.
|
writePos int // end of data in writeBuf.
|
||||||
writeFrameType int // type of the current frame.
|
writeFrameType int // type of the current frame.
|
||||||
writeSeq int // incremented to invalidate message writers.
|
|
||||||
writeDeadline time.Time
|
writeDeadline time.Time
|
||||||
isWriting bool // for best-effort concurrent write detection
|
isWriting bool // for best-effort concurrent write detection
|
||||||
|
messageWriter *messageWriter // the current writer
|
||||||
|
|
||||||
// Read fields
|
// Read fields
|
||||||
readErr error
|
readErr error
|
||||||
br *bufio.Reader
|
br *bufio.Reader
|
||||||
readRemaining int64 // bytes remaining in current frame.
|
readRemaining int64 // bytes remaining in current frame.
|
||||||
readFinal bool // true the current message has more frames.
|
readFinal bool // true the current message has more frames.
|
||||||
readSeq int // incremented to invalidate message readers.
|
|
||||||
readLength int64 // Message size.
|
readLength int64 // Message size.
|
||||||
readLimit int64 // Maximum message size.
|
readLimit int64 // Maximum message size.
|
||||||
readMaskPos int
|
readMaskPos int
|
||||||
|
@ -255,6 +254,7 @@ type Conn struct {
|
||||||
handlePong func(string) error
|
handlePong func(string) error
|
||||||
handlePing func(string) error
|
handlePing func(string) error
|
||||||
readErrCount int
|
readErrCount int
|
||||||
|
messageReader *messageReader // the current reader
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
|
||||||
|
@ -264,6 +264,9 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
|
||||||
if readBufferSize == 0 {
|
if readBufferSize == 0 {
|
||||||
readBufferSize = defaultReadBufferSize
|
readBufferSize = defaultReadBufferSize
|
||||||
}
|
}
|
||||||
|
if readBufferSize < maxControlFramePayloadSize {
|
||||||
|
readBufferSize = maxControlFramePayloadSize
|
||||||
|
}
|
||||||
if writeBufferSize == 0 {
|
if writeBufferSize == 0 {
|
||||||
writeBufferSize = defaultWriteBufferSize
|
writeBufferSize = defaultWriteBufferSize
|
||||||
}
|
}
|
||||||
|
@ -390,8 +393,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
||||||
return hideTempErr(err)
|
return hideTempErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NextWriter returns a writer for the next message to send. The writer's
|
// NextWriter returns a writer for the next message to send. The writer's Close
|
||||||
// Close method flushes the complete message to the network.
|
// method flushes the complete message to the network.
|
||||||
//
|
//
|
||||||
// There can be at most one open writer on a connection. NextWriter closes the
|
// There can be at most one open writer on a connection. NextWriter closes the
|
||||||
// previous writer if the application has not already done so.
|
// previous writer if the application has not already done so.
|
||||||
|
@ -411,7 +414,9 @@ func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
c.writeFrameType = messageType
|
c.writeFrameType = messageType
|
||||||
return messageWriter{c, c.writeSeq}, nil
|
w := &messageWriter{c}
|
||||||
|
c.messageWriter = w
|
||||||
|
return w, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) flushFrame(final bool, extra []byte) error {
|
func (c *Conn) flushFrame(final bool, extra []byte) error {
|
||||||
|
@ -420,7 +425,7 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
||||||
// Check for invalid control frames.
|
// Check for invalid control frames.
|
||||||
if isControl(c.writeFrameType) &&
|
if isControl(c.writeFrameType) &&
|
||||||
(!final || length > maxControlFramePayloadSize) {
|
(!final || length > maxControlFramePayloadSize) {
|
||||||
c.writeSeq++
|
c.messageWriter = nil
|
||||||
c.writeFrameType = noFrame
|
c.writeFrameType = noFrame
|
||||||
c.writePos = maxFrameHeaderSize
|
c.writePos = maxFrameHeaderSize
|
||||||
return errInvalidControlFrame
|
return errInvalidControlFrame
|
||||||
|
@ -488,20 +493,17 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
||||||
c.writePos = maxFrameHeaderSize
|
c.writePos = maxFrameHeaderSize
|
||||||
c.writeFrameType = continuationFrame
|
c.writeFrameType = continuationFrame
|
||||||
if final {
|
if final {
|
||||||
c.writeSeq++
|
c.messageWriter = nil
|
||||||
c.writeFrameType = noFrame
|
c.writeFrameType = noFrame
|
||||||
}
|
}
|
||||||
return c.writeErr
|
return c.writeErr
|
||||||
}
|
}
|
||||||
|
|
||||||
type messageWriter struct {
|
type messageWriter struct{ c *Conn }
|
||||||
c *Conn
|
|
||||||
seq int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w messageWriter) err() error {
|
func (w *messageWriter) err() error {
|
||||||
c := w.c
|
c := w.c
|
||||||
if c.writeSeq != w.seq {
|
if c.messageWriter != w {
|
||||||
return errWriteClosed
|
return errWriteClosed
|
||||||
}
|
}
|
||||||
if c.writeErr != nil {
|
if c.writeErr != nil {
|
||||||
|
@ -510,7 +512,7 @@ func (w messageWriter) err() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w messageWriter) ncopy(max int) (int, error) {
|
func (w *messageWriter) ncopy(max int) (int, error) {
|
||||||
n := len(w.c.writeBuf) - w.c.writePos
|
n := len(w.c.writeBuf) - w.c.writePos
|
||||||
if n <= 0 {
|
if n <= 0 {
|
||||||
if err := w.c.flushFrame(false, nil); err != nil {
|
if err := w.c.flushFrame(false, nil); err != nil {
|
||||||
|
@ -524,7 +526,7 @@ func (w messageWriter) ncopy(max int) (int, error) {
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w messageWriter) write(final bool, p []byte) (int, error) {
|
func (w *messageWriter) write(final bool, p []byte) (int, error) {
|
||||||
if err := w.err(); err != nil {
|
if err := w.err(); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -551,11 +553,11 @@ func (w messageWriter) write(final bool, p []byte) (int, error) {
|
||||||
return nn, nil
|
return nn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w messageWriter) Write(p []byte) (int, error) {
|
func (w *messageWriter) Write(p []byte) (int, error) {
|
||||||
return w.write(false, p)
|
return w.write(false, p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w messageWriter) WriteString(p string) (int, error) {
|
func (w *messageWriter) WriteString(p string) (int, error) {
|
||||||
if err := w.err(); err != nil {
|
if err := w.err(); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -573,7 +575,7 @@ func (w messageWriter) WriteString(p string) (int, error) {
|
||||||
return nn, nil
|
return nn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
|
func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
|
||||||
if err := w.err(); err != nil {
|
if err := w.err(); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -598,7 +600,7 @@ func (w messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
|
||||||
return nn, err
|
return nn, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w messageWriter) Close() error {
|
func (w *messageWriter) Close() error {
|
||||||
if err := w.err(); err != nil {
|
if err := w.err(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -608,20 +610,22 @@ func (w messageWriter) Close() error {
|
||||||
// WriteMessage is a helper method for getting a writer using NextWriter,
|
// WriteMessage is a helper method for getting a writer using NextWriter,
|
||||||
// writing the message and closing the writer.
|
// writing the message and closing the writer.
|
||||||
func (c *Conn) WriteMessage(messageType int, data []byte) error {
|
func (c *Conn) WriteMessage(messageType int, data []byte) error {
|
||||||
wr, err := c.NextWriter(messageType)
|
w, err := c.NextWriter(messageType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w := wr.(messageWriter)
|
if _, ok := w.(*messageWriter); ok && c.isServer {
|
||||||
if _, err := w.write(true, data); err != nil {
|
// Optimize write as a single frame.
|
||||||
|
n := copy(c.writeBuf[c.writePos:], data)
|
||||||
|
c.writePos += n
|
||||||
|
data = data[n:]
|
||||||
|
err = c.flushFrame(true, data)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if c.writeSeq == w.seq {
|
if _, err = w.Write(data); err != nil {
|
||||||
if err := c.flushFrame(true, nil); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
return w.Close()
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetWriteDeadline sets the write deadline on the underlying network
|
// SetWriteDeadline sets the write deadline on the underlying network
|
||||||
|
@ -635,22 +639,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
|
|
||||||
// Read methods
|
// Read methods
|
||||||
|
|
||||||
// readFull is like io.ReadFull except that io.EOF is never returned.
|
|
||||||
func (c *Conn) readFull(p []byte) (err error) {
|
|
||||||
var n int
|
|
||||||
for n < len(p) && err == nil {
|
|
||||||
var nn int
|
|
||||||
nn, err = c.br.Read(p[n:])
|
|
||||||
n += nn
|
|
||||||
}
|
|
||||||
if n == len(p) {
|
|
||||||
err = nil
|
|
||||||
} else if err == io.EOF {
|
|
||||||
err = errUnexpectedEOF
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) advanceFrame() (int, error) {
|
func (c *Conn) advanceFrame() (int, error) {
|
||||||
|
|
||||||
// 1. Skip remainder of previous frame.
|
// 1. Skip remainder of previous frame.
|
||||||
|
@ -663,16 +651,16 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
|
|
||||||
// 2. Read and parse first two bytes of frame header.
|
// 2. Read and parse first two bytes of frame header.
|
||||||
|
|
||||||
var b [8]byte
|
p, err := c.read(2)
|
||||||
if err := c.readFull(b[:2]); err != nil {
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
|
|
||||||
final := b[0]&finalBit != 0
|
final := p[0]&finalBit != 0
|
||||||
frameType := int(b[0] & 0xf)
|
frameType := int(p[0] & 0xf)
|
||||||
reserved := int((b[0] >> 4) & 0x7)
|
reserved := int((p[0] >> 4) & 0x7)
|
||||||
mask := b[1]&maskBit != 0
|
mask := p[1]&maskBit != 0
|
||||||
c.readRemaining = int64(b[1] & 0x7f)
|
c.readRemaining = int64(p[1] & 0x7f)
|
||||||
|
|
||||||
if reserved != 0 {
|
if reserved != 0 {
|
||||||
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
|
return noFrame, c.handleProtocolError("unexpected reserved bits " + strconv.Itoa(reserved))
|
||||||
|
@ -704,15 +692,17 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
|
|
||||||
switch c.readRemaining {
|
switch c.readRemaining {
|
||||||
case 126:
|
case 126:
|
||||||
if err := c.readFull(b[:2]); err != nil {
|
p, err := c.read(2)
|
||||||
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
|
c.readRemaining = int64(binary.BigEndian.Uint16(p))
|
||||||
case 127:
|
case 127:
|
||||||
if err := c.readFull(b[:8]); err != nil {
|
p, err := c.read(8)
|
||||||
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
|
c.readRemaining = int64(binary.BigEndian.Uint64(p))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Handle frame masking.
|
// 4. Handle frame masking.
|
||||||
|
@ -723,9 +713,11 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
|
|
||||||
if mask {
|
if mask {
|
||||||
c.readMaskPos = 0
|
c.readMaskPos = 0
|
||||||
if err := c.readFull(c.readMaskKey[:]); err != nil {
|
p, err := c.read(len(c.readMaskKey))
|
||||||
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
|
copy(c.readMaskKey[:], p)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. For text and binary messages, enforce read limit and return.
|
// 5. For text and binary messages, enforce read limit and return.
|
||||||
|
@ -745,9 +737,9 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
|
|
||||||
var payload []byte
|
var payload []byte
|
||||||
if c.readRemaining > 0 {
|
if c.readRemaining > 0 {
|
||||||
payload = make([]byte, c.readRemaining)
|
payload, err = c.read(int(c.readRemaining))
|
||||||
c.readRemaining = 0
|
c.readRemaining = 0
|
||||||
if err := c.readFull(payload); err != nil {
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
if c.isServer {
|
if c.isServer {
|
||||||
|
@ -805,7 +797,7 @@ func (c *Conn) handleProtocolError(message string) error {
|
||||||
// this method return the same error.
|
// this method return the same error.
|
||||||
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
||||||
|
|
||||||
c.readSeq++
|
c.messageReader = nil
|
||||||
c.readLength = 0
|
c.readLength = 0
|
||||||
|
|
||||||
for c.readErr == nil {
|
for c.readErr == nil {
|
||||||
|
@ -815,7 +807,9 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if frameType == TextMessage || frameType == BinaryMessage {
|
if frameType == TextMessage || frameType == BinaryMessage {
|
||||||
return frameType, messageReader{c, c.readSeq}, nil
|
r := &messageReader{c}
|
||||||
|
c.messageReader = r
|
||||||
|
return frameType, r, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -830,51 +824,48 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
||||||
return noFrame, nil, c.readErr
|
return noFrame, nil, c.readErr
|
||||||
}
|
}
|
||||||
|
|
||||||
type messageReader struct {
|
type messageReader struct{ c *Conn }
|
||||||
c *Conn
|
|
||||||
seq int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r messageReader) Read(b []byte) (int, error) {
|
func (r *messageReader) Read(b []byte) (int, error) {
|
||||||
|
c := r.c
|
||||||
if r.seq != r.c.readSeq {
|
if c.messageReader != r {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
for r.c.readErr == nil {
|
for c.readErr == nil {
|
||||||
|
|
||||||
if r.c.readRemaining > 0 {
|
if c.readRemaining > 0 {
|
||||||
if int64(len(b)) > r.c.readRemaining {
|
if int64(len(b)) > c.readRemaining {
|
||||||
b = b[:r.c.readRemaining]
|
b = b[:c.readRemaining]
|
||||||
}
|
}
|
||||||
n, err := r.c.br.Read(b)
|
n, err := c.br.Read(b)
|
||||||
r.c.readErr = hideTempErr(err)
|
c.readErr = hideTempErr(err)
|
||||||
if r.c.isServer {
|
if c.isServer {
|
||||||
r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
|
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
|
||||||
}
|
}
|
||||||
r.c.readRemaining -= int64(n)
|
c.readRemaining -= int64(n)
|
||||||
if r.c.readRemaining > 0 && r.c.readErr == io.EOF {
|
if c.readRemaining > 0 && c.readErr == io.EOF {
|
||||||
r.c.readErr = errUnexpectedEOF
|
c.readErr = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
return n, r.c.readErr
|
return n, c.readErr
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.c.readFinal {
|
if c.readFinal {
|
||||||
r.c.readSeq++
|
c.messageReader = nil
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
frameType, err := r.c.advanceFrame()
|
frameType, err := c.advanceFrame()
|
||||||
switch {
|
switch {
|
||||||
case err != nil:
|
case err != nil:
|
||||||
r.c.readErr = hideTempErr(err)
|
c.readErr = hideTempErr(err)
|
||||||
case frameType == TextMessage || frameType == BinaryMessage:
|
case frameType == TextMessage || frameType == BinaryMessage:
|
||||||
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := r.c.readErr
|
err := c.readErr
|
||||||
if err == io.EOF && r.seq == r.c.readSeq {
|
if err == io.EOF && c.messageReader == r {
|
||||||
err = errUnexpectedEOF
|
err = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build go1.5
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
func (c *Conn) read(n int) ([]byte, error) {
|
||||||
|
p, err := c.br.Peek(n)
|
||||||
|
if err == io.EOF {
|
||||||
|
err = errUnexpectedEOF
|
||||||
|
}
|
||||||
|
c.br.Discard(len(p))
|
||||||
|
return p, err
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build !go1.5
|
||||||
|
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
func (c *Conn) read(n int) ([]byte, error) {
|
||||||
|
p, err := c.br.Peek(n)
|
||||||
|
if err == io.EOF {
|
||||||
|
err = errUnexpectedEOF
|
||||||
|
}
|
||||||
|
if len(p) > 0 {
|
||||||
|
// advance over the bytes just read
|
||||||
|
io.ReadFull(c.br, p)
|
||||||
|
}
|
||||||
|
return p, err
|
||||||
|
}
|
Loading…
Reference in New Issue