diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go index 00315d5c..064c0f9b 100644 --- a/container/mts/mpegts.go +++ b/container/mts/mpegts.go @@ -225,9 +225,8 @@ var ( ) // FindPSI finds the index of a PAT in an a slice of MPEG-TS and returns, along -// with a map of meta from the PMT and the PmtElementaryStreams, which contain -// the media PIDs and their types. -func FindPSI(d []byte) (int, []gotspsi.PmtElementaryStream, map[string]string, error) { +// with a map of meta from the PMT and the stream PIDs and their types. +func FindPSI(d []byte) (int, map[uint16]uint8, map[string]string, error) { if len(d) < PacketSize { return -1, nil, nil, ErrInvalidLen } @@ -278,7 +277,12 @@ func FindPSI(d []byte) (int, []gotspsi.PmtElementaryStream, map[string]string, e return i, nil, meta, errors.Wrap(err, "could not get streams from PMT") } - return i, streams, meta, nil + var streamMap map[uint16]uint8 + for _, s := range streams { + streamMap[s.ElementaryPid()] = s.StreamType() + } + + return i, streamMap, meta, nil } // FillPayload takes a channel and fills the packets Payload field until the diff --git a/container/mts/mpegts_test.go b/container/mts/mpegts_test.go index 791d9de5..66a7f93c 100644 --- a/container/mts/mpegts_test.go +++ b/container/mts/mpegts_test.go @@ -734,9 +734,15 @@ func TestFindPSI(t *testing.T) { t.Fatalf("gotStreams should not be 0 length") } - s := gotStreams[0] - gotStreamType := s.StreamType() - gotStreamPID := s.ElementaryPid() + var ( + gotStreamPID uint16 + gotStreamType uint8 + ) + + for k, v := range gotStreams { + gotStreamPID = k + gotStreamType = v + } if gotStreamType != test.want.streamType { t.Errorf("did not get expected stream type for test %d\nGot: %v\nWant: %v\n", i, gotStreamType, test.want.streamType)