Cleanup EOF handling.

- Modify data message reader to return io.ErrUnexpectedEOF if a close
  message is received before the final frame of the message.
- Modify NextReader to return io.ErrUnexpectedEOF if underlying
  connection returns io.EOF before a close message.
This commit is contained in:
Gary Burd 2014-07-10 19:36:51 -07:00
parent aef42a8ae6
commit 10afcadf69
2 changed files with 81 additions and 9 deletions

40
conn.go
View File

@ -516,6 +516,22 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
// Read methods
// readFull is like io.ReadFull except that io.EOF is never returned.
func (c *Conn) readFull(p []byte) (err error) {
var n int
for n < len(p) && err == nil {
var nn int
nn, err = c.br.Read(p[n:])
n += nn
}
if n == len(p) {
err = nil
} else if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return
}
func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame.
@ -529,7 +545,7 @@ func (c *Conn) advanceFrame() (int, error) {
// 2. Read and parse first two bytes of frame header.
var b [8]byte
if _, err := io.ReadFull(c.br, b[:2]); err != nil {
if err := c.readFull(b[:2]); err != nil {
return noFrame, err
}
@ -569,12 +585,12 @@ func (c *Conn) advanceFrame() (int, error) {
switch c.readRemaining {
case 126:
if _, err := io.ReadFull(c.br, b[:2]); err != nil {
if err := c.readFull(b[:2]); err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
case 127:
if _, err := io.ReadFull(c.br, b[:8]); err != nil {
if err := c.readFull(b[:8]); err != nil {
return noFrame, err
}
c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
@ -588,7 +604,7 @@ func (c *Conn) advanceFrame() (int, error) {
if mask {
c.readMaskPos = 0
if _, err := io.ReadFull(c.br, c.readMaskKey[:]); err != nil {
if err := c.readFull(c.readMaskKey[:]); err != nil {
return noFrame, err
}
}
@ -612,7 +628,7 @@ func (c *Conn) advanceFrame() (int, error) {
if c.readRemaining > 0 {
payload = make([]byte, c.readRemaining)
c.readRemaining = 0
if _, err := io.ReadFull(c.br, payload); err != nil {
if err := c.readFull(payload); err != nil {
return noFrame, err
}
if c.isServer {
@ -686,7 +702,7 @@ type messageReader struct {
seq int
}
func (r messageReader) Read(b []byte) (n int, err error) {
func (r messageReader) Read(b []byte) (int, error) {
if r.seq != r.c.readSeq {
return 0, io.EOF
@ -713,13 +729,19 @@ func (r messageReader) Read(b []byte) (n int, err error) {
}
frameType, err := r.c.advanceFrame()
if err != nil {
switch {
case err != nil:
r.c.readErr = hideTempErr(err)
} else if frameType == TextMessage || frameType == BinaryMessage {
case frameType == TextMessage || frameType == BinaryMessage:
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
}
}
return 0, r.c.readErr
err := r.c.readErr
if err == io.EOF && r.seq == r.c.readSeq {
err = io.ErrUnexpectedEOF
}
return 0, err
}
// ReadMessage is a helper method for getting a reader using NextReader and

View File

@ -143,6 +143,56 @@ func TestControl(t *testing.T) {
}
}
func TestCloseBeforeFinalFrame(t *testing.T) {
const bufSize = 512
var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2))
wc.WriteControl(CloseMessage, []byte{}, time.Now().Add(10*time.Second))
w.Close()
op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err)
}
_, err = io.Copy(ioutil.Discard, r)
if err != io.ErrUnexpectedEOF {
t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF)
}
_, _, err = rc.NextReader()
if err != io.EOF {
t.Fatalf("NextReader() returned %v, want %v", err, io.EOF)
}
}
func TestEOFBeforeFinalFrame(t *testing.T) {
const bufSize = 512
var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2))
op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("NextReader() returned %d, %v", op, err)
}
_, err = io.Copy(ioutil.Discard, r)
if err != io.ErrUnexpectedEOF {
t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF)
}
_, _, err = rc.NextReader()
if err != io.ErrUnexpectedEOF {
t.Fatalf("NextReader() returned %v, want %v", err, io.ErrUnexpectedEOF)
}
}
func TestReadLimit(t *testing.T) {
const readLimit = 512