// av/container/mts/mpegts_test.go

/*
NAME
mpegts_test.go
DESCRIPTION
mpegts_test.go contains testing for functionality found in mpegts.go.
AUTHORS
Saxon A. Nelson-Milton <saxon@ausocean.org>
LICENSE
Copyright (C) 2019 the Australian Ocean Lab (AusOcean)
It is free software: you can redistribute it and/or modify them
under the terms of the GNU General Public License as published by the
Free Software Foundation, either version 3 of the License, or (at your
option) any later version.
It is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
for more details.
You should have received a copy of the GNU General Public License
in gpl.txt. If not, see http://www.gnu.org/licenses.
*/
package mts
import (
"bytes"
"math/rand"
"testing"
"time"
"bitbucket.org/ausocean/av/container/mts/pes"
"bitbucket.org/ausocean/av/container/mts/psi"
"github.com/Comcast/gots/packet"
)
// TestGetPTSRange checks that GetPTSRange can correctly get the first and last
// PTS in an MPEGTS clip.
func TestGetPTSRange(t *testing.T) {
const (
numOfFrames = 20
maxFrameSize = 1000
minFrameSize = 100
rate = 25 // fps
interval = float64(1) / rate // s
ptsFreq = 90000 // Hz
)
// Generate randomly sized data for each frame.
rand.Seed(time.Now().UnixNano())
frames := make([][]byte, numOfFrames)
for i := range frames {
size := rand.Intn(maxFrameSize-minFrameSize) + minFrameSize
frames[i] = make([]byte, size)
}
var clip bytes.Buffer
// Write the PSI first.
err := writePSI(&clip)
if err != nil {
t.Fatalf("did not expect error writing psi: %v", err)
}
// Now write frames.
var curTime float64
for _, frame := range frames {
nextPTS := curTime * ptsFreq
err = writeFrame(&clip, frame, uint64(nextPTS))
if err != nil {
t.Fatalf("did not expect error writing frame: %v", err)
}
curTime += interval
}
2019-06-05 20:24:00 +03:00
got, err := GetPTSRange(clip.Bytes(), VideoPid)
if err != nil {
t.Fatalf("did not expect error getting PTS range: %v", err)
}
want := [2]uint64{0, uint64((numOfFrames - 1) * interval * ptsFreq)}
if got != want {
t.Errorf("did not get expected result.\n Got: %v\n Want: %v\n", got, want)
}
}
// writePSI is a helper function that writes the PSI found at the start of a
// clip to b. It writes a PAT packet and then a PMT packet. Each is a single
// payload-only packet with PUSI set, a continuity counter of 0, and its table
// padded by psi.AddPadding.
func writePSI(b *bytes.Buffer) error {
	psiPackets := []Packet{
		// PAT.
		{
			PUSI:    true,
			PID:     PatPid,
			CC:      0,
			AFC:     HasPayload,
			Payload: psi.AddPadding(patTable),
		},
		// PMT.
		{
			PUSI:    true,
			PID:     PmtPid,
			CC:      0,
			AFC:     HasPayload,
			Payload: psi.AddPadding(pmtTable),
		},
	}

	// Emit the packets in order, stopping at the first write failure.
	for idx := range psiPackets {
		if _, werr := b.Write(psiPackets[idx].Bytes(nil)); werr != nil {
			return werr
		}
	}
	return nil
}
// writeFrame is a helper function used to form a PES packet from a frame, and
// then fragment this across MPEGTS packets where they are then written to the
// given buffer.
func writeFrame(b *bytes.Buffer, frame []byte, pts uint64) error {
// Prepare PES data.
pesPkt := pes.Packet{
StreamID: H264ID,
PDI: hasPTS,
PTS: pts,
Data: frame,
HeaderLength: 5,
}
buf := pesPkt.Bytes(nil)
// Write PES data acroos MPEGTS packets.
pusi := true
for len(buf) != 0 {
pkt := Packet{
PUSI: pusi,
2019-06-05 20:24:00 +03:00
PID: VideoPid,
RAI: pusi,
CC: 0,
AFC: hasAdaptationField | hasPayload,
PCRF: pusi,
}
n := pkt.FillPayload(buf)
buf = buf[n:]
pusi = false
_, err := b.Write(pkt.Bytes(nil))
if err != nil {
return err
}
}
return nil
}
2019-05-07 06:47:33 +03:00
// TestBytes checks that Packet.Bytes() correctly produces a []byte
// representation of a Packet.
func TestBytes(t *testing.T) {
2019-05-07 06:47:33 +03:00
const payloadLen, payloadChar, stuffingChar = 120, 0x11, 0xff
const stuffingLen = PacketSize - payloadLen - 12
tests := []struct {
packet Packet
expectedHeader []byte
}{
{
packet: Packet{
PUSI: true,
PID: 1,
RAI: true,
CC: 4,
AFC: HasPayload | HasAdaptationField,
PCRF: true,
PCR: 1,
},
expectedHeader: []byte{
0x47, // Sync byte.
0x40, // TEI=0, PUSI=1, TP=0, PID=00000.
0x01, // PID(Cont)=00000001.
0x34, // TSC=00, AFC=11(adaptation followed by payload), CC=0100(4).
2019-05-07 06:47:33 +03:00
byte(7 + stuffingLen), // AFL=.
0x50, // DI=0,RAI=1,ESPI=0,PCRF=1,OPCRF=0,SPF=0,TPDF=0, AFEF=0.
0x00, 0x00, 0x00, 0x00, 0x80, 0x00, // PCR.
},
},
}
for testNum, test := range tests {
// Construct payload.
2019-05-07 06:47:33 +03:00
payload := make([]byte, 0, payloadLen)
for i := 0; i < payloadLen; i++ {
payload = append(payload, payloadChar)
}
// Fill the packet payload.
test.packet.FillPayload(payload)
2019-05-07 06:47:33 +03:00
// Create expected packet data and copy in expected header.
expected := make([]byte, len(test.expectedHeader), PacketSize)
copy(expected, test.expectedHeader)
// Append stuffing.
2019-05-07 06:47:33 +03:00
for i := 0; i < stuffingLen; i++ {
expected = append(expected, stuffingChar)
}
2019-05-07 06:47:33 +03:00
// Append payload to expected bytes.
expected = append(expected, payload...)
// Compare got with expected.
got := test.packet.Bytes(nil)
if !bytes.Equal(got, expected) {
t.Errorf("did not get expected result for test: %v.\n Got: %v\n Want: %v\n", testNum, got, expected)
}
}
}
2019-05-07 06:47:33 +03:00
// TestFindPid checks that FindPid can correctly extract the first instance
// of a PID from an MPEG-TS stream.
func TestFindPid(t *testing.T) {
const targetPacketNum, numOfPackets, targetPid, stdPid = 6, 15, 1, 0
2019-05-07 06:47:33 +03:00
// Prepare the stream of packets.
var stream []byte
for i := 0; i < numOfPackets; i++ {
pid := uint16(stdPid)
if i == targetPacketNum {
pid = targetPid
}
2019-05-07 06:47:33 +03:00
p := Packet{
PID: pid,
AFC: hasPayload | hasAdaptationField,
}
p.FillPayload([]byte{byte(i)})
stream = append(stream, p.Bytes(nil)...)
}
2019-05-07 06:47:33 +03:00
// Try to find the targetPid in the stream.
p, i, err := FindPid(stream, targetPid)
if err != nil {
t.Fatalf("unexpected error finding PID: %v\n", err)
}
2019-05-07 06:47:33 +03:00
// Check the payload.
var _p packet.Packet
copy(_p[:], p)
payload, err := packet.Payload(&_p)
if err != nil {
t.Fatalf("unexpected error getting packet payload: %v\n", err)
}
got := payload[0]
if got != targetPacketNum {
t.Errorf("payload of found packet is not correct.\nGot: %v, Want: %v\n", got, targetPacketNum)
}
2019-05-07 06:47:33 +03:00
// Check the index.
_got := i / PacketSize
if _got != targetPacketNum {
t.Errorf("index of found packet is not correct.\nGot: %v, want: %v\n", _got, targetPacketNum)
}
}