mirror of https://github.com/gorilla/websocket.git
Return write buffer to pool on write error (#427)
Fix bug where connection did not return the write buffer to the pool after a write error. Add test for the same. Rename messsageWriter.fatal method to endMessage and consolidate all message cleanup code there. This ensures that the buffer is returned to pool on all code paths. Rename Conn.prepMessage to beginMessage for symmetry with endMessage. Move some duplicated code at calls to prepMessage to beginMessage. Bonus improvement: Adjust message and buffer size in TestWriteBufferPool to test that pool works with fragmented messages.
This commit is contained in:
parent
cdd40f587d
commit
3130e8d3f1
45
conn.go
45
conn.go
|
@ -451,7 +451,8 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) prepWrite(messageType int) error {
|
// beginMessage prepares a connection and message writer for a new message.
|
||||||
|
func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
|
||||||
// Close previous writer if not already closed by the application. It's
|
// Close previous writer if not already closed by the application. It's
|
||||||
// probably better to return an error in this situation, but we cannot
|
// probably better to return an error in this situation, but we cannot
|
||||||
// change this without breaking existing applications.
|
// change this without breaking existing applications.
|
||||||
|
@ -471,6 +472,10 @@ func (c *Conn) prepWrite(messageType int) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mw.c = c
|
||||||
|
mw.frameType = messageType
|
||||||
|
mw.pos = maxFrameHeaderSize
|
||||||
|
|
||||||
if c.writeBuf == nil {
|
if c.writeBuf == nil {
|
||||||
wpd, ok := c.writePool.Get().(writePoolData)
|
wpd, ok := c.writePool.Get().(writePoolData)
|
||||||
if ok {
|
if ok {
|
||||||
|
@ -491,16 +496,11 @@ func (c *Conn) prepWrite(messageType int) error {
|
||||||
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
|
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
|
||||||
// PongMessage) are supported.
|
// PongMessage) are supported.
|
||||||
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
||||||
if err := c.prepWrite(messageType); err != nil {
|
var mw messageWriter
|
||||||
|
if err := c.beginMessage(&mw, messageType); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
c.writer = &mw
|
||||||
mw := &messageWriter{
|
|
||||||
c: c,
|
|
||||||
frameType: messageType,
|
|
||||||
pos: maxFrameHeaderSize,
|
|
||||||
}
|
|
||||||
c.writer = mw
|
|
||||||
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
||||||
w := c.newCompressionWriter(c.writer, c.compressionLevel)
|
w := c.newCompressionWriter(c.writer, c.compressionLevel)
|
||||||
mw.compress = true
|
mw.compress = true
|
||||||
|
@ -517,10 +517,16 @@ type messageWriter struct {
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *messageWriter) fatal(err error) error {
|
func (w *messageWriter) endMessage(err error) error {
|
||||||
if w.err != nil {
|
if w.err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c := w.c
|
||||||
w.err = err
|
w.err = err
|
||||||
w.c.writer = nil
|
c.writer = nil
|
||||||
|
if c.writePool != nil {
|
||||||
|
c.writePool.Put(writePoolData{buf: c.writeBuf})
|
||||||
|
c.writeBuf = nil
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -534,7 +540,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
||||||
// Check for invalid control frames.
|
// Check for invalid control frames.
|
||||||
if isControl(w.frameType) &&
|
if isControl(w.frameType) &&
|
||||||
(!final || length > maxControlFramePayloadSize) {
|
(!final || length > maxControlFramePayloadSize) {
|
||||||
return w.fatal(errInvalidControlFrame)
|
return w.endMessage(errInvalidControlFrame)
|
||||||
}
|
}
|
||||||
|
|
||||||
b0 := byte(w.frameType)
|
b0 := byte(w.frameType)
|
||||||
|
@ -579,7 +585,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
||||||
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
|
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
|
||||||
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
|
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
|
||||||
if len(extra) > 0 {
|
if len(extra) > 0 {
|
||||||
return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
|
return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -600,15 +606,11 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
||||||
c.isWriting = false
|
c.isWriting = false
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return w.fatal(err)
|
return w.endMessage(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if final {
|
if final {
|
||||||
c.writer = nil
|
w.endMessage(errWriteClosed)
|
||||||
if c.writePool != nil {
|
|
||||||
c.writePool.Put(writePoolData{buf: c.writeBuf})
|
|
||||||
c.writeBuf = nil
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -709,7 +711,6 @@ func (w *messageWriter) Close() error {
|
||||||
if err := w.flushFrame(true, nil); err != nil {
|
if err := w.flushFrame(true, nil); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
w.err = errWriteClosed
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -742,10 +743,10 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
|
||||||
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
|
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
|
||||||
// Fast path with no allocations and single frame.
|
// Fast path with no allocations and single frame.
|
||||||
|
|
||||||
if err := c.prepWrite(messageType); err != nil {
|
var mw messageWriter
|
||||||
|
if err := c.beginMessage(&mw, messageType); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
mw := messageWriter{c: c, frameType: messageType, pos: maxFrameHeaderSize}
|
|
||||||
n := copy(c.writeBuf[mw.pos:], data)
|
n := copy(c.writeBuf[mw.pos:], data)
|
||||||
mw.pos += n
|
mw.pos += n
|
||||||
data = data[n:]
|
data = data[n:]
|
||||||
|
|
60
conn_test.go
60
conn_test.go
|
@ -196,11 +196,16 @@ func (p *simpleBufferPool) Put(v interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteBufferPool(t *testing.T) {
|
func TestWriteBufferPool(t *testing.T) {
|
||||||
|
const message = "Now is the time for all good people to come to the aid of the party."
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
var pool simpleBufferPool
|
var pool simpleBufferPool
|
||||||
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
|
|
||||||
rc := newTestConn(&buf, nil, false)
|
rc := newTestConn(&buf, nil, false)
|
||||||
|
|
||||||
|
// Specify writeBufferSize smaller than message size to ensure that pooling
|
||||||
|
// works with fragmented messages.
|
||||||
|
wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil)
|
||||||
|
|
||||||
if wc.writeBuf != nil {
|
if wc.writeBuf != nil {
|
||||||
t.Fatal("writeBuf not nil after create")
|
t.Fatal("writeBuf not nil after create")
|
||||||
}
|
}
|
||||||
|
@ -218,8 +223,6 @@ func TestWriteBufferPool(t *testing.T) {
|
||||||
|
|
||||||
writeBufAddr := &wc.writeBuf[0]
|
writeBufAddr := &wc.writeBuf[0]
|
||||||
|
|
||||||
const message = "Hello World!"
|
|
||||||
|
|
||||||
if _, err := io.WriteString(w, message); err != nil {
|
if _, err := io.WriteString(w, message); err != nil {
|
||||||
t.Fatalf("io.WriteString(w, message) returned %v", err)
|
t.Fatalf("io.WriteString(w, message) returned %v", err)
|
||||||
}
|
}
|
||||||
|
@ -269,6 +272,7 @@ func TestWriteBufferPool(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool.
|
||||||
func TestWriteBufferPoolSync(t *testing.T) {
|
func TestWriteBufferPoolSync(t *testing.T) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
var pool sync.Pool
|
var pool sync.Pool
|
||||||
|
@ -290,6 +294,56 @@ func TestWriteBufferPoolSync(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// errorWriter is an io.Writer than returns an error on all writes.
|
||||||
|
type errorWriter struct{}
|
||||||
|
|
||||||
|
func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("Error!") }
|
||||||
|
|
||||||
|
// TestWriteBufferPoolError ensures that buffer is returned to pool after error
|
||||||
|
// on write.
|
||||||
|
func TestWriteBufferPoolError(t *testing.T) {
|
||||||
|
|
||||||
|
// Part 1: Test NextWriter/Write/Close
|
||||||
|
|
||||||
|
var pool simpleBufferPool
|
||||||
|
wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
|
||||||
|
|
||||||
|
w, err := wc.NextWriter(TextMessage)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("wc.NextWriter() returned %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if wc.writeBuf == nil {
|
||||||
|
t.Fatal("writeBuf is nil after NextWriter")
|
||||||
|
}
|
||||||
|
|
||||||
|
writeBufAddr := &wc.writeBuf[0]
|
||||||
|
|
||||||
|
if _, err := io.WriteString(w, "Hello"); err != nil {
|
||||||
|
t.Fatalf("io.WriteString(w, message) returned %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.Close(); err == nil {
|
||||||
|
t.Fatalf("w.Close() did not return error")
|
||||||
|
}
|
||||||
|
|
||||||
|
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
|
||||||
|
t.Fatal("writeBuf not returned to pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Part 2: Test WriteMessage
|
||||||
|
|
||||||
|
wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil)
|
||||||
|
|
||||||
|
if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil {
|
||||||
|
t.Fatalf("wc.WriteMessage did not return error")
|
||||||
|
}
|
||||||
|
|
||||||
|
if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
|
||||||
|
t.Fatal("writeBuf not returned to pool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||||||
const bufSize = 512
|
const bufSize = 512
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue