Cleanup read operations.

- Use io.ReadFull instead of similar function in package.
- Return from Read with partial data. Don't attempt to fill buffer.
- Do not return net.Error with Temporary() == true
This commit is contained in:
Gary Burd 2014-06-06 09:12:15 -07:00
parent f4076986b6
commit efd7f76a14
1 changed files with 34 additions and 37 deletions

71
conn.go
View File

@ -95,6 +95,13 @@ const (
writeWait = time.Second writeWait = time.Second
) )
func hideTempErr(err error) error {
if e, ok := err.(net.Error); ok && e.Temporary() {
err = struct{ error }{err}
}
return err
}
func isControl(frameType int) bool { func isControl(frameType int) bool {
return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage
} }
@ -501,7 +508,7 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
// SetWriteDeadline sets the write deadline on the underlying network // SetWriteDeadline sets the write deadline on the underlying network
// connection. After a write has timed out, the websocket state is corrupt and // connection. After a write has timed out, the websocket state is corrupt and
// all future writes will return an error. A zero value for t means writes will // all future writes will return an error. A zero value for t means writes will
// not time out // not time out
func (c *Conn) SetWriteDeadline(t time.Time) error { func (c *Conn) SetWriteDeadline(t time.Time) error {
c.writeDeadline = t c.writeDeadline = t
return nil return nil
@ -522,7 +529,7 @@ func (c *Conn) advanceFrame() (int, error) {
// 2. Read and parse first two bytes of frame header. // 2. Read and parse first two bytes of frame header.
var b [8]byte var b [8]byte
if err := c.read(b[:2]); err != nil { if _, err := io.ReadFull(c.br, b[:2]); err != nil {
return noFrame, err return noFrame, err
} }
@ -562,12 +569,12 @@ func (c *Conn) advanceFrame() (int, error) {
switch c.readRemaining { switch c.readRemaining {
case 126: case 126:
if err := c.read(b[:2]); err != nil { if _, err := io.ReadFull(c.br, b[:2]); err != nil {
return noFrame, err return noFrame, err
} }
c.readRemaining = int64(binary.BigEndian.Uint16(b[:2])) c.readRemaining = int64(binary.BigEndian.Uint16(b[:2]))
case 127: case 127:
if err := c.read(b[:8]); err != nil { if _, err := io.ReadFull(c.br, b[:8]); err != nil {
return noFrame, err return noFrame, err
} }
c.readRemaining = int64(binary.BigEndian.Uint64(b[:8])) c.readRemaining = int64(binary.BigEndian.Uint64(b[:8]))
@ -581,7 +588,7 @@ func (c *Conn) advanceFrame() (int, error) {
if mask { if mask {
c.readMaskPos = 0 c.readMaskPos = 0
if err := c.read(c.readMaskKey[:]); err != nil { if _, err := io.ReadFull(c.br, c.readMaskKey[:]); err != nil {
return noFrame, err return noFrame, err
} }
} }
@ -601,12 +608,15 @@ func (c *Conn) advanceFrame() (int, error) {
// 6. Read control frame payload. // 6. Read control frame payload.
payload := make([]byte, c.readRemaining) var payload []byte
c.readRemaining = 0 if c.readRemaining > 0 {
if err := c.read(payload); err != nil { payload = make([]byte, c.readRemaining)
return noFrame, err c.readRemaining = 0
if _, err := io.ReadFull(c.br, payload); err != nil {
return noFrame, err
}
maskBytes(c.readMaskKey, 0, payload)
} }
maskBytes(c.readMaskKey, 0, payload)
// 7. Process control frame payload. // 7. Process control frame payload.
@ -643,23 +653,6 @@ func (c *Conn) handleProtocolError(message string) error {
return errors.New("websocket: " + message) return errors.New("websocket: " + message)
} }
func (c *Conn) read(buf []byte) error {
var err error
for len(buf) > 0 && err == nil {
var nn int
nn, err = c.br.Read(buf)
buf = buf[nn:]
}
if err == io.EOF {
if len(buf) == 0 {
err = nil
} else {
err = io.ErrUnexpectedEOF
}
}
return err
}
// NextReader returns the next data message received from the peer. The // NextReader returns the next data message received from the peer. The
// returned messageType is either TextMessage or BinaryMessage. // returned messageType is either TextMessage or BinaryMessage.
// //
@ -674,8 +667,11 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c.readLength = 0 c.readLength = 0
for c.readErr == nil { for c.readErr == nil {
var frameType int frameType, err := c.advanceFrame()
frameType, c.readErr = c.advanceFrame() if err != nil {
c.readErr = hideTempErr(err)
break
}
if frameType == TextMessage || frameType == BinaryMessage { if frameType == TextMessage || frameType == BinaryMessage {
return frameType, messageReader{c, c.readSeq}, nil return frameType, messageReader{c, c.readSeq}, nil
} }
@ -700,10 +696,11 @@ func (r messageReader) Read(b []byte) (n int, err error) {
if int64(len(b)) > r.c.readRemaining { if int64(len(b)) > r.c.readRemaining {
b = b[:r.c.readRemaining] b = b[:r.c.readRemaining]
} }
r.c.readErr = r.c.read(b) n, err := r.c.br.Read(b)
r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b) r.c.readErr = hideTempErr(err)
r.c.readRemaining -= int64(len(b)) r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n])
return len(b), r.c.readErr r.c.readRemaining -= int64(n)
return n, r.c.readErr
} }
if r.c.readFinal { if r.c.readFinal {
@ -711,10 +708,10 @@ func (r messageReader) Read(b []byte) (n int, err error) {
return 0, io.EOF return 0, io.EOF
} }
var frameType int frameType, err := r.c.advanceFrame()
frameType, r.c.readErr = r.c.advanceFrame() if err != nil {
r.c.readErr = hideTempErr(err)
if frameType == TextMessage || frameType == BinaryMessage { } else if frameType == TextMessage || frameType == BinaryMessage {
r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") r.c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
} }
} }