From 22079fcb48790aa954dec3eee0c9c029beb22d8d Mon Sep 17 00:00:00 2001 From: Saxon Date: Fri, 14 Jun 2019 14:21:45 +0930 Subject: [PATCH] container/mts: wrote test for BytesForPTSInterval and corrected bugs --- container/mts/payload.go | 23 ++++--- container/mts/payload_test.go | 112 +++++++++++++++++++++++++++++++++- 2 files changed, 125 insertions(+), 10 deletions(-) diff --git a/container/mts/payload.go b/container/mts/payload.go index a824cb64..f8491ade 100644 --- a/container/mts/payload.go +++ b/container/mts/payload.go @@ -119,6 +119,13 @@ func (c *Clip) Bytes() []byte { return c.backing } +// Errors used in BytesForPTSInterval. +var ( + errBadLowerBound = errors.New("'from' cannot be found") + errBadUpperBound = errors.New("'to' cannot be found") + errInvalidPTSRange = errors.New("PTS interval invalid") +) + // BytesForPTSInterval returns the media data between PTS' from and to. If from // sits between two PTS, the Frame posessing lower PTS will be considered the start. // The Frame before the Frame corresponding to to will be considered the final @@ -126,22 +133,22 @@ func (c *Clip) Bytes() []byte { func (c *Clip) BytesForPTSInterval(from, to uint64) ([]byte, error) { // First check that the interval makes sense. if from >= to { - return nil, errors.New("PTS interval is invalid") + return nil, errInvalidPTSRange } // Use binary search to find 'from'. - n := len(c.frames) + n := len(c.frames) - 1 idx := sort.Search( n, func(i int) bool { - if from >= c.frames[i].PTS && from < c.frames[i].PTS { + if from < c.frames[i+1].PTS { return true } return false }, ) if idx == n { - return nil, errors.New("'from' cannot be found") + return nil, errBadLowerBound } // Now get the start index for the backing slice from this Frame. @@ -153,18 +160,18 @@ func (c *Clip) BytesForPTSInterval(from, to uint64) ([]byte, error) { idx = sort.Search( n, func(i int) bool { - if to >= c.frames[i+off].PTS && from < c.frames[i+off].PTS { + if to <= c.frames[i+off].PTS { return true } return false }, ) if idx == n { - return nil, errors.New("'to' cannot be found") + return nil, errBadUpperBound } // Now get the end index for the backing slice from this Frame, and return // segment from backing slice corresponding to start and end. - end := c.frames[idx+off].idx - return c.backing[start:end], nil + end := c.frames[idx+off-1].idx + return c.backing[start : end+len(c.frames[idx+off].Media)], nil } diff --git a/container/mts/payload_test.go b/container/mts/payload_test.go index 568a842e..7f7dd968 100644 --- a/container/mts/payload_test.go +++ b/container/mts/payload_test.go @@ -85,8 +85,25 @@ func TestExtract(t *testing.T) { // Check frames individually. for i, frame := range want.frames { - if !reflect.DeepEqual(frame, got.frames[i]) { - t.Fatalf("did not get expected result.\nGot: %v\n, Want: %v\n", got.frames[i], frame) + // Check media data. + wantMedia := frame.Media + gotMedia := got.frames[i].Media + if !bytes.Equal(wantMedia, gotMedia) { + t.Fatalf("did not get expected data for frame: %v\nGot: %v\nWant: %v\n", i, gotMedia, wantMedia) + } + + // Check stream ID. + wantID := frame.ID + gotID := got.frames[i].ID + if wantID != gotID { + t.Fatalf("did not get expected ID for frame: %v\nGot: %v\nWant: %v\n", i, gotID, wantID) + } + + // Check meta. + wantMeta := frame.Meta + gotMeta := got.frames[i].Meta + if !reflect.DeepEqual(wantMeta, gotMeta) { + t.Fatalf("did not get expected meta for frame: %v\nGot: %v\nwant: %v\n", i, gotMeta, wantMeta) } } } @@ -207,3 +224,94 @@ func genFrames(n, min, max int) [][]byte { } return frames } + +// TestBytesForPTSInterval checks that BytesForPTSInterval can correctly return +// a slice of media data corresponding to a given PTS interval. +func TestBytesForPTSInterval(t *testing.T) { + const ( + numOfTestFrames = 10 + ptsInterval = 4 + frameSize = 3 + ) + + clip := &Clip{} + + // Generate test frames. + for i := 0; i < numOfTestFrames; i++ { + clip.backing = append(clip.backing, []byte{byte(i), byte(i), byte(i)}...) + clip.frames = append( + clip.frames, + Frame{ + Media: clip.backing[i*frameSize : (i+1)*frameSize], + PTS: uint64(i * ptsInterval), + idx: i * frameSize, + }, + ) + } + + // We test each of these scenarios. + tests := []struct { + from uint64 + to uint64 + expect []byte + err error + }{ + { + from: 6, + to: 15, + expect: []byte{ + 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, + 0x03, 0x03, 0x03, + }, + err: nil, + }, + { + from: 4, + to: 16, + expect: []byte{ + 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, + 0x03, 0x03, 0x03, + }, + err: nil, + }, + { + from: 10, + to: 5, + expect: nil, + err: errInvalidPTSRange, + }, + { + from: 50, + to: 70, + expect: nil, + err: errBadLowerBound, + }, + { + from: 5, + to: 70, + expect: nil, + err: errBadUpperBound, + }, + } + + // Run tests. + for i, test := range tests { + got, err := clip.BytesForPTSInterval(test.from, test.to) + + // First check the error. + if err != nil && err != test.err { + t.Errorf("unexpected error: %v for test: %v from BytesForPTSInterval", err, i) + continue + } else if err != test.err { + t.Errorf("expected to get error: %v for test: %v from BytesForPTSInterval", test.err, i) + continue + } + + // Now check data. + if test.err == nil && !bytes.Equal(test.expect, got) { + t.Errorf("did not get expected data for test: %v from BytesForPTSInterval.\n Got: %v\n, Want: %v\n", i, got, test.expect) + } + } +}