Improve write error handling

- Do not fail NextWriter when close of previous writer fails.
- Replace closeSent field with mutex protected writeErr. Set writeErr on
  any error writing to underlying network connection. Check and return
  writeErr before attempting to write to network connection. Check
  writeErr in NextWriter so application can detect failed connection
  before attempting to write.
- Do not close underlying network connection on error.
- Move message writing state and method flushFrame from Conn to
  messageWriter. This makes error code paths (and the code in general)
  easier to understand.
- Add messageWriter field err to latch errors in messageWriter.

Bonus: Improve test coverage.
This commit is contained in:
Gary Burd 2016-11-02 09:41:43 -07:00
parent 343fff4c5c
commit 80a0029a65
2 changed files with 200 additions and 122 deletions

224
conn.go
View File

@ -12,6 +12,7 @@ import (
"io/ioutil" "io/ioutil"
"net" "net"
"strconv" "strconv"
"sync"
"time" "time"
"unicode/utf8" "unicode/utf8"
) )
@ -223,19 +224,16 @@ type Conn struct {
subprotocol string subprotocol string
// Write fields // Write fields
mu chan bool // used as mutex to protect write to conn and closeSent mu chan bool // used as mutex to protect write to conn
closeSent bool // whether close message was sent
writeErr error
writeBuf []byte // frame is constructed in this buffer. writeBuf []byte // frame is constructed in this buffer.
writePos int // end of data in writeBuf.
writeFrameType int // type of the current frame.
writeDeadline time.Time writeDeadline time.Time
messageWriter *messageWriter // the current low-level message writer
writer io.WriteCloser // the current writer returned to the application writer io.WriteCloser // the current writer returned to the application
isWriting bool // for best-effort concurrent write detection isWriting bool // for best-effort concurrent write detection
writeErrMu sync.Mutex
writeErr error
enableWriteCompression bool enableWriteCompression bool
writeCompress bool // whether next call to flushFrame should set RSV1
newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error) newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error)
// Read fields // Read fields
@ -277,8 +275,6 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
mu: mu, mu: mu,
readFinal: true, readFinal: true,
writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize), writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize),
writeFrameType: noFrame,
writePos: maxFrameHeaderSize,
enableWriteCompression: true, enableWriteCompression: true,
} }
c.SetPingHandler(nil) c.SetPingHandler(nil)
@ -308,29 +304,40 @@ func (c *Conn) RemoteAddr() net.Addr {
// Write methods // Write methods
func (c *Conn) writeFatal(err error) error {
err = hideTempErr(err)
c.writeErrMu.Lock()
if c.writeErr == nil {
c.writeErr = err
}
c.writeErrMu.Unlock()
return err
}
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
<-c.mu <-c.mu
defer func() { c.mu <- true }() defer func() { c.mu <- true }()
if c.closeSent { c.writeErrMu.Lock()
return ErrCloseSent err := c.writeErr
} else if frameType == CloseMessage { c.writeErrMu.Unlock()
c.closeSent = true if err != nil {
return err
} }
c.conn.SetWriteDeadline(deadline) c.conn.SetWriteDeadline(deadline)
for _, buf := range bufs { for _, buf := range bufs {
if len(buf) > 0 { if len(buf) > 0 {
n, err := c.conn.Write(buf) _, err := c.conn.Write(buf)
if n != len(buf) {
// Close on partial write.
c.conn.Close()
}
if err != nil { if err != nil {
return err return c.writeFatal(err)
} }
} }
} }
if frameType == CloseMessage {
c.writeFatal(ErrCloseSent)
}
return nil return nil
} }
@ -379,18 +386,22 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
} }
defer func() { c.mu <- true }() defer func() { c.mu <- true }()
if c.closeSent { c.writeErrMu.Lock()
return ErrCloseSent err := c.writeErr
} else if messageType == CloseMessage { c.writeErrMu.Unlock()
c.closeSent = true if err != nil {
return err
} }
c.conn.SetWriteDeadline(deadline) c.conn.SetWriteDeadline(deadline)
n, err := c.conn.Write(buf) _, err = c.conn.Write(buf)
if n != 0 && n != len(buf) { if err != nil {
c.conn.Close() return c.writeFatal(err)
} }
return hideTempErr(err) if messageType == CloseMessage {
c.writeFatal(ErrCloseSent)
}
return err
} }
// NextWriter returns a writer for the next message to send. The writer's Close // NextWriter returns a writer for the next message to send. The writer's Close
@ -399,64 +410,79 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
// There can be at most one open writer on a connection. NextWriter closes the // There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so. // previous writer if the application has not already done so.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
if c.writeErr != nil {
return nil, c.writeErr
}
// Close previous writer if not already closed by the application. It's // Close previous writer if not already closed by the application. It's
// probably better to return an error in this situation, but we cannot // probably better to return an error in this situation, but we cannot
// change this without breaking existing applications. // change this without breaking existing applications.
if c.writer != nil { if c.writer != nil {
err := c.writer.Close() c.writer.Close()
if err != nil { c.writer = nil
return nil, err
}
} }
if !isControl(messageType) && !isData(messageType) { if !isControl(messageType) && !isData(messageType) {
return nil, errBadWriteOpCode return nil, errBadWriteOpCode
} }
c.writeFrameType = messageType c.writeErrMu.Lock()
c.messageWriter = &messageWriter{c} err := c.writeErr
c.writeErrMu.Unlock()
var w io.WriteCloser = c.messageWriter
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
c.writeCompress = true
var err error
w, err = c.newCompressionWriter(w)
if err != nil { if err != nil {
c.writer.Close()
return nil, err return nil, err
} }
}
return w, nil mw := &messageWriter{
c: c,
frameType: messageType,
pos: maxFrameHeaderSize,
}
c.writer = mw
if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
w, err := c.newCompressionWriter(c.writer)
if err != nil {
c.writer = nil
return nil, err
}
mw.compress = true
c.writer = w
}
return c.writer, nil
}
type messageWriter struct {
c *Conn
compress bool // whether next call to flushFrame should set RSV1
pos int // end of data in writeBuf.
frameType int // type of the current frame.
err error
}
func (w *messageWriter) fatal(err error) error {
if w.err != nil {
w.err = err
w.c.writer = nil
}
return err
} }
// flushFrame writes buffered data and extra as a frame to the network. The // flushFrame writes buffered data and extra as a frame to the network. The
// final argument indicates that this is the last frame in the message. // final argument indicates that this is the last frame in the message.
func (c *Conn) flushFrame(final bool, extra []byte) error { func (w *messageWriter) flushFrame(final bool, extra []byte) error {
length := c.writePos - maxFrameHeaderSize + len(extra) c := w.c
length := w.pos - maxFrameHeaderSize + len(extra)
// Check for invalid control frames. // Check for invalid control frames.
if isControl(c.writeFrameType) && if isControl(w.frameType) &&
(!final || length > maxControlFramePayloadSize) { (!final || length > maxControlFramePayloadSize) {
c.messageWriter = nil return w.fatal(errInvalidControlFrame)
c.writer = nil
c.writeFrameType = noFrame
c.writePos = maxFrameHeaderSize
return errInvalidControlFrame
} }
b0 := byte(c.writeFrameType) b0 := byte(w.frameType)
if final { if final {
b0 |= finalBit b0 |= finalBit
} }
if c.writeCompress { if w.compress {
b0 |= rsv1Bit b0 |= rsv1Bit
} }
c.writeCompress = false w.compress = false
b1 := byte(0) b1 := byte(0)
if !c.isServer { if !c.isServer {
@ -489,10 +515,9 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
if !c.isServer { if !c.isServer {
key := newMaskKey() key := newMaskKey()
copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos]) maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
if len(extra) > 0 { if len(extra) > 0 {
c.writeErr = errors.New("websocket: internal error, extra used in client mode") return c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))
return c.writeErr
} }
} }
@ -505,44 +530,35 @@ func (c *Conn) flushFrame(final bool, extra []byte) error {
} }
c.isWriting = true c.isWriting = true
c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra) err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)
if !c.isWriting { if !c.isWriting {
panic("concurrent write to websocket connection") panic("concurrent write to websocket connection")
} }
c.isWriting = false c.isWriting = false
// Setup for next frame. if err != nil {
c.writePos = maxFrameHeaderSize return w.fatal(err)
c.writeFrameType = continuationFrame }
if final { if final {
c.messageWriter = nil
c.writer = nil c.writer = nil
c.writeFrameType = noFrame return nil
} }
return c.writeErr
}
type messageWriter struct{ c *Conn } // Setup for next frame.
w.pos = maxFrameHeaderSize
func (w *messageWriter) err() error { w.frameType = continuationFrame
c := w.c
if c.messageWriter != w {
return errWriteClosed
}
if c.writeErr != nil {
return c.writeErr
}
return nil return nil
} }
func (w *messageWriter) ncopy(max int) (int, error) { func (w *messageWriter) ncopy(max int) (int, error) {
n := len(w.c.writeBuf) - w.c.writePos n := len(w.c.writeBuf) - w.pos
if n <= 0 { if n <= 0 {
if err := w.c.flushFrame(false, nil); err != nil { if err := w.flushFrame(false, nil); err != nil {
return 0, err return 0, err
} }
n = len(w.c.writeBuf) - w.c.writePos n = len(w.c.writeBuf) - w.pos
} }
if n > max { if n > max {
n = max n = max
@ -551,13 +567,13 @@ func (w *messageWriter) ncopy(max int) (int, error) {
} }
func (w *messageWriter) Write(p []byte) (int, error) { func (w *messageWriter) Write(p []byte) (int, error) {
if err := w.err(); err != nil { if w.err != nil {
return 0, err return 0, w.err
} }
if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
// Don't buffer large messages. // Don't buffer large messages.
err := w.c.flushFrame(false, p) err := w.flushFrame(false, p)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -570,16 +586,16 @@ func (w *messageWriter) Write(p []byte) (int, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
copy(w.c.writeBuf[w.c.writePos:], p[:n]) copy(w.c.writeBuf[w.pos:], p[:n])
w.c.writePos += n w.pos += n
p = p[n:] p = p[n:]
} }
return nn, nil return nn, nil
} }
func (w *messageWriter) WriteString(p string) (int, error) { func (w *messageWriter) WriteString(p string) (int, error) {
if err := w.err(); err != nil { if w.err != nil {
return 0, err return 0, w.err
} }
nn := len(p) nn := len(p)
@ -588,27 +604,27 @@ func (w *messageWriter) WriteString(p string) (int, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
copy(w.c.writeBuf[w.c.writePos:], p[:n]) copy(w.c.writeBuf[w.pos:], p[:n])
w.c.writePos += n w.pos += n
p = p[n:] p = p[n:]
} }
return nn, nil return nn, nil
} }
func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
if err := w.err(); err != nil { if w.err != nil {
return 0, err return 0, w.err
} }
for { for {
if w.c.writePos == len(w.c.writeBuf) { if w.pos == len(w.c.writeBuf) {
err = w.c.flushFrame(false, nil) err = w.flushFrame(false, nil)
if err != nil { if err != nil {
break break
} }
} }
var n int var n int
n, err = r.Read(w.c.writeBuf[w.c.writePos:]) n, err = r.Read(w.c.writeBuf[w.pos:])
w.c.writePos += n w.pos += n
nn += int64(n) nn += int64(n)
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
@ -621,10 +637,14 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
} }
func (w *messageWriter) Close() error { func (w *messageWriter) Close() error {
if err := w.err(); err != nil { if w.err != nil {
return w.err
}
if err := w.flushFrame(true, nil); err != nil {
return err return err
} }
return w.c.flushFrame(true, nil) w.err = errWriteClosed
return nil
} }
// WriteMessage is a helper method for getting a writer using NextWriter, // WriteMessage is a helper method for getting a writer using NextWriter,
@ -634,12 +654,12 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
if err != nil { if err != nil {
return err return err
} }
if _, ok := w.(*messageWriter); ok && c.isServer { if mw, ok := w.(*messageWriter); ok && c.isServer {
// Optimize write as a single frame. // Optimize write as a single frame.
n := copy(c.writeBuf[c.writePos:], data) n := copy(c.writeBuf[mw.pos:], data)
c.writePos += n mw.pos += n
data = data[n:] data = data[n:]
err = c.flushFrame(true, data) err = mw.flushFrame(true, data)
return err return err
} }
if _, err = w.Write(data); err != nil { if _, err = w.Write(data); err != nil {

View File

@ -26,12 +26,27 @@ type fakeNetConn struct {
} }
func (c fakeNetConn) Close() error { return nil } func (c fakeNetConn) Close() error { return nil }
func (c fakeNetConn) LocalAddr() net.Addr { return nil } func (c fakeNetConn) LocalAddr() net.Addr { return localAddr }
func (c fakeNetConn) RemoteAddr() net.Addr { return nil } func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr }
func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
type fakeAddr int
var (
localAddr = fakeAddr(1)
remoteAddr = fakeAddr(2)
)
func (a fakeAddr) Network() string {
return "net"
}
func (a fakeAddr) String() string {
return "str"
}
func TestFraming(t *testing.T) { func TestFraming(t *testing.T) {
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537} frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
var readChunkers = []struct { var readChunkers = []struct {
@ -42,11 +57,25 @@ func TestFraming(t *testing.T) {
{"one", iotest.OneByteReader}, {"one", iotest.OneByteReader},
{"asis", func(r io.Reader) io.Reader { return r }}, {"asis", func(r io.Reader) io.Reader { return r }},
} }
writeBuf := make([]byte, 65537) writeBuf := make([]byte, 65537)
for i := range writeBuf { for i := range writeBuf {
writeBuf[i] = byte(i) writeBuf[i] = byte(i)
} }
var writers = []struct {
name string
f func(w io.Writer, n int) (int, error)
}{
{"iocopy", func(w io.Writer, n int) (int, error) {
nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
return int(nn), err
}},
{"write", func(w io.Writer, n int) (int, error) {
return w.Write(writeBuf[:n])
}},
{"string", func(w io.Writer, n int) (int, error) {
return io.WriteString(w, string(writeBuf[:n]))
}},
}
for _, compress := range []bool{false, true} { for _, compress := range []bool{false, true} {
for _, isServer := range []bool{true, false} { for _, isServer := range []bool{true, false} {
@ -60,22 +89,15 @@ func TestFraming(t *testing.T) {
rc.newDecompressionReader = decompressNoContextTakeover rc.newDecompressionReader = decompressNoContextTakeover
} }
for _, n := range frameSizes { for _, n := range frameSizes {
for _, iocopy := range []bool{true, false} { for _, writer := range writers {
name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d c:%v", compress, isServer, chunker.name, n, iocopy) name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
w, err := wc.NextWriter(TextMessage) w, err := wc.NextWriter(TextMessage)
if err != nil { if err != nil {
t.Errorf("%s: wc.NextWriter() returned %v", name, err) t.Errorf("%s: wc.NextWriter() returned %v", name, err)
continue continue
} }
var nn int nn, err := writer.f(w, n)
if iocopy {
var n64 int64
n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n]))
nn = int(n64)
} else {
nn, err = w.Write(writeBuf[:n])
}
if err != nil || nn != n { if err != nil || nn != n {
t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
continue continue
@ -151,7 +173,7 @@ func TestControl(t *testing.T) {
} }
} }
func TestCloseBeforeFinalFrame(t *testing.T) { func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
const bufSize = 512 const bufSize = 512
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
@ -238,6 +260,32 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
} }
} }
func TestWriteAfterMessageWriterClose(t *testing.T) {
wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024)
w, _ := wc.NextWriter(BinaryMessage)
io.WriteString(w, "hello")
if err := w.Close(); err != nil {
t.Fatalf("unxpected error closing message writer, %v", err)
}
if _, err := io.WriteString(w, "world"); err == nil {
t.Fatalf("no error writing after close")
}
w, _ = wc.NextWriter(BinaryMessage)
io.WriteString(w, "hello")
// close w by getting next writer
_, err := wc.NextWriter(BinaryMessage)
if err != nil {
t.Fatalf("unexpected error getting next writer, %v", err)
}
if _, err := io.WriteString(w, "world"); err == nil {
t.Fatalf("no error writing after close")
}
}
func TestReadLimit(t *testing.T) { func TestReadLimit(t *testing.T) {
const readLimit = 512 const readLimit = 512
@ -272,6 +320,16 @@ func TestReadLimit(t *testing.T) {
} }
} }
func TestAddrs(t *testing.T) {
c := newConn(&fakeNetConn{}, true, 1024, 1024)
if c.LocalAddr() != localAddr {
t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
}
if c.RemoteAddr() != remoteAddr {
t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
}
}
func TestUnderlyingConn(t *testing.T) { func TestUnderlyingConn(t *testing.T) {
var b1, b2 bytes.Buffer var b1, b2 bytes.Buffer
fc := fakeNetConn{Reader: &b1, Writer: &b2} fc := fakeNetConn{Reader: &b1, Writer: &b2}