diff --git a/container/mts/payload.go b/container/mts/payload.go index d05e994e..71c7c022 100644 --- a/container/mts/payload.go +++ b/container/mts/payload.go @@ -8,8 +8,11 @@ import ( "github.com/Comcast/gots/pes" ) +// TODO: write function for getting subslice of mpegts for meta interval + // Extract extracts the media, PTS, stream ID and meta for an MPEG-TS clip given // by p, and returns as a Clip. The MPEG-TS must contain only complete packets. +// The resultant data is a copy of the original. func Extract(p []byte) (*Clip, error) { l := len(p) // Check that clip is divisible by 188, i.e. contains a series of full MPEG-TS clips. @@ -127,11 +130,12 @@ var ( errPTSRange = errors.New("PTS interval invalid") ) -// BytesForPTSInterval returns the media data between PTS' from and to. If from -// sits between two PTS, the Frame posessing lower PTS will be considered the start. -// The Frame before the Frame corresponding to to will be considered the final -// Frame. -func (c *Clip) BytesForPTSInterval(from, to uint64) ([]byte, error) { +// TrimToPTSRange returns the sub Clip in PTS range defined by from and to. +// The first Frame in the new Clip will be the Frame for which from corresponds +// exactly with Frame.PTS, or the Frame in which from lies within. The final +// Frame in the Clip will be the previous of that for which to coincides with, +// or the Frame that to lies within. +func (c *Clip) TrimToPTSRange(from, to uint64) (*Clip, error) { // First check that the interval makes sense. if from >= to { return nil, errPTSRange @@ -139,7 +143,7 @@ func (c *Clip) BytesForPTSInterval(from, to uint64) ([]byte, error) { // Use binary search to find 'from'. n := len(c.frames) - 1 - idx := sort.Search( + startFrameIdx := sort.Search( n, func(i int) bool { if from < c.frames[i+1].PTS { @@ -148,17 +152,17 @@ func (c *Clip) BytesForPTSInterval(from, to uint64) ([]byte, error) { return false }, ) - if idx == n { + if startFrameIdx == n { return nil, errPTSLowerBound } // Now get the start index for the backing slice from this Frame. - start := c.frames[idx].idx + startBackingIdx := c.frames[startFrameIdx].idx // Now use binary search again to find 'to'. - off := idx + 1 + off := startFrameIdx + 1 n = n - (off) - idx = sort.Search( + endFrameIdx := sort.Search( n, func(i int) bool { if to <= c.frames[i+off].PTS { @@ -167,14 +171,18 @@ func (c *Clip) BytesForPTSInterval(from, to uint64) ([]byte, error) { return false }, ) - if idx == n { + if endFrameIdx == n { return nil, errPTSUpperBound } - // Now get the end index for the backing slice from this Frame, and return - // segment from backing slice corresponding to start and end. - end := c.frames[idx+off-1].idx - return c.backing[start : end+len(c.frames[idx+off].Media)], nil + // Now get the end index for the backing slice from this Frame. + endBackingIdx := c.frames[endFrameIdx+off-1].idx + + // Now return a new clip. NB: data is not copied. + return &Clip{ + frames: c.frames[startFrameIdx : endFrameIdx+1], + backing: c.backing[startBackingIdx : endBackingIdx+len(c.frames[endFrameIdx+off].Media)], + }, nil } // Errors that maybe returned from BytesForMetaInterval. diff --git a/container/mts/payload_test.go b/container/mts/payload_test.go index d56ca516..c98876fe 100644 --- a/container/mts/payload_test.go +++ b/container/mts/payload_test.go @@ -225,9 +225,9 @@ func genFrames(n, min, max int) [][]byte { return frames } -// TestBytesForPTSInterval checks that BytesForPTSInterval can correctly return -// a slice of media data corresponding to a given PTS interval. -func TestBytesForPTSInterval(t *testing.T) { +// TestTrimToPTSRange checks that Clip.TrimToPTSRange will correctly return a +// sub Clip of the given PTS range. +func TestTrimToPTSRange(t *testing.T) { const ( numOfTestFrames = 10 ptsInterval = 4 @@ -298,7 +298,7 @@ func TestBytesForPTSInterval(t *testing.T) { // Run tests. for i, test := range tests { - got, err := clip.BytesForPTSInterval(test.from, test.to) + got, err := clip.TrimToPTSRange(test.from, test.to) // First check the error. if err != nil && err != test.err { @@ -310,7 +310,7 @@ func TestBytesForPTSInterval(t *testing.T) { } // Now check data. - if test.err == nil && !bytes.Equal(test.expect, got) { + if test.err == nil && !bytes.Equal(test.expect, got.Bytes()) { t.Errorf("did not get expected data for test: %v\n Got: %v\n, Want: %v\n", i, got, test.expect) } }