Remove hideTempErr to allow downstream users to check for errors like net.ErrClosed (#894)

Since this change
https://github.com/gorilla/websocket/pull/840/files#diff-4f427d2b022907c552328e63f137561f6de92396d7a6e8f6c2ea1bcf0db52654L190-R197
we can no longer determinate if the errors coming from ReadMessage() are
net.ErrClosed for example
Hardcoding the error message is not great option because it may vary
from OS to OS and system locale
This commit is contained in:
Rumen Nikiforov 2024-02-15 04:59:16 +02:00 committed by GitHub
parent d293aa53e1
commit 695e9095ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 11 deletions

14
conn.go
View File

@ -192,13 +192,6 @@ func newMaskKey() [4]byte {
return k
}
func hideTempErr(err error) error {
if e, ok := err.(net.Error); ok {
err = &netError{msg: e.Error(), timeout: e.Timeout()}
}
return err
}
func isControl(frameType int) bool {
return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage
}
@ -364,7 +357,6 @@ func (c *Conn) RemoteAddr() net.Addr {
// Write methods
func (c *Conn) writeFatal(err error) error {
err = hideTempErr(err)
c.writeErrMu.Lock()
if c.writeErr == nil {
c.writeErr = err
@ -1033,7 +1025,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
for c.readErr == nil {
frameType, err := c.advanceFrame()
if err != nil {
c.readErr = hideTempErr(err)
c.readErr = err
break
}
@ -1073,7 +1065,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
b = b[:c.readRemaining]
}
n, err := c.br.Read(b)
c.readErr = hideTempErr(err)
c.readErr = err
if c.isServer {
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
}
@ -1096,7 +1088,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
frameType, err := c.advanceFrame()
switch {
case err != nil:
c.readErr = hideTempErr(err)
c.readErr = err
case frameType == TextMessage || frameType == BinaryMessage:
c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
}

View File

@ -814,3 +814,25 @@ func TestFormatMessageType(t *testing.T) {
t.Error("failed to format message type")
}
}
type fakeNetClosedReader struct {
}
func (r fakeNetClosedReader) Read([]byte) (int, error) {
return 0, net.ErrClosed
}
func TestConnectionClosed(t *testing.T) {
var b1, b2 bytes.Buffer
client := newTestConn(fakeNetClosedReader{}, &b1, false)
server := newTestConn(fakeNetClosedReader{}, &b2, true)
if _, _, err := server.NextReader(); !errors.Is(err, net.ErrClosed) {
t.Fatalf("server expects a net.ErrClosed error, %v returned", err)
}
if _, _, err := client.NextReader(); !errors.Is(err, net.ErrClosed) {
t.Fatalf("client expects a net.ErrClosed error, %v returned", err)
}
}