mirror of https://github.com/gorilla/websocket.git
Improve write error handling
- Do not fail NextWriter when close of previous writer fails. - Replace closeSent field with mutex protected writeErr. Set writeErr on any error writing to underlying network connection. Check and return writeErr before attempting to write to network connection. Check writeErr in NextWriter so application can detect failed connection before attempting to write. - Do not close underlying network connection on error. - Move message writing state and method flushFrame from Conn to messageWriter. This makes error code paths (and the code in general) easier to understand. - Add messageWriter field err to latch errors in messageWriter. Bonus: Improve test coverage.
This commit is contained in:
parent
343fff4c5c
commit
80a0029a65
222
conn.go
222
conn.go
|
@ -12,6 +12,7 @@ import (
|
|||
"io/ioutil"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
@ -223,19 +224,16 @@ type Conn struct {
|
|||
subprotocol string
|
||||
|
||||
// Write fields
|
||||
mu chan bool // used as mutex to protect write to conn and closeSent
|
||||
closeSent bool // whether close message was sent
|
||||
writeErr error
|
||||
mu chan bool // used as mutex to protect write to conn
|
||||
writeBuf []byte // frame is constructed in this buffer.
|
||||
writePos int // end of data in writeBuf.
|
||||
writeFrameType int // type of the current frame.
|
||||
writeDeadline time.Time
|
||||
messageWriter *messageWriter // the current low-level message writer
|
||||
writer io.WriteCloser // the current writer returned to the application
|
||||
isWriting bool // for best-effort concurrent write detection
|
||||
|
||||
writeErrMu sync.Mutex
|
||||
writeErr error
|
||||
|
||||
enableWriteCompression bool
|
||||
writeCompress bool // whether next call to flushFrame should set RSV1
|
||||
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
|
||||
|
||||
// Read fields
|
||||
|
@ -277,8 +275,6 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
|
|||
mu: mu,
|
||||
readFinal: true,
|
||||
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
|
||||
writeFrameType: noFrame,
|
||||
writePos: maxFrameHeaderSize,
|
||||
enableWriteCompression: true,
|
||||
}
|
||||
c.SetPingHandler(nil)
|
||||
|
@ -308,29 +304,40 @@ func (c *Conn) RemoteAddr() net.Addr {
|
|||
|
||||
// Write methods
|
||||
|
||||
func (c *Conn) writeFatal(err error) error {
|
||||
err = hideTempErr(err)
|
||||
c.writeErrMu.Lock()
|
||||
if c.writeErr == nil {
|
||||
c.writeErr = err
|
||||
}
|
||||
c.writeErrMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
|
||||
<-c.mu
|
||||
defer func() { c.mu <- true }()
|
||||
|
||||
if c.closeSent {
|
||||
return ErrCloseSent
|
||||
} else if frameType == CloseMessage {
|
||||
c.closeSent = true
|
||||
c.writeErrMu.Lock()
|
||||
err := c.writeErr
|
||||
c.writeErrMu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.conn.SetWriteDeadline(deadline)
|
||||
for _, buf := range bufs {
|
||||
if len(buf) > 0 {
|
||||
n, err := c.conn.Write(buf)
|
||||
if n != len(buf) {
|
||||
// Close on partial write.
|
||||
c.conn.Close()
|
||||
}
|
||||
_, err := c.conn.Write(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
return c.writeFatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if frameType == CloseMessage {
|
||||
c.writeFatal(ErrCloseSent)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -379,18 +386,22 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
|||
}
|
||||
defer func() { c.mu <- true }()
|
||||
|
||||
if c.closeSent {
|
||||
return ErrCloseSent
|
||||
} else if messageType == CloseMessage {
|
||||
c.closeSent = true
|
||||
c.writeErrMu.Lock()
|
||||
err := c.writeErr
|
||||
c.writeErrMu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.conn.SetWriteDeadline(deadline)
|
||||
n, err := c.conn.Write(buf)
|
||||
if n != 0 && n != len(buf) {
|
||||
c.conn.Close()
|
||||
_, err = c.conn.Write(buf)
|
||||
if err != nil {
|
||||
return c.writeFatal(err)
|
||||
}
|
||||
return hideTempErr(err)
|
||||
if messageType == CloseMessage {
|
||||
c.writeFatal(ErrCloseSent)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// NextWriter returns a writer for the next message to send. The writer's Close
|
||||
|
@ -399,64 +410,79 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
|
|||
// There can be at most one open writer on a connection. NextWriter closes the
|
||||
// previous writer if the application has not already done so.
|
||||
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
||||
if c.writeErr != nil {
|
||||
return nil, c.writeErr
|
||||
}
|
||||
|
||||
// Close previous writer if not already closed by the application. It's
|
||||
// probably better to return an error in this situation, but we cannot
|
||||
// change this without breaking existing applications.
|
||||
if c.writer != nil {
|
||||
err := c.writer.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.writer.Close()
|
||||
c.writer = nil
|
||||
}
|
||||
|
||||
if !isControl(messageType) && !isData(messageType) {
|
||||
return nil, errBadWriteOpCode
|
||||
}
|
||||
|
||||
c.writeFrameType = messageType
|
||||
c.messageWriter = &messageWriter{c}
|
||||
|
||||
var w io.WriteCloser = c.messageWriter
|
||||
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
||||
c.writeCompress = true
|
||||
var err error
|
||||
w, err = c.newCompressionWriter(w)
|
||||
c.writeErrMu.Lock()
|
||||
err := c.writeErr
|
||||
c.writeErrMu.Unlock()
|
||||
if err != nil {
|
||||
c.writer.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mw := &messageWriter{
|
||||
c: c,
|
||||
frameType: messageType,
|
||||
pos: maxFrameHeaderSize,
|
||||
}
|
||||
c.writer = mw
|
||||
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
|
||||
w, err := c.newCompressionWriter(c.writer)
|
||||
if err != nil {
|
||||
c.writer = nil
|
||||
return nil, err
|
||||
}
|
||||
mw.compress = true
|
||||
c.writer = w
|
||||
}
|
||||
return c.writer, nil
|
||||
}
|
||||
|
||||
return w, nil
|
||||
type messageWriter struct {
|
||||
c *Conn
|
||||
compress bool // whether next call to flushFrame should set RSV1
|
||||
pos int // end of data in writeBuf.
|
||||
frameType int // type of the current frame.
|
||||
err error
|
||||
}
|
||||
|
||||
func (w *messageWriter) fatal(err error) error {
|
||||
if w.err != nil {
|
||||
w.err = err
|
||||
w.c.writer = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// flushFrame writes buffered data and extra as a frame to the network. The
|
||||
// final argument indicates that this is the last frame in the message.
|
||||
func (c *Conn) flushFrame(final bool, extra []byte) error {
|
||||
length := c.writePos - maxFrameHeaderSize + len(extra)
|
||||
func (w *messageWriter) flushFrame(final bool, extra []byte) error {
|
||||
c := w.c
|
||||
length := w.pos - maxFrameHeaderSize + len(extra)
|
||||
|
||||
// Check for invalid control frames.
|
||||
if isControl(c.writeFrameType) &&
|
||||
if isControl(w.frameType) &&
|
||||
(!final || length > maxControlFramePayloadSize) {
|
||||
c.messageWriter = nil
|
||||
c.writer = nil
|
||||
c.writeFrameType = noFrame
|
||||
c.writePos = maxFrameHeaderSize
|
||||
return errInvalidControlFrame
|
||||
return w.fatal(errInvalidControlFrame)
|
||||
}
|
||||
|
||||
b0 := byte(c.writeFrameType)
|
||||
b0 := byte(w.frameType)
|
||||
if final {
|
||||
b0 |= finalBit
|
||||
}
|
||||
if c.writeCompress {
|
||||
if w.compress {
|
||||
b0 |= rsv1Bit
|
||||
}
|
||||
c.writeCompress = false
|
||||
w.compress = false
|
||||
|
||||
b1 := byte(0)
|
||||
if !c.isServer {
|
||||
|
@ -489,10 +515,9 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|||
if !c.isServer {
|
||||
key := newMaskKey()
|
||||
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
|
||||
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos])
|
||||
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
|
||||
if len(extra) > 0 {
|
||||
c.writeErr = errors.New("websocket: internal error, extra used in client mode")
|
||||
return c.writeErr
|
||||
return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -505,44 +530,35 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
|
|||
}
|
||||
c.isWriting = true
|
||||
|
||||
c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra)
|
||||
err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)
|
||||
|
||||
if !c.isWriting {
|
||||
panic("concurrent write to websocket connection")
|
||||
}
|
||||
c.isWriting = false
|
||||
|
||||
// Setup for next frame.
|
||||
c.writePos = maxFrameHeaderSize
|
||||
c.writeFrameType = continuationFrame
|
||||
if err != nil {
|
||||
return w.fatal(err)
|
||||
}
|
||||
|
||||
if final {
|
||||
c.messageWriter = nil
|
||||
c.writer = nil
|
||||
c.writeFrameType = noFrame
|
||||
}
|
||||
return c.writeErr
|
||||
return nil
|
||||
}
|
||||
|
||||
type messageWriter struct{ c *Conn }
|
||||
|
||||
func (w *messageWriter) err() error {
|
||||
c := w.c
|
||||
if c.messageWriter != w {
|
||||
return errWriteClosed
|
||||
}
|
||||
if c.writeErr != nil {
|
||||
return c.writeErr
|
||||
}
|
||||
// Setup for next frame.
|
||||
w.pos = maxFrameHeaderSize
|
||||
w.frameType = continuationFrame
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *messageWriter) ncopy(max int) (int, error) {
|
||||
n := len(w.c.writeBuf) - w.c.writePos
|
||||
n := len(w.c.writeBuf) - w.pos
|
||||
if n <= 0 {
|
||||
if err := w.c.flushFrame(false, nil); err != nil {
|
||||
if err := w.flushFrame(false, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n = len(w.c.writeBuf) - w.c.writePos
|
||||
n = len(w.c.writeBuf) - w.pos
|
||||
}
|
||||
if n > max {
|
||||
n = max
|
||||
|
@ -551,13 +567,13 @@ func (w *messageWriter) ncopy(max int) (int, error) {
|
|||
}
|
||||
|
||||
func (w *messageWriter) Write(p []byte) (int, error) {
|
||||
if err := w.err(); err != nil {
|
||||
return 0, err
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
|
||||
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
|
||||
// Don't buffer large messages.
|
||||
err := w.c.flushFrame(false, p)
|
||||
err := w.flushFrame(false, p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -570,16 +586,16 @@ func (w *messageWriter) Write(p []byte) (int, error) {
|
|||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
copy(w.c.writeBuf[w.c.writePos:], p[:n])
|
||||
w.c.writePos += n
|
||||
copy(w.c.writeBuf[w.pos:], p[:n])
|
||||
w.pos += n
|
||||
p = p[n:]
|
||||
}
|
||||
return nn, nil
|
||||
}
|
||||
|
||||
func (w *messageWriter) WriteString(p string) (int, error) {
|
||||
if err := w.err(); err != nil {
|
||||
return 0, err
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
|
||||
nn := len(p)
|
||||
|
@ -588,27 +604,27 @@ func (w *messageWriter) WriteString(p string) (int, error) {
|
|||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
copy(w.c.writeBuf[w.c.writePos:], p[:n])
|
||||
w.c.writePos += n
|
||||
copy(w.c.writeBuf[w.pos:], p[:n])
|
||||
w.pos += n
|
||||
p = p[n:]
|
||||
}
|
||||
return nn, nil
|
||||
}
|
||||
|
||||
func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
|
||||
if err := w.err(); err != nil {
|
||||
return 0, err
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
for {
|
||||
if w.c.writePos == len(w.c.writeBuf) {
|
||||
err = w.c.flushFrame(false, nil)
|
||||
if w.pos == len(w.c.writeBuf) {
|
||||
err = w.flushFrame(false, nil)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
var n int
|
||||
n, err = r.Read(w.c.writeBuf[w.c.writePos:])
|
||||
w.c.writePos += n
|
||||
n, err = r.Read(w.c.writeBuf[w.pos:])
|
||||
w.pos += n
|
||||
nn += int64(n)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
|
@ -621,10 +637,14 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
|
|||
}
|
||||
|
||||
func (w *messageWriter) Close() error {
|
||||
if err := w.err(); err != nil {
|
||||
if w.err != nil {
|
||||
return w.err
|
||||
}
|
||||
if err := w.flushFrame(true, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
return w.c.flushFrame(true, nil)
|
||||
w.err = errWriteClosed
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteMessage is a helper method for getting a writer using NextWriter,
|
||||
|
@ -634,12 +654,12 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, ok := w.(*messageWriter); ok && c.isServer {
|
||||
if mw, ok := w.(*messageWriter); ok && c.isServer {
|
||||
// Optimize write as a single frame.
|
||||
n := copy(c.writeBuf[c.writePos:], data)
|
||||
c.writePos += n
|
||||
n := copy(c.writeBuf[mw.pos:], data)
|
||||
mw.pos += n
|
||||
data = data[n:]
|
||||
err = c.flushFrame(true, data)
|
||||
err = mw.flushFrame(true, data)
|
||||
return err
|
||||
}
|
||||
if _, err = w.Write(data); err != nil {
|
||||
|
|
86
conn_test.go
86
conn_test.go
|
@ -26,12 +26,27 @@ type fakeNetConn struct {
|
|||
}
|
||||
|
||||
func (c fakeNetConn) Close() error { return nil }
|
||||
func (c fakeNetConn) LocalAddr() net.Addr { return nil }
|
||||
func (c fakeNetConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c fakeNetConn) LocalAddr() net.Addr { return localAddr }
|
||||
func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr }
|
||||
func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
type fakeAddr int
|
||||
|
||||
var (
|
||||
localAddr = fakeAddr(1)
|
||||
remoteAddr = fakeAddr(2)
|
||||
)
|
||||
|
||||
func (a fakeAddr) Network() string {
|
||||
return "net"
|
||||
}
|
||||
|
||||
func (a fakeAddr) String() string {
|
||||
return "str"
|
||||
}
|
||||
|
||||
func TestFraming(t *testing.T) {
|
||||
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
|
||||
var readChunkers = []struct {
|
||||
|
@ -42,11 +57,25 @@ func TestFraming(t *testing.T) {
|
|||
{"one", iotest.OneByteReader},
|
||||
{"asis", func(r io.Reader) io.Reader { return r }},
|
||||
}
|
||||
|
||||
writeBuf := make([]byte, 65537)
|
||||
for i := range writeBuf {
|
||||
writeBuf[i] = byte(i)
|
||||
}
|
||||
var writers = []struct {
|
||||
name string
|
||||
f func(w io.Writer, n int) (int, error)
|
||||
}{
|
||||
{"iocopy", func(w io.Writer, n int) (int, error) {
|
||||
nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
|
||||
return int(nn), err
|
||||
}},
|
||||
{"write", func(w io.Writer, n int) (int, error) {
|
||||
return w.Write(writeBuf[:n])
|
||||
}},
|
||||
{"string", func(w io.Writer, n int) (int, error) {
|
||||
return io.WriteString(w, string(writeBuf[:n]))
|
||||
}},
|
||||
}
|
||||
|
||||
for _, compress := range []bool{false, true} {
|
||||
for _, isServer := range []bool{true, false} {
|
||||
|
@ -60,22 +89,15 @@ func TestFraming(t *testing.T) {
|
|||
rc.newDecompressionReader = decompressNoContextTakeover
|
||||
}
|
||||
for _, n := range frameSizes {
|
||||
for _, iocopy := range []bool{true, false} {
|
||||
name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d c:%v", compress, isServer, chunker.name, n, iocopy)
|
||||
for _, writer := range writers {
|
||||
name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
|
||||
|
||||
w, err := wc.NextWriter(TextMessage)
|
||||
if err != nil {
|
||||
t.Errorf("%s: wc.NextWriter() returned %v", name, err)
|
||||
continue
|
||||
}
|
||||
var nn int
|
||||
if iocopy {
|
||||
var n64 int64
|
||||
n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n]))
|
||||
nn = int(n64)
|
||||
} else {
|
||||
nn, err = w.Write(writeBuf[:n])
|
||||
}
|
||||
nn, err := writer.f(w, n)
|
||||
if err != nil || nn != n {
|
||||
t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
|
||||
continue
|
||||
|
@ -151,7 +173,7 @@ func TestControl(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCloseBeforeFinalFrame(t *testing.T) {
|
||||
func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
|
||||
const bufSize = 512
|
||||
|
||||
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
|
||||
|
@ -238,6 +260,32 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024)
|
||||
w, _ := wc.NextWriter(BinaryMessage)
|
||||
io.WriteString(w, "hello")
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatalf("unxpected error closing message writer, %v", err)
|
||||
}
|
||||
|
||||
if _, err := io.WriteString(w, "world"); err == nil {
|
||||
t.Fatalf("no error writing after close")
|
||||
}
|
||||
|
||||
w, _ = wc.NextWriter(BinaryMessage)
|
||||
io.WriteString(w, "hello")
|
||||
|
||||
// close w by getting next writer
|
||||
_, err := wc.NextWriter(BinaryMessage)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error getting next writer, %v", err)
|
||||
}
|
||||
|
||||
if _, err := io.WriteString(w, "world"); err == nil {
|
||||
t.Fatalf("no error writing after close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadLimit(t *testing.T) {
|
||||
|
||||
const readLimit = 512
|
||||
|
@ -272,6 +320,16 @@ func TestReadLimit(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAddrs(t *testing.T) {
|
||||
c := newConn(&fakeNetConn{}, true, 1024, 1024)
|
||||
if c.LocalAddr() != localAddr {
|
||||
t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
|
||||
}
|
||||
if c.RemoteAddr() != remoteAddr {
|
||||
t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnderlyingConn(t *testing.T) {
|
||||
var b1, b2 bytes.Buffer
|
||||
fc := fakeNetConn{Reader: &b1, Writer: &b2}
|
||||
|
|
Loading…
Reference in New Issue