mirror of https://github.com/gorilla/websocket.git
Use errors.As and errors.Is
Replace type assertions to error types with errors.As. Replace comparisons with error values with errors.Is.
This commit is contained in:
parent
666c197fc9
commit
493d31ecb9
|
@ -131,7 +131,7 @@ func (r *flateReadWrapper) Read(p []byte) (int, error) {
|
||||||
return 0, io.ErrClosedPipe
|
return 0, io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
n, err := r.fr.Read(p)
|
n, err := r.fr.Read(p)
|
||||||
if err == io.EOF {
|
if errors.Is(err, io.EOF) {
|
||||||
// Preemptively place the reader back in the pool. This helps with
|
// Preemptively place the reader back in the pool. This helps with
|
||||||
// scenarios where the application does not call NextReader() soon after
|
// scenarios where the application does not call NextReader() soon after
|
||||||
// this final read.
|
// this final read.
|
||||||
|
|
14
conn.go
14
conn.go
|
@ -149,7 +149,8 @@ func (e *CloseError) Error() string {
|
||||||
// IsCloseError returns boolean indicating whether the error is a *CloseError
|
// IsCloseError returns boolean indicating whether the error is a *CloseError
|
||||||
// with one of the specified codes.
|
// with one of the specified codes.
|
||||||
func IsCloseError(err error, codes ...int) bool {
|
func IsCloseError(err error, codes ...int) bool {
|
||||||
if e, ok := err.(*CloseError); ok {
|
var e *CloseError
|
||||||
|
if errors.As(err, &e) {
|
||||||
for _, code := range codes {
|
for _, code := range codes {
|
||||||
if e.Code == code {
|
if e.Code == code {
|
||||||
return true
|
return true
|
||||||
|
@ -162,7 +163,8 @@ func IsCloseError(err error, codes ...int) bool {
|
||||||
// IsUnexpectedCloseError returns boolean indicating whether the error is a
|
// IsUnexpectedCloseError returns boolean indicating whether the error is a
|
||||||
// *CloseError with a code not in the list of expected codes.
|
// *CloseError with a code not in the list of expected codes.
|
||||||
func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
|
func IsUnexpectedCloseError(err error, expectedCodes ...int) bool {
|
||||||
if e, ok := err.(*CloseError); ok {
|
var e *CloseError
|
||||||
|
if errors.As(err, &e) {
|
||||||
for _, code := range expectedCodes {
|
for _, code := range expectedCodes {
|
||||||
if e.Code == code {
|
if e.Code == code {
|
||||||
return false
|
return false
|
||||||
|
@ -376,7 +378,7 @@ func (c *Conn) writeFatal(err error) error {
|
||||||
|
|
||||||
func (c *Conn) read(n int) ([]byte, error) {
|
func (c *Conn) read(n int) ([]byte, error) {
|
||||||
p, err := c.br.Peek(n)
|
p, err := c.br.Peek(n)
|
||||||
if err == io.EOF {
|
if errors.Is(err, io.EOF) {
|
||||||
err = errUnexpectedEOF
|
err = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
if _, err := c.br.Discard(len(p)); err != nil {
|
if _, err := c.br.Discard(len(p)); err != nil {
|
||||||
|
@ -730,7 +732,7 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) {
|
||||||
w.pos += n
|
w.pos += n
|
||||||
nn += int64(n)
|
nn += int64(n)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF {
|
if errors.Is(err, io.EOF) {
|
||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
|
@ -1082,7 +1084,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
||||||
if err := c.setReadRemaining(rem); err != nil {
|
if err := c.setReadRemaining(rem); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if c.readRemaining > 0 && c.readErr == io.EOF {
|
if c.readRemaining > 0 && errors.Is(c.readErr, io.EOF) {
|
||||||
c.readErr = errUnexpectedEOF
|
c.readErr = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
return n, c.readErr
|
return n, c.readErr
|
||||||
|
@ -1103,7 +1105,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
err := c.readErr
|
err := c.readErr
|
||||||
if err == io.EOF && c.messageReader == r {
|
if errors.Is(err, io.EOF) && c.messageReader == r {
|
||||||
err = errUnexpectedEOF
|
err = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
return 0, err
|
return 0, err
|
||||||
|
|
16
conn_test.go
16
conn_test.go
|
@ -404,18 +404,18 @@ func TestEOFWithinFrame(t *testing.T) {
|
||||||
b.Truncate(n)
|
b.Truncate(n)
|
||||||
|
|
||||||
op, r, err := rc.NextReader()
|
op, r, err := rc.NextReader()
|
||||||
if err == errUnexpectedEOF {
|
if errors.Is(err, errUnexpectedEOF) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if op != BinaryMessage || err != nil {
|
if op != BinaryMessage || err != nil {
|
||||||
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
|
t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(io.Discard, r)
|
_, err = io.Copy(io.Discard, r)
|
||||||
if err != errUnexpectedEOF {
|
if !errors.Is(err, errUnexpectedEOF) {
|
||||||
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
|
t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
|
||||||
}
|
}
|
||||||
_, _, err = rc.NextReader()
|
_, _, err = rc.NextReader()
|
||||||
if err != errUnexpectedEOF {
|
if !errors.Is(err, errUnexpectedEOF) {
|
||||||
t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
|
t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -438,11 +438,11 @@ func TestEOFBeforeFinalFrame(t *testing.T) {
|
||||||
t.Fatalf("NextReader() returned %d, %v", op, err)
|
t.Fatalf("NextReader() returned %d, %v", op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(io.Discard, r)
|
_, err = io.Copy(io.Discard, r)
|
||||||
if err != errUnexpectedEOF {
|
if !errors.Is(err, errUnexpectedEOF) {
|
||||||
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
|
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
|
||||||
}
|
}
|
||||||
_, _, err = rc.NextReader()
|
_, _, err = rc.NextReader()
|
||||||
if err != errUnexpectedEOF {
|
if !errors.Is(err, errUnexpectedEOF) {
|
||||||
t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
|
t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -514,7 +514,7 @@ func TestReadLimit(t *testing.T) {
|
||||||
t.Fatalf("2: NextReader() returned %d, %v", op, err)
|
t.Fatalf("2: NextReader() returned %d, %v", op, err)
|
||||||
}
|
}
|
||||||
_, err = io.Copy(io.Discard, r)
|
_, err = io.Copy(io.Discard, r)
|
||||||
if err != ErrReadLimit {
|
if !errors.Is(err, ErrReadLimit) {
|
||||||
t.Fatalf("io.Copy() returned %v", err)
|
t.Fatalf("io.Copy() returned %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -558,13 +558,13 @@ func TestReadLimit(t *testing.T) {
|
||||||
var buf [10]byte
|
var buf [10]byte
|
||||||
var read int
|
var read int
|
||||||
n, err := r.Read(buf[:])
|
n, err := r.Read(buf[:])
|
||||||
if err != nil && err != ErrReadLimit {
|
if err != nil && !errors.Is(err, ErrReadLimit) {
|
||||||
t.Fatalf("unexpected error testing read limit: %v", err)
|
t.Fatalf("unexpected error testing read limit: %v", err)
|
||||||
}
|
}
|
||||||
read += n
|
read += n
|
||||||
|
|
||||||
n, err = r.Read(buf[:])
|
n, err = r.Read(buf[:])
|
||||||
if err != nil && err != ErrReadLimit {
|
if err != nil && !errors.Is(err, ErrReadLimit) {
|
||||||
t.Fatalf("unexpected error testing read limit: %v", err)
|
t.Fatalf("unexpected error testing read limit: %v", err)
|
||||||
}
|
}
|
||||||
read += n
|
read += n
|
||||||
|
|
3
join.go
3
join.go
|
@ -5,6 +5,7 @@
|
||||||
package websocket
|
package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
@ -34,7 +35,7 @@ func (r *joinReader) Read(p []byte) (int, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
n, err := r.r.Read(p)
|
n, err := r.r.Read(p)
|
||||||
if err == io.EOF {
|
if errors.Is(err, io.EOF) {
|
||||||
err = nil
|
err = nil
|
||||||
r.r = nil
|
r.r = nil
|
||||||
}
|
}
|
||||||
|
|
3
json.go
3
json.go
|
@ -6,6 +6,7 @@ package websocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -52,7 +53,7 @@ func (c *Conn) ReadJSON(v interface{}) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = json.NewDecoder(r).Decode(v)
|
err = json.NewDecoder(r).Decode(v)
|
||||||
if err == io.EOF {
|
if errors.Is(err, io.EOF) {
|
||||||
// One value is expected in the message.
|
// One value is expected in the message.
|
||||||
err = io.ErrUnexpectedEOF
|
err = io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ package websocket
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -79,13 +80,14 @@ func TestPartialJSONRead(t *testing.T) {
|
||||||
|
|
||||||
for i := 0; i < messageCount; i++ {
|
for i := 0; i < messageCount; i++ {
|
||||||
err := rc.ReadJSON(&v)
|
err := rc.ReadJSON(&v)
|
||||||
if err != io.ErrUnexpectedEOF {
|
if !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||||
t.Error("read", i, err)
|
t.Error("read", i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = rc.ReadJSON(&v)
|
err = rc.ReadJSON(&v)
|
||||||
if _, ok := err.(*CloseError); !ok {
|
var ce *CloseError
|
||||||
|
if !errors.As(err, &ce) {
|
||||||
t.Error("final", err)
|
t.Error("final", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue