gracefully close connection fixes: #448

This commit is contained in:
Sanket Patel 2019-03-02 03:03:53 +05:30
parent 7c8e298727
commit 4b34454a14
1 changed files with 51 additions and 6 deletions

57
conn.go
View File

@ -12,6 +12,7 @@ import (
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
"net" "net"
"reflect"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@ -219,7 +220,7 @@ var validReceivedCloseCodes = map[int]bool{
CloseTLSHandshake: false, CloseTLSHandshake: false,
} }
func isValidReceivedCloseCode(code int) bool { func isValidCloseCode(code int) bool {
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
} }
@ -325,10 +326,53 @@ func (c *Conn) Subprotocol() string {
return c.subprotocol return c.subprotocol
} }
// Close closes the underlying network connection without sending or waiting // Close sends close frame and waits for one in response
// for a close message. // it expects two args. `closeCode int` and `closeMessage string` in order
func (c *Conn) Close() error { // it uses variadic args to maintain backwards compatibility
return c.conn.Close() func (c *Conn) Close(args ...interface{}) error {
closeCode := CloseNoStatusReceived
message := ""
ok := false
if len(args) == 2 {
closeCode, ok = args[0].(int)
if !ok {
closeCode = CloseNoStatusReceived
}
message, ok = args[1].(string)
if !ok {
message = ""
}
}
err := c.Shutdown(closeCode, message)
if err != nil {
return err
}
c.conn.Close()
return nil
}
// Shutdown sends a close frame and waits for one in response
func (c *Conn) Shutdown(closeCode int, closeMessage string) error {
if !isValidCloseCode(closeCode) {
// we do not shutdown connection
return errors.New("invalid close code received")
}
if !utf8.ValidString(closeMessage) {
return errors.New("invalid utf8 payload for shutdown message")
}
message := FormatCloseMessage(closeCode, closeMessage)
err := c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
if err != nil {
return err
}
timeStart := time.Now()
c.conn.SetReadDeadline(time.Now().Add(time.Minute))
for _, _, err := c.ReadMessage(); reflect.TypeOf(err) != reflect.TypeOf(&CloseError{}) ; {
if timeStart.Sub(time.Now()) > time.Minute {
break
}
}
return nil
} }
// LocalAddr returns the local network address. // LocalAddr returns the local network address.
@ -496,6 +540,7 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
// PongMessage) are supported. // PongMessage) are supported.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
var mw messageWriter var mw messageWriter
if err := c.beginMessage(&mw, messageType); err != nil { if err := c.beginMessage(&mw, messageType); err != nil {
return nil, err return nil, err
@ -902,7 +947,7 @@ func (c *Conn) advanceFrame() (int, error) {
closeText := "" closeText := ""
if len(payload) >= 2 { if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload)) closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) { if !isValidCloseCode(closeCode) {
return noFrame, c.handleProtocolError("invalid close code") return noFrame, c.handleProtocolError("invalid close code")
} }
closeText = string(payload[2:]) closeText = string(payload[2:])