Read Limit Fix (#537)

This fix addresses a potential denial-of-service (DoS) vector that can cause an integer overflow in the presence of malicious WebSocket frames.

The fix adds additional checks against the remaining bytes on a connection, as well as a test to prevent regression.

Credit to Max Justicz (https://justi.cz/) for discovering and reporting this, as well as providing a robust PoC and review.

* build: go.mod to go1.12
* bugfix: fix DoS vector caused by readLimit bypass
* test: update TestReadLimit sub-test
* bugfix: payload length 127 should read bytes as uint64
* bugfix: defend against readLength overflows
This commit is contained in:
Matt Silverlock 2019-08-24 18:17:28 -07:00 committed by GitHub
parent 7e9819d926
commit 5b740c2926
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 138 additions and 37 deletions

58
conn.go
View File

@ -260,10 +260,12 @@ type Conn struct {
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
// Read fields // Read fields
reader io.ReadCloser // the current reader returned to the application reader io.ReadCloser // the current reader returned to the application
readErr error readErr error
br *bufio.Reader br *bufio.Reader
readRemaining int64 // bytes remaining in current frame. // bytes remaining in current frame.
// set setReadRemaining to safely update this value and prevent overflow
readRemaining int64
readFinal bool // true the current message has more frames. readFinal bool // true the current message has more frames.
readLength int64 // Message size. readLength int64 // Message size.
readLimit int64 // Maximum message size. readLimit int64 // Maximum message size.
@ -320,6 +322,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
return c return c
} }
// setReadRemaining tracks the number of bytes remaining on the connection. If n
// overflows, an ErrReadLimit is returned.
func (c *Conn) setReadRemaining(n int64) error {
if n < 0 {
return ErrReadLimit
}
c.readRemaining = n
return nil
}
// Subprotocol returns the negotiated protocol for the connection. // Subprotocol returns the negotiated protocol for the connection.
func (c *Conn) Subprotocol() string { func (c *Conn) Subprotocol() string {
return c.subprotocol return c.subprotocol
@ -790,7 +803,7 @@ func (c *Conn) advanceFrame() (int, error) {
final := p[0]&finalBit != 0 final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf) frameType := int(p[0] & 0xf)
mask := p[1]&maskBit != 0 mask := p[1]&maskBit != 0
c.readRemaining = int64(p[1] & 0x7f) c.setReadRemaining(int64(p[1] & 0x7f))
c.readDecompress = false c.readDecompress = false
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 { if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
@ -824,7 +837,17 @@ func (c *Conn) advanceFrame() (int, error) {
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType)) return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
} }
// 3. Read and parse frame length. // 3. Read and parse frame length as per
// https://tools.ietf.org/html/rfc6455#section-5.2
//
// The length of the "Payload data", in bytes: if 0-125, that is the payload
// length.
// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
// integer are the payload length.
// - If 127, the following 8 bytes interpreted as
// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
// payload length. Multibyte length quantities are expressed in network byte
// order.
switch c.readRemaining { switch c.readRemaining {
case 126: case 126:
@ -832,13 +855,19 @@ func (c *Conn) advanceFrame() (int, error) {
if err != nil { if err != nil {
return noFrame, err return noFrame, err
} }
c.readRemaining = int64(binary.BigEndian.Uint16(p))
if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
return noFrame, err
}
case 127: case 127:
p, err := c.read(8) p, err := c.read(8)
if err != nil { if err != nil {
return noFrame, err return noFrame, err
} }
c.readRemaining = int64(binary.BigEndian.Uint64(p))
if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
return noFrame, err
}
} }
// 4. Handle frame masking. // 4. Handle frame masking.
@ -861,6 +890,12 @@ func (c *Conn) advanceFrame() (int, error) {
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
c.readLength += c.readRemaining c.readLength += c.readRemaining
// Don't allow readLength to overflow in the presence of a large readRemaining
// counter.
if c.readLength < 0 {
return noFrame, ErrReadLimit
}
if c.readLimit > 0 && c.readLength > c.readLimit { if c.readLimit > 0 && c.readLength > c.readLimit {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
return noFrame, ErrReadLimit return noFrame, ErrReadLimit
@ -874,7 +909,7 @@ func (c *Conn) advanceFrame() (int, error) {
var payload []byte var payload []byte
if c.readRemaining > 0 { if c.readRemaining > 0 {
payload, err = c.read(int(c.readRemaining)) payload, err = c.read(int(c.readRemaining))
c.readRemaining = 0 c.setReadRemaining(0)
if err != nil { if err != nil {
return noFrame, err return noFrame, err
} }
@ -947,6 +982,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
c.readErr = hideTempErr(err) c.readErr = hideTempErr(err)
break break
} }
if frameType == TextMessage || frameType == BinaryMessage { if frameType == TextMessage || frameType == BinaryMessage {
c.messageReader = &messageReader{c} c.messageReader = &messageReader{c}
c.reader = c.messageReader c.reader = c.messageReader
@ -987,7 +1023,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
if c.isServer { if c.isServer {
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
} }
c.readRemaining -= int64(n) rem := c.readRemaining
rem -= int64(n)
c.setReadRemaining(rem)
if c.readRemaining > 0 && c.readErr == io.EOF { if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF c.readErr = errUnexpectedEOF
} }

View File

@ -55,7 +55,10 @@ func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
} }
func TestFraming(t *testing.T) { func TestFraming(t *testing.T) {
frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537} frameSizes := []int{
0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535,
// 65536, 65537
}
var readChunkers = []struct { var readChunkers = []struct {
name string name string
f func(io.Reader) io.Reader f func(io.Reader) io.Reader
@ -120,6 +123,8 @@ func TestFraming(t *testing.T) {
t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
continue continue
} }
t.Logf("frame size: %d", n)
rbuf, err := ioutil.ReadAll(r) rbuf, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
@ -458,37 +463,93 @@ func TestWriteAfterMessageWriterClose(t *testing.T) {
} }
func TestReadLimit(t *testing.T) { func TestReadLimit(t *testing.T) {
t.Run("Test ReadLimit is enforced", func(t *testing.T) {
const readLimit = 512
message := make([]byte, readLimit+1)
const readLimit = 512 var b1, b2 bytes.Buffer
message := make([]byte, readLimit+1) wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
rc := newTestConn(&b1, &b2, true)
rc.SetReadLimit(readLimit)
var b1, b2 bytes.Buffer // Send message at the limit with interleaved pong.
wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) w, _ := wc.NextWriter(BinaryMessage)
rc := newTestConn(&b1, &b2, true) w.Write(message[:readLimit-1])
rc.SetReadLimit(readLimit) wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
w.Write(message[:1])
w.Close()
// Send message at the limit with interleaved pong. // Send message larger than the limit.
w, _ := wc.NextWriter(BinaryMessage) wc.WriteMessage(BinaryMessage, message[:readLimit+1])
w.Write(message[:readLimit-1])
wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
w.Write(message[:1])
w.Close()
// Send message larger than the limit. op, _, err := rc.NextReader()
wc.WriteMessage(BinaryMessage, message[:readLimit+1]) if op != BinaryMessage || err != nil {
t.Fatalf("1: NextReader() returned %d, %v", op, err)
}
op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("2: NextReader() returned %d, %v", op, err)
}
_, err = io.Copy(ioutil.Discard, r)
if err != ErrReadLimit {
t.Fatalf("io.Copy() returned %v", err)
}
})
op, _, err := rc.NextReader() t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) {
if op != BinaryMessage || err != nil { const readLimit = 1
t.Fatalf("1: NextReader() returned %d, %v", op, err)
} var b1, b2 bytes.Buffer
op, r, err := rc.NextReader() rc := newTestConn(&b1, &b2, true)
if op != BinaryMessage || err != nil { rc.SetReadLimit(readLimit)
t.Fatalf("2: NextReader() returned %d, %v", op, err)
} // First, send a non-final binary message
_, err = io.Copy(ioutil.Discard, r) b1.Write([]byte("\x02\x81"))
if err != ErrReadLimit {
t.Fatalf("io.Copy() returned %v", err) // Mask key
} b1.Write([]byte("\x00\x00\x00\x00"))
// First payload
b1.Write([]byte("A"))
// Next, send a negative-length, non-final continuation frame
b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00"))
// Mask key
b1.Write([]byte("\x00\x00\x00\x00"))
// Next, send a too long, final continuation frame
b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05"))
// Mask key
b1.Write([]byte("\x00\x00\x00\x00"))
// Too-long payload
b1.Write([]byte("BCDEF"))
op, r, err := rc.NextReader()
if op != BinaryMessage || err != nil {
t.Fatalf("1: NextReader() returned %d, %v", op, err)
}
var buf [10]byte
var read int
n, err := r.Read(buf[:])
if err != nil && err != ErrReadLimit {
t.Fatalf("unexpected error testing read limit: %v", err)
}
read += n
n, err = r.Read(buf[:])
if err != nil && err != ErrReadLimit {
t.Fatalf("unexpected error testing read limit: %v", err)
}
read += n
if err == nil && read > readLimit {
t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read)
}
})
} }
func TestAddrs(t *testing.T) { func TestAddrs(t *testing.T) {

2
go.mod
View File

@ -1 +1,3 @@
module github.com/gorilla/websocket module github.com/gorilla/websocket
go 1.12