From 0e7b5f878f841d76632b20fd79f53ca0efd06fae Mon Sep 17 00:00:00 2001 From: Gary Burd Date: Sat, 5 Jul 2014 13:56:34 -0700 Subject: [PATCH] Do not mask bytes when reading on the client. - The bytes were masked with zero, a nop. - Add test for control messages. --- conn.go | 8 ++++++-- conn_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 819ab0d..d778d39 100644 --- a/conn.go +++ b/conn.go @@ -615,7 +615,9 @@ func (c *Conn) advanceFrame() (int, error) { if _, err := io.ReadFull(c.br, payload); err != nil { return noFrame, err } - maskBytes(c.readMaskKey, 0, payload) + if c.isServer { + maskBytes(c.readMaskKey, 0, payload) + } } // 7. Process control frame payload. @@ -698,7 +700,9 @@ func (r messageReader) Read(b []byte) (n int, err error) { } n, err := r.c.br.Read(b) r.c.readErr = hideTempErr(err) - r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n]) + if r.c.isServer { + r.c.readMaskPos = maskBytes(r.c.readMaskKey, r.c.readMaskPos, b[:n]) + } r.c.readRemaining -= int64(n) return n, r.c.readErr } diff --git a/conn_test.go b/conn_test.go index fcba850..52bbede 100644 --- a/conn_test.go +++ b/conn_test.go @@ -107,6 +107,42 @@ func TestFraming(t *testing.T) { } } +func TestControl(t *testing.T) { + const message = "this is a ping/pong messsage" + for _, isServer := range []bool{true, false} { + for _, isWriteControl := range []bool{true, false} { + name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) + var connBuf bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) + rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024) + if isWriteControl { + wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) + } else { + w, err := wc.NextWriter(PongMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + if _, err := w.Write([]byte(message)); err != nil { + t.Errorf("%s: w.Write() returned %v", name, err) + continue + } + if err := w.Close(); err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } + var actualMessage string + rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) + rc.NextReader() + if actualMessage != message { + t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) + continue + } + } + } + } +} + func TestReadLimit(t *testing.T) { const readLimit = 512