From 695e9095ce8736ac99c83939ca6b0fe93768f680 Mon Sep 17 00:00:00 2001 From: Rumen Nikiforov Date: Thu, 15 Feb 2024 04:59:16 +0200 Subject: [PATCH] Remove hideTempErr to allow downstream users to check for errors like net.ErrClosed (#894) Since this change https://github.com/gorilla/websocket/pull/840/files#diff-4f427d2b022907c552328e63f137561f6de92396d7a6e8f6c2ea1bcf0db52654L190-R197 we can no longer determinate if the errors coming from ReadMessage() are net.ErrClosed for example Hardcoding the error message is not great option because it may vary from OS to OS and system locale --- conn.go | 14 +++----------- conn_test.go | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/conn.go b/conn.go index 2ac19e9..72f5df8 100644 --- a/conn.go +++ b/conn.go @@ -192,13 +192,6 @@ func newMaskKey() [4]byte { return k } -func hideTempErr(err error) error { - if e, ok := err.(net.Error); ok { - err = &netError{msg: e.Error(), timeout: e.Timeout()} - } - return err -} - func isControl(frameType int) bool { return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage } @@ -364,7 +357,6 @@ func (c *Conn) RemoteAddr() net.Addr { // Write methods func (c *Conn) writeFatal(err error) error { - err = hideTempErr(err) c.writeErrMu.Lock() if c.writeErr == nil { c.writeErr = err @@ -1033,7 +1025,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { for c.readErr == nil { frameType, err := c.advanceFrame() if err != nil { - c.readErr = hideTempErr(err) + c.readErr = err break } @@ -1073,7 +1065,7 @@ func (r *messageReader) Read(b []byte) (int, error) { b = b[:c.readRemaining] } n, err := c.br.Read(b) - c.readErr = hideTempErr(err) + c.readErr = err if c.isServer { c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) } @@ -1096,7 +1088,7 @@ func (r *messageReader) Read(b []byte) (int, error) { frameType, err := c.advanceFrame() switch { case err != nil: - c.readErr = hideTempErr(err) + c.readErr = err case frameType == TextMessage || frameType == BinaryMessage: c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") } diff --git a/conn_test.go b/conn_test.go index f0c29c3..3bd4f61 100644 --- a/conn_test.go +++ b/conn_test.go @@ -814,3 +814,25 @@ func TestFormatMessageType(t *testing.T) { t.Error("failed to format message type") } } + +type fakeNetClosedReader struct { +} + +func (r fakeNetClosedReader) Read([]byte) (int, error) { + return 0, net.ErrClosed +} + +func TestConnectionClosed(t *testing.T) { + var b1, b2 bytes.Buffer + + client := newTestConn(fakeNetClosedReader{}, &b1, false) + server := newTestConn(fakeNetClosedReader{}, &b2, true) + + if _, _, err := server.NextReader(); !errors.Is(err, net.ErrClosed) { + t.Fatalf("server expects a net.ErrClosed error, %v returned", err) + } + + if _, _, err := client.NextReader(); !errors.Is(err, net.ErrClosed) { + t.Fatalf("client expects a net.ErrClosed error, %v returned", err) + } +}