From 10afcadf69098c003a97f23349c25ca16f0543e3 Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Thu, 10 Jul 2014 19:36:51 -0700 Subject: [PATCH] Cleanup EOF handling. - Modify data message reader to return io.ErrUnexpectedEOF if a close message is received before the final frame of the message. - Modify NextReader to return io.ErrUnexpectedEOF if underlying connection returns io.EOF before a close message. --- conn.go | 40 +++++++++++++++++++++++++++++++--------- conn_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 9 deletions(-) diff --git a/conn.go b/conn.go index d778d39..af41beb 100644 --- a/conn.go +++ b/conn.go @@ -516,6 +516,22 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { // Read methods +// readFull is like io.ReadFull except that io.EOF is never returned. +func (c *Conn) readFull(p []byte) (err error) { + var n int + for n < len(p) && err == nil { + var nn int + nn, err = c.br.Read(p[n:]) + n += nn + } + if n == len(p) { + err = nil + } else if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return +} + func (c *Conn) advanceFrame() (int, error) { // 1. Skip remainder of previous frame. @@ -529,7 +545,7 @@ func (c *Conn) advanceFrame() (int, error) { // 2. Read and parse first two bytes of frame header. var b [8]byte - if _, err := io.ReadFull(c.br, b[:2]); err != nil { + if err := c.readFull(b[:2]); err != nil { return noFrame, err } @@ -569,12 +585,12 @@ func (c *Conn) advanceFrame() (int, error) { switch c.readRemaining { case 126: - if _, err := io.ReadFull(c.br, b[:2]); err != nil { + if err := c.readFull(b[:2]); err != nil { return noFrame, err } c.readRemaining = int64(binary.BigEndian.Uint16(b[:2])) case 127: - if _, err := io.ReadFull(c.br, b[:8]); err != nil { + if err := c.readFull(b[:8]); err != nil { return noFrame, err } c.readRemaining = int64(binary.BigEndian.Uint64(b[:8])) @@ -588,7 +604,7 @@ func (c *Conn) advanceFrame() (int, error) { if mask { c.readMaskPos = 0 - if _, err := io.ReadFull(c.br, c.readMaskKey[:]); err != nil { + if err := c.readFull(c.readMaskKey[:]); err != nil { return noFrame, err } } @@ -612,7 +628,7 @@ func (c *Conn) advanceFrame() (int, error) { if c.readRemaining > 0 { payload = make([]byte, c.readRemaining) c.readRemaining = 0 - if _, err := io.ReadFull(c.br, payload); err != nil { + if err := c.readFull(payload); err != nil { return noFrame, err } if c.isServer { @@ -686,7 +702,7 @@ type messageReader struct { seq int } -func (r messageReader) Read(b []byte) (n int, err error) { +func (r messageReader) Read(b []byte) (int, error) { if r.seq != r.c.readSeq { return 0, io.EOF @@ -713,13 +729,19 @@ func (r messageReader) Read(b []byte) (n int, err error) { } frameType, err := r.c.advanceFrame() - if err != nil { + switch { + case err != nil: r.c.readErr = hideTempErr(err) - } else if frameType == TextMessage || frameType == BinaryMessage { + case frameType == TextMessage || frameType == BinaryMessage: r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") } } - return 0, r.c.readErr + + err := r.c.readErr + if err == io.EOF && r.seq == r.c.readSeq { + err = io.ErrUnexpectedEOF + } + return 0, err } // ReadMessage is a helper method for getting a reader using NextReader and diff --git a/conn_test.go b/conn_test.go index 52bbede..632725a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -143,6 +143,56 @@ func TestControl(t *testing.T) { } } +func TestCloseBeforeFinalFrame(t *testing.T) { + const bufSize = 512 + + var b1, b2 bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(make([]byte, bufSize+bufSize/2)) + wc.WriteControl(CloseMessage, []byte{}, time.Now().Add(10*time.Second)) + w.Close() + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != io.ErrUnexpectedEOF { + t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != io.EOF { + t.Fatalf("NextReader() returned %v, want %v", err, io.EOF) + } +} + +func TestEOFBeforeFinalFrame(t *testing.T) { + const bufSize = 512 + + var b1, b2 bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(make([]byte, bufSize+bufSize/2)) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != io.ErrUnexpectedEOF { + t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != io.ErrUnexpectedEOF { + t.Fatalf("NextReader() returned %v, want %v", err, io.ErrUnexpectedEOF) + } +} + func TestReadLimit(t *testing.T) { const readLimit = 512