mirror of https://github.com/gorilla/websocket.git
gracefully close connection fixes: #448
This commit is contained in:
parent
7c8e298727
commit
4b34454a14
57
conn.go
57
conn.go
|
@ -12,6 +12,7 @@ import (
|
|||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -219,7 +220,7 @@ var validReceivedCloseCodes = map[int]bool{
|
|||
CloseTLSHandshake: false,
|
||||
}
|
||||
|
||||
func isValidReceivedCloseCode(code int) bool {
|
||||
func isValidCloseCode(code int) bool {
|
||||
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
|
||||
}
|
||||
|
||||
|
@ -325,10 +326,53 @@ func (c *Conn) Subprotocol() string {
|
|||
return c.subprotocol
|
||||
}
|
||||
|
||||
// Close closes the underlying network connection without sending or waiting
|
||||
// for a close message.
|
||||
func (c *Conn) Close() error {
|
||||
return c.conn.Close()
|
||||
// Close sends close frame and waits for one in response
|
||||
// it expects two args. `closeCode int` and `closeMessage string` in order
|
||||
// it uses variadic args to maintain backwards compatibility
|
||||
func (c *Conn) Close(args ...interface{}) error {
|
||||
closeCode := CloseNoStatusReceived
|
||||
message := ""
|
||||
ok := false
|
||||
if len(args) == 2 {
|
||||
closeCode, ok = args[0].(int)
|
||||
if !ok {
|
||||
closeCode = CloseNoStatusReceived
|
||||
}
|
||||
message, ok = args[1].(string)
|
||||
if !ok {
|
||||
message = ""
|
||||
}
|
||||
}
|
||||
err := c.Shutdown(closeCode, message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown sends a close frame and waits for one in response
|
||||
func (c *Conn) Shutdown(closeCode int, closeMessage string) error {
|
||||
if !isValidCloseCode(closeCode) {
|
||||
// we do not shutdown connection
|
||||
return errors.New("invalid close code received")
|
||||
}
|
||||
if !utf8.ValidString(closeMessage) {
|
||||
return errors.New("invalid utf8 payload for shutdown message")
|
||||
}
|
||||
message := FormatCloseMessage(closeCode, closeMessage)
|
||||
err := c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
timeStart := time.Now()
|
||||
c.conn.SetReadDeadline(time.Now().Add(time.Minute))
|
||||
for _, _, err := c.ReadMessage(); reflect.TypeOf(err) != reflect.TypeOf(&CloseError{}) ; {
|
||||
if timeStart.Sub(time.Now()) > time.Minute {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
|
@ -496,6 +540,7 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
|
|||
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
|
||||
// PongMessage) are supported.
|
||||
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
|
||||
|
||||
var mw messageWriter
|
||||
if err := c.beginMessage(&mw, messageType); err != nil {
|
||||
return nil, err
|
||||
|
@ -902,7 +947,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
|||
closeText := ""
|
||||
if len(payload) >= 2 {
|
||||
closeCode = int(binary.BigEndian.Uint16(payload))
|
||||
if !isValidReceivedCloseCode(closeCode) {
|
||||
if !isValidCloseCode(closeCode) {
|
||||
return noFrame, c.handleProtocolError("invalid close code")
|
||||
}
|
||||
closeText = string(payload[2:])
|
||||
|
|
Loading…
Reference in New Issue