diff --git a/conn.go b/conn.go index 8179ecb..af4d1b1 100644 --- a/conn.go +++ b/conn.go @@ -360,6 +360,23 @@ func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } +// +func (c *Conn) WriteLock() { + c.writeMu.Lock() + c.isWriting = true +} + +// +func (c *Conn) WriteUnlock() { + c.isWriting = false + c.writeMu.Unlock() +} + +// Return the conn writing status +func (c *Conn) IsWriting() bool { + return c.isWriting +} + // Write methods func (c *Conn) writeFatal(err error) error { @@ -579,22 +596,8 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { } } - // Write the buffers to the connection with best-effort detection of - // concurrent writes. See the concurrency section in the package - // documentation for more info. - - if c.isWriting { - panic("concurrent write to websocket connection") - } - c.isWriting = true - err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) - if !c.isWriting { - panic("concurrent write to websocket connection") - } - c.isWriting = false - if err != nil { return w.fatal(err) } @@ -715,23 +718,17 @@ func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { if err != nil { return err } - if c.isWriting { - panic("concurrent write to websocket connection") - } - c.isWriting = true + err = c.write(frameType, c.writeDeadline, frameData, nil) - if !c.isWriting { - panic("concurrent write to websocket connection") - } - c.isWriting = false + return err } // WriteMessage is a helper method for getting a writer using NextWriter, // writing the message and closing the writer. func (c *Conn) WriteMessage(messageType int, data []byte) error { - c.writeMu.Lock() - defer c.writeMu.Unlock() + c.WriteLock() + defer c.WriteUnlock() if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { // Fast path with no allocations and single frame. @@ -765,11 +762,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { return nil } -// Return the conn writing status -func (c *Conn) IsWriting() bool { - return c.isWriting -} - // Read methods func (c *Conn) advanceFrame() (int, error) { diff --git a/json.go b/json.go index 7f78d2a..b2c6608 100644 --- a/json.go +++ b/json.go @@ -19,8 +19,8 @@ func WriteJSON(c *Conn, v interface{}) error { // See the documentation for encoding/json Marshal for details about the // conversion of Go values to JSON. func (c *Conn) WriteJSON(v interface{}) error { - c.writeMu.Lock() - defer c.writeMu.Unlock() + c.WriteLock() + defer c.WriteUnlock() w, err := c.NextWriter(TextMessage) if err != nil { return err