Add IsCloseError, improve doc about errors

This commit is contained in:
Gary Burd 2016-01-19 09:20:21 -08:00
parent 3986be78bf
commit a2d85bcbfc
2 changed files with 36 additions and 2 deletions

17
conn.go
View File

@ -102,6 +102,19 @@ func (e *CloseError) Error() string {
return "websocket: close " + strconv.Itoa(e.Code) + " " + e.Text return "websocket: close " + strconv.Itoa(e.Code) + " " + e.Text
} }
// IsCloseError returns boolean indicating whether the error is a *CloseError
// with one of the specified codes.
func IsCloseError(err error, codes ...int) bool {
if e, ok := err.(*CloseError); ok {
for _, code := range codes {
if e.Code == code {
return true
}
}
}
return false
}
var ( var (
errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true} errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true}
errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()}
@ -694,8 +707,8 @@ func (c *Conn) handleProtocolError(message string) error {
// There can be at most one open reader on a connection. NextReader discards // There can be at most one open reader on a connection. NextReader discards
// the previous message if the application has not already consumed it. // the previous message if the application has not already consumed it.
// //
// The NextReader method and the readers returned from the method cannot be // Errors returned from NextReader are permanent. If NextReader returns a
// accessed by more than one goroutine at a time. // non-nil error, then all subsequent calls to NextReader will the same error.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c.readSeq++ c.readSeq++

View File

@ -7,6 +7,7 @@ package websocket
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -270,3 +271,23 @@ func TestBufioReadBytes(t *testing.T) {
t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m)) t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m))
} }
} }
var closeErrorTests = []struct {
err error
codes []int
ok bool
}{
{&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
{&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
{errors.New("hello"), []int{CloseNormalClosure}, false},
}
func TestCloseError(t *testing.T) {
for _, tt := range closeErrorTests {
ok := IsCloseError(tt.err, tt.codes...)
if ok != tt.ok {
t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
}
}
}