diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go index f091b10e..2dc608ed 100644 --- a/container/mts/mpegts.go +++ b/container/mts/mpegts.go @@ -33,6 +33,7 @@ import ( "fmt" "github.com/Comcast/gots/packet" + gotspsi "github.com/Comcast/gots/psi" "github.com/pkg/errors" "bitbucket.org/ausocean/av/container/mts/meta" @@ -178,16 +179,15 @@ func FindPat(d []byte) ([]byte, int, error) { // Errors used by FindPid. var ( - errInvalidLen = errors.New("MPEG-TS data not of valid length") - errCouldNotFind = errors.New("could not find packet with given PID") - errNotConsecutive = errors.New("could not find consecutive PIDs") + ErrInvalidLen = errors.New("MPEG-TS data not of valid length") + errCouldNotFind = errors.New("could not find packet with given PID") ) // FindPid will take a clip of MPEG-TS and try to find a packet with given PID - if one // is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. func FindPid(d []byte, pid uint16) (pkt []byte, i int, err error) { if len(d) < PacketSize { - return nil, -1, errInvalidLen + return nil, -1, ErrInvalidLen } for i = 0; i < len(d); i += PacketSize { p := (uint16(d[i+1]&0x1f) << 8) | uint16(d[i+2]) @@ -205,7 +205,7 @@ func FindPid(d []byte, pid uint16) (pkt []byte, i int, err error) { // nil, -1 and an error is returned. func LastPid(d []byte, pid uint16) (pkt []byte, i int, err error) { if len(d) < PacketSize { - return nil, -1, errInvalidLen + return nil, -1, ErrInvalidLen } for i = len(d) - PacketSize; i >= 0; i -= PacketSize { @@ -218,30 +218,72 @@ func LastPid(d []byte, pid uint16) (pkt []byte, i int, err error) { return nil, -1, errCouldNotFind } -// IndexPid returns the position of one or more consecutive pids, -// along with optional metadata if present. Commonly used to find a -// PAT immediately followed by a PMT. -func IndexPid(d []byte, pids ...uint16) (idx int, m map[string]string, err error) { - idx = -1 - for _, pid := range pids { - if len(d) < PacketSize { - return idx, m, errInvalidLen - } - pkt, i, err := FindPid(d, pid) - if err != nil { - return idx, m, errCouldNotFind - } - if pid == PmtPid { - m, _ = metaFromPMT(pkt) - } - if idx == -1 { - idx = i - } else if i != 0 { - return idx, m, errNotConsecutive - } - d = d[i+PacketSize:] +// Errors used by FindPSI. +var ( + ErrMultiplePrograms = errors.New("more than one program not supported") + ErrNoPrograms = errors.New("no programs in PAT") + ErrNotConsecutive = errors.New("could not find consecutive PIDs") +) + +// FindPSI finds the index of a PAT in an a slice of MPEG-TS and returns, along +// with a map of meta from the PMT and the stream PIDs and their types. +func FindPSI(d []byte) (int, map[uint16]uint8, map[string]string, error) { + if len(d) < PacketSize { + return -1, nil, nil, ErrInvalidLen } - return + + // Find the PAT if it exists. + pkt, i, err := FindPid(d, PatPid) + if err != nil { + return -1, nil, nil, errors.Wrap(err, "error finding PAT") + } + + // Let's take this opportunity to check what programs are in this MPEG-TS + // stream, and therefore the PID of the PMT, from which we can get metadata. + // NB: currently we only support one program. + progs, err := Programs(pkt) + if err != nil { + return i, nil, nil, errors.Wrap(err, "cannot get programs from PAT") + } + + if len(progs) == 0 { + return i, nil, nil, ErrNoPrograms + } + + if len(progs) > 1 { + return i, nil, nil, ErrMultiplePrograms + } + + pmtPID := pmtPIDs(progs)[0] + + // Now we can look for the PMT. We want to adjust d so that we're not looking + // at the same data twice. + d = d[i+PacketSize:] + pkt, pmtIdx, err := FindPid(d, pmtPID) + if err != nil { + return i, nil, nil, errors.Wrap(err, "error finding PMT") + } + + // Check that the PMT comes straight after the PAT. + if pmtIdx != 0 { + return i, nil, nil, ErrNotConsecutive + } + + // Now we can try to get meta from the PMT. + meta, _ := metaFromPMT(pkt) + + // Now to get the elementary streams defined for this program. + streams, err := Streams(pkt) + if err != nil { + return i, nil, meta, errors.Wrap(err, "could not get streams from PMT") + } + + streamMap := make(map[uint16]uint8) + for _, s := range streams { + streamMap[s.ElementaryPid()] = s.StreamType() + } + + return i, streamMap, meta, nil } // FillPayload takes a channel and fills the packets Payload field until the @@ -421,9 +463,9 @@ func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) { } var ( - errNoPesPayload = errors.New("no PES payload") - errNoPesPTS = errors.New("no PES PTS") - errInvalidPesHeader = errors.New("invalid PES header") + errNoPesPayload = errors.New("no PES payload") + errNoPesPTS = errors.New("no PES PTS") + errInvalidPesHeader = errors.New("invalid PES header") errInvalidPesPayload = errors.New("invalid PES payload") ) @@ -593,3 +635,98 @@ func SegmentForMeta(d []byte, key, val string) ([][]byte, error) { return res, nil } + +// pid returns the packet identifier for the given packet. +func pid(p []byte) uint16 { + return uint16(p[1]&0x1f)<<8 | uint16(p[2]) +} + +// Programs returns a map of program numbers and corresponding PMT PIDs for a +// given MPEG-TS PAT packet. +func Programs(p []byte) (map[uint16]uint16, error) { + pat, err := gotspsi.NewPAT(p) + if err != nil { + return nil, err + } + return pat.ProgramMap(), nil +} + +// Streams returns elementary streams defined in a given MPEG-TS PMT packet. +func Streams(p []byte) ([]gotspsi.PmtElementaryStream, error) { + payload, err := Payload(p) + if err != nil { + return nil, errors.Wrap(err, "cannot get packet payload") + } + pmt, err := gotspsi.NewPMT(payload) + if err != nil { + return nil, err + } + return pmt.ElementaryStreams(), nil +} + +// MediaStreams retrieves the PmtElementaryStreams from the given PSI. This +// function currently assumes that PSI contain a PAT followed by a PMT directly +// after. We also assume that this MPEG-TS stream contains just one program, +// but this program may contain different streams, i.e. a video stream + audio +// stream. +func MediaStreams(p []byte) ([]gotspsi.PmtElementaryStream, error) { + pat := p[:PacketSize] + pmt := p[PacketSize : 2*PacketSize] + + if pid(pat) != PatPid { + return nil, errors.New("first packet is not a PAT") + } + + m, err := Programs(pat) + if err != nil { + return nil, errors.Wrap(err, "could not get programs from PAT") + } + + if len(m) == 0 { + return nil, ErrNoPrograms + } + + if len(m) > 1 { + return nil, ErrMultiplePrograms + } + + if pid(pmt) != pmtPIDs(m)[0] { + return nil, errors.New("second packet is not desired PMT") + } + + s, err := Streams(pmt) + if err != nil { + return nil, errors.Wrap(err, "could not get streams from PMT") + } + return s, nil +} + +// pmtPIDs returns PMT PIDS from a map containing program number as keys and +// corresponding PMT PIDs as values. +func pmtPIDs(m map[uint16]uint16) []uint16 { + r := make([]uint16, 0, len(m)) + for _, v := range m { + r = append(r, v) + } + return r +} + +// Errors used by Payload. +var ErrNoPayload = errors.New("no payload") + +// Payload returns the payload of an MPEG-TS packet p. +// NB: this is not a copy of the payload in the interest of performance. +// TODO: offer function that will do copy if we have interests in safety. +func Payload(p []byte) ([]byte, error) { + c := byte((p[3] & 0x30) >> 4) + if c == 2 { + return nil, ErrNoPayload + } + + // Check if there is an adaptation field. + off := 4 + if p[3]&0x20 == 1 { + off = int(5 + p[4]) + } + return p[off:], nil +} diff --git a/container/mts/mpegts_test.go b/container/mts/mpegts_test.go index 1cd1f643..99741c92 100644 --- a/container/mts/mpegts_test.go +++ b/container/mts/mpegts_test.go @@ -35,10 +35,13 @@ import ( "testing" "time" + "github.com/Comcast/gots/packet" + gotspsi "github.com/Comcast/gots/psi" + "github.com/pkg/errors" + "bitbucket.org/ausocean/av/container/mts/meta" "bitbucket.org/ausocean/av/container/mts/pes" "bitbucket.org/ausocean/av/container/mts/psi" - "github.com/Comcast/gots/packet" ) // TestGetPTSRange checks that GetPTSRange can correctly get the first and last @@ -493,7 +496,7 @@ func TestSegmentForMeta(t *testing.T) { } // Now test IndexPid. - i, m, err := IndexPid(clip.Bytes(), PatPid, PmtPid) + i, _, m, err := FindPSI(clip.Bytes()) if err != nil { t.Fatalf("IndexPid failed with error: %v", err) } @@ -507,7 +510,7 @@ func TestSegmentForMeta(t *testing.T) { // Finally, test IndexPid error handling. for _, d := range [][]byte{[]byte{}, make([]byte, PacketSize/2), make([]byte, PacketSize)} { - _, _, err := IndexPid(d, PatPid, PmtPid) + _, _, _, err := FindPSI(d) if err == nil { t.Fatalf("IndexPid expected error") } @@ -530,3 +533,229 @@ func scale(x, y int) rng { ((y * 2) + 1) * PacketSize, } } + +func TestFindPSI(t *testing.T) { + const ( + pat = iota + pmt + media + ) + + const ( + metaKey = "key" + mediaType = gotspsi.PmtStreamTypeMpeg4Video + pmtPID = 3 + streamPID = 4 + ) + + type want struct { + idx int + streamType uint8 + streamPID uint16 + meta map[string]string + err error + } + + tests := []struct { + pkts []int + meta string + want want + }{ + { + pkts: []int{pat, pmt, media, media}, + meta: "1", + want: want{ + idx: 0, + streamType: gotspsi.PmtStreamTypeMpeg4Video, + streamPID: 4, + meta: map[string]string{ + "key": "1", + }, + err: nil, + }, + }, + { + pkts: []int{media, pat, pmt, media, media}, + meta: "1", + want: want{ + idx: 188, + streamType: gotspsi.PmtStreamTypeMpeg4Video, + streamPID: 4, + meta: map[string]string{ + "key": "1", + }, + err: nil, + }, + }, + { + pkts: []int{pat, media, pmt, media, media}, + meta: "1", + want: want{ + idx: 0, + streamType: gotspsi.PmtStreamTypeMpeg4Video, + streamPID: 4, + meta: map[string]string{ + "key": "1", + }, + err: ErrNotConsecutive, + }, + }, + } + + var clip bytes.Buffer + var err error + Meta = meta.New() + + for i, test := range tests { + // Generate MTS packets for this test. + clip.Reset() + + for _, pkt := range test.pkts { + switch pkt { + case pat: + patTable := (&psi.PSI{ + Pf: 0x00, + Tid: 0x00, + Ssi: true, + Pb: false, + Sl: 0x0d, + Tss: &psi.TSS{ + Tide: 0x01, + V: 0, + Cni: true, + Sn: 0, + Lsn: 0, + Sd: &psi.PAT{ + Pn: 0x01, + Pmpid: pmtPID, + }, + }, + }).Bytes() + + pat := Packet{ + PUSI: true, + PID: PatPid, + CC: 0, + AFC: HasPayload, + Payload: psi.AddPadding(patTable), + } + _, err := clip.Write(pat.Bytes(nil)) + if err != nil { + t.Fatalf("could not write PAT to clip for test %d", i) + } + case pmt: + pmtTable := (&psi.PSI{ + Pf: 0x00, + Tid: 0x02, + Ssi: true, + Sl: 0x12, + Tss: &psi.TSS{ + Tide: 0x01, + V: 0, + Cni: true, + Sn: 0, + Lsn: 0, + Sd: &psi.PMT{ + Pcrpid: 0x0100, + Pil: 0, + Essd: &psi.ESSD{ + St: mediaType, + Epid: streamPID, + Esil: 0x00, + }, + }, + }, + }).Bytes() + + Meta.Add(metaKey, test.meta) + pmtTable, err = updateMeta(pmtTable) + if err != nil { + t.Fatalf("could not update meta for test %d", i) + } + + pmt := Packet{ + PUSI: true, + PID: pmtPID, + CC: 0, + AFC: HasPayload, + Payload: psi.AddPadding(pmtTable), + } + _, err = clip.Write(pmt.Bytes(nil)) + if err != nil { + t.Fatalf("could not write PMT to clip for test %d", i) + } + case media: + pesPkt := pes.Packet{ + StreamID: mediaType, + PDI: hasPTS, + Data: []byte{}, + HeaderLength: 5, + } + buf := pesPkt.Bytes(nil) + + pkt := Packet{ + PUSI: true, + PID: uint16(streamPID), + RAI: true, + CC: 0, + AFC: hasAdaptationField | hasPayload, + PCRF: true, + } + pkt.FillPayload(buf) + + _, err := clip.Write(pkt.Bytes(nil)) + if err != nil { + t.Fatalf("did not expect clip write error: %v", err) + } + default: + t.Fatalf("undefined pkt type %d in test %d", pkt, i) + } + } + + gotIdx, gotStreams, gotMeta, gotErr := FindPSI(clip.Bytes()) + + // Check error + if errors.Cause(gotErr) != test.want.err { + t.Errorf("did not get expected error for test %d\nGot: %v\nWant: %v\n", i, gotErr, test.want.err) + } + + if gotErr == nil { + // Check idx + if gotIdx != test.want.idx { + t.Errorf("did not get expected idx for test %d\nGot: %v\nWant: %v\n", i, gotIdx, test.want.idx) + } + + // Check stream type and PID + if gotStreams == nil { + t.Fatalf("gotStreams should not be nil") + } + + if len(gotStreams) == 0 { + t.Fatalf("gotStreams should not be 0 length") + } + + var ( + gotStreamPID uint16 + gotStreamType uint8 + ) + + for k, v := range gotStreams { + gotStreamPID = k + gotStreamType = v + } + + if gotStreamType != test.want.streamType { + t.Errorf("did not get expected stream type for test %d\nGot: %v\nWant: %v\n", i, gotStreamType, test.want.streamType) + } + + if gotStreamPID != test.want.streamPID { + t.Errorf("did not get expected stream PID for test %d\nGot: %v\nWant: %v\n", i, gotStreamPID, test.want.streamPID) + } + + // Check meta + if !reflect.DeepEqual(gotMeta, test.want.meta) { + t.Errorf("did not get expected meta for test %d\nGot: %v\nWant: %v\n", i, gotMeta, test.want.meta) + } + } + } +}