Provide all close frame data to application

- Export closeError.
- Do not convert normal closure and going away to io.EOF.
This commit is contained in:
Gary Burd 2015-08-11 10:14:32 -07:00
parent 6eb6ad425a
commit b6ab76f1fe
4 changed files with 23 additions and 23 deletions

25
conn.go
View File

@ -88,19 +88,23 @@ func (e *netError) Error() string { return e.msg }
func (e *netError) Temporary() bool { return e.temporary } func (e *netError) Temporary() bool { return e.temporary }
func (e *netError) Timeout() bool { return e.timeout } func (e *netError) Timeout() bool { return e.timeout }
// closeError represents close frame. // CloseError represents close frame.
type closeError struct { type CloseError struct {
code int
text string // Code is defined in RFC 6455, section 11.7.
Code int
// Text is the optional text payload.
Text string
} }
func (e *closeError) Error() string { func (e *CloseError) Error() string {
return "websocket: close " + strconv.Itoa(e.code) + " " + e.text return "websocket: close " + strconv.Itoa(e.Code) + " " + e.Text
} }
var ( var (
errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true} errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true}
errUnexpectedEOF = &closeError{code: CloseAbnormalClosure, text: io.ErrUnexpectedEOF.Error()} errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()}
errBadWriteOpCode = errors.New("websocket: bad write message type") errBadWriteOpCode = errors.New("websocket: bad write message type")
errWriteClosed = errors.New("websocket: write closed") errWriteClosed = errors.New("websocket: write closed")
errInvalidControlFrame = errors.New("websocket: invalid control frame") errInvalidControlFrame = errors.New("websocket: invalid control frame")
@ -673,12 +677,7 @@ func (c *Conn) advanceFrame() (int, error) {
closeCode = int(binary.BigEndian.Uint16(payload)) closeCode = int(binary.BigEndian.Uint16(payload))
closeText = string(payload[2:]) closeText = string(payload[2:])
} }
switch closeCode { return noFrame, &CloseError{Code: closeCode, Text: closeText}
case CloseNormalClosure, CloseGoingAway:
return noFrame, io.EOF
default:
return noFrame, &closeError{code: closeCode, text: closeText}
}
} }
return frameType, nil return frameType, nil

View File

@ -10,6 +10,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"reflect"
"testing" "testing"
"testing/iotest" "testing/iotest"
"time" "time"
@ -146,13 +147,15 @@ func TestControl(t *testing.T) {
func TestCloseBeforeFinalFrame(t *testing.T) { func TestCloseBeforeFinalFrame(t *testing.T) {
const bufSize = 512 const bufSize = 512
expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
var b1, b2 bytes.Buffer var b1, b2 bytes.Buffer
wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
w, _ := wc.NextWriter(BinaryMessage) w, _ := wc.NextWriter(BinaryMessage)
w.Write(make([]byte, bufSize+bufSize/2)) w.Write(make([]byte, bufSize+bufSize/2))
wc.WriteControl(CloseMessage, FormatCloseMessage(CloseNormalClosure, ""), time.Now().Add(10*time.Second)) wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
w.Close() w.Close()
op, r, err := rc.NextReader() op, r, err := rc.NextReader()
@ -160,12 +163,12 @@ func TestCloseBeforeFinalFrame(t *testing.T) {
t.Fatalf("NextReader() returned %d, %v", op, err) t.Fatalf("NextReader() returned %d, %v", op, err)
} }
_, err = io.Copy(ioutil.Discard, r) _, err = io.Copy(ioutil.Discard, r)
if err != errUnexpectedEOF { if !reflect.DeepEqual(err, expectedErr) {
t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
} }
_, _, err = rc.NextReader() _, _, err = rc.NextReader()
if err != io.EOF { if !reflect.DeepEqual(err, expectedErr) {
t.Fatalf("NextReader() returned %v, want %v", err, io.EOF) t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
} }
} }

View File

@ -48,9 +48,7 @@ func (c *Conn) ReadJSON(v interface{}) error {
} }
err = json.NewDecoder(r).Decode(v) err = json.NewDecoder(r).Decode(v)
if err == io.EOF { if err == io.EOF {
// Decode returns io.EOF when the message is empty or all whitespace. // One value is expected in the message.
// Convert to io.ErrUnexpectedEOF so that application can distinguish
// between an error reading the JSON value and the connection closing.
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }
return err return err

View File

@ -38,7 +38,7 @@ func TestJSON(t *testing.T) {
} }
} }
func TestPartialJsonRead(t *testing.T) { func TestPartialJSONRead(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
c := fakeNetConn{&buf, &buf} c := fakeNetConn{&buf, &buf}
wc := newConn(c, true, 1024, 1024) wc := newConn(c, true, 1024, 1024)
@ -87,7 +87,7 @@ func TestPartialJsonRead(t *testing.T) {
} }
err = rc.ReadJSON(&v) err = rc.ReadJSON(&v)
if err != io.EOF { if _, ok := err.(*CloseError); !ok {
t.Error("final", err) t.Error("final", err)
} }
} }