diff --git a/compression.go b/compression.go index 9fed0ef..7e9c2ca 100644 --- a/compression.go +++ b/compression.go @@ -131,7 +131,7 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) { return 0, io.ErrClosedPipe } n, err := r.fr.Read(p) - if err == io.EOF { + if errors.Is(err, io.EOF) { // Preemptively place the reader back in the pool. This helps with // scenarios where the application does not call NextReader() soon after // this final read. diff --git a/conn.go b/conn.go index 221e6cf..6c36850 100644 --- a/conn.go +++ b/conn.go @@ -149,7 +149,8 @@ func (e *CloseError) Error() string { // IsCloseError returns boolean indicating whether the error is a *CloseError // with one of the specified codes. func IsCloseError(err error, codes ...int) bool { - if e, ok := err.(*CloseError); ok { + var e *CloseError + if errors.As(err, &e) { for _, code := range codes { if e.Code == code { return true @@ -162,7 +163,8 @@ func IsCloseError(err error, codes ...int) bool { // IsUnexpectedCloseError returns boolean indicating whether the error is a // *CloseError with a code not in the list of expected codes. func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { - if e, ok := err.(*CloseError); ok { + var e *CloseError + if errors.As(err, &e) { for _, code := range expectedCodes { if e.Code == code { return false @@ -376,7 +378,7 @@ func (c *Conn) writeFatal(err error) error { func (c *Conn) read(n int) ([]byte, error) { p, err := c.br.Peek(n) - if err == io.EOF { + if errors.Is(err, io.EOF) { err = errUnexpectedEOF } if _, err := c.br.Discard(len(p)); err != nil { @@ -730,7 +732,7 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { w.pos += n nn += int64(n) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { err = nil } break @@ -1082,7 +1084,7 @@ func (r *messageReader) Read(b []byte) (int, error) { if err := c.setReadRemaining(rem); err != nil { return 0, err } - if c.readRemaining > 0 && c.readErr == io.EOF { + if c.readRemaining > 0 && errors.Is(c.readErr, io.EOF) { c.readErr = errUnexpectedEOF } return n, c.readErr @@ -1103,7 +1105,7 @@ func (r *messageReader) Read(b []byte) (int, error) { } err := c.readErr - if err == io.EOF && c.messageReader == r { + if errors.Is(err, io.EOF) && c.messageReader == r { err = errUnexpectedEOF } return 0, err diff --git a/conn_test.go b/conn_test.go index 2b823dd..58201ed 100644 --- a/conn_test.go +++ b/conn_test.go @@ -404,18 +404,18 @@ func TestEOFWithinFrame(t *testing.T) { b.Truncate(n) op, r, err := rc.NextReader() - if err == errUnexpectedEOF { + if errors.Is(err, errUnexpectedEOF) { continue } if op != BinaryMessage || err != nil { t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) } _, err = io.Copy(io.Discard, r) - if err != errUnexpectedEOF { + if !errors.Is(err, errUnexpectedEOF) { t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) } _, _, err = rc.NextReader() - if err != errUnexpectedEOF { + if !errors.Is(err, errUnexpectedEOF) { t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF) } } @@ -438,11 +438,11 @@ func TestEOFBeforeFinalFrame(t *testing.T) { t.Fatalf("NextReader() returned %d, %v", op, err) } _, err = io.Copy(io.Discard, r) - if err != errUnexpectedEOF { + if !errors.Is(err, errUnexpectedEOF) { t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) } _, _, err = rc.NextReader() - if err != errUnexpectedEOF { + if !errors.Is(err, errUnexpectedEOF) { t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF) } } @@ -514,7 +514,7 @@ func TestReadLimit(t *testing.T) { t.Fatalf("2: NextReader() returned %d, %v", op, err) } _, err = io.Copy(io.Discard, r) - if err != ErrReadLimit { + if !errors.Is(err, ErrReadLimit) { t.Fatalf("io.Copy() returned %v", err) } }) @@ -558,13 +558,13 @@ func TestReadLimit(t *testing.T) { var buf [10]byte var read int n, err := r.Read(buf[:]) - if err != nil && err != ErrReadLimit { + if err != nil && !errors.Is(err, ErrReadLimit) { t.Fatalf("unexpected error testing read limit: %v", err) } read += n n, err = r.Read(buf[:]) - if err != nil && err != ErrReadLimit { + if err != nil && !errors.Is(err, ErrReadLimit) { t.Fatalf("unexpected error testing read limit: %v", err) } read += n diff --git a/join.go b/join.go index c64f8c8..d335e82 100644 --- a/join.go +++ b/join.go @@ -5,6 +5,7 @@ package websocket import ( + "errors" "io" "strings" ) @@ -34,7 +35,7 @@ func (r *joinReader) Read(p []byte) (int, error) { } } n, err := r.r.Read(p) - if err == io.EOF { + if errors.Is(err, io.EOF) { err = nil r.r = nil } diff --git a/json.go b/json.go index dc2c1f6..4cc6a01 100644 --- a/json.go +++ b/json.go @@ -6,6 +6,7 @@ package websocket import ( "encoding/json" + "errors" "io" ) @@ -52,7 +53,7 @@ func (c *Conn) ReadJSON(v interface{}) error { return err } err = json.NewDecoder(r).Decode(v) - if err == io.EOF { + if errors.Is(err, io.EOF) { // One value is expected in the message. err = io.ErrUnexpectedEOF } diff --git a/json_test.go b/json_test.go index e4c4bdf..27a5a86 100644 --- a/json_test.go +++ b/json_test.go @@ -7,6 +7,7 @@ package websocket import ( "bytes" "encoding/json" + "errors" "io" "reflect" "testing" @@ -79,13 +80,14 @@ func TestPartialJSONRead(t *testing.T) { for i := 0; i < messageCount; i++ { err := rc.ReadJSON(&v) - if err != io.ErrUnexpectedEOF { + if !errors.Is(err, io.ErrUnexpectedEOF) { t.Error("read", i, err) } } err = rc.ReadJSON(&v) - if _, ok := err.(*CloseError); !ok { + var ce *CloseError + if !errors.As(err, &ce) { t.Error("final", err) } }