forked from mirror/websocket
Improve protocol error messages
To aid protocol error debugging, report all errors found in the first two bytes of a message header.
This commit is contained in:
parent
2d6ee4c55c
commit
f0643a3a18
57
conn.go
57
conn.go
|
@ -13,6 +13,7 @@ import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
@ -794,47 +795,69 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Read and parse first two bytes of frame header.
|
// 2. Read and parse first two bytes of frame header.
|
||||||
|
// To aid debugging, collect and report all errors in the first two bytes
|
||||||
|
// of the header.
|
||||||
|
|
||||||
|
var errors []string
|
||||||
|
|
||||||
p, err := c.read(2)
|
p, err := c.read(2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
|
|
||||||
final := p[0]&finalBit != 0
|
|
||||||
frameType := int(p[0] & 0xf)
|
frameType := int(p[0] & 0xf)
|
||||||
|
final := p[0]&finalBit != 0
|
||||||
|
rsv1 := p[0]&rsv1Bit != 0
|
||||||
|
rsv2 := p[0]&rsv2Bit != 0
|
||||||
|
rsv3 := p[0]&rsv3Bit != 0
|
||||||
mask := p[1]&maskBit != 0
|
mask := p[1]&maskBit != 0
|
||||||
c.setReadRemaining(int64(p[1] & 0x7f))
|
c.setReadRemaining(int64(p[1] & 0x7f))
|
||||||
|
|
||||||
c.readDecompress = false
|
c.readDecompress = false
|
||||||
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
|
if rsv1 {
|
||||||
c.readDecompress = true
|
if c.newDecompressionReader != nil {
|
||||||
p[0] &^= rsv1Bit
|
c.readDecompress = true
|
||||||
|
} else {
|
||||||
|
errors = append(errors, "RSV1 set")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
|
if rsv2 {
|
||||||
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
|
errors = append(errors, "RSV2 set")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rsv3 {
|
||||||
|
errors = append(errors, "RSV3 set")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch frameType {
|
switch frameType {
|
||||||
case CloseMessage, PingMessage, PongMessage:
|
case CloseMessage, PingMessage, PongMessage:
|
||||||
if c.readRemaining > maxControlFramePayloadSize {
|
if c.readRemaining > maxControlFramePayloadSize {
|
||||||
return noFrame, c.handleProtocolError("control frame length > 125")
|
errors = append(errors, "len > 125 for control")
|
||||||
}
|
}
|
||||||
if !final {
|
if !final {
|
||||||
return noFrame, c.handleProtocolError("control frame not final")
|
errors = append(errors, "FIN not set on control")
|
||||||
}
|
}
|
||||||
case TextMessage, BinaryMessage:
|
case TextMessage, BinaryMessage:
|
||||||
if !c.readFinal {
|
if !c.readFinal {
|
||||||
return noFrame, c.handleProtocolError("message start before final message frame")
|
errors = append(errors, "data before FIN")
|
||||||
}
|
}
|
||||||
c.readFinal = final
|
c.readFinal = final
|
||||||
case continuationFrame:
|
case continuationFrame:
|
||||||
if c.readFinal {
|
if c.readFinal {
|
||||||
return noFrame, c.handleProtocolError("continuation after final message frame")
|
errors = append(errors, "continuation after FIN")
|
||||||
}
|
}
|
||||||
c.readFinal = final
|
c.readFinal = final
|
||||||
default:
|
default:
|
||||||
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
|
errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
|
||||||
|
}
|
||||||
|
|
||||||
|
if mask != c.isServer {
|
||||||
|
errors = append(errors, "bad MASK")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errors) > 0 {
|
||||||
|
return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Read and parse frame length as per
|
// 3. Read and parse frame length as per
|
||||||
|
@ -872,10 +895,6 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
|
|
||||||
// 4. Handle frame masking.
|
// 4. Handle frame masking.
|
||||||
|
|
||||||
if mask != c.isServer {
|
|
||||||
return noFrame, c.handleProtocolError("incorrect mask flag")
|
|
||||||
}
|
|
||||||
|
|
||||||
if mask {
|
if mask {
|
||||||
c.readMaskPos = 0
|
c.readMaskPos = 0
|
||||||
p, err := c.read(len(c.readMaskKey))
|
p, err := c.read(len(c.readMaskKey))
|
||||||
|
@ -935,7 +954,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
if len(payload) >= 2 {
|
if len(payload) >= 2 {
|
||||||
closeCode = int(binary.BigEndian.Uint16(payload))
|
closeCode = int(binary.BigEndian.Uint16(payload))
|
||||||
if !isValidReceivedCloseCode(closeCode) {
|
if !isValidReceivedCloseCode(closeCode) {
|
||||||
return noFrame, c.handleProtocolError("invalid close code")
|
return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
|
||||||
}
|
}
|
||||||
closeText = string(payload[2:])
|
closeText = string(payload[2:])
|
||||||
if !utf8.ValidString(closeText) {
|
if !utf8.ValidString(closeText) {
|
||||||
|
@ -952,7 +971,11 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) handleProtocolError(message string) error {
|
func (c *Conn) handleProtocolError(message string) error {
|
||||||
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
|
data := FormatCloseMessage(CloseProtocolError, message)
|
||||||
|
if len(data) > maxControlFramePayloadSize {
|
||||||
|
data = data[:maxControlFramePayloadSize]
|
||||||
|
}
|
||||||
|
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
|
||||||
return errors.New("websocket: " + message)
|
return errors.New("websocket: " + message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue