mirror of https://github.com/gorilla/websocket.git
Cleanup EOF handling.
- Modify data message reader to return io.ErrUnexpectedEOF if a close message is received before the final frame of the message. - Modify NextReader to return io.ErrUnexpectedEOF if underlying connection returns io.EOF before a close message.
This commit is contained in:
parent
aef42a8ae6
commit
10afcadf69
40
conn.go
40
conn.go
|
@ -516,6 +516,22 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
|
|
||||||
// Read methods
|
// Read methods
|
||||||
|
|
||||||
|
// readFull is like io.ReadFull except that io.EOF is never returned.
|
||||||
|
func (c *Conn) readFull(p []byte) (err error) {
|
||||||
|
var n int
|
||||||
|
for n < len(p) && err == nil {
|
||||||
|
var nn int
|
||||||
|
nn, err = c.br.Read(p[n:])
|
||||||
|
n += nn
|
||||||
|
}
|
||||||
|
if n == len(p) {
|
||||||
|
err = nil
|
||||||
|
} else if err == io.EOF {
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) advanceFrame() (int, error) {
|
func (c *Conn) advanceFrame() (int, error) {
|
||||||
|
|
||||||
// 1. Skip remainder of previous frame.
|
// 1. Skip remainder of previous frame.
|
||||||
|
@ -529,7 +545,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
// 2. Read and parse first two bytes of frame header.
|
// 2. Read and parse first two bytes of frame header.
|
||||||
|
|
||||||
var b [8]byte
|
var b [8]byte
|
||||||
if _, err := io.ReadFull(c.br, b[:2]); err != nil {
|
if err := c.readFull(b[:2]); err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -569,12 +585,12 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
|
|
||||||
switch c.readRemaining {
|
switch c.readRemaining {
|
||||||
case 126:
|
case 126:
|
||||||
if _, err := io.ReadFull(c.br, b[:2]); err != nil {
|
if err := c.readFull(b[:2]); err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
|
c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
|
||||||
case 127:
|
case 127:
|
||||||
if _, err := io.ReadFull(c.br, b[:8]); err != nil {
|
if err := c.readFull(b[:8]); err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
|
c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
|
||||||
|
@ -588,7 +604,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
|
|
||||||
if mask {
|
if mask {
|
||||||
c.readMaskPos = 0
|
c.readMaskPos = 0
|
||||||
if _, err := io.ReadFull(c.br, c.readMaskKey[:]); err != nil {
|
if err := c.readFull(c.readMaskKey[:]); err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -612,7 +628,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
if c.readRemaining > 0 {
|
if c.readRemaining > 0 {
|
||||||
payload = make([]byte, c.readRemaining)
|
payload = make([]byte, c.readRemaining)
|
||||||
c.readRemaining = 0
|
c.readRemaining = 0
|
||||||
if _, err := io.ReadFull(c.br, payload); err != nil {
|
if err := c.readFull(payload); err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
if c.isServer {
|
if c.isServer {
|
||||||
|
@ -686,7 +702,7 @@ type messageReader struct {
|
||||||
seq int
|
seq int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r messageReader) Read(b []byte) (n int, err error) {
|
func (r messageReader) Read(b []byte) (int, error) {
|
||||||
|
|
||||||
if r.seq != r.c.readSeq {
|
if r.seq != r.c.readSeq {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
|
@ -713,13 +729,19 @@ func (r messageReader) Read(b []byte) (n int, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
frameType, err := r.c.advanceFrame()
|
frameType, err := r.c.advanceFrame()
|
||||||
if err != nil {
|
switch {
|
||||||
|
case err != nil:
|
||||||
r.c.readErr = hideTempErr(err)
|
r.c.readErr = hideTempErr(err)
|
||||||
} else if frameType == TextMessage || frameType == BinaryMessage {
|
case frameType == TextMessage || frameType == BinaryMessage:
|
||||||
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return 0, r.c.readErr
|
|
||||||
|
err := r.c.readErr
|
||||||
|
if err == io.EOF && r.seq == r.c.readSeq {
|
||||||
|
err = io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadMessage is a helper method for getting a reader using NextReader and
|
// ReadMessage is a helper method for getting a reader using NextReader and
|
||||||
|
|
50
conn_test.go
50
conn_test.go
|
@ -143,6 +143,56 @@ func TestControl(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCloseBeforeFinalFrame(t *testing.T) {
|
||||||
|
const bufSize = 512
|
||||||
|
|
||||||
|
var b1, b2 bytes.Buffer
|
||||||
|
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
|
||||||
|
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
|
||||||
|
|
||||||
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
|
w.Write(make([]byte, bufSize+bufSize/2))
|
||||||
|
wc.WriteControl(CloseMessage, []byte{}, time.Now().Add(10*time.Second))
|
||||||
|
w.Close()
|
||||||
|
|
||||||
|
op, r, err := rc.NextReader()
|
||||||
|
if op != BinaryMessage || err != nil {
|
||||||
|
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||||
|
}
|
||||||
|
_, err = io.Copy(ioutil.Discard, r)
|
||||||
|
if err != io.ErrUnexpectedEOF {
|
||||||
|
t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF)
|
||||||
|
}
|
||||||
|
_, _, err = rc.NextReader()
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Fatalf("NextReader() returned %v, want %v", err, io.EOF)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
|
const bufSize = 512
|
||||||
|
|
||||||
|
var b1, b2 bytes.Buffer
|
||||||
|
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
|
||||||
|
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
|
||||||
|
|
||||||
|
w, _ := wc.NextWriter(BinaryMessage)
|
||||||
|
w.Write(make([]byte, bufSize+bufSize/2))
|
||||||
|
|
||||||
|
op, r, err := rc.NextReader()
|
||||||
|
if op != BinaryMessage || err != nil {
|
||||||
|
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||||
|
}
|
||||||
|
_, err = io.Copy(ioutil.Discard, r)
|
||||||
|
if err != io.ErrUnexpectedEOF {
|
||||||
|
t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF)
|
||||||
|
}
|
||||||
|
_, _, err = rc.NextReader()
|
||||||
|
if err != io.ErrUnexpectedEOF {
|
||||||
|
t.Fatalf("NextReader() returned %v, want %v", err, io.ErrUnexpectedEOF)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReadLimit(t *testing.T) {
|
func TestReadLimit(t *testing.T) {
|
||||||
|
|
||||||
const readLimit = 512
|
const readLimit = 512
|
||||||
|
|
Loading…
Reference in New Issue