From b6ab76f1fe9803ee1d59e7e5b2a797c1fe897ce5 Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Tue, 11 Aug 2015 10:14:32 -0700 Subject: [PATCH] Provide all close frame data to application - Export closeError. - Do not convert normal closure and going away to io.EOF. --- conn.go | 25 ++++++++++++------------- conn_test.go | 13 ++++++++----- json.go | 4 +--- json_test.go | 4 ++-- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/conn.go b/conn.go index e719f1c..a2374a8 100644 --- a/conn.go +++ b/conn.go @@ -88,19 +88,23 @@ func (e *netError) Error() string { return e.msg } func (e *netError) Temporary() bool { return e.temporary } func (e *netError) Timeout() bool { return e.timeout } -// closeError represents close frame. -type closeError struct { - code int - text string +// CloseError represents close frame. +type CloseError struct { + + // Code is defined in RFC 6455, section 11.7. + Code int + + // Text is the optional text payload. + Text string } -func (e *closeError) Error() string { - return "websocket: close " + strconv.Itoa(e.code) + " " + e.text +func (e *CloseError) Error() string { + return "websocket: close " + strconv.Itoa(e.Code) + " " + e.Text } var ( errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true} - errUnexpectedEOF = &closeError{code: CloseAbnormalClosure, text: io.ErrUnexpectedEOF.Error()} + errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} errBadWriteOpCode = errors.New("websocket: bad write message type") errWriteClosed = errors.New("websocket: write closed") errInvalidControlFrame = errors.New("websocket: invalid control frame") @@ -673,12 +677,7 @@ func (c *Conn) advanceFrame() (int, error) { closeCode = int(binary.BigEndian.Uint16(payload)) closeText = string(payload[2:]) } - switch closeCode { - case CloseNormalClosure, CloseGoingAway: - return noFrame, io.EOF - default: - return noFrame, &closeError{code: closeCode, text: closeText} - } + return noFrame, &CloseError{Code: closeCode, Text: closeText} } return frameType, nil diff --git a/conn_test.go b/conn_test.go index 1f1197e..929be0e 100644 --- a/conn_test.go +++ b/conn_test.go @@ -10,6 +10,7 @@ import ( "io" "io/ioutil" "net" + "reflect" "testing" "testing/iotest" "time" @@ -146,13 +147,15 @@ func TestControl(t *testing.T) { func TestCloseBeforeFinalFrame(t *testing.T) { const bufSize = 512 + expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} + var b1, b2 bytes.Buffer wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) w, _ := wc.NextWriter(BinaryMessage) w.Write(make([]byte, bufSize+bufSize/2)) - wc.WriteControl(CloseMessage, FormatCloseMessage(CloseNormalClosure, ""), time.Now().Add(10*time.Second)) + wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) w.Close() op, r, err := rc.NextReader() @@ -160,12 +163,12 @@ func TestCloseBeforeFinalFrame(t *testing.T) { t.Fatalf("NextReader() returned %d, %v", op, err) } _, err = io.Copy(ioutil.Discard, r) - if err != errUnexpectedEOF { - t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) } _, _, err = rc.NextReader() - if err != io.EOF { - t.Fatalf("NextReader() returned %v, want %v", err, io.EOF) + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("NextReader() returned %v, want %v", err, expectedErr) } } diff --git a/json.go b/json.go index 18e62f2..4f0e368 100644 --- a/json.go +++ b/json.go @@ -48,9 +48,7 @@ func (c *Conn) ReadJSON(v interface{}) error { } err = json.NewDecoder(r).Decode(v) if err == io.EOF { - // Decode returns io.EOF when the message is empty or all whitespace. - // Convert to io.ErrUnexpectedEOF so that application can distinguish - // between an error reading the JSON value and the connection closing. + // One value is expected in the message. err = io.ErrUnexpectedEOF } return err diff --git a/json_test.go b/json_test.go index 1b7a5ec..61100e4 100644 --- a/json_test.go +++ b/json_test.go @@ -38,7 +38,7 @@ func TestJSON(t *testing.T) { } } -func TestPartialJsonRead(t *testing.T) { +func TestPartialJSONRead(t *testing.T) { var buf bytes.Buffer c := fakeNetConn{&buf, &buf} wc := newConn(c, true, 1024, 1024) @@ -87,7 +87,7 @@ func TestPartialJsonRead(t *testing.T) { } err = rc.ReadJSON(&v) - if err != io.EOF { + if _, ok := err.(*CloseError); !ok { t.Error("final", err) } }