diff --git a/conn.go b/conn.go index d485949..819ab0d 100644 --- a/conn.go +++ b/conn.go @@ -95,6 +95,13 @@ const ( writeWait = time.Second ) +func hideTempErr(err error) error { + if e, ok := err.(net.Error); ok && e.Temporary() { + err = struct{ error }{err} + } + return err +} + func isControl(frameType int) bool { return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage } @@ -501,7 +508,7 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { // SetWriteDeadline sets the write deadline on the underlying network // connection. After a write has timed out, the websocket state is corrupt and // all future writes will return an error. A zero value for t means writes will -// not time out +// not time out func (c *Conn) SetWriteDeadline(t time.Time) error { c.writeDeadline = t return nil @@ -522,7 +529,7 @@ func (c *Conn) advanceFrame() (int, error) { // 2. Read and parse first two bytes of frame header. var b [8]byte - if err := c.read(b[:2]); err != nil { + if _, err := io.ReadFull(c.br, b[:2]); err != nil { return noFrame, err } @@ -562,12 +569,12 @@ func (c *Conn) advanceFrame() (int, error) { switch c.readRemaining { case 126: - if err := c.read(b[:2]); err != nil { + if _, err := io.ReadFull(c.br, b[:2]); err != nil { return noFrame, err } c.readRemaining = int64(binary.BigEndian.Uint16(b[:2])) case 127: - if err := c.read(b[:8]); err != nil { + if _, err := io.ReadFull(c.br, b[:8]); err != nil { return noFrame, err } c.readRemaining = int64(binary.BigEndian.Uint64(b[:8])) @@ -581,7 +588,7 @@ func (c *Conn) advanceFrame() (int, error) { if mask { c.readMaskPos = 0 - if err := c.read(c.readMaskKey[:]); err != nil { + if _, err := io.ReadFull(c.br, c.readMaskKey[:]); err != nil { return noFrame, err } } @@ -601,12 +608,15 @@ func (c *Conn) advanceFrame() (int, error) { // 6. Read control frame payload. - payload := make([]byte, c.readRemaining) - c.readRemaining = 0 - if err := c.read(payload); err != nil { - return noFrame, err + var payload []byte + if c.readRemaining > 0 { + payload = make([]byte, c.readRemaining) + c.readRemaining = 0 + if _, err := io.ReadFull(c.br, payload); err != nil { + return noFrame, err + } + maskBytes(c.readMaskKey, 0, payload) } - maskBytes(c.readMaskKey, 0, payload) // 7. Process control frame payload. @@ -643,23 +653,6 @@ func (c *Conn) handleProtocolError(message string) error { return errors.New("websocket: " + message) } -func (c *Conn) read(buf []byte) error { - var err error - for len(buf) > 0 && err == nil { - var nn int - nn, err = c.br.Read(buf) - buf = buf[nn:] - } - if err == io.EOF { - if len(buf) == 0 { - err = nil - } else { - err = io.ErrUnexpectedEOF - } - } - return err -} - // NextReader returns the next data message received from the peer. The // returned messageType is either TextMessage or BinaryMessage. // @@ -674,8 +667,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { c.readLength = 0 for c.readErr == nil { - var frameType int - frameType, c.readErr = c.advanceFrame() + frameType, err := c.advanceFrame() + if err != nil { + c.readErr = hideTempErr(err) + break + } if frameType == TextMessage || frameType == BinaryMessage { return frameType, messageReader{c, c.readSeq}, nil } @@ -700,10 +696,11 @@ func (r messageReader) Read(b []byte) (n int, err error) { if int64(len(b)) > r.c.readRemaining { b = b[:r.c.readRemaining] } - r.c.readErr = r.c.read(b) - r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b) - r.c.readRemaining -= int64(len(b)) - return len(b), r.c.readErr + n, err := r.c.br.Read(b) + r.c.readErr = hideTempErr(err) + r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n]) + r.c.readRemaining -= int64(n) + return n, r.c.readErr } if r.c.readFinal { @@ -711,10 +708,10 @@ func (r messageReader) Read(b []byte) (n int, err error) { return 0, io.EOF } - var frameType int - frameType, r.c.readErr = r.c.advanceFrame() - - if frameType == TextMessage || frameType == BinaryMessage { + frameType, err := r.c.advanceFrame() + if err != nil { + r.c.readErr = hideTempErr(err) + } else if frameType == TextMessage || frameType == BinaryMessage { r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") } }