Improve protocol error messages

To aid protocol error debugging, report all errors found in the first two bytes of a message header.
This commit is contained in:
Gary Burd 2022-01-02 12:16:08 -08:00 committed by GitHub
parent 2d6ee4c55c
commit f0643a3a18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 40 additions and 17 deletions

55
conn.go
View File

@ -13,6 +13,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
"unicode/utf8" "unicode/utf8"
@ -794,47 +795,69 @@ func (c *Conn) advanceFrame() (int, error) {
} }
// 2. Read and parse first two bytes of frame header. // 2. Read and parse first two bytes of frame header.
// To aid debugging, collect and report all errors in the first two bytes
// of the header.
var errors []string
p, err := c.read(2) p, err := c.read(2)
if err != nil { if err != nil {
return noFrame, err return noFrame, err
} }
final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf) frameType := int(p[0] & 0xf)
final := p[0]&finalBit != 0
rsv1 := p[0]&rsv1Bit != 0
rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0 mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f)) c.setReadRemaining(int64(p[1] & 0x7f))
c.readDecompress = false c.readDecompress = false
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { if rsv1 {
if c.newDecompressionReader != nil {
c.readDecompress = true c.readDecompress = true
p[0] &^= rsv1Bit } else {
errors = append(errors, "RSV1 set")
}
} }
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 { if rsv2 {
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16)) errors = append(errors, "RSV2 set")
}
if rsv3 {
errors = append(errors, "RSV3 set")
} }
switch frameType { switch frameType {
case CloseMessage, PingMessage, PongMessage: case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize { if c.readRemaining > maxControlFramePayloadSize {
return noFrame, c.handleProtocolError("control frame length > 125") errors = append(errors, "len > 125 for control")
} }
if !final { if !final {
return noFrame, c.handleProtocolError("control frame not final") errors = append(errors, "FIN not set on control")
} }
case TextMessage, BinaryMessage: case TextMessage, BinaryMessage:
if !c.readFinal { if !c.readFinal {
return noFrame, c.handleProtocolError("message start before final message frame") errors = append(errors, "data before FIN")
} }
c.readFinal = final c.readFinal = final
case continuationFrame: case continuationFrame:
if c.readFinal { if c.readFinal {
return noFrame, c.handleProtocolError("continuation after final message frame") errors = append(errors, "continuation after FIN")
} }
c.readFinal = final c.readFinal = final
default: default:
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
}
if mask != c.isServer {
errors = append(errors, "bad MASK")
}
if len(errors) > 0 {
return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
} }
// 3. Read and parse frame length as per // 3. Read and parse frame length as per
@ -872,10 +895,6 @@ func (c *Conn) advanceFrame() (int, error) {
// 4. Handle frame masking. // 4. Handle frame masking.
if mask != c.isServer {
return noFrame, c.handleProtocolError("incorrect mask flag")
}
if mask { if mask {
c.readMaskPos = 0 c.readMaskPos = 0
p, err := c.read(len(c.readMaskKey)) p, err := c.read(len(c.readMaskKey))
@ -935,7 +954,7 @@ func (c *Conn) advanceFrame() (int, error) {
if len(payload) >= 2 { if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload)) closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) { if !isValidReceivedCloseCode(closeCode) {
return noFrame, c.handleProtocolError("invalid close code") return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
} }
closeText = string(payload[2:]) closeText = string(payload[2:])
if !utf8.ValidString(closeText) { if !utf8.ValidString(closeText) {
@ -952,7 +971,11 @@ func (c *Conn) advanceFrame() (int, error) {
} }
func (c *Conn) handleProtocolError(message string) error { func (c *Conn) handleProtocolError(message string) error {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait)) data := FormatCloseMessage(CloseProtocolError, message)
if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize]
}
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message) return errors.New("websocket: " + message)
} }