diff --git a/conn.go b/conn.go index 2701142..73c64a4 100644 --- a/conn.go +++ b/conn.go @@ -70,18 +70,30 @@ var ( ErrReadLimit = errors.New("websocket: read limit exceeded") ) -type websocketError struct { +// netError satisfies the net Error interface. +type netError struct { msg string temporary bool timeout bool } -func (e *websocketError) Error() string { return e.msg } -func (e *websocketError) Temporary() bool { return e.temporary } -func (e *websocketError) Timeout() bool { return e.timeout } +func (e *netError) Error() string { return e.msg } +func (e *netError) Temporary() bool { return e.temporary } +func (e *netError) Timeout() bool { return e.timeout } + +// closeError represents close frame. +type closeError struct { + code int + text string +} + +func (e *closeError) Error() string { + return "websocket: close " + strconv.Itoa(e.code) + " " + e.text +} var ( - errWriteTimeout = &websocketError{msg: "websocket: write timeout", timeout: true} + errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true} + errUnexpectedEOF = &closeError{code: CloseAbnormalClosure, text: io.ErrUnexpectedEOF.Error()} errBadWriteOpCode = errors.New("websocket: bad write message type") errWriteClosed = errors.New("websocket: write closed") errInvalidControlFrame = errors.New("websocket: invalid control frame") @@ -527,7 +539,7 @@ func (c *Conn) readFull(p []byte) (err error) { if n == len(p) { err = nil } else if err == io.EOF { - err = io.ErrUnexpectedEOF + err = errUnexpectedEOF } return } @@ -649,17 +661,17 @@ func (c *Conn) advanceFrame() (int, error) { } case CloseMessage: c.WriteControl(CloseMessage, []byte{}, time.Now().Add(writeWait)) - if len(payload) < 2 { - return noFrame, io.EOF + closeCode := CloseNoStatusReceived + closeText := "" + if len(payload) >= 2 { + closeCode = int(binary.BigEndian.Uint16(payload)) + closeText = string(payload[2:]) } - closeCode := binary.BigEndian.Uint16(payload) switch closeCode { case CloseNormalClosure, CloseGoingAway: return noFrame, io.EOF default: - return noFrame, errors.New("websocket: close " + - strconv.Itoa(int(closeCode)) + " " + - string(payload[2:])) + return noFrame, &closeError{code: closeCode, text: closeText} } } @@ -739,7 +751,7 @@ func (r messageReader) Read(b []byte) (int, error) { err := r.c.readErr if err == io.EOF && r.seq == r.c.readSeq { - err = io.ErrUnexpectedEOF + err = errUnexpectedEOF } return 0, err } diff --git a/conn_test.go b/conn_test.go index 632725a..1f1197e 100644 --- a/conn_test.go +++ b/conn_test.go @@ -152,7 +152,7 @@ func TestCloseBeforeFinalFrame(t *testing.T) { w, _ := wc.NextWriter(BinaryMessage) w.Write(make([]byte, bufSize+bufSize/2)) - wc.WriteControl(CloseMessage, []byte{}, time.Now().Add(10*time.Second)) + wc.WriteControl(CloseMessage, FormatCloseMessage(CloseNormalClosure, ""), time.Now().Add(10*time.Second)) w.Close() op, r, err := rc.NextReader() @@ -160,8 +160,8 @@ func TestCloseBeforeFinalFrame(t *testing.T) { t.Fatalf("NextReader() returned %d, %v", op, err) } _, err = io.Copy(ioutil.Discard, r) - if err != io.ErrUnexpectedEOF { - t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF) + if err != errUnexpectedEOF { + t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) } _, _, err = rc.NextReader() if err != io.EOF { @@ -184,12 +184,12 @@ func TestEOFBeforeFinalFrame(t *testing.T) { t.Fatalf("NextReader() returned %d, %v", op, err) } _, err = io.Copy(ioutil.Discard, r) - if err != io.ErrUnexpectedEOF { - t.Fatalf("io.Copy() returned %v, want %v", err, io.ErrUnexpectedEOF) + if err != errUnexpectedEOF { + t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) } _, _, err = rc.NextReader() - if err != io.ErrUnexpectedEOF { - t.Fatalf("NextReader() returned %v, want %v", err, io.ErrUnexpectedEOF) + if err != errUnexpectedEOF { + t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF) } }