container/mts: wrote test for BytesForPTSInterval and corrected bugs

This commit is contained in:
Saxon 2019-06-14 14:21:45 +09:30
parent 70eb8193cb
commit 22079fcb48
2 changed files with 125 additions and 10 deletions

View File

@ -119,6 +119,13 @@ func (c *Clip) Bytes() []byte {
return c.backing return c.backing
} }
// Errors used in BytesForPTSInterval.
var (
errBadLowerBound = errors.New("'from' cannot be found")
errBadUpperBound = errors.New("'to' cannot be found")
errInvalidPTSRange = errors.New("PTS interval invalid")
)
// BytesForPTSInterval returns the media data between PTS' from and to. If from // BytesForPTSInterval returns the media data between PTS' from and to. If from
// sits between two PTS, the Frame posessing lower PTS will be considered the start. // sits between two PTS, the Frame posessing lower PTS will be considered the start.
// The Frame before the Frame corresponding to to will be considered the final // The Frame before the Frame corresponding to to will be considered the final
@ -126,22 +133,22 @@ func (c *Clip) Bytes() []byte {
func (c *Clip) BytesForPTSInterval(from, to uint64) ([]byte, error) { func (c *Clip) BytesForPTSInterval(from, to uint64) ([]byte, error) {
// First check that the interval makes sense. // First check that the interval makes sense.
if from >= to { if from >= to {
return nil, errors.New("PTS interval is invalid") return nil, errInvalidPTSRange
} }
// Use binary search to find 'from'. // Use binary search to find 'from'.
n := len(c.frames) n := len(c.frames) - 1
idx := sort.Search( idx := sort.Search(
n, n,
func(i int) bool { func(i int) bool {
if from >= c.frames[i].PTS && from < c.frames[i].PTS { if from < c.frames[i+1].PTS {
return true return true
} }
return false return false
}, },
) )
if idx == n { if idx == n {
return nil, errors.New("'from' cannot be found") return nil, errBadLowerBound
} }
// Now get the start index for the backing slice from this Frame. // Now get the start index for the backing slice from this Frame.
@ -153,18 +160,18 @@ func (c *Clip) BytesForPTSInterval(from, to uint64) ([]byte, error) {
idx = sort.Search( idx = sort.Search(
n, n,
func(i int) bool { func(i int) bool {
if to >= c.frames[i+off].PTS && from < c.frames[i+off].PTS { if to <= c.frames[i+off].PTS {
return true return true
} }
return false return false
}, },
) )
if idx == n { if idx == n {
return nil, errors.New("'to' cannot be found") return nil, errBadUpperBound
} }
// Now get the end index for the backing slice from this Frame, and return // Now get the end index for the backing slice from this Frame, and return
// segment from backing slice corresponding to start and end. // segment from backing slice corresponding to start and end.
end := c.frames[idx+off].idx end := c.frames[idx+off-1].idx
return c.backing[start:end], nil return c.backing[start : end+len(c.frames[idx+off].Media)], nil
} }

View File

@ -85,8 +85,25 @@ func TestExtract(t *testing.T) {
// Check frames individually. // Check frames individually.
for i, frame := range want.frames { for i, frame := range want.frames {
if !reflect.DeepEqual(frame, got.frames[i]) { // Check media data.
t.Fatalf("did not get expected result.\nGot: %v\n, Want: %v\n", got.frames[i], frame) wantMedia := frame.Media
gotMedia := got.frames[i].Media
if !bytes.Equal(wantMedia, gotMedia) {
t.Fatalf("did not get expected data for frame: %v\nGot: %v\nWant: %v\n", i, gotMedia, wantMedia)
}
// Check stream ID.
wantID := frame.ID
gotID := got.frames[i].ID
if wantID != gotID {
t.Fatalf("did not get expected ID for frame: %v\nGot: %v\nWant: %v\n", i, gotID, wantID)
}
// Check meta.
wantMeta := frame.Meta
gotMeta := got.frames[i].Meta
if !reflect.DeepEqual(wantMeta, gotMeta) {
t.Fatalf("did not get expected meta for frame: %v\nGot: %v\nwant: %v\n", i, gotMeta, wantMeta)
} }
} }
} }
@ -207,3 +224,94 @@ func genFrames(n, min, max int) [][]byte {
} }
return frames return frames
} }
// TestBytesForPTSInterval checks that BytesForPTSInterval can correctly return
// a slice of media data corresponding to a given PTS interval.
func TestBytesForPTSInterval(t *testing.T) {
const (
numOfTestFrames = 10
ptsInterval = 4
frameSize = 3
)
clip := &Clip{}
// Generate test frames.
for i := 0; i < numOfTestFrames; i++ {
clip.backing = append(clip.backing, []byte{byte(i), byte(i), byte(i)}...)
clip.frames = append(
clip.frames,
Frame{
Media: clip.backing[i*frameSize : (i+1)*frameSize],
PTS: uint64(i * ptsInterval),
idx: i * frameSize,
},
)
}
// We test each of these scenarios.
tests := []struct {
from uint64
to uint64
expect []byte
err error
}{
{
from: 6,
to: 15,
expect: []byte{
0x01, 0x01, 0x01,
0x02, 0x02, 0x02,
0x03, 0x03, 0x03,
},
err: nil,
},
{
from: 4,
to: 16,
expect: []byte{
0x01, 0x01, 0x01,
0x02, 0x02, 0x02,
0x03, 0x03, 0x03,
},
err: nil,
},
{
from: 10,
to: 5,
expect: nil,
err: errInvalidPTSRange,
},
{
from: 50,
to: 70,
expect: nil,
err: errBadLowerBound,
},
{
from: 5,
to: 70,
expect: nil,
err: errBadUpperBound,
},
}
// Run tests.
for i, test := range tests {
got, err := clip.BytesForPTSInterval(test.from, test.to)
// First check the error.
if err != nil && err != test.err {
t.Errorf("unexpected error: %v for test: %v from BytesForPTSInterval", err, i)
continue
} else if err != test.err {
t.Errorf("expected to get error: %v for test: %v from BytesForPTSInterval", test.err, i)
continue
}
// Now check data.
if test.err == nil && !bytes.Equal(test.expect, got) {
t.Errorf("did not get expected data for test: %v from BytesForPTSInterval.\n Got: %v\n, Want: %v\n", i, got, test.expect)
}
}
}