Use errors.As and errors.Is

Replace type assertions to error types with errors.As. Replace
comparisons with error values with errors.Is.
This commit is contained in:
Halo Arrow 2023-08-29 15:06:08 -07:00
parent 666c197fc9
commit 493d31ecb9
6 changed files with 25 additions and 19 deletions

View File

@ -131,7 +131,7 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
} }
n, err := r.fr.Read(p) n, err := r.fr.Read(p)
if err == io.EOF { if errors.Is(err, io.EOF) {
// Preemptively place the reader back in the pool. This helps with // Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after // scenarios where the application does not call NextReader() soon after
// this final read. // this final read.

14
conn.go
View File

@ -149,7 +149,8 @@ func (e *CloseError) Error() string {
// IsCloseError returns boolean indicating whether the error is a *CloseError // IsCloseError returns boolean indicating whether the error is a *CloseError
// with one of the specified codes. // with one of the specified codes.
func IsCloseError(err error, codes ...int) bool { func IsCloseError(err error, codes ...int) bool {
if e, ok := err.(*CloseError); ok { var e *CloseError
if errors.As(err, &e) {
for _, code := range codes { for _, code := range codes {
if e.Code == code { if e.Code == code {
return true return true
@ -162,7 +163,8 @@ func IsCloseError(err error, codes ...int) bool {
// IsUnexpectedCloseError returns boolean indicating whether the error is a // IsUnexpectedCloseError returns boolean indicating whether the error is a
// *CloseError with a code not in the list of expected codes. // *CloseError with a code not in the list of expected codes.
func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
if e, ok := err.(*CloseError); ok { var e *CloseError
if errors.As(err, &e) {
for _, code := range expectedCodes { for _, code := range expectedCodes {
if e.Code == code { if e.Code == code {
return false return false
@ -376,7 +378,7 @@ func (c *Conn) writeFatal(err error) error {
func (c *Conn) read(n int) ([]byte, error) { func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n) p, err := c.br.Peek(n)
if err == io.EOF { if errors.Is(err, io.EOF) {
err = errUnexpectedEOF err = errUnexpectedEOF
} }
if _, err := c.br.Discard(len(p)); err != nil { if _, err := c.br.Discard(len(p)); err != nil {
@ -730,7 +732,7 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
w.pos += n w.pos += n
nn += int64(n) nn += int64(n)
if err != nil { if err != nil {
if err == io.EOF { if errors.Is(err, io.EOF) {
err = nil err = nil
} }
break break
@ -1082,7 +1084,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
if err := c.setReadRemaining(rem); err != nil { if err := c.setReadRemaining(rem); err != nil {
return 0, err return 0, err
} }
if c.readRemaining > 0 && c.readErr == io.EOF { if c.readRemaining > 0 && errors.Is(c.readErr, io.EOF) {
c.readErr = errUnexpectedEOF c.readErr = errUnexpectedEOF
} }
return n, c.readErr return n, c.readErr
@ -1103,7 +1105,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
} }
err := c.readErr err := c.readErr
if err == io.EOF && c.messageReader == r { if errors.Is(err, io.EOF) && c.messageReader == r {
err = errUnexpectedEOF err = errUnexpectedEOF
} }
return 0, err return 0, err

View File

@ -404,18 +404,18 @@ func TestEOFWithinFrame(t *testing.T) {
b.Truncate(n) b.Truncate(n)
op, r, err := rc.NextReader() op, r, err := rc.NextReader()
if err == errUnexpectedEOF { if errors.Is(err, errUnexpectedEOF) {
continue continue
} }
if op != BinaryMessage || err != nil { if op != BinaryMessage || err != nil {
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
} }
_, err = io.Copy(io.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if !errors.Is(err, errUnexpectedEOF) {
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
} }
_, _, err = rc.NextReader() _, _, err = rc.NextReader()
if err != errUnexpectedEOF { if !errors.Is(err, errUnexpectedEOF) {
t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF) t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
} }
} }
@ -438,11 +438,11 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(io.Discard, r) _, err = io.Copy(io.Discard, r)
if err != errUnexpectedEOF { if !errors.Is(err, errUnexpectedEOF) {
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
} }
_, _, err = rc.NextReader() _, _, err = rc.NextReader()
if err != errUnexpectedEOF { if !errors.Is(err, errUnexpectedEOF) {
t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF) t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
} }
} }
@ -514,7 +514,7 @@ func TestReadLimit(t *testing.T) {
t.Fatalf("2: NextReader() returned %d, %v", op, err) t.Fatalf("2: NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(io.Discard, r) _, err = io.Copy(io.Discard, r)
if err != ErrReadLimit { if !errors.Is(err, ErrReadLimit) {
t.Fatalf("io.Copy() returned %v", err) t.Fatalf("io.Copy() returned %v", err)
} }
}) })
@ -558,13 +558,13 @@ func TestReadLimit(t *testing.T) {
var buf [10]byte var buf [10]byte
var read int var read int
n, err := r.Read(buf[:]) n, err := r.Read(buf[:])
if err != nil && err != ErrReadLimit { if err != nil && !errors.Is(err, ErrReadLimit) {
t.Fatalf("unexpected error testing read limit: %v", err) t.Fatalf("unexpected error testing read limit: %v", err)
} }
read += n read += n
n, err = r.Read(buf[:]) n, err = r.Read(buf[:])
if err != nil && err != ErrReadLimit { if err != nil && !errors.Is(err, ErrReadLimit) {
t.Fatalf("unexpected error testing read limit: %v", err) t.Fatalf("unexpected error testing read limit: %v", err)
} }
read += n read += n

View File

@ -5,6 +5,7 @@
package websocket package websocket
import ( import (
"errors"
"io" "io"
"strings" "strings"
) )
@ -34,7 +35,7 @@ func (r *joinReader) Read(p []byte) (int, error) {
} }
} }
n, err := r.r.Read(p) n, err := r.r.Read(p)
if err == io.EOF { if errors.Is(err, io.EOF) {
err = nil err = nil
r.r = nil r.r = nil
} }

View File

@ -6,6 +6,7 @@ package websocket
import ( import (
"encoding/json" "encoding/json"
"errors"
"io" "io"
) )
@ -52,7 +53,7 @@ func (c *Conn) ReadJSON(v interface{}) error {
return err return err
} }
err = json.NewDecoder(r).Decode(v) err = json.NewDecoder(r).Decode(v)
if err == io.EOF { if errors.Is(err, io.EOF) {
// One value is expected in the message. // One value is expected in the message.
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }

View File

@ -7,6 +7,7 @@ package websocket
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"reflect" "reflect"
"testing" "testing"
@ -79,13 +80,14 @@ func TestPartialJSONRead(t *testing.T) {
for i := 0; i < messageCount; i++ { for i := 0; i < messageCount; i++ {
err := rc.ReadJSON(&v) err := rc.ReadJSON(&v)
if err != io.ErrUnexpectedEOF { if !errors.Is(err, io.ErrUnexpectedEOF) {
t.Error("read", i, err) t.Error("read", i, err)
} }
} }
err = rc.ReadJSON(&v) err = rc.ReadJSON(&v)
if _, ok := err.(*CloseError); !ok { var ce *CloseError
if !errors.As(err, &ce) {
t.Error("final", err) t.Error("final", err)
} }
} }