From c55cd9a64ea5985469dd1aeb59cf531de76bdbcd Mon Sep 17 00:00:00 2001 From: Saxon Date: Fri, 26 Jul 2019 14:16:05 +0930 Subject: [PATCH] container/mts: wrote test for FindPSI and corrected bugs Wrote the FindPSI test which revealed a bug regarding creation of a PMT using comcast gots. This was fixed by writing Payload function and extracting payload of PMT packet before giving to psi.NewPMT. --- container/mts/mpegts.go | 26 ++++- container/mts/mpegts_test.go | 197 ++++++++++++++++++++++++++++++++++- 2 files changed, 221 insertions(+), 2 deletions(-) diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go index 009f2ec0..00315d5c 100644 --- a/container/mts/mpegts.go +++ b/container/mts/mpegts.go @@ -648,7 +648,11 @@ func Programs(p []byte) (map[uint16]uint16, error) { // Streams returns elementary streams defined in a given MPEG-TS PMT packet. func Streams(p []byte) ([]gotspsi.PmtElementaryStream, error) { - pmt, err := gotspsi.NewPMT(p) + payload, err := Payload(p) + if err != nil { + return nil, errors.Wrap(err, "cannot get packet payload") + } + pmt, err := gotspsi.NewPMT(payload) if err != nil { return nil, err } @@ -701,3 +705,23 @@ func pmtPIDs(m map[uint16]uint16) []uint16 { } return r } + +// Errors used by Payload. +var ErrNoPayload = errors.New("no payload") + +// Payload returns the payload of an MPEG-TS packet p. +// NB: this is not a copy of the payload in the interest of performance. +// TODO: offer function that will do copy if we have interests in safety. +func Payload(p []byte) ([]byte, error) { + c := byte((p[3] & 0x30) >> 4) + if c == 2 { + return nil, ErrNoPayload + } + + // Check if there is an adaptation field. + off := 4 + if p[3]&0x20 == 1 { + off = int(5 + p[4]) + } + return p[off:], nil +} diff --git a/container/mts/mpegts_test.go b/container/mts/mpegts_test.go index 3de27cf1..e801d6f1 100644 --- a/container/mts/mpegts_test.go +++ b/container/mts/mpegts_test.go @@ -35,10 +35,13 @@ import ( "testing" "time" + "github.com/Comcast/gots/packet" + gotspsi "github.com/Comcast/gots/psi" + "github.com/pkg/errors" + "bitbucket.org/ausocean/av/container/mts/meta" "bitbucket.org/ausocean/av/container/mts/pes" "bitbucket.org/ausocean/av/container/mts/psi" - "github.com/Comcast/gots/packet" ) // TestGetPTSRange checks that GetPTSRange can correctly get the first and last @@ -530,3 +533,195 @@ func scale(x, y int) rng { ((y * 2) + 1) * PacketSize, } } + +func TestFindPSI(t *testing.T) { + const ( + pat = iota + pmt + media + ) + + const ( + metaKey = "key" + mediaType = gotspsi.PmtStreamTypeMpeg4Video + pmtPID = 3 + streamPID = 4 + ) + + type want struct { + idx int + streamType uint8 + streamPID uint16 + meta map[string]string + err error + } + + tests := []struct { + pkts []int + meta string + want want + }{ + { + pkts: []int{pat, pmt, media, media}, + meta: "1", + want: want{ + idx: 0, + streamType: gotspsi.PmtStreamTypeMpeg4Video, + streamPID: 4, + meta: map[string]string{ + "key": "1", + }, + err: nil, + }, + }, + } + + var clip bytes.Buffer + var err error + Meta = meta.New() + + for i, test := range tests { + // Generate MTS packets for this test. + clip.Reset() + + for _, pkt := range test.pkts { + switch pkt { + case pat: + patTable := (&psi.PSI{ + Pf: 0x00, + Tid: 0x00, + Ssi: true, + Pb: false, + Sl: 0x0d, + Tss: &psi.TSS{ + Tide: 0x01, + V: 0, + Cni: true, + Sn: 0, + Lsn: 0, + Sd: &psi.PAT{ + Pn: 0x01, + Pmpid: pmtPID, + }, + }, + }).Bytes() + + pat := Packet{ + PUSI: true, + PID: PatPid, + CC: 0, + AFC: HasPayload, + Payload: psi.AddPadding(patTable), + } + _, err := clip.Write(pat.Bytes(nil)) + if err != nil { + t.Fatalf("could not write PAT to clip for test %d", i) + } + case pmt: + pmtTable := (&psi.PSI{ + Pf: 0x00, + Tid: 0x02, + Ssi: true, + Sl: 0x12, + Tss: &psi.TSS{ + Tide: 0x01, + V: 0, + Cni: true, + Sn: 0, + Lsn: 0, + Sd: &psi.PMT{ + Pcrpid: 0x0100, + Pil: 0, + Essd: &psi.ESSD{ + St: mediaType, + Epid: streamPID, + Esil: 0x00, + }, + }, + }, + }).Bytes() + + Meta.Add(metaKey, test.meta) + pmtTable, err = updateMeta(pmtTable) + if err != nil { + t.Fatalf("could not update meta for test %d", i) + } + + pmt := Packet{ + PUSI: true, + PID: pmtPID, + CC: 0, + AFC: HasPayload, + Payload: psi.AddPadding(pmtTable), + } + _, err = clip.Write(pmt.Bytes(nil)) + if err != nil { + t.Fatalf("could not write PMT to clip for test %d", i) + } + case media: + pesPkt := pes.Packet{ + StreamID: mediaType, + PDI: hasPTS, + Data: []byte{}, + HeaderLength: 5, + } + buf := pesPkt.Bytes(nil) + + pkt := Packet{ + PUSI: true, + PID: uint16(streamPID), + RAI: true, + CC: 0, + AFC: hasAdaptationField | hasPayload, + PCRF: true, + } + pkt.FillPayload(buf) + + _, err := clip.Write(pkt.Bytes(nil)) + if err != nil { + t.Fatalf("did not expect clip write error: %v", err) + } + default: + t.Fatalf("undefined pkt type %d in test %d", pkt, i) + } + } + + gotIdx, gotStreams, gotMeta, gotErr := FindPSI(clip.Bytes()) + + // Check error + if errors.Cause(gotErr) != test.want.err { + t.Errorf("did not get expected error for test %d\nGot: %v\nWant: %v\n", i, gotErr, test.want.err) + } + + // Check idx + if gotIdx != test.want.idx { + t.Errorf("did not get expected idx for test %d\nGot: %v\nWant: %v\n", i, gotIdx, test.want.idx) + } + + // Check stream type and PID + if gotStreams == nil { + t.Fatalf("gotStreams should not be nil") + } + + if len(gotStreams) == 0 { + t.Fatalf("gotStreams should not be 0 length") + } + + s := gotStreams[0] + gotStreamType := s.StreamType() + gotStreamPID := s.ElementaryPid() + + if gotStreamType != test.want.streamType { + t.Errorf("did not get expected stream type for test %d\nGot: %v\nWant: %v\n", i, gotStreamType, test.want.streamType) + } + + if gotStreamPID != test.want.streamPID { + t.Errorf("did not get expected stream PID for test %d\nGot: %v\nWant: %v\n", i, gotStreamPID, test.want.streamPID) + } + + // Check meta + if !reflect.DeepEqual(gotMeta, test.want.meta) { + t.Errorf("did not get expected meta for test %d\nGot: %v\nWant: %v\n", i, gotMeta, test.want.meta) + } + } +}