From aacb4ca3b3f1bb8cbc0bfd98703881d5a8723865 Mon Sep 17 00:00:00 2001 From: Saxon Nelson-Milton Date: Fri, 18 Mar 2022 13:55:09 +1030 Subject: [PATCH] protocol/rtmp: fixing handling of large packet sizes --- protocol/rtmp/packet.go | 37 +++++++++++++++++++++++++++++-------- protocol/rtmp/rtmp.go | 25 +++++++++++++++++-------- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/protocol/rtmp/packet.go b/protocol/rtmp/packet.go index 3b5bf6ca..128751a7 100644 --- a/protocol/rtmp/packet.go +++ b/protocol/rtmp/packet.go @@ -91,10 +91,15 @@ type packet struct { timestamp uint32 streamID uint32 bodySize uint32 + bytesRead uint32 buf []byte body []byte } +func (pkt *packet) isReady() bool { + return pkt.bytesRead == pkt.bodySize +} + // readFrom reads a packet from the RTMP connection. func (pkt *packet) readFrom(c *Conn) error { var hbuf [fullHeaderSize]byte @@ -177,6 +182,7 @@ func (pkt *packet) readFrom(c *Conn) error { if size >= 3 { pkt.timestamp = amf.DecodeInt24(header[:3]) + pkt.bytesRead = 0 if size >= 6 { pkt.bodySize = amf.DecodeInt24(header[3:6]) @@ -206,12 +212,24 @@ func (pkt *packet) readFrom(c *Conn) error { c.log(WarnLevel, pkg+"reading large packet", "size", int(pkt.bodySize)) } - _, err = c.read(pkt.body[:pkt.bodySize]) + nToRead := pkt.bodySize - pkt.bytesRead + nChunk := c.inChunkSize + if nToRead < nChunk { + nChunk = nToRead + } + + n, err := c.read(pkt.body[pkt.bytesRead : pkt.bytesRead+nChunk]) if err != nil { c.log(DebugLevel, pkg+"failed to read packet body", "error", err.Error()) return fmt.Errorf("failed to read packet body: %w", err) } + if uint32(n) != nChunk { + return fmt.Errorf("did not read correct number of bytes, read: %d, expected: %d", n, nChunk) + } + + pkt.bytesRead += nChunk + // Keep the packet as a reference for other packets on this channel. if c.channelsIn[pkt.channel] == nil { c.channelsIn[pkt.channel] = &packet{} @@ -222,14 +240,17 @@ func (pkt *packet) readFrom(c *Conn) error { c.channelsIn[pkt.channel].timestamp = 0xffffff } - if !pkt.hasAbsTimestamp { - // Timestamps seem to always be relative. - pkt.timestamp += uint32(c.channelTimestamp[pkt.channel]) - } - c.channelTimestamp[pkt.channel] = int32(pkt.timestamp) + if pkt.isReady() { + if !pkt.hasAbsTimestamp { + // Timestamps seem to always be relative. + pkt.timestamp += uint32(c.channelTimestamp[pkt.channel]) + } + c.channelTimestamp[pkt.channel] = int32(pkt.timestamp) + + c.channelsIn[pkt.channel].body = nil + c.channelsIn[pkt.channel].hasAbsTimestamp = false + } - c.channelsIn[pkt.channel].body = nil - c.channelsIn[pkt.channel].hasAbsTimestamp = false return nil } diff --git a/protocol/rtmp/rtmp.go b/protocol/rtmp/rtmp.go index 411ea7c7..c7e4efc0 100644 --- a/protocol/rtmp/rtmp.go +++ b/protocol/rtmp/rtmp.go @@ -197,22 +197,31 @@ func connect(c *Conn) error { c.log(DebugLevel, pkg+"negotiating") var buf [256]byte + pkt := packet{buf: buf[:]} for !c.isPlaying { - pkt := packet{buf: buf[:]} err = pkt.readFrom(c) if err != nil { return fmt.Errorf("could not read from packet: %w", err) } - switch pkt.packetType { - case packetTypeAudio, packetTypeVideo, packetTypeInfo: - c.log(WarnLevel, pkg+"got packet before play; ignoring", "type", pkt.packetType) - default: - err = handlePacket(c, &pkt) - if err != nil { - return fmt.Errorf("could not handle packet: %w", err) + if pkt.isReady() { + if pkt.bodySize == 0 { + continue } + + switch pkt.packetType { + case packetTypeAudio, packetTypeVideo, packetTypeInfo: + c.log(WarnLevel, pkg+"got packet before play; ignoring", "type", pkt.packetType) + default: + err = handlePacket(c, &pkt) + if err != nil { + return fmt.Errorf("could not handle packet: %w", err) + } + } + pkt = packet{buf: buf[:]} } + + } return nil }