From 0497ee5302392b4f5556cb77a7cd7a487a342ed5 Mon Sep 17 00:00:00 2001 From: Saxon <saxon@ausocean.org> Date: Mon, 1 Jul 2019 12:36:11 +0930 Subject: [PATCH 1/4] container/mts: GetPTSRange checks for PUSI when looking for first PTS --- container/mts/mpegts.go | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go index 3f3851a3..25e36ff1 100644 --- a/container/mts/mpegts.go +++ b/container/mts/mpegts.go @@ -320,15 +320,22 @@ func DiscontinuityIndicator(f bool) Option { // GetPTSRange retreives the first and last PTS of an MPEGTS clip. func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) { - // Find the first packet with PID pidType. - pkt, _, err := FindPid(clip, pid) - if err != nil { - return [2]uint64{}, err + var _pkt packet.Packet + // Find the first packet with PID pidType and PUSI. + var i int + for { + pkt, _i, err := FindPid(clip[i:], pid) + if err != nil { + return [2]uint64{}, err + } + copy(_pkt[:], pkt) + if _pkt.PayloadUnitStartIndicator() { + break + } + i = _i + PacketSize } // Get the payload of the packet, which will be the start of the PES packet. - var _pkt packet.Packet - copy(_pkt[:], pkt) payload, err := packet.Payload(&_pkt) if err != nil { return [2]uint64{}, err From 0d240fa7ff9fa1ca6d23d9f6785325017282b1b5 Mon Sep 17 00:00:00 2001 From: Saxon <saxon@ausocean.org> Date: Mon, 1 Jul 2019 13:54:18 +0930 Subject: [PATCH 2/4] container/mts: checking index so that we don't go out of bounds --- container/mts/mpegts.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go index 25e36ff1..27ad6528 100644 --- a/container/mts/mpegts.go +++ b/container/mts/mpegts.go @@ -324,6 +324,9 @@ func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) { // Find the first packet with PID pidType and PUSI. var i int for { + if i >= len(clip) { + return [2]uint64{}, errors.New("could not find payload start") + } pkt, _i, err := FindPid(clip[i:], pid) if err != nil { return [2]uint64{}, err From 7bd885bcfb23cdfe66253d5dfd6072444707bd72 Mon Sep 17 00:00:00 2001 From: Saxon <saxon@ausocean.org> Date: Mon, 1 Jul 2019 14:01:49 +0930 Subject: [PATCH 3/4] container/mts: fixed infinite loop --- container/mts/mpegts.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go index 27ad6528..6ade2f87 100644 --- a/container/mts/mpegts.go +++ b/container/mts/mpegts.go @@ -335,7 +335,7 @@ func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) { if _pkt.PayloadUnitStartIndicator() { break } - i = _i + PacketSize + i += _i + PacketSize } // Get the payload of the packet, which will be the start of the PES packet. From b017e92185d8cb64ee301556ca2520b6aef38f02 Mon Sep 17 00:00:00 2001 From: Saxon <saxon@ausocean.org> Date: Mon, 1 Jul 2019 19:08:20 +0930 Subject: [PATCH 4/4] container/mts: wrote more tests for GetPTSRange --- container/mts/mpegts.go | 8 ++-- container/mts/mpegts_test.go | 84 +++++++++++++++++++++++++++++++++++- 2 files changed, 87 insertions(+), 5 deletions(-) diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go index 6ade2f87..60b1c484 100644 --- a/container/mts/mpegts.go +++ b/container/mts/mpegts.go @@ -30,11 +30,11 @@ LICENSE package mts import ( - "errors" "fmt" "github.com/Comcast/gots/packet" "github.com/Comcast/gots/pes" + "github.com/pkg/errors" ) const PacketSize = 188 @@ -325,11 +325,11 @@ func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) { var i int for { if i >= len(clip) { - return [2]uint64{}, errors.New("could not find payload start") + return [2]uint64{}, errNoPTS } pkt, _i, err := FindPid(clip[i:], pid) if err != nil { - return [2]uint64{}, err + return [2]uint64{}, errors.Wrap(err, fmt.Sprintf("could not find packet of PID: %d", pid)) } copy(_pkt[:], pkt) if _pkt.PayloadUnitStartIndicator() { @@ -369,3 +369,5 @@ func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) { } return } + +var errNoPTS = errors.New("could not find PTS") diff --git a/container/mts/mpegts_test.go b/container/mts/mpegts_test.go index 4c90cc0e..1a6409c7 100644 --- a/container/mts/mpegts_test.go +++ b/container/mts/mpegts_test.go @@ -39,8 +39,8 @@ import ( ) // TestGetPTSRange checks that GetPTSRange can correctly get the first and last -// PTS in an MPEGTS clip. -func TestGetPTSRange(t *testing.T) { +// PTS in an MPEGTS clip for a general case. +func TestGetPTSRange1(t *testing.T) { const ( numOfFrames = 20 maxFrameSize = 1000 @@ -157,6 +157,86 @@ func writeFrame(b *bytes.Buffer, frame []byte, pts uint64) error { return nil } +// TestGetPTSRange2 checks that GetPTSRange behaves correctly with cases where +// the first instance of a PID is not a payload start, and also where there +// are no payload starts. +func TestGetPTSRange2(t *testing.T) { + const ( + nPackets = 8 // The number of MTS packets we will generate. + wantPID = 1 // The PID we want. + ) + tests := []struct { + pusi []bool // The value of PUSI for each packet. + pid []uint16 // The PIDs for each packet. + pts []uint64 // The PTS for each packet. + want [2]uint64 // The wanted PTS from GetPTSRange. + err error // The error we expect from GetPTSRange. + }{ + { + []bool{false, false, false, true, false, false, true, false}, + []uint16{0, 0, 1, 1, 1, 1, 1, 1}, + []uint64{0, 0, 0, 1, 0, 0, 2, 0}, + [2]uint64{1, 2}, + nil, + }, + { + []bool{false, false, false, true, false, false, false, false}, + []uint16{0, 0, 1, 1, 1, 1, 1, 1}, + []uint64{0, 0, 0, 1, 0, 0, 0, 0}, + [2]uint64{1, 1}, + nil, + }, + { + []bool{false, false, false, false, false, false, false, false}, + []uint16{0, 0, 1, 1, 1, 1, 1, 1}, + []uint64{0, 0, 0, 0, 0, 0, 0, 0}, + [2]uint64{0, 0}, + errNoPTS, + }, + } + + var clip bytes.Buffer + + for i, test := range tests { + // Generate MTS packets for this test. + clip.Reset() + for j := 0; j < nPackets; j++ { + pesPkt := pes.Packet{ + StreamID: H264ID, + PDI: hasPTS, + PTS: test.pts[j], + Data: []byte{}, + HeaderLength: 5, + } + buf := pesPkt.Bytes(nil) + + pkt := Packet{ + PUSI: test.pusi[j], + PID: test.pid[j], + RAI: true, + CC: 0, + AFC: hasAdaptationField | hasPayload, + PCRF: true, + } + pkt.FillPayload(buf) + + _, err := clip.Write(pkt.Bytes(nil)) + if err != nil { + t.Fatalf("did not expect clip write error: %v", err) + } + } + + pts, err := GetPTSRange(clip.Bytes(), wantPID) + if err != test.err { + t.Errorf("did not get expected error for test: %v\nGot: %v\nWant: %v\n", i, err, test.err) + } + + if pts != test.want { + t.Errorf("did not get expected result for test: %v\nGot: %v\nWant: %v\n", i, pts, test.want) + } + } +} + // TestBytes checks that Packet.Bytes() correctly produces a []byte // representation of a Packet. func TestBytes(t *testing.T) {