container/mts: wrote more tests for GetPTSRange

This commit is contained in:
Saxon 2019-07-01 19:08:20 +09:30
parent 7bd885bcfb
commit b017e92185
2 changed files with 87 additions and 5 deletions

View File

@ -30,11 +30,11 @@ LICENSE
package mts package mts
import ( import (
"errors"
"fmt" "fmt"
"github.com/Comcast/gots/packet" "github.com/Comcast/gots/packet"
"github.com/Comcast/gots/pes" "github.com/Comcast/gots/pes"
"github.com/pkg/errors"
) )
const PacketSize = 188 const PacketSize = 188
@ -325,11 +325,11 @@ func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) {
var i int var i int
for { for {
if i >= len(clip) { if i >= len(clip) {
return [2]uint64{}, errors.New("could not find payload start") return [2]uint64{}, errNoPTS
} }
pkt, _i, err := FindPid(clip[i:], pid) pkt, _i, err := FindPid(clip[i:], pid)
if err != nil { if err != nil {
return [2]uint64{}, err return [2]uint64{}, errors.Wrap(err, fmt.Sprintf("could not find packet of PID: %d", pid))
} }
copy(_pkt[:], pkt) copy(_pkt[:], pkt)
if _pkt.PayloadUnitStartIndicator() { if _pkt.PayloadUnitStartIndicator() {
@ -369,3 +369,5 @@ func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) {
} }
return return
} }
var errNoPTS = errors.New("could not find PTS")

View File

@ -39,8 +39,8 @@ import (
) )
// TestGetPTSRange checks that GetPTSRange can correctly get the first and last // TestGetPTSRange checks that GetPTSRange can correctly get the first and last
// PTS in an MPEGTS clip. // PTS in an MPEGTS clip for a general case.
func TestGetPTSRange(t *testing.T) { func TestGetPTSRange1(t *testing.T) {
const ( const (
numOfFrames = 20 numOfFrames = 20
maxFrameSize = 1000 maxFrameSize = 1000
@ -157,6 +157,86 @@ func writeFrame(b *bytes.Buffer, frame []byte, pts uint64) error {
return nil return nil
} }
// TestGetPTSRange2 checks that GetPTSRange behaves correctly with cases where
// the first instance of a PID is not a payload start, and also where there
// are no payload starts.
func TestGetPTSRange2(t *testing.T) {
const (
nPackets = 8 // The number of MTS packets we will generate.
wantPID = 1 // The PID we want.
)
tests := []struct {
pusi []bool // The value of PUSI for each packet.
pid []uint16 // The PIDs for each packet.
pts []uint64 // The PTS for each packet.
want [2]uint64 // The wanted PTS from GetPTSRange.
err error // The error we expect from GetPTSRange.
}{
{
[]bool{false, false, false, true, false, false, true, false},
[]uint16{0, 0, 1, 1, 1, 1, 1, 1},
[]uint64{0, 0, 0, 1, 0, 0, 2, 0},
[2]uint64{1, 2},
nil,
},
{
[]bool{false, false, false, true, false, false, false, false},
[]uint16{0, 0, 1, 1, 1, 1, 1, 1},
[]uint64{0, 0, 0, 1, 0, 0, 0, 0},
[2]uint64{1, 1},
nil,
},
{
[]bool{false, false, false, false, false, false, false, false},
[]uint16{0, 0, 1, 1, 1, 1, 1, 1},
[]uint64{0, 0, 0, 0, 0, 0, 0, 0},
[2]uint64{0, 0},
errNoPTS,
},
}
var clip bytes.Buffer
for i, test := range tests {
// Generate MTS packets for this test.
clip.Reset()
for j := 0; j < nPackets; j++ {
pesPkt := pes.Packet{
StreamID: H264ID,
PDI: hasPTS,
PTS: test.pts[j],
Data: []byte{},
HeaderLength: 5,
}
buf := pesPkt.Bytes(nil)
pkt := Packet{
PUSI: test.pusi[j],
PID: test.pid[j],
RAI: true,
CC: 0,
AFC: hasAdaptationField | hasPayload,
PCRF: true,
}
pkt.FillPayload(buf)
_, err := clip.Write(pkt.Bytes(nil))
if err != nil {
t.Fatalf("did not expect clip write error: %v", err)
}
}
pts, err := GetPTSRange(clip.Bytes(), wantPID)
if err != test.err {
t.Errorf("did not get expected error for test: %v\nGot: %v\nWant: %v\n", i, err, test.err)
}
if pts != test.want {
t.Errorf("did not get expected result for test: %v\nGot: %v\nWant: %v\n", i, pts, test.want)
}
}
}
// TestBytes checks that Packet.Bytes() correctly produces a []byte // TestBytes checks that Packet.Bytes() correctly produces a []byte
// representation of a Packet. // representation of a Packet.
func TestBytes(t *testing.T) { func TestBytes(t *testing.T) {