diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go index 5f326942..f5fae7dc 100644 --- a/container/mts/mpegts.go +++ b/container/mts/mpegts.go @@ -318,9 +318,9 @@ func DiscontinuityIndicator(f bool) Option { } // GetPTSRange retreives the first and last PTS of an MPEGTS clip. -func GetPTSRange(clip []byte) (pts [2]uint64, err error) { - // Find the first video packet. - pkt, _, err := FindPid(clip, videoPid) +func GetPTSRange(clip []byte, pidType int) (pts [2]uint64, err error) { + // Find the first packet with PID pidType. + pkt, _, err := FindPid(clip, uint16(pidType)) if err != nil { return [2]uint64{}, err } @@ -344,7 +344,7 @@ func GetPTSRange(clip []byte) (pts [2]uint64, err error) { // Get the final PTS searching from end of clip for access unit start. for i := len(clip) - PacketSize; i >= 0; i -= PacketSize { copy(_pkt[:], clip[i:i+PacketSize]) - if packet.PayloadUnitStartIndicator(&_pkt) { + if packet.PayloadUnitStartIndicator(&_pkt) && _pkt.PID() == pidType { payload, err = packet.Payload(&_pkt) if err != nil { return [2]uint64{}, err diff --git a/container/mts/mpegts_test.go b/container/mts/mpegts_test.go index eed1162a..579acd65 100644 --- a/container/mts/mpegts_test.go +++ b/container/mts/mpegts_test.go @@ -51,7 +51,7 @@ func TestGetPTSRange(t *testing.T) { curTime += interval } - got, err := GetPTSRange(clip.Bytes()) + got, err := GetPTSRange(clip.Bytes(), videoPid) if err != nil { t.Fatalf("did not expect error getting PTS range: %v", err) }