mirror of https://github.com/gorilla/websocket.git
Cleanup EOF handling.
- Modify data message reader to return io.ErrUnexpectedEOF if a close message is received before the final frame of the message. - Modify NextReader to return io.ErrUnexpectedEOF if underlying connection returns io.EOF before a close message.
This commit is contained in:
parent
aef42a8ae6
commit
10afcadf69
40
conn.go
40
conn.go
|
@ -516,6 +516,22 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
|
|||
|
||||
// Read methods
|
||||
|
||||
// readFull is like io.ReadFull except that io.EOF is never returned.
|
||||
func (c *Conn) readFull(p []byte) (err error) {
|
||||
var n int
|
||||
for n < len(p) && err == nil {
|
||||
var nn int
|
||||
nn, err = c.br.Read(p[n:])
|
||||
n += nn
|
||||
}
|
||||
if n == len(p) {
|
||||
err = nil
|
||||
} else if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) advanceFrame() (int, error) {
|
||||
|
||||
// 1. Skip remainder of previous frame.
|
||||
|
@ -529,7 +545,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
|||
// 2. Read and parse first two bytes of frame header.
|
||||
|
||||
var b [8]byte
|
||||
if _, err := io.ReadFull(c.br, b[:2]); err != nil {
|
||||
if err := c.readFull(b[:2]); err != nil {
|
||||
return noFrame, err
|
||||
}
|
||||
|
||||
|
@ -569,12 +585,12 @@ func (c *Conn) advanceFrame() (int, error) {
|
|||
|
||||
switch c.readRemaining {
|
||||
case 126:
|
||||
if _, err := io.ReadFull(c.br, b[:2]); err != nil {
|
||||
if err := c.readFull(b[:2]); err != nil {
|
||||
return noFrame, err
|
||||
}
|
||||
c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
|
||||
case 127:
|
||||
if _, err := io.ReadFull(c.br, b[:8]); err != nil {
|
||||
if err := c.readFull(b[:8]); err != nil {
|
||||
return noFrame, err
|
||||
}
|
||||
c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
|
||||
|
@ -588,7 +604,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
|||
|
||||
if mask {
|
||||
c.readMaskPos = 0
|
||||
if _, err := io.ReadFull(c.br, c.readMaskKey[:]); err != nil {
|
||||
if err := c.readFull(c.readMaskKey[:]); err != nil {
|
||||
return noFrame, err
|
||||
}
|
||||
}
|
||||
|
@ -612,7 +628,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
|||
if c.readRemaining > 0 {
|
||||
payload = make([]byte, c.readRemaining)
|
||||
c.readRemaining = 0
|
||||
if _, err := io.ReadFull(c.br, payload); err != nil {
|
||||
if err := c.readFull(payload); err != nil {
|
||||
return noFrame, err
|
||||
}
|
||||
if c.isServer {
|
||||
|
@ -686,7 +702,7 @@ type messageReader struct {
|
|||
seq int
|
||||
}
|
||||
|
||||
func (r messageReader) Read(b []byte) (n int, err error) {
|
||||
func (r messageReader) Read(b []byte) (int, error) {
|
||||
|
||||
if r.seq != r.c.readSeq {
|
||||
return 0, io.EOF
|
||||
|
@ -713,13 +729,19 @@ func (r messageReader) Read(b []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
frameType, err := r.c.advanceFrame()
|
||||
if err != nil {
|
||||
switch {
|
||||
case err != nil:
|
||||
r.c.readErr = hideTempErr(err)
|
||||
} else if frameType == TextMessage || frameType == BinaryMessage {
|
||||
case frameType == TextMessage || frameType == BinaryMessage:
|
||||
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
||||
}
|
||||
}
|
||||
return 0, r.c.readErr
|
||||
|
||||
err := r.c.readErr
|
||||
if err == io.EOF && r.seq == r.c.readSeq {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// ReadMessage is a helper method for getting a reader using NextReader and
|
||||
|
|
50
conn_test.go
50
conn_test.go
|
@ -143,6 +143,56 @@ func TestControl(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCloseBeforeFinalFrame(t *testing.T) {
|
||||
const bufSize = 512
|
||||
|
||||
var b1, b2 bytes.Buffer
|
||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
|
||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
|
||||
|
||||
w, _ := wc.NextWriter(BinaryMessage)
|
||||
w.Write(make([]byte, bufSize+bufSize/2))
|
||||
wc.WriteControl(CloseMessage, []byte{}, time.Now().Add(10*time.Second))
|
||||
w.Close()
|
||||
|
||||
op, r, err := rc.NextReader()
|
||||
if op != BinaryMessage || err != nil {
|
||||
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||
}
|
||||
_, err = io.Copy(ioutil.Discard, r)
|
||||
if err != io.ErrUnexpectedEOF {
|
||||
t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF)
|
||||
}
|
||||
_, _, err = rc.NextReader()
|
||||
if err != io.EOF {
|
||||
t.Fatalf("NextReader() returned %v, want %v", err, io.EOF)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||
const bufSize = 512
|
||||
|
||||
var b1, b2 bytes.Buffer
|
||||
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
|
||||
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
|
||||
|
||||
w, _ := wc.NextWriter(BinaryMessage)
|
||||
w.Write(make([]byte, bufSize+bufSize/2))
|
||||
|
||||
op, r, err := rc.NextReader()
|
||||
if op != BinaryMessage || err != nil {
|
||||
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||
}
|
||||
_, err = io.Copy(ioutil.Discard, r)
|
||||
if err != io.ErrUnexpectedEOF {
|
||||
t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF)
|
||||
}
|
||||
_, _, err = rc.NextReader()
|
||||
if err != io.ErrUnexpectedEOF {
|
||||
t.Fatalf("NextReader() returned %v, want %v", err, io.ErrUnexpectedEOF)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadLimit(t *testing.T) {
|
||||
|
||||
const readLimit = 512
|
||||
|
|
Loading…
Reference in New Issue