From 4b34454a146e17e68bd317157e8059be5a9a2e0e Mon Sep 17 00:00:00 2001 From: Sanket Patel Date: Sat, 2 Mar 2019 03:03:53 +0530 Subject: [PATCH] gracefully close connection fixes: #448 --- conn.go | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/conn.go b/conn.go index 3848ab4..48f21cd 100644 --- a/conn.go +++ b/conn.go @@ -12,6 +12,7 @@ import ( "io/ioutil" "math/rand" "net" + "reflect" "strconv" "sync" "time" @@ -219,7 +220,7 @@ var validReceivedCloseCodes = map[int]bool{ CloseTLSHandshake: false, } -func isValidReceivedCloseCode(code int) bool { +func isValidCloseCode(code int) bool { return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) } @@ -325,10 +326,53 @@ func (c *Conn) Subprotocol() string { return c.subprotocol } -// Close closes the underlying network connection without sending or waiting -// for a close message. -func (c *Conn) Close() error { - return c.conn.Close() +// Close sends close frame and waits for one in response +// it expects two args. `closeCode int` and `closeMessage string` in order +// it uses variadic args to maintain backwards compatibility +func (c *Conn) Close(args ...interface{}) error { + closeCode := CloseNoStatusReceived + message := "" + ok := false + if len(args) == 2 { + closeCode, ok = args[0].(int) + if !ok { + closeCode = CloseNoStatusReceived + } + message, ok = args[1].(string) + if !ok { + message = "" + } + } + err := c.Shutdown(closeCode, message) + if err != nil { + return err + } + c.conn.Close() + return nil +} + +// Shutdown sends a close frame and waits for one in response +func (c *Conn) Shutdown(closeCode int, closeMessage string) error { + if !isValidCloseCode(closeCode) { + // we do not shutdown connection + return errors.New("invalid close code received") + } + if !utf8.ValidString(closeMessage) { + return errors.New("invalid utf8 payload for shutdown message") + } + message := FormatCloseMessage(closeCode, closeMessage) + err := c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + if err != nil { + return err + } + timeStart := time.Now() + c.conn.SetReadDeadline(time.Now().Add(time.Minute)) + for _, _, err := c.ReadMessage(); reflect.TypeOf(err) != reflect.TypeOf(&CloseError{}) ; { + if timeStart.Sub(time.Now()) > time.Minute { + break + } + } + return nil } // LocalAddr returns the local network address. @@ -496,6 +540,7 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and // PongMessage) are supported. func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { + var mw messageWriter if err := c.beginMessage(&mw, messageType); err != nil { return nil, err @@ -902,7 +947,7 @@ func (c *Conn) advanceFrame() (int, error) { closeText := "" if len(payload) >= 2 { closeCode = int(binary.BigEndian.Uint16(payload)) - if !isValidReceivedCloseCode(closeCode) { + if !isValidCloseCode(closeCode) { return noFrame, c.handleProtocolError("invalid close code") } closeText = string(payload[2:])