Use errors.As and errors.Is

Replace type assertions to error types with errors.As. Replace
comparisons with error values with errors.Is.
This commit is contained in:
Halo Arrow 2023-08-29 15:06:08 -07:00
parent 666c197fc9
commit 493d31ecb9
6 changed files with 25 additions and 19 deletions

View File

@ -131,7 +131,7 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) {
return 0, io.ErrClosedPipe
}
n, err := r.fr.Read(p)
if err == io.EOF {
if errors.Is(err, io.EOF) {
// Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after
// this final read.

14
conn.go
View File

@ -149,7 +149,8 @@ func (e *CloseError) Error() string {
// IsCloseError returns boolean indicating whether the error is a *CloseError
// with one of the specified codes.
func IsCloseError(err error, codes ...int) bool {
if e, ok := err.(*CloseError); ok {
var e *CloseError
if errors.As(err, &e) {
for _, code := range codes {
if e.Code == code {
return true
@ -162,7 +163,8 @@ func IsCloseError(err error, codes ...int) bool {
// IsUnexpectedCloseError returns boolean indicating whether the error is a
// *CloseError with a code not in the list of expected codes.
func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
if e, ok := err.(*CloseError); ok {
var e *CloseError
if errors.As(err, &e) {
for _, code := range expectedCodes {
if e.Code == code {
return false
@ -376,7 +378,7 @@ func (c *Conn) writeFatal(err error) error {
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
if errors.Is(err, io.EOF) {
err = errUnexpectedEOF
}
if _, err := c.br.Discard(len(p)); err != nil {
@ -730,7 +732,7 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
w.pos += n
nn += int64(n)
if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
err = nil
}
break
@ -1082,7 +1084,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
if err := c.setReadRemaining(rem); err != nil {
return 0, err
}
if c.readRemaining > 0 && c.readErr == io.EOF {
if c.readRemaining > 0 && errors.Is(c.readErr, io.EOF) {
c.readErr = errUnexpectedEOF
}
return n, c.readErr
@ -1103,7 +1105,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
}
err := c.readErr
if err == io.EOF && c.messageReader == r {
if errors.Is(err, io.EOF) && c.messageReader == r {
err = errUnexpectedEOF
}
return 0, err

View File

@ -404,18 +404,18 @@ func TestEOFWithinFrame(t *testing.T) {
b.Truncate(n)
op, r, err := rc.NextReader()
if err == errUnexpectedEOF {
if errors.Is(err, errUnexpectedEOF) {
continue
}
if op != BinaryMessage || err != nil {
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
}
_, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF {
if !errors.Is(err, errUnexpectedEOF) {
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
}
_, _, err = rc.NextReader()
if err != errUnexpectedEOF {
if !errors.Is(err, errUnexpectedEOF) {
t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
}
}
@ -438,11 +438,11 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
t.Fatalf("NextReader() returned %d, %v", op, err)
}
_, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF {
if !errors.Is(err, errUnexpectedEOF) {
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
}
_, _, err = rc.NextReader()
if err != errUnexpectedEOF {
if !errors.Is(err, errUnexpectedEOF) {
t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
}
}
@ -514,7 +514,7 @@ func TestReadLimit(t *testing.T) {
t.Fatalf("2: NextReader() returned %d, %v", op, err)
}
_, err = io.Copy(io.Discard, r)
if err != ErrReadLimit {
if !errors.Is(err, ErrReadLimit) {
t.Fatalf("io.Copy() returned %v", err)
}
})
@ -558,13 +558,13 @@ func TestReadLimit(t *testing.T) {
var buf [10]byte
var read int
n, err := r.Read(buf[:])
if err != nil && err != ErrReadLimit {
if err != nil && !errors.Is(err, ErrReadLimit) {
t.Fatalf("unexpected error testing read limit: %v", err)
}
read += n
n, err = r.Read(buf[:])
if err != nil && err != ErrReadLimit {
if err != nil && !errors.Is(err, ErrReadLimit) {
t.Fatalf("unexpected error testing read limit: %v", err)
}
read += n

View File

@ -5,6 +5,7 @@
package websocket
import (
"errors"
"io"
"strings"
)
@ -34,7 +35,7 @@ func (r *joinReader) Read(p []byte) (int, error) {
}
}
n, err := r.r.Read(p)
if err == io.EOF {
if errors.Is(err, io.EOF) {
err = nil
r.r = nil
}

View File

@ -6,6 +6,7 @@ package websocket
import (
"encoding/json"
"errors"
"io"
)
@ -52,7 +53,7 @@ func (c *Conn) ReadJSON(v interface{}) error {
return err
}
err = json.NewDecoder(r).Decode(v)
if err == io.EOF {
if errors.Is(err, io.EOF) {
// One value is expected in the message.
err = io.ErrUnexpectedEOF
}

View File

@ -7,6 +7,7 @@ package websocket
import (
"bytes"
"encoding/json"
"errors"
"io"
"reflect"
"testing"
@ -79,13 +80,14 @@ func TestPartialJSONRead(t *testing.T) {
for i := 0; i < messageCount; i++ {
err := rc.ReadJSON(&v)
if err != io.ErrUnexpectedEOF {
if !errors.Is(err, io.ErrUnexpectedEOF) {
t.Error("read", i, err)
}
}
err = rc.ReadJSON(&v)
if _, ok := err.(*CloseError); !ok {
var ce *CloseError
if !errors.As(err, &ce) {
t.Error("final", err)
}
}