diff --git a/conn.go b/conn.go index 5b26b53..cc04ed9 100644 --- a/conn.go +++ b/conn.go @@ -12,6 +12,7 @@ import ( "io/ioutil" "net" "strconv" + "sync" "time" "unicode/utf8" ) @@ -223,19 +224,16 @@ type Conn struct { subprotocol string // Write fields - mu chan bool // used as mutex to protect write to conn and closeSent - closeSent bool // whether close message was sent - writeErr error - writeBuf []byte // frame is constructed in this buffer. - writePos int // end of data in writeBuf. - writeFrameType int // type of the current frame. - writeDeadline time.Time - messageWriter *messageWriter // the current low-level message writer - writer io.WriteCloser // the current writer returned to the application - isWriting bool // for best-effort concurrent write detection + mu chan bool // used as mutex to protect write to conn + writeBuf []byte // frame is constructed in this buffer. + writeDeadline time.Time + writer io.WriteCloser // the current writer returned to the application + isWriting bool // for best-effort concurrent write detection + + writeErrMu sync.Mutex + writeErr error enableWriteCompression bool - writeCompress bool // whether next call to flushFrame should set RSV1 newCompressionWriter func(io.WriteCloser) (io.WriteCloser, error) // Read fields @@ -277,8 +275,6 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) mu: mu, readFinal: true, writeBuf: make([]byte, writeBufferSize+maxFrameHeaderSize), - writeFrameType: noFrame, - writePos: maxFrameHeaderSize, enableWriteCompression: true, } c.SetPingHandler(nil) @@ -308,29 +304,40 @@ func (c *Conn) RemoteAddr() net.Addr { // Write methods +func (c *Conn) writeFatal(err error) error { + err = hideTempErr(err) + c.writeErrMu.Lock() + if c.writeErr == nil { + c.writeErr = err + } + c.writeErrMu.Unlock() + return err +} + func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { <-c.mu defer func() { c.mu <- true }() - if c.closeSent { - return ErrCloseSent - } else if frameType == CloseMessage { - c.closeSent = true + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err } c.conn.SetWriteDeadline(deadline) for _, buf := range bufs { if len(buf) > 0 { - n, err := c.conn.Write(buf) - if n != len(buf) { - // Close on partial write. - c.conn.Close() - } + _, err := c.conn.Write(buf) if err != nil { - return err + return c.writeFatal(err) } } } + + if frameType == CloseMessage { + c.writeFatal(ErrCloseSent) + } return nil } @@ -379,18 +386,22 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er } defer func() { c.mu <- true }() - if c.closeSent { - return ErrCloseSent - } else if messageType == CloseMessage { - c.closeSent = true + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err } c.conn.SetWriteDeadline(deadline) - n, err := c.conn.Write(buf) - if n != 0 && n != len(buf) { - c.conn.Close() + _, err = c.conn.Write(buf) + if err != nil { + return c.writeFatal(err) } - return hideTempErr(err) + if messageType == CloseMessage { + c.writeFatal(ErrCloseSent) + } + return err } // NextWriter returns a writer for the next message to send. The writer's Close @@ -399,64 +410,79 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er // There can be at most one open writer on a connection. NextWriter closes the // previous writer if the application has not already done so. func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { - if c.writeErr != nil { - return nil, c.writeErr - } - // Close previous writer if not already closed by the application. It's // probably better to return an error in this situation, but we cannot // change this without breaking existing applications. if c.writer != nil { - err := c.writer.Close() - if err != nil { - return nil, err - } + c.writer.Close() + c.writer = nil } if !isControl(messageType) && !isData(messageType) { return nil, errBadWriteOpCode } - c.writeFrameType = messageType - c.messageWriter = &messageWriter{c} - - var w io.WriteCloser = c.messageWriter - if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { - c.writeCompress = true - var err error - w, err = c.newCompressionWriter(w) - if err != nil { - c.writer.Close() - return nil, err - } + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return nil, err } - return w, nil + mw := &messageWriter{ + c: c, + frameType: messageType, + pos: maxFrameHeaderSize, + } + c.writer = mw + if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + w, err := c.newCompressionWriter(c.writer) + if err != nil { + c.writer = nil + return nil, err + } + mw.compress = true + c.writer = w + } + return c.writer, nil +} + +type messageWriter struct { + c *Conn + compress bool // whether next call to flushFrame should set RSV1 + pos int // end of data in writeBuf. + frameType int // type of the current frame. + err error +} + +func (w *messageWriter) fatal(err error) error { + if w.err != nil { + w.err = err + w.c.writer = nil + } + return err } // flushFrame writes buffered data and extra as a frame to the network. The // final argument indicates that this is the last frame in the message. -func (c *Conn) flushFrame(final bool, extra []byte) error { - length := c.writePos - maxFrameHeaderSize + len(extra) +func (w *messageWriter) flushFrame(final bool, extra []byte) error { + c := w.c + length := w.pos - maxFrameHeaderSize + len(extra) // Check for invalid control frames. - if isControl(c.writeFrameType) && + if isControl(w.frameType) && (!final || length > maxControlFramePayloadSize) { - c.messageWriter = nil - c.writer = nil - c.writeFrameType = noFrame - c.writePos = maxFrameHeaderSize - return errInvalidControlFrame + return w.fatal(errInvalidControlFrame) } - b0 := byte(c.writeFrameType) + b0 := byte(w.frameType) if final { b0 |= finalBit } - if c.writeCompress { + if w.compress { b0 |= rsv1Bit } - c.writeCompress = false + w.compress = false b1 := byte(0) if !c.isServer { @@ -489,10 +515,9 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { if !c.isServer { key := newMaskKey() copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) - maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:c.writePos]) + maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) if len(extra) > 0 { - c.writeErr = errors.New("websocket: internal error, extra used in client mode") - return c.writeErr + return c.writeFatal(errors.New("websocket: internal error, extra used in client mode")) } } @@ -505,44 +530,35 @@ func (c *Conn) flushFrame(final bool, extra []byte) error { } c.isWriting = true - c.writeErr = c.write(c.writeFrameType, c.writeDeadline, c.writeBuf[framePos:c.writePos], extra) + err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) if !c.isWriting { panic("concurrent write to websocket connection") } c.isWriting = false - // Setup for next frame. - c.writePos = maxFrameHeaderSize - c.writeFrameType = continuationFrame + if err != nil { + return w.fatal(err) + } + if final { - c.messageWriter = nil c.writer = nil - c.writeFrameType = noFrame + return nil } - return c.writeErr -} -type messageWriter struct{ c *Conn } - -func (w *messageWriter) err() error { - c := w.c - if c.messageWriter != w { - return errWriteClosed - } - if c.writeErr != nil { - return c.writeErr - } + // Setup for next frame. + w.pos = maxFrameHeaderSize + w.frameType = continuationFrame return nil } func (w *messageWriter) ncopy(max int) (int, error) { - n := len(w.c.writeBuf) - w.c.writePos + n := len(w.c.writeBuf) - w.pos if n <= 0 { - if err := w.c.flushFrame(false, nil); err != nil { + if err := w.flushFrame(false, nil); err != nil { return 0, err } - n = len(w.c.writeBuf) - w.c.writePos + n = len(w.c.writeBuf) - w.pos } if n > max { n = max @@ -551,13 +567,13 @@ func (w *messageWriter) ncopy(max int) (int, error) { } func (w *messageWriter) Write(p []byte) (int, error) { - if err := w.err(); err != nil { - return 0, err + if w.err != nil { + return 0, w.err } if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { // Don't buffer large messages. - err := w.c.flushFrame(false, p) + err := w.flushFrame(false, p) if err != nil { return 0, err } @@ -570,16 +586,16 @@ func (w *messageWriter) Write(p []byte) (int, error) { if err != nil { return 0, err } - copy(w.c.writeBuf[w.c.writePos:], p[:n]) - w.c.writePos += n + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n p = p[n:] } return nn, nil } func (w *messageWriter) WriteString(p string) (int, error) { - if err := w.err(); err != nil { - return 0, err + if w.err != nil { + return 0, w.err } nn := len(p) @@ -588,27 +604,27 @@ func (w *messageWriter) WriteString(p string) (int, error) { if err != nil { return 0, err } - copy(w.c.writeBuf[w.c.writePos:], p[:n]) - w.c.writePos += n + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n p = p[n:] } return nn, nil } func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { - if err := w.err(); err != nil { - return 0, err + if w.err != nil { + return 0, w.err } for { - if w.c.writePos == len(w.c.writeBuf) { - err = w.c.flushFrame(false, nil) + if w.pos == len(w.c.writeBuf) { + err = w.flushFrame(false, nil) if err != nil { break } } var n int - n, err = r.Read(w.c.writeBuf[w.c.writePos:]) - w.c.writePos += n + n, err = r.Read(w.c.writeBuf[w.pos:]) + w.pos += n nn += int64(n) if err != nil { if err == io.EOF { @@ -621,10 +637,14 @@ func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { } func (w *messageWriter) Close() error { - if err := w.err(); err != nil { + if w.err != nil { + return w.err + } + if err := w.flushFrame(true, nil); err != nil { return err } - return w.c.flushFrame(true, nil) + w.err = errWriteClosed + return nil } // WriteMessage is a helper method for getting a writer using NextWriter, @@ -634,12 +654,12 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { if err != nil { return err } - if _, ok := w.(*messageWriter); ok && c.isServer { + if mw, ok := w.(*messageWriter); ok && c.isServer { // Optimize write as a single frame. - n := copy(c.writeBuf[c.writePos:], data) - c.writePos += n + n := copy(c.writeBuf[mw.pos:], data) + mw.pos += n data = data[n:] - err = c.flushFrame(true, data) + err = mw.flushFrame(true, data) return err } if _, err = w.Write(data); err != nil { diff --git a/conn_test.go b/conn_test.go index 3c938cb..7431383 100644 --- a/conn_test.go +++ b/conn_test.go @@ -26,12 +26,27 @@ type fakeNetConn struct { } func (c fakeNetConn) Close() error { return nil } -func (c fakeNetConn) LocalAddr() net.Addr { return nil } -func (c fakeNetConn) RemoteAddr() net.Addr { return nil } +func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } +func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } +type fakeAddr int + +var ( + localAddr = fakeAddr(1) + remoteAddr = fakeAddr(2) +) + +func (a fakeAddr) Network() string { + return "net" +} + +func (a fakeAddr) String() string { + return "str" +} + func TestFraming(t *testing.T) { frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537} var readChunkers = []struct { @@ -42,11 +57,25 @@ func TestFraming(t *testing.T) { {"one", iotest.OneByteReader}, {"asis", func(r io.Reader) io.Reader { return r }}, } - writeBuf := make([]byte, 65537) for i := range writeBuf { writeBuf[i] = byte(i) } + var writers = []struct { + name string + f func(w io.Writer, n int) (int, error) + }{ + {"iocopy", func(w io.Writer, n int) (int, error) { + nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n])) + return int(nn), err + }}, + {"write", func(w io.Writer, n int) (int, error) { + return w.Write(writeBuf[:n]) + }}, + {"string", func(w io.Writer, n int) (int, error) { + return io.WriteString(w, string(writeBuf[:n])) + }}, + } for _, compress := range []bool{false, true} { for _, isServer := range []bool{true, false} { @@ -60,22 +89,15 @@ func TestFraming(t *testing.T) { rc.newDecompressionReader = decompressNoContextTakeover } for _, n := range frameSizes { - for _, iocopy := range []bool{true, false} { - name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d c:%v", compress, isServer, chunker.name, n, iocopy) + for _, writer := range writers { + name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) w, err := wc.NextWriter(TextMessage) if err != nil { t.Errorf("%s: wc.NextWriter() returned %v", name, err) continue } - var nn int - if iocopy { - var n64 int64 - n64, err = io.Copy(w, bytes.NewReader(writeBuf[:n])) - nn = int(n64) - } else { - nn, err = w.Write(writeBuf[:n]) - } + nn, err := writer.f(w, n) if err != nil || nn != n { t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) continue @@ -151,7 +173,7 @@ func TestControl(t *testing.T) { } } -func TestCloseBeforeFinalFrame(t *testing.T) { +func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { const bufSize = 512 expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} @@ -238,6 +260,32 @@ func TestEOFBeforeFinalFrame(t *testing.T) { } } +func TestWriteAfterMessageWriterClose(t *testing.T) { + wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024) + w, _ := wc.NextWriter(BinaryMessage) + io.WriteString(w, "hello") + if err := w.Close(); err != nil { + t.Fatalf("unxpected error closing message writer, %v", err) + } + + if _, err := io.WriteString(w, "world"); err == nil { + t.Fatalf("no error writing after close") + } + + w, _ = wc.NextWriter(BinaryMessage) + io.WriteString(w, "hello") + + // close w by getting next writer + _, err := wc.NextWriter(BinaryMessage) + if err != nil { + t.Fatalf("unexpected error getting next writer, %v", err) + } + + if _, err := io.WriteString(w, "world"); err == nil { + t.Fatalf("no error writing after close") + } +} + func TestReadLimit(t *testing.T) { const readLimit = 512 @@ -272,6 +320,16 @@ func TestReadLimit(t *testing.T) { } } +func TestAddrs(t *testing.T) { + c := newConn(&fakeNetConn{}, true, 1024, 1024) + if c.LocalAddr() != localAddr { + t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) + } + if c.RemoteAddr() != remoteAddr { + t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr) + } +} + func TestUnderlyingConn(t *testing.T) { var b1, b2 bytes.Buffer fc := fakeNetConn{Reader: &b1, Writer: &b2}