From 0497ee5302392b4f5556cb77a7cd7a487a342ed5 Mon Sep 17 00:00:00 2001
From: Saxon <saxon@ausocean.org>
Date: Mon, 1 Jul 2019 12:36:11 +0930
Subject: [PATCH 1/4] container/mts: GetPTSRange checks for PUSI when looking
 for first PTS

---
 container/mts/mpegts.go | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go
index 3f3851a3..25e36ff1 100644
--- a/container/mts/mpegts.go
+++ b/container/mts/mpegts.go
@@ -320,15 +320,22 @@ func DiscontinuityIndicator(f bool) Option {
 
 // GetPTSRange retreives the first and last PTS of an MPEGTS clip.
 func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) {
-	// Find the first packet with PID pidType.
-	pkt, _, err := FindPid(clip, pid)
-	if err != nil {
-		return [2]uint64{}, err
+	var _pkt packet.Packet
+	// Find the first packet with PID pidType and PUSI.
+	var i int
+	for {
+		pkt, _i, err := FindPid(clip[i:], pid)
+		if err != nil {
+			return [2]uint64{}, err
+		}
+		copy(_pkt[:], pkt)
+		if _pkt.PayloadUnitStartIndicator() {
+			break
+		}
+		i = _i + PacketSize
 	}
 
 	// Get the payload of the packet, which will be the start of the PES packet.
-	var _pkt packet.Packet
-	copy(_pkt[:], pkt)
 	payload, err := packet.Payload(&_pkt)
 	if err != nil {
 		return [2]uint64{}, err

From 0d240fa7ff9fa1ca6d23d9f6785325017282b1b5 Mon Sep 17 00:00:00 2001
From: Saxon <saxon@ausocean.org>
Date: Mon, 1 Jul 2019 13:54:18 +0930
Subject: [PATCH 2/4] container/mts: checking index so that we don't go out of
 bounds

---
 container/mts/mpegts.go | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go
index 25e36ff1..27ad6528 100644
--- a/container/mts/mpegts.go
+++ b/container/mts/mpegts.go
@@ -324,6 +324,9 @@ func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) {
 	// Find the first packet with PID pidType and PUSI.
 	var i int
 	for {
+		if i >= len(clip) {
+			return [2]uint64{}, errors.New("could not find payload start")
+		}
 		pkt, _i, err := FindPid(clip[i:], pid)
 		if err != nil {
 			return [2]uint64{}, err

From 7bd885bcfb23cdfe66253d5dfd6072444707bd72 Mon Sep 17 00:00:00 2001
From: Saxon <saxon@ausocean.org>
Date: Mon, 1 Jul 2019 14:01:49 +0930
Subject: [PATCH 3/4] container/mts: fixed infinite loop

---
 container/mts/mpegts.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go
index 27ad6528..6ade2f87 100644
--- a/container/mts/mpegts.go
+++ b/container/mts/mpegts.go
@@ -335,7 +335,7 @@ func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) {
 		if _pkt.PayloadUnitStartIndicator() {
 			break
 		}
-		i = _i + PacketSize
+		i += _i + PacketSize
 	}
 
 	// Get the payload of the packet, which will be the start of the PES packet.

From b017e92185d8cb64ee301556ca2520b6aef38f02 Mon Sep 17 00:00:00 2001
From: Saxon <saxon@ausocean.org>
Date: Mon, 1 Jul 2019 19:08:20 +0930
Subject: [PATCH 4/4] container/mts: wrote more tests for GetPTSRange

---
 container/mts/mpegts.go      |  8 ++--
 container/mts/mpegts_test.go | 84 +++++++++++++++++++++++++++++++++++-
 2 files changed, 87 insertions(+), 5 deletions(-)

diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go
index 6ade2f87..60b1c484 100644
--- a/container/mts/mpegts.go
+++ b/container/mts/mpegts.go
@@ -30,11 +30,11 @@ LICENSE
 package mts
 
 import (
-	"errors"
 	"fmt"
 
 	"github.com/Comcast/gots/packet"
 	"github.com/Comcast/gots/pes"
+	"github.com/pkg/errors"
 )
 
 const PacketSize = 188
@@ -325,11 +325,11 @@ func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) {
 	var i int
 	for {
 		if i >= len(clip) {
-			return [2]uint64{}, errors.New("could not find payload start")
+			return [2]uint64{}, errNoPTS
 		}
 		pkt, _i, err := FindPid(clip[i:], pid)
 		if err != nil {
-			return [2]uint64{}, err
+			return [2]uint64{}, errors.Wrap(err, fmt.Sprintf("could not find packet of PID: %d", pid))
 		}
 		copy(_pkt[:], pkt)
 		if _pkt.PayloadUnitStartIndicator() {
@@ -369,3 +369,5 @@ func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) {
 	}
 	return
 }
+
+var errNoPTS = errors.New("could not find PTS")
diff --git a/container/mts/mpegts_test.go b/container/mts/mpegts_test.go
index 4c90cc0e..1a6409c7 100644
--- a/container/mts/mpegts_test.go
+++ b/container/mts/mpegts_test.go
@@ -39,8 +39,8 @@ import (
 )
 
 // TestGetPTSRange checks that GetPTSRange can correctly get the first and last
-// PTS in an MPEGTS clip.
-func TestGetPTSRange(t *testing.T) {
+// PTS in an MPEGTS clip for a general case.
+func TestGetPTSRange1(t *testing.T) {
 	const (
 		numOfFrames  = 20
 		maxFrameSize = 1000
@@ -157,6 +157,86 @@ func writeFrame(b *bytes.Buffer, frame []byte, pts uint64) error {
 	return nil
 }
 
+// TestGetPTSRange2 checks that GetPTSRange behaves correctly with cases where
+// the first instance of a PID is not a payload start, and also where there
+// are no payload starts.
+func TestGetPTSRange2(t *testing.T) {
+	const (
+		nPackets = 8 // The number of MTS packets we will generate.
+		wantPID  = 1 // The PID we want.
+	)
+	tests := []struct {
+		pusi []bool    // The value of PUSI for each packet.
+		pid  []uint16  // The PIDs for each packet.
+		pts  []uint64  // The PTS for each packet.
+		want [2]uint64 // The wanted PTS from GetPTSRange.
+		err  error     // The error we expect from GetPTSRange.
+	}{
+		{
+			[]bool{false, false, false, true, false, false, true, false},
+			[]uint16{0, 0, 1, 1, 1, 1, 1, 1},
+			[]uint64{0, 0, 0, 1, 0, 0, 2, 0},
+			[2]uint64{1, 2},
+			nil,
+		},
+		{
+			[]bool{false, false, false, true, false, false, false, false},
+			[]uint16{0, 0, 1, 1, 1, 1, 1, 1},
+			[]uint64{0, 0, 0, 1, 0, 0, 0, 0},
+			[2]uint64{1, 1},
+			nil,
+		},
+		{
+			[]bool{false, false, false, false, false, false, false, false},
+			[]uint16{0, 0, 1, 1, 1, 1, 1, 1},
+			[]uint64{0, 0, 0, 0, 0, 0, 0, 0},
+			[2]uint64{0, 0},
+			errNoPTS,
+		},
+	}
+
+	var clip bytes.Buffer
+
+	for i, test := range tests {
+		// Generate MTS packets for this test.
+		clip.Reset()
+		for j := 0; j < nPackets; j++ {
+			pesPkt := pes.Packet{
+				StreamID:     H264ID,
+				PDI:          hasPTS,
+				PTS:          test.pts[j],
+				Data:         []byte{},
+				HeaderLength: 5,
+			}
+			buf := pesPkt.Bytes(nil)
+
+			pkt := Packet{
+				PUSI: test.pusi[j],
+				PID:  test.pid[j],
+				RAI:  true,
+				CC:   0,
+				AFC:  hasAdaptationField | hasPayload,
+				PCRF: true,
+			}
+			pkt.FillPayload(buf)
+
+			_, err := clip.Write(pkt.Bytes(nil))
+			if err != nil {
+				t.Fatalf("did not expect clip write error: %v", err)
+			}
+		}
+
+		pts, err := GetPTSRange(clip.Bytes(), wantPID)
+		if err != test.err {
+			t.Errorf("did not get expected error for test: %v\nGot: %v\nWant: %v\n", i, err, test.err)
+		}
+
+		if pts != test.want {
+			t.Errorf("did not get expected result for test: %v\nGot: %v\nWant: %v\n", i, pts, test.want)
+		}
+	}
+}
+
 // TestBytes checks that Packet.Bytes() correctly produces a []byte
 // representation of a Packet.
 func TestBytes(t *testing.T) {