diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go index c43df44c..28ad802b 100644 --- a/container/mts/mpegts.go +++ b/container/mts/mpegts.go @@ -178,8 +178,9 @@ func FindPat(d []byte) ([]byte, int, error) { // Errors used by FindPid. var ( - errInvalidLen = errors.New("MPEG-TS data not of valid length") - errCouldNotFind = errors.New("could not find packet with given PID") + errInvalidLen = errors.New("MPEG-TS data not of valid length") + errCouldNotFind = errors.New("could not find packet with given PID") + errNotConsecutive = errors.New("could not find consecutive PIDs") ) // FindPid will take a clip of MPEG-TS and try to find a packet with given PID - if one @@ -221,26 +222,24 @@ func LastPid(d []byte, pid uint16) (pkt []byte, i int, err error) { // along with optional metadata if present. Commonly used to find a // PAT immediately followed by a PMT. func IndexPid(d []byte, pids ...uint16) (idx int, m map[string]string, err error) { - prev := 0 + idx = -1 for _, pid := range pids { - pkt, i, _err := FindPid(d, pid) + if len(d) < PacketSize { + return idx, m, errInvalidLen + } + pkt, i, err := FindPid(d, pid) if err != nil { - err = errors.Wrap(_err, "could not find PID") - return + return idx, m, errCouldNotFind } if pid == PmtPid { m, _ = metaFromPMT(pkt) } - if prev == 0 { + if idx == -1 { idx = i - prev = i - continue + } else if i != 0 { + return idx, m, errNotConsecutive } - if i != prev+PacketSize { - err = errors.New("PIDs not consecutive") - return - } - prev = i + d = d[i+PacketSize:] } return } diff --git a/container/mts/mpegts_test.go b/container/mts/mpegts_test.go index 3a603a20..9ce93b5b 100644 --- a/container/mts/mpegts_test.go +++ b/container/mts/mpegts_test.go @@ -491,6 +491,26 @@ func TestSegmentForMeta(t *testing.T) { if !reflect.DeepEqual(want, got) { t.Errorf("did not get expected result for test %v\nGot: %v\nWant: %v\n", testn, got, want) } + + // Now test IndexPid. + i, m, err := IndexPid(clip.Bytes(), PatPid, PmtPid) + if err != nil { + t.Fatalf("IndexPid failed with error: %v", err) + } + if i != 0 { + t.Fatalf("IndexPid unexpected index; got %d, expected 0", i) + } + if m["n"] != "1" { + t.Fatalf("IndexPid unexpected metadata; got %s, expected 1", m["n"]) + } + } + + // Finally, test IndexPid error handling. + for _, d := range [][]byte{[]byte{}, make([]byte, PacketSize/2), make([]byte, PacketSize)} { + _, _, err := IndexPid(d, PatPid, PmtPid) + if err == nil { + t.Fatalf("IndexPid expected error") + } } }