diff --git a/conn.go b/conn.go index eff26c6..e7180c8 100644 --- a/conn.go +++ b/conn.go @@ -821,6 +821,9 @@ func (r messageReader) Read(b []byte) (int, error) { r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n]) } r.c.readRemaining -= int64(n) + if r.c.readRemaining > 0 && r.c.readErr == io.EOF { + r.c.readErr = errUnexpectedEOF + } return n, r.c.readErr } diff --git a/conn_test.go b/conn_test.go index 04c8dd8..2e37ece 100644 --- a/conn_test.go +++ b/conn_test.go @@ -174,6 +174,33 @@ func TestCloseBeforeFinalFrame(t *testing.T) { } } +func TestEOFWithinFrame(t *testing.T) { + const bufSize = 512 + + var b bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024) + rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(make([]byte, bufSize)) + w.Close() + + b.Truncate(bufSize / 2) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != errUnexpectedEOF { + t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != errUnexpectedEOF { + t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF) + } +} + func TestEOFBeforeFinalFrame(t *testing.T) { const bufSize = 512