timeout based shutdown

This commit is contained in:
Sanket Patel 2019-03-02 17:47:42 +05:30
parent 4b34454a14
commit 4fbd52ea52
1 changed files with 18 additions and 38 deletions

54
conn.go
View File

@ -12,7 +12,6 @@ import (
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
"net" "net"
"reflect"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@ -243,6 +242,7 @@ type Conn struct {
conn net.Conn conn net.Conn
isServer bool isServer bool
subprotocol string subprotocol string
isClosed chan bool
// Write fields // Write fields
mu chan bool // used as mutex to protect write to conn mu chan bool // used as mutex to protect write to conn
@ -326,33 +326,17 @@ func (c *Conn) Subprotocol() string {
return c.subprotocol return c.subprotocol
} }
// Close sends close frame and waits for one in response // Close closes the underlying network connection without sending or waiting
// it expects two args. `closeCode int` and `closeMessage string` in order // for a close message.
// it uses variadic args to maintain backwards compatibility func (c *Conn) Close() error {
func (c *Conn) Close(args ...interface{}) error { return c.conn.Close()
closeCode := CloseNoStatusReceived
message := ""
ok := false
if len(args) == 2 {
closeCode, ok = args[0].(int)
if !ok {
closeCode = CloseNoStatusReceived
}
message, ok = args[1].(string)
if !ok {
message = ""
}
}
err := c.Shutdown(closeCode, message)
if err != nil {
return err
}
c.conn.Close()
return nil
} }
// Shutdown sends a close frame and waits for one in response
func (c *Conn) Shutdown(closeCode int, closeMessage string) error { // Shutdown sends a close frame to the peer and waits for close frame in resopnse.
// Shutdown assumes that the application is reading the connection in another
// goroutine and hence it does not try to read close frame itself
func (c *Conn) Shutdown(closeCode int, closeMessage string, timeout time.Duration) error {
if !isValidCloseCode(closeCode) { if !isValidCloseCode(closeCode) {
// we do not shutdown connection // we do not shutdown connection
return errors.New("invalid close code received") return errors.New("invalid close code received")
@ -360,19 +344,14 @@ func (c *Conn) Shutdown(closeCode int, closeMessage string) error {
if !utf8.ValidString(closeMessage) { if !utf8.ValidString(closeMessage) {
return errors.New("invalid utf8 payload for shutdown message") return errors.New("invalid utf8 payload for shutdown message")
} }
message := FormatCloseMessage(closeCode, closeMessage) message := FormatCloseMessage(closeCode, closeMessage)
err := c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
if err != nil { select {
return err case <-time.After(timeout): // if nothing happens and we timeout
case <-c.isClosed: // if existing reader encounters close frame
} }
timeStart := time.Now() return c.Close()
c.conn.SetReadDeadline(time.Now().Add(time.Minute))
for _, _, err := c.ReadMessage(); reflect.TypeOf(err) != reflect.TypeOf(&CloseError{}) ; {
if timeStart.Sub(time.Now()) > time.Minute {
break
}
}
return nil
} }
// LocalAddr returns the local network address. // LocalAddr returns the local network address.
@ -943,6 +922,7 @@ func (c *Conn) advanceFrame() (int, error) {
return noFrame, err return noFrame, err
} }
case CloseMessage: case CloseMessage:
c.isClosed <- true
closeCode := CloseNoStatusReceived closeCode := CloseNoStatusReceived
closeText := "" closeText := ""
if len(payload) >= 2 { if len(payload) >= 2 {