forked from mirror/websocket
Read Limit Fix (#537)
This fix addresses a potential denial-of-service (DoS) vector that can cause an integer overflow in the presence of malicious WebSocket frames. The fix adds additional checks against the remaining bytes on a connection, as well as a test to prevent regression. Credit to Max Justicz (https://justi.cz/) for discovering and reporting this, as well as providing a robust PoC and review. * build: go.mod to go1.12 * bugfix: fix DoS vector caused by readLimit bypass * test: update TestReadLimit sub-test * bugfix: payload length 127 should read bytes as uint64 * bugfix: defend against readLength overflows
This commit is contained in:
parent
7e9819d926
commit
5b740c2926
52
conn.go
52
conn.go
|
@ -263,7 +263,9 @@ type Conn struct {
|
||||||
reader io.ReadCloser // the current reader returned to the application
|
reader io.ReadCloser // the current reader returned to the application
|
||||||
readErr error
|
readErr error
|
||||||
br *bufio.Reader
|
br *bufio.Reader
|
||||||
readRemaining int64 // bytes remaining in current frame.
|
// bytes remaining in current frame.
|
||||||
|
// set setReadRemaining to safely update this value and prevent overflow
|
||||||
|
readRemaining int64
|
||||||
readFinal bool // true the current message has more frames.
|
readFinal bool // true the current message has more frames.
|
||||||
readLength int64 // Message size.
|
readLength int64 // Message size.
|
||||||
readLimit int64 // Maximum message size.
|
readLimit int64 // Maximum message size.
|
||||||
|
@ -320,6 +322,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setReadRemaining tracks the number of bytes remaining on the connection. If n
|
||||||
|
// overflows, an ErrReadLimit is returned.
|
||||||
|
func (c *Conn) setReadRemaining(n int64) error {
|
||||||
|
if n < 0 {
|
||||||
|
return ErrReadLimit
|
||||||
|
}
|
||||||
|
|
||||||
|
c.readRemaining = n
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Subprotocol returns the negotiated protocol for the connection.
|
// Subprotocol returns the negotiated protocol for the connection.
|
||||||
func (c *Conn) Subprotocol() string {
|
func (c *Conn) Subprotocol() string {
|
||||||
return c.subprotocol
|
return c.subprotocol
|
||||||
|
@ -790,7 +803,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
final := p[0]&finalBit != 0
|
final := p[0]&finalBit != 0
|
||||||
frameType := int(p[0] & 0xf)
|
frameType := int(p[0] & 0xf)
|
||||||
mask := p[1]&maskBit != 0
|
mask := p[1]&maskBit != 0
|
||||||
c.readRemaining = int64(p[1] & 0x7f)
|
c.setReadRemaining(int64(p[1] & 0x7f))
|
||||||
|
|
||||||
c.readDecompress = false
|
c.readDecompress = false
|
||||||
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
|
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
|
||||||
|
@ -824,7 +837,17 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
|
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Read and parse frame length.
|
// 3. Read and parse frame length as per
|
||||||
|
// https://tools.ietf.org/html/rfc6455#section-5.2
|
||||||
|
//
|
||||||
|
// The length of the "Payload data", in bytes: if 0-125, that is the payload
|
||||||
|
// length.
|
||||||
|
// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
|
||||||
|
// integer are the payload length.
|
||||||
|
// - If 127, the following 8 bytes interpreted as
|
||||||
|
// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
|
||||||
|
// payload length. Multibyte length quantities are expressed in network byte
|
||||||
|
// order.
|
||||||
|
|
||||||
switch c.readRemaining {
|
switch c.readRemaining {
|
||||||
case 126:
|
case 126:
|
||||||
|
@ -832,13 +855,19 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
c.readRemaining = int64(binary.BigEndian.Uint16(p))
|
|
||||||
|
if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
|
||||||
|
return noFrame, err
|
||||||
|
}
|
||||||
case 127:
|
case 127:
|
||||||
p, err := c.read(8)
|
p, err := c.read(8)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
c.readRemaining = int64(binary.BigEndian.Uint64(p))
|
|
||||||
|
if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
|
||||||
|
return noFrame, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Handle frame masking.
|
// 4. Handle frame masking.
|
||||||
|
@ -861,6 +890,12 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
|
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
|
||||||
|
|
||||||
c.readLength += c.readRemaining
|
c.readLength += c.readRemaining
|
||||||
|
// Don't allow readLength to overflow in the presence of a large readRemaining
|
||||||
|
// counter.
|
||||||
|
if c.readLength < 0 {
|
||||||
|
return noFrame, ErrReadLimit
|
||||||
|
}
|
||||||
|
|
||||||
if c.readLimit > 0 && c.readLength > c.readLimit {
|
if c.readLimit > 0 && c.readLength > c.readLimit {
|
||||||
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
|
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
|
||||||
return noFrame, ErrReadLimit
|
return noFrame, ErrReadLimit
|
||||||
|
@ -874,7 +909,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
||||||
var payload []byte
|
var payload []byte
|
||||||
if c.readRemaining > 0 {
|
if c.readRemaining > 0 {
|
||||||
payload, err = c.read(int(c.readRemaining))
|
payload, err = c.read(int(c.readRemaining))
|
||||||
c.readRemaining = 0
|
c.setReadRemaining(0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noFrame, err
|
return noFrame, err
|
||||||
}
|
}
|
||||||
|
@ -947,6 +982,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
||||||
c.readErr = hideTempErr(err)
|
c.readErr = hideTempErr(err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if frameType == TextMessage || frameType == BinaryMessage {
|
if frameType == TextMessage || frameType == BinaryMessage {
|
||||||
c.messageReader = &messageReader{c}
|
c.messageReader = &messageReader{c}
|
||||||
c.reader = c.messageReader
|
c.reader = c.messageReader
|
||||||
|
@ -987,7 +1023,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
||||||
if c.isServer {
|
if c.isServer {
|
||||||
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
|
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
|
||||||
}
|
}
|
||||||
c.readRemaining -= int64(n)
|
rem := c.readRemaining
|
||||||
|
rem -= int64(n)
|
||||||
|
c.setReadRemaining(rem)
|
||||||
if c.readRemaining > 0 && c.readErr == io.EOF {
|
if c.readRemaining > 0 && c.readErr == io.EOF {
|
||||||
c.readErr = errUnexpectedEOF
|
c.readErr = errUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
|
65
conn_test.go
65
conn_test.go
|
@ -55,7 +55,10 @@ func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFraming(t *testing.T) {
|
func TestFraming(t *testing.T) {
|
||||||
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
|
frameSizes := []int{
|
||||||
|
0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
|
||||||
|
// 65536, 65537
|
||||||
|
}
|
||||||
var readChunkers = []struct {
|
var readChunkers = []struct {
|
||||||
name string
|
name string
|
||||||
f func(io.Reader) io.Reader
|
f func(io.Reader) io.Reader
|
||||||
|
@ -120,6 +123,8 @@ func TestFraming(t *testing.T) {
|
||||||
t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
|
t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Logf("frame size: %d", n)
|
||||||
rbuf, err := ioutil.ReadAll(r)
|
rbuf, err := ioutil.ReadAll(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
|
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
|
||||||
|
@ -458,7 +463,7 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestReadLimit(t *testing.T) {
|
func TestReadLimit(t *testing.T) {
|
||||||
|
t.Run("Test ReadLimit is enforced", func(t *testing.T) {
|
||||||
const readLimit = 512
|
const readLimit = 512
|
||||||
message := make([]byte, readLimit+1)
|
message := make([]byte, readLimit+1)
|
||||||
|
|
||||||
|
@ -489,6 +494,62 @@ func TestReadLimit(t *testing.T) {
|
||||||
if err != ErrReadLimit {
|
if err != ErrReadLimit {
|
||||||
t.Fatalf("io.Copy() returned %v", err)
|
t.Fatalf("io.Copy() returned %v", err)
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
|
||||||
|
const readLimit = 1
|
||||||
|
|
||||||
|
var b1, b2 bytes.Buffer
|
||||||
|
rc := newTestConn(&b1, &b2, true)
|
||||||
|
rc.SetReadLimit(readLimit)
|
||||||
|
|
||||||
|
// First, send a non-final binary message
|
||||||
|
b1.Write([]byte("\x02\x81"))
|
||||||
|
|
||||||
|
// Mask key
|
||||||
|
b1.Write([]byte("\x00\x00\x00\x00"))
|
||||||
|
|
||||||
|
// First payload
|
||||||
|
b1.Write([]byte("A"))
|
||||||
|
|
||||||
|
// Next, send a negative-length, non-final continuation frame
|
||||||
|
b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))
|
||||||
|
|
||||||
|
// Mask key
|
||||||
|
b1.Write([]byte("\x00\x00\x00\x00"))
|
||||||
|
|
||||||
|
// Next, send a too long, final continuation frame
|
||||||
|
b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))
|
||||||
|
|
||||||
|
// Mask key
|
||||||
|
b1.Write([]byte("\x00\x00\x00\x00"))
|
||||||
|
|
||||||
|
// Too-long payload
|
||||||
|
b1.Write([]byte("BCDEF"))
|
||||||
|
|
||||||
|
op, r, err := rc.NextReader()
|
||||||
|
if op != BinaryMessage || err != nil {
|
||||||
|
t.Fatalf("1: NextReader() returned %d, %v", op, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf [10]byte
|
||||||
|
var read int
|
||||||
|
n, err := r.Read(buf[:])
|
||||||
|
if err != nil && err != ErrReadLimit {
|
||||||
|
t.Fatalf("unexpected error testing read limit: %v", err)
|
||||||
|
}
|
||||||
|
read += n
|
||||||
|
|
||||||
|
n, err = r.Read(buf[:])
|
||||||
|
if err != nil && err != ErrReadLimit {
|
||||||
|
t.Fatalf("unexpected error testing read limit: %v", err)
|
||||||
|
}
|
||||||
|
read += n
|
||||||
|
|
||||||
|
if err == nil && read > readLimit {
|
||||||
|
t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddrs(t *testing.T) {
|
func TestAddrs(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue