av/container/mts/mpegts_test.go

764 lines
18 KiB
Go
Raw Permalink Normal View History

/*
NAME
mpegts_test.go
DESCRIPTION
mpegts_test.go contains testing for functionality found in mpegts.go.
AUTHORS
Saxon A. Nelson-Milton <saxon@ausocean.org>
LICENSE
Copyright (C) 2019 the Australian Ocean Lab (AusOcean)
It is free software: you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation, either version 3 of the License, or (at your
option) any later version.
It is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
for more details.
You should have received a copy of the GNU General Public License
in gpl.txt. If not, see http://www.gnu.org/licenses.
*/
package mts
import (
"bytes"
"math/rand"
"reflect"
"strconv"
"testing"
"time"
"github.com/Comcast/gots/packet"
gotspsi "github.com/Comcast/gots/psi"
"github.com/pkg/errors"
"bitbucket.org/ausocean/av/container/mts/meta"
"bitbucket.org/ausocean/av/container/mts/pes"
"bitbucket.org/ausocean/av/container/mts/psi"
)
// TestGetPTSRange1 checks that GetPTSRange reports the first and last PTS of
// an MPEG-TS clip in the general case: a PSI header followed by a run of
// frames written at a constant frame rate.
func TestGetPTSRange1(t *testing.T) {
	const (
		nFrames       = 20
		maxSize       = 1000
		minSize       = 100
		fps           = 25               // Frames per second.
		frameInterval = float64(1) / fps // Seconds between frames.
		ptsFreq       = 90000            // PTS clock frequency in Hz.
	)

	// Each frame is zero filled; only its (random) length varies.
	rand.Seed(time.Now().UnixNano())
	frames := make([][]byte, nFrames)
	for i := range frames {
		frames[i] = make([]byte, rand.Intn(maxSize-minSize)+minSize)
	}

	// The clip starts with PSI.
	var clip bytes.Buffer
	if err := writePSI(&clip); err != nil {
		t.Fatalf("did not expect error writing psi: %v", err)
	}

	// Then come the frames, each stamped with the PTS for the running time.
	var elapsed float64
	for _, f := range frames {
		pts := elapsed * ptsFreq
		if err := writeFrame(&clip, f, uint64(pts)); err != nil {
			t.Fatalf("did not expect error writing frame: %v", err)
		}
		elapsed += frameInterval
	}

	got, err := GetPTSRange(clip.Bytes(), PIDVideo)
	if err != nil {
		t.Fatalf("did not expect error getting PTS range: %v", err)
	}

	// The first frame is at time zero and the last is nFrames-1 intervals later.
	want := [2]uint64{0, uint64((nFrames - 1) * frameInterval * ptsFreq)}
	if got != want {
		t.Errorf("did not get expected result.\n Got: %v\n Want: %v\n", got, want)
	}
}
// writePSI is a helper function that writes the PSI found at the start of a
// clip to b: one packet carrying the padded PAT followed by one packet
// carrying the padded PMT. It returns the first write error encountered.
func writePSI(b *bytes.Buffer) error {
	patTable := psi.NewPATPSI().Bytes()
	pmtTable := psi.NewPMTPSI().Bytes()

	// PAT packet.
	patPkt := Packet{
		PUSI:    true,
		PID:     PatPid,
		CC:      0,
		AFC:     HasPayload,
		Payload: psi.AddPadding(patTable),
	}
	if _, err := b.Write(patPkt.Bytes(nil)); err != nil {
		return err
	}

	// PMT packet.
	pmtPkt := Packet{
		PUSI:    true,
		PID:     PmtPid,
		CC:      0,
		AFC:     HasPayload,
		Payload: psi.AddPadding(pmtTable),
	}
	_, err := b.Write(pmtPkt.Bytes(nil))
	return err
}
// writeFrame is a helper function that wraps frame in a PES packet stamped
// with pts, fragments the PES bytes across MPEG-TS packets on PIDVideo and
// writes those packets to b. Only the first packet of the frame has PUSI, RAI
// and PCRF set. It returns the first write error encountered.
func writeFrame(b *bytes.Buffer, frame []byte, pts uint64) error {
	// Build the PES packet for this frame.
	pesPkt := pes.Packet{
		StreamID:     pes.H264SID,
		PDI:          hasPTS,
		PTS:          pts,
		Data:         frame,
		HeaderLength: 5,
	}
	remaining := pesPkt.Bytes(nil)

	// Spread the PES bytes across as many MPEG-TS packets as are needed.
	for first := true; len(remaining) > 0; first = false {
		pkt := Packet{
			PUSI: first,
			PID:  PIDVideo,
			RAI:  first,
			CC:   0,
			AFC:  hasAdaptationField | hasPayload,
			PCRF: first,
		}
		// FillPayload reports how many bytes it consumed.
		remaining = remaining[pkt.FillPayload(remaining):]
		if _, err := b.Write(pkt.Bytes(nil)); err != nil {
			return err
		}
	}
	return nil
}
// TestGetPTSRange2 checks that GetPTSRange behaves correctly when the first
// packet of the wanted PID is not a payload start, and when the wanted PID has
// no payload starts at all.
func TestGetPTSRange2(t *testing.T) {
	const (
		nPackets = 8 // Number of MTS packets generated per test case.
		wantPID  = 1 // PID passed to GetPTSRange.
	)

	// The pusi, pid and pts slices hold one entry per generated packet.
	tests := []struct {
		pusi []bool    // PUSI flag of each packet.
		pid  []uint16  // PID of each packet.
		pts  []uint64  // PTS placed in each packet's PES header.
		want [2]uint64 // Expected result from GetPTSRange.
		err  error     // Expected error from GetPTSRange.
	}{
		{
			pusi: []bool{false, false, false, true, false, false, true, false},
			pid:  []uint16{0, 0, 1, 1, 1, 1, 1, 1},
			pts:  []uint64{0, 0, 0, 1, 0, 0, 2, 0},
			want: [2]uint64{1, 2},
			err:  nil,
		},
		{
			pusi: []bool{false, false, false, true, false, false, false, false},
			pid:  []uint16{0, 0, 1, 1, 1, 1, 1, 1},
			pts:  []uint64{0, 0, 0, 1, 0, 0, 0, 0},
			want: [2]uint64{1, 1},
			err:  nil,
		},
		{
			pusi: []bool{false, false, false, false, false, false, false, false},
			pid:  []uint16{0, 0, 1, 1, 1, 1, 1, 1},
			pts:  []uint64{0, 0, 0, 0, 0, 0, 0, 0},
			want: [2]uint64{0, 0},
			err:  errNoPTS,
		},
	}

	for i, test := range tests {
		// Build a fresh clip of nPackets packets for this case.
		var clip bytes.Buffer
		for j := 0; j < nPackets; j++ {
			pesPkt := pes.Packet{
				StreamID:     pes.H264SID,
				PDI:          hasPTS,
				PTS:          test.pts[j],
				Data:         []byte{},
				HeaderLength: 5,
			}
			pkt := Packet{
				PUSI: test.pusi[j],
				PID:  test.pid[j],
				RAI:  true,
				CC:   0,
				AFC:  hasAdaptationField | hasPayload,
				PCRF: true,
			}
			pkt.FillPayload(pesPkt.Bytes(nil))
			if _, err := clip.Write(pkt.Bytes(nil)); err != nil {
				t.Fatalf("did not expect clip write error: %v", err)
			}
		}

		gotPTS, gotErr := GetPTSRange(clip.Bytes(), wantPID)
		if gotErr != test.err {
			t.Errorf("did not get expected error for test: %v\nGot: %v\nWant: %v\n", i, gotErr, test.err)
		}
		if gotPTS != test.want {
			t.Errorf("did not get expected result for test: %v\nGot: %v\nWant: %v\n", i, gotPTS, test.want)
		}
	}
}
2019-05-07 06:47:33 +03:00
// TestBytes checks that Packet.Bytes() correctly produces a []byte
// representation of a Packet.
func TestBytes(t *testing.T) {
2019-05-07 06:47:33 +03:00
const payloadLen, payloadChar, stuffingChar = 120, 0x11, 0xff
const stuffingLen = PacketSize - payloadLen - 12
tests := []struct {
packet Packet
expectedHeader []byte
}{
{
packet: Packet{
PUSI: true,
PID: 1,
RAI: true,
CC: 4,
AFC: HasPayload | HasAdaptationField,
PCRF: true,
PCR: 1,
},
expectedHeader: []byte{
0x47, // Sync byte.
0x40, // TEI=0, PUSI=1, TP=0, PID=00000.
0x01, // PID(Cont)=00000001.
0x34, // TSC=00, AFC=11(adaptation followed by payload), CC=0100(4).
2019-05-07 06:47:33 +03:00
byte(7 + stuffingLen), // AFL=.
0x50, // DI=0,RAI=1,ESPI=0,PCRF=1,OPCRF=0,SPF=0,TPDF=0, AFEF=0.
0x00, 0x00, 0x00, 0x00, 0x80, 0x00, // PCR.
},
},
}
for testNum, test := range tests {
// Construct payload.
2019-05-07 06:47:33 +03:00
payload := make([]byte, 0, payloadLen)
for i := 0; i < payloadLen; i++ {
payload = append(payload, payloadChar)
}
// Fill the packet payload.
test.packet.FillPayload(payload)
2019-05-07 06:47:33 +03:00
// Create expected packet data and copy in expected header.
expected := make([]byte, len(test.expectedHeader), PacketSize)
copy(expected, test.expectedHeader)
// Append stuffing.
2019-05-07 06:47:33 +03:00
for i := 0; i < stuffingLen; i++ {
expected = append(expected, stuffingChar)
}
2019-05-07 06:47:33 +03:00
// Append payload to expected bytes.
expected = append(expected, payload...)
// Compare got with expected.
got := test.packet.Bytes(nil)
if !bytes.Equal(got, expected) {
t.Errorf("did not get expected result for test: %v.\n Got: %v\n Want: %v\n", testNum, got, expected)
}
}
}
2019-05-07 06:47:33 +03:00
// TestFindPid checks that FindPid can correctly extract the first instance
// of a PID from an MPEG-TS stream.
func TestFindPid(t *testing.T) {
const targetPacketNum, numOfPackets, targetPid, stdPid = 6, 15, 1, 0
2019-05-07 06:47:33 +03:00
// Prepare the stream of packets.
var stream []byte
for i := 0; i < numOfPackets; i++ {
pid := uint16(stdPid)
if i == targetPacketNum {
pid = targetPid
}
2019-05-07 06:47:33 +03:00
p := Packet{
PID: pid,
AFC: hasPayload | hasAdaptationField,
}
p.FillPayload([]byte{byte(i)})
stream = append(stream, p.Bytes(nil)...)
}
2019-05-07 06:47:33 +03:00
// Try to find the targetPid in the stream.
p, i, err := FindPid(stream, targetPid)
if err != nil {
t.Fatalf("unexpected error finding PID: %v\n", err)
}
2019-05-07 06:47:33 +03:00
// Check the payload.
var _p packet.Packet
copy(_p[:], p)
payload, err := packet.Payload(&_p)
if err != nil {
t.Fatalf("unexpected error getting packet payload: %v\n", err)
}
got := payload[0]
if got != targetPacketNum {
t.Errorf("payload of found packet is not correct.\nGot: %v, Want: %v\n", got, targetPacketNum)
}
2019-05-07 06:47:33 +03:00
// Check the index.
_got := i / PacketSize
if _got != targetPacketNum {
t.Errorf("index of found packet is not correct.\nGot: %v, want: %v\n", _got, targetPacketNum)
}
}
// TestTrimToMetaRange checks that TrimToMetaRange returns the segment of an
// MPEG-TS slice bounded by two values of a meta key, and that it reports the
// expected error when either bound cannot be found.
func TestTrimToMetaRange(t *testing.T) {
	Meta = meta.New()

	const (
		nPSI = 10  // Number of PSI writes making up the clip.
		key  = "n" // Meta key whose value changes with every PSI write.
	)

	// Build a clip of PSI only, tagging successive writes with 1, 3, 5, ...
	var clip bytes.Buffer
	for i := 0; i < nPSI; i++ {
		Meta.Add(key, strconv.Itoa((i*2)+1))
		if err := writePSIWithMeta(&clip, t); err != nil {
			t.Fatalf("did not expect to get error writing PSI, error: %v", err)
		}
	}
	data := clip.Bytes()

	tests := []struct {
		from   string // Meta value marking the start of the range.
		to     string // Meta value marking the end of the range.
		expect []byte // Expected bytes; only checked when err is nil.
		err    error  // Expected error.
	}{
		{
			from:   "3",
			to:     "9",
			expect: data[3*PacketSize : 10*PacketSize],
			err:    nil,
		},
		{
			from:   "30",
			to:     "8",
			expect: nil,
			err:    errMetaLowerBound,
		},
		{
			from:   "3",
			to:     "30",
			expect: nil,
			err:    errMetaUpperBound,
		},
	}

	// Run tests.
	for i, test := range tests {
		got, err := TrimToMetaRange(data, key, test.from, test.to)

		// The error is checked first; on a mismatch the data is not inspected.
		switch {
		case err != nil && err != test.err:
			t.Errorf("unexpected error: %v for test: %v", err, i)
			continue
		case err != test.err:
			t.Errorf("expected to get error: %v for test: %v", test.err, i)
			continue
		}

		// Now check data.
		if test.err == nil && !bytes.Equal(test.expect, got) {
			t.Errorf("did not get expected data for test: %v\n Got: %v\n, Want: %v\n", i, got, test.expect)
		}
	}
}
// TestSegmentForMeta checks that SegmentForMeta can correctly segment some MTS
// data based on a given meta key and value.
func TestSegmentForMeta(t *testing.T) {
Meta = meta.New()
const (
nPSI = 10 // The number of PSI pairs to write.
key = "n" // The meta key we will work with.
val = "*" // This is the meta value we will look for.
)
tests := []struct {
metaVals [nPSI]string // This represents the meta value for meta pairs (PAT and PMT)
expectIdxs []rng // This gives the expected index ranges for the segments.
}{
{
metaVals: [nPSI]string{"1", "2", val, val, val, "3", val, val, "4", "4"},
expectIdxs: []rng{
scale(2, 5),
scale(6, 8),
},
},
{
metaVals: [nPSI]string{"1", "2", val, val, val, "", "3", val, val, "4"},
expectIdxs: []rng{
scale(2, 5),
scale(7, 9),
},
},
{
metaVals: [nPSI]string{"1", "2", val, val, val, "", "3", val, val, val},
expectIdxs: []rng{
scale(2, 5),
{((7 * 2) + 1) * PacketSize, (nPSI * 2) * PacketSize},
},
},
{
metaVals: [nPSI]string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"},
expectIdxs: nil,
},
}
var clip bytes.Buffer
for testn, test := range tests {
// We want a clean buffer for each new test, so reset.
clip.Reset()
// Add meta and write PSI to clip.
for i := 0; i < nPSI; i++ {
if test.metaVals[i] != "" {
Meta.Add(key, test.metaVals[i])
} else {
Meta.Delete(key)
}
err := writePSIWithMeta(&clip, t)
if err != nil {
t.Fatalf("did not expect to get error writing PSI, error: %v", err)
}
}
// Now we get the expected segments using the index ranges from the test.
var want [][]byte
for _, idxs := range test.expectIdxs {
want = append(want, clip.Bytes()[idxs.start:idxs.end])
}
// Now use the function we're testing to get the segments.
got, err := SegmentForMeta(clip.Bytes(), key, val)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check that segments are OK.
if !reflect.DeepEqual(want, got) {
t.Errorf("did not get expected result for test %v\nGot: %v\nWant: %v\n", testn, got, want)
}
2019-07-11 12:59:46 +03:00
// Now test IndexPid.
i, _, m, err := FindPSI(clip.Bytes())
2019-07-11 12:59:46 +03:00
if err != nil {
t.Fatalf("IndexPid failed with error: %v", err)
}
if i != 0 {
t.Fatalf("IndexPid unexpected index; got %d, expected 0", i)
}
if m["n"] != "1" {
t.Fatalf("IndexPid unexpected metadata; got %s, expected 1", m["n"])
}
}
// Finally, test IndexPid error handling.
for _, d := range [][]byte{[]byte{}, make([]byte, PacketSize/2), make([]byte, PacketSize)} {
_, _, _, err := FindPSI(d)
2019-07-11 12:59:46 +03:00
if err == nil {
t.Fatalf("IndexPid expected error")
}
}
}
// rng is a pair of byte indices delimiting a range. TestSegmentForMeta uses
// it as the bounds of a slice expression, data[start:end].
type rng struct {
	start, end int
}
// scale converts a pair of PSI indices (the first PSI pair is 0, the next is
// 1, and so on) into a byte index range within the stream. Index n maps to
// the start of packet 2n+1.
// NOTE(review): this is presumably the PMT of pair n, given each pair is a
// PAT followed by a PMT; confirm against writePSIWithMeta.
// This assumes only PSI has been written, consecutively, and is used by
// TestSegmentForMeta.
func scale(x, y int) rng {
	// offset returns the byte offset at which packet 2n+1 begins.
	offset := func(n int) int {
		return (2*n + 1) * PacketSize
	}
	return rng{start: offset(x), end: offset(y)}
}
// TestFindPSI checks the index, stream map, metadata and error returned by
// FindPSI for clips in which the PAT and PMT come first, come after a media
// packet, or are separated by a media packet.
func TestFindPSI(t *testing.T) {
	// Kinds of packet that a generated clip can contain.
	const (
		pat = iota
		pmt
		media
	)

	const (
		metaKey   = "key"
		mediaType = gotspsi.PmtStreamTypeMpeg4Video
		pmtPID    = 3
		streamPID = 4
	)

	// result holds what FindPSI is expected to return for a clip.
	type result struct {
		idx        int
		streamType uint8
		streamPID  uint16
		meta       map[string]string
		err        error
	}

	tests := []struct {
		pkts []int  // Kinds of packet in the clip, in order.
		meta string // Value stored under metaKey before the PMT is written.
		want result
	}{
		{
			pkts: []int{pat, pmt, media, media},
			meta: "1",
			want: result{
				idx:        0,
				streamType: mediaType,
				streamPID:  streamPID,
				meta:       map[string]string{metaKey: "1"},
				err:        nil,
			},
		},
		{
			pkts: []int{media, pat, pmt, media, media},
			meta: "1",
			want: result{
				idx:        188,
				streamType: mediaType,
				streamPID:  streamPID,
				meta:       map[string]string{metaKey: "1"},
				err:        nil,
			},
		},
		{
			pkts: []int{pat, media, pmt, media, media},
			meta: "1",
			want: result{
				idx:        0,
				streamType: mediaType,
				streamPID:  streamPID,
				meta:       map[string]string{metaKey: "1"},
				err:        ErrNotConsecutive,
			},
		},
	}

	var clip bytes.Buffer
	Meta = meta.New()

	for i, test := range tests {
		// Generate MTS packets for this test.
		clip.Reset()
		for _, kind := range test.pkts {
			switch kind {
			case pat:
				table := (&psi.PSI{
					PointerField:    0x00,
					TableID:         0x00,
					SyntaxIndicator: true,
					PrivateBit:      false,
					SectionLen:      0x0d,
					SyntaxSection: &psi.SyntaxSection{
						TableIDExt:  0x01,
						Version:     0,
						CurrentNext: true,
						Section:     0,
						LastSection: 0,
						SpecificData: &psi.PAT{
							Program:       0x01,
							ProgramMapPID: pmtPID,
						},
					},
				}).Bytes()

				patPkt := Packet{
					PUSI:    true,
					PID:     PatPid,
					CC:      0,
					AFC:     HasPayload,
					Payload: psi.AddPadding(table),
				}
				if _, err := clip.Write(patPkt.Bytes(nil)); err != nil {
					t.Fatalf("could not write PAT to clip for test %d", i)
				}

			case pmt:
				table := (&psi.PSI{
					PointerField:    0x00,
					TableID:         0x02,
					SyntaxIndicator: true,
					SectionLen:      0x12,
					SyntaxSection: &psi.SyntaxSection{
						TableIDExt:  0x01,
						Version:     0,
						CurrentNext: true,
						Section:     0,
						LastSection: 0,
						SpecificData: &psi.PMT{
							ProgramClockPID: 0x0100,
							ProgramInfoLen:  0,
							StreamSpecificData: &psi.StreamSpecificData{
								StreamType:    mediaType,
								PID:           streamPID,
								StreamInfoLen: 0x00,
							},
						},
					},
				}).Bytes()

				// Register the meta value, then pass the table through
				// updateMeta before it is written.
				Meta.Add(metaKey, test.meta)
				table, err := updateMeta(table, (*testLogger)(t))
				if err != nil {
					t.Fatalf("could not update meta for test %d", i)
				}

				pmtPkt := Packet{
					PUSI:    true,
					PID:     pmtPID,
					CC:      0,
					AFC:     HasPayload,
					Payload: psi.AddPadding(table),
				}
				if _, err := clip.Write(pmtPkt.Bytes(nil)); err != nil {
					t.Fatalf("could not write PMT to clip for test %d", i)
				}

			case media:
				pesPkt := pes.Packet{
					StreamID:     mediaType,
					PDI:          hasPTS,
					Data:         []byte{},
					HeaderLength: 5,
				}
				mediaPkt := Packet{
					PUSI: true,
					PID:  uint16(streamPID),
					RAI:  true,
					CC:   0,
					AFC:  hasAdaptationField | hasPayload,
					PCRF: true,
				}
				mediaPkt.FillPayload(pesPkt.Bytes(nil))
				if _, err := clip.Write(mediaPkt.Bytes(nil)); err != nil {
					t.Fatalf("did not expect clip write error: %v", err)
				}

			default:
				t.Fatalf("undefined pkt type %d in test %d", kind, i)
			}
		}

		gotIdx, gotStreams, gotMeta, gotErr := FindPSI(clip.Bytes())

		// Check error.
		if errors.Cause(gotErr) != test.want.err {
			t.Errorf("did not get expected error for test %d\nGot: %v\nWant: %v\n", i, gotErr, test.want.err)
		}

		// The remaining results are only checked when FindPSI succeeded.
		if gotErr != nil {
			continue
		}

		// Check idx.
		if gotIdx != test.want.idx {
			t.Errorf("did not get expected idx for test %d\nGot: %v\nWant: %v\n", i, gotIdx, test.want.idx)
		}

		// Check stream type and PID.
		if gotStreams == nil {
			t.Fatalf("gotStreams should not be nil")
		}
		if len(gotStreams) == 0 {
			t.Fatalf("gotStreams should not be 0 length")
		}

		// The generated PMT describes one stream, so the loop variables are
		// left holding that stream's PID and type.
		var (
			gotStreamPID  uint16
			gotStreamType uint8
		)
		for gotStreamPID, gotStreamType = range gotStreams {
		}

		if gotStreamType != test.want.streamType {
			t.Errorf("did not get expected stream type for test %d\nGot: %v\nWant: %v\n", i, gotStreamType, test.want.streamType)
		}
		if gotStreamPID != test.want.streamPID {
			t.Errorf("did not get expected stream PID for test %d\nGot: %v\nWant: %v\n", i, gotStreamPID, test.want.streamPID)
		}

		// Check meta.
		if !reflect.DeepEqual(gotMeta, test.want.meta) {
			t.Errorf("did not get expected meta for test %d\nGot: %v\nWant: %v\n", i, gotMeta, test.want.meta)
		}
	}
}