diff --git a/conn.go b/conn.go index 221e6cf..2a5ff76 100644 --- a/conn.go +++ b/conn.go @@ -1163,7 +1163,8 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) { if h == nil { h = func(code int, text string) error { message := FormatCloseMessage(code, "") - if err := c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)); err != nil { + err := c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + if err != nil && err != ErrCloseSent { return err } return nil diff --git a/conn_test.go b/conn_test.go index 2b823dd..5b27f2f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -477,6 +477,26 @@ func TestWriteAfterMessageWriterClose(t *testing.T) { } } +func TestWriteHandlerDoesNotReturnErrCloseSent(t *testing.T) { + var b1, b2 bytes.Buffer + + client := newTestConn(&b2, &b1, false) + server := newTestConn(&b1, &b2, true) + + msg := FormatCloseMessage(CloseNormalClosure, "") + if err := client.WriteMessage(CloseMessage, msg); err != nil { + t.Fatalf("unexpected error when writing close message, %v", err) + } + + if _, _, err := server.NextReader(); !IsCloseError(err, 1000) { + t.Fatalf("server expects a close message, %v returned", err) + } + + if _, _, err := client.NextReader(); !IsCloseError(err, 1000) { + t.Fatalf("client expects a close message, %v returned", err) + } +} + func TestReadLimit(t *testing.T) { t.Run("Test ReadLimit is enforced", func(t *testing.T) { const readLimit = 512