diff --git a/json.go b/json.go index e0668f2..18e62f2 100644 --- a/json.go +++ b/json.go @@ -6,6 +6,7 @@ package websocket import ( "encoding/json" + "io" ) // WriteJSON is deprecated, use c.WriteJSON instead. @@ -45,5 +46,12 @@ func (c *Conn) ReadJSON(v interface{}) error { if err != nil { return err } - return json.NewDecoder(r).Decode(v) + err = json.NewDecoder(r).Decode(v) + if err == io.EOF { + // Decode returns io.EOF when the message is empty or all whitespace. + // Convert to io.ErrUnexpectedEOF so that application can distinguish + // between an error reading the JSON value and the connection closing. + err = io.ErrUnexpectedEOF + } + return err } diff --git a/json_test.go b/json_test.go index 2edb28d..1b7a5ec 100644 --- a/json_test.go +++ b/json_test.go @@ -6,6 +6,8 @@ package websocket import ( "bytes" + "encoding/json" + "io" "reflect" "testing" ) @@ -36,6 +38,60 @@ func TestJSON(t *testing.T) { } } +func TestPartialJsonRead(t *testing.T) { + var buf bytes.Buffer + c := fakeNetConn{&buf, &buf} + wc := newConn(c, true, 1024, 1024) + rc := newConn(c, false, 1024, 1024) + + var v struct { + A int + B string + } + v.A = 1 + v.B = "hello" + + messageCount := 0 + + // Partial JSON values. + + data, err := json.Marshal(v) + if err != nil { + t.Fatal(err) + } + for i := len(data) - 1; i >= 0; i-- { + if err := wc.WriteMessage(TextMessage, data[:i]); err != nil { + t.Fatal(err) + } + messageCount++ + } + + // Whitespace. + + if err := wc.WriteMessage(TextMessage, []byte(" ")); err != nil { + t.Fatal(err) + } + messageCount++ + + // Close. + + if err := wc.WriteMessage(CloseMessage, FormatCloseMessage(CloseNormalClosure, "")); err != nil { + t.Fatal(err) + } + + for i := 0; i < messageCount; i++ { + err := rc.ReadJSON(&v) + if err != io.ErrUnexpectedEOF { + t.Error("read", i, err) + } + } + + err = rc.ReadJSON(&v) + if err != io.EOF { + t.Error("final", err) + } +} + func TestDeprecatedJSON(t *testing.T) { var buf bytes.Buffer c := fakeNetConn{&buf, &buf}