Merged in media-pids (pull request #220)

container/mts/mpegts.go: added Programs, Streams and MediaStreams functions

Approved-by: Alan Noble <anoble@gmail.com>
This commit is contained in:
Saxon Milton 2019-08-06 12:01:11 +00:00
commit f6c296de01
2 changed files with 400 additions and 34 deletions

View File

@ -33,6 +33,7 @@ import (
"fmt" "fmt"
"github.com/Comcast/gots/packet" "github.com/Comcast/gots/packet"
gotspsi "github.com/Comcast/gots/psi"
"github.com/pkg/errors" "github.com/pkg/errors"
"bitbucket.org/ausocean/av/container/mts/meta" "bitbucket.org/ausocean/av/container/mts/meta"
@ -178,16 +179,15 @@ func FindPat(d []byte) ([]byte, int, error) {
// Errors used by FindPid. // Errors used by FindPid.
var ( var (
errInvalidLen = errors.New("MPEG-TS data not of valid length") ErrInvalidLen = errors.New("MPEG-TS data not of valid length")
errCouldNotFind = errors.New("could not find packet with given PID") errCouldNotFind = errors.New("could not find packet with given PID")
errNotConsecutive = errors.New("could not find consecutive PIDs")
) )
// FindPid will take a clip of MPEG-TS and try to find a packet with given PID - if one // FindPid will take a clip of MPEG-TS and try to find a packet with given PID - if one
// is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. // is found, then it is returned along with its index, otherwise nil, -1 and an error is returned.
func FindPid(d []byte, pid uint16) (pkt []byte, i int, err error) { func FindPid(d []byte, pid uint16) (pkt []byte, i int, err error) {
if len(d) < PacketSize { if len(d) < PacketSize {
return nil, -1, errInvalidLen return nil, -1, ErrInvalidLen
} }
for i = 0; i < len(d); i += PacketSize { for i = 0; i < len(d); i += PacketSize {
p := (uint16(d[i+1]&0x1f) << 8) | uint16(d[i+2]) p := (uint16(d[i+1]&0x1f) << 8) | uint16(d[i+2])
@ -205,7 +205,7 @@ func FindPid(d []byte, pid uint16) (pkt []byte, i int, err error) {
// nil, -1 and an error is returned. // nil, -1 and an error is returned.
func LastPid(d []byte, pid uint16) (pkt []byte, i int, err error) { func LastPid(d []byte, pid uint16) (pkt []byte, i int, err error) {
if len(d) < PacketSize { if len(d) < PacketSize {
return nil, -1, errInvalidLen return nil, -1, ErrInvalidLen
} }
for i = len(d) - PacketSize; i >= 0; i -= PacketSize { for i = len(d) - PacketSize; i >= 0; i -= PacketSize {
@ -218,30 +218,72 @@ func LastPid(d []byte, pid uint16) (pkt []byte, i int, err error) {
return nil, -1, errCouldNotFind return nil, -1, errCouldNotFind
} }
// IndexPid returns the position of one or more consecutive pids, // Errors used by FindPSI.
// along with optional metadata if present. Commonly used to find a var (
// PAT immediately followed by a PMT. ErrMultiplePrograms = errors.New("more than one program not supported")
func IndexPid(d []byte, pids ...uint16) (idx int, m map[string]string, err error) { ErrNoPrograms = errors.New("no programs in PAT")
idx = -1 ErrNotConsecutive = errors.New("could not find consecutive PIDs")
for _, pid := range pids { )
if len(d) < PacketSize {
return idx, m, errInvalidLen // FindPSI finds the index of a PAT in an a slice of MPEG-TS and returns, along
} // with a map of meta from the PMT and the stream PIDs and their types.
pkt, i, err := FindPid(d, pid) func FindPSI(d []byte) (int, map[uint16]uint8, map[string]string, error) {
if err != nil { if len(d) < PacketSize {
return idx, m, errCouldNotFind return -1, nil, nil, ErrInvalidLen
}
if pid == PmtPid {
m, _ = metaFromPMT(pkt)
}
if idx == -1 {
idx = i
} else if i != 0 {
return idx, m, errNotConsecutive
}
d = d[i+PacketSize:]
} }
return
// Find the PAT if it exists.
pkt, i, err := FindPid(d, PatPid)
if err != nil {
return -1, nil, nil, errors.Wrap(err, "error finding PAT")
}
// Let's take this opportunity to check what programs are in this MPEG-TS
// stream, and therefore the PID of the PMT, from which we can get metadata.
// NB: currently we only support one program.
progs, err := Programs(pkt)
if err != nil {
return i, nil, nil, errors.Wrap(err, "cannot get programs from PAT")
}
if len(progs) == 0 {
return i, nil, nil, ErrNoPrograms
}
if len(progs) > 1 {
return i, nil, nil, ErrMultiplePrograms
}
pmtPID := pmtPIDs(progs)[0]
// Now we can look for the PMT. We want to adjust d so that we're not looking
// at the same data twice.
d = d[i+PacketSize:]
pkt, pmtIdx, err := FindPid(d, pmtPID)
if err != nil {
return i, nil, nil, errors.Wrap(err, "error finding PMT")
}
// Check that the PMT comes straight after the PAT.
if pmtIdx != 0 {
return i, nil, nil, ErrNotConsecutive
}
// Now we can try to get meta from the PMT.
meta, _ := metaFromPMT(pkt)
// Now to get the elementary streams defined for this program.
streams, err := Streams(pkt)
if err != nil {
return i, nil, meta, errors.Wrap(err, "could not get streams from PMT")
}
streamMap := make(map[uint16]uint8)
for _, s := range streams {
streamMap[s.ElementaryPid()] = s.StreamType()
}
return i, streamMap, meta, nil
} }
// FillPayload takes a channel and fills the packets Payload field until the // FillPayload takes a channel and fills the packets Payload field until the
@ -421,9 +463,9 @@ func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) {
} }
var ( var (
errNoPesPayload = errors.New("no PES payload") errNoPesPayload = errors.New("no PES payload")
errNoPesPTS = errors.New("no PES PTS") errNoPesPTS = errors.New("no PES PTS")
errInvalidPesHeader = errors.New("invalid PES header") errInvalidPesHeader = errors.New("invalid PES header")
errInvalidPesPayload = errors.New("invalid PES payload") errInvalidPesPayload = errors.New("invalid PES payload")
) )
@ -593,3 +635,98 @@ func SegmentForMeta(d []byte, key, val string) ([][]byte, error) {
return res, nil return res, nil
} }
// pid returns the packet identifier for the given packet.
func pid(p []byte) uint16 {
return uint16(p[1]&0x1f)<<8 | uint16(p[2])
}
// Programs returns a map of program numbers and corresponding PMT PIDs for a
// given MPEG-TS PAT packet.
func Programs(p []byte) (map[uint16]uint16, error) {
pat, err := gotspsi.NewPAT(p)
if err != nil {
return nil, err
}
return pat.ProgramMap(), nil
}
// Streams returns elementary streams defined in a given MPEG-TS PMT packet.
func Streams(p []byte) ([]gotspsi.PmtElementaryStream, error) {
payload, err := Payload(p)
if err != nil {
return nil, errors.Wrap(err, "cannot get packet payload")
}
pmt, err := gotspsi.NewPMT(payload)
if err != nil {
return nil, err
}
return pmt.ElementaryStreams(), nil
}
// MediaStreams retrieves the PmtElementaryStreams from the given PSI. This
// function currently assumes that PSI contain a PAT followed by a PMT directly
// after. We also assume that this MPEG-TS stream contains just one program,
// but this program may contain different streams, i.e. a video stream + audio
// stream.
func MediaStreams(p []byte) ([]gotspsi.PmtElementaryStream, error) {
pat := p[:PacketSize]
pmt := p[PacketSize : 2*PacketSize]
if pid(pat) != PatPid {
return nil, errors.New("first packet is not a PAT")
}
m, err := Programs(pat)
if err != nil {
return nil, errors.Wrap(err, "could not get programs from PAT")
}
if len(m) == 0 {
return nil, ErrNoPrograms
}
if len(m) > 1 {
return nil, ErrMultiplePrograms
}
if pid(pmt) != pmtPIDs(m)[0] {
return nil, errors.New("second packet is not desired PMT")
}
s, err := Streams(pmt)
if err != nil {
return nil, errors.Wrap(err, "could not get streams from PMT")
}
return s, nil
}
// pmtPIDs returns PMT PIDS from a map containing program number as keys and
// corresponding PMT PIDs as values.
func pmtPIDs(m map[uint16]uint16) []uint16 {
r := make([]uint16, 0, len(m))
for _, v := range m {
r = append(r, v)
}
return r
}
// Errors used by Payload.
var ErrNoPayload = errors.New("no payload")
// Payload returns the payload of an MPEG-TS packet p.
// NB: this is not a copy of the payload in the interest of performance.
// TODO: offer function that will do copy if we have interests in safety.
func Payload(p []byte) ([]byte, error) {
c := byte((p[3] & 0x30) >> 4)
if c == 2 {
return nil, ErrNoPayload
}
// Check if there is an adaptation field.
off := 4
if p[3]&0x20 == 1 {
off = int(5 + p[4])
}
return p[off:], nil
}

View File

@ -35,10 +35,13 @@ import (
"testing" "testing"
"time" "time"
"github.com/Comcast/gots/packet"
gotspsi "github.com/Comcast/gots/psi"
"github.com/pkg/errors"
"bitbucket.org/ausocean/av/container/mts/meta" "bitbucket.org/ausocean/av/container/mts/meta"
"bitbucket.org/ausocean/av/container/mts/pes" "bitbucket.org/ausocean/av/container/mts/pes"
"bitbucket.org/ausocean/av/container/mts/psi" "bitbucket.org/ausocean/av/container/mts/psi"
"github.com/Comcast/gots/packet"
) )
// TestGetPTSRange checks that GetPTSRange can correctly get the first and last // TestGetPTSRange checks that GetPTSRange can correctly get the first and last
@ -493,7 +496,7 @@ func TestSegmentForMeta(t *testing.T) {
} }
// Now test IndexPid. // Now test IndexPid.
i, m, err := IndexPid(clip.Bytes(), PatPid, PmtPid) i, _, m, err := FindPSI(clip.Bytes())
if err != nil { if err != nil {
t.Fatalf("IndexPid failed with error: %v", err) t.Fatalf("IndexPid failed with error: %v", err)
} }
@ -507,7 +510,7 @@ func TestSegmentForMeta(t *testing.T) {
// Finally, test IndexPid error handling. // Finally, test IndexPid error handling.
for _, d := range [][]byte{[]byte{}, make([]byte, PacketSize/2), make([]byte, PacketSize)} { for _, d := range [][]byte{[]byte{}, make([]byte, PacketSize/2), make([]byte, PacketSize)} {
_, _, err := IndexPid(d, PatPid, PmtPid) _, _, _, err := FindPSI(d)
if err == nil { if err == nil {
t.Fatalf("IndexPid expected error") t.Fatalf("IndexPid expected error")
} }
@ -530,3 +533,229 @@ func scale(x, y int) rng {
((y * 2) + 1) * PacketSize, ((y * 2) + 1) * PacketSize,
} }
} }
func TestFindPSI(t *testing.T) {
const (
pat = iota
pmt
media
)
const (
metaKey = "key"
mediaType = gotspsi.PmtStreamTypeMpeg4Video
pmtPID = 3
streamPID = 4
)
type want struct {
idx int
streamType uint8
streamPID uint16
meta map[string]string
err error
}
tests := []struct {
pkts []int
meta string
want want
}{
{
pkts: []int{pat, pmt, media, media},
meta: "1",
want: want{
idx: 0,
streamType: gotspsi.PmtStreamTypeMpeg4Video,
streamPID: 4,
meta: map[string]string{
"key": "1",
},
err: nil,
},
},
{
pkts: []int{media, pat, pmt, media, media},
meta: "1",
want: want{
idx: 188,
streamType: gotspsi.PmtStreamTypeMpeg4Video,
streamPID: 4,
meta: map[string]string{
"key": "1",
},
err: nil,
},
},
{
pkts: []int{pat, media, pmt, media, media},
meta: "1",
want: want{
idx: 0,
streamType: gotspsi.PmtStreamTypeMpeg4Video,
streamPID: 4,
meta: map[string]string{
"key": "1",
},
err: ErrNotConsecutive,
},
},
}
var clip bytes.Buffer
var err error
Meta = meta.New()
for i, test := range tests {
// Generate MTS packets for this test.
clip.Reset()
for _, pkt := range test.pkts {
switch pkt {
case pat:
patTable := (&psi.PSI{
Pf: 0x00,
Tid: 0x00,
Ssi: true,
Pb: false,
Sl: 0x0d,
Tss: &psi.TSS{
Tide: 0x01,
V: 0,
Cni: true,
Sn: 0,
Lsn: 0,
Sd: &psi.PAT{
Pn: 0x01,
Pmpid: pmtPID,
},
},
}).Bytes()
pat := Packet{
PUSI: true,
PID: PatPid,
CC: 0,
AFC: HasPayload,
Payload: psi.AddPadding(patTable),
}
_, err := clip.Write(pat.Bytes(nil))
if err != nil {
t.Fatalf("could not write PAT to clip for test %d", i)
}
case pmt:
pmtTable := (&psi.PSI{
Pf: 0x00,
Tid: 0x02,
Ssi: true,
Sl: 0x12,
Tss: &psi.TSS{
Tide: 0x01,
V: 0,
Cni: true,
Sn: 0,
Lsn: 0,
Sd: &psi.PMT{
Pcrpid: 0x0100,
Pil: 0,
Essd: &psi.ESSD{
St: mediaType,
Epid: streamPID,
Esil: 0x00,
},
},
},
}).Bytes()
Meta.Add(metaKey, test.meta)
pmtTable, err = updateMeta(pmtTable)
if err != nil {
t.Fatalf("could not update meta for test %d", i)
}
pmt := Packet{
PUSI: true,
PID: pmtPID,
CC: 0,
AFC: HasPayload,
Payload: psi.AddPadding(pmtTable),
}
_, err = clip.Write(pmt.Bytes(nil))
if err != nil {
t.Fatalf("could not write PMT to clip for test %d", i)
}
case media:
pesPkt := pes.Packet{
StreamID: mediaType,
PDI: hasPTS,
Data: []byte{},
HeaderLength: 5,
}
buf := pesPkt.Bytes(nil)
pkt := Packet{
PUSI: true,
PID: uint16(streamPID),
RAI: true,
CC: 0,
AFC: hasAdaptationField | hasPayload,
PCRF: true,
}
pkt.FillPayload(buf)
_, err := clip.Write(pkt.Bytes(nil))
if err != nil {
t.Fatalf("did not expect clip write error: %v", err)
}
default:
t.Fatalf("undefined pkt type %d in test %d", pkt, i)
}
}
gotIdx, gotStreams, gotMeta, gotErr := FindPSI(clip.Bytes())
// Check error
if errors.Cause(gotErr) != test.want.err {
t.Errorf("did not get expected error for test %d\nGot: %v\nWant: %v\n", i, gotErr, test.want.err)
}
if gotErr == nil {
// Check idx
if gotIdx != test.want.idx {
t.Errorf("did not get expected idx for test %d\nGot: %v\nWant: %v\n", i, gotIdx, test.want.idx)
}
// Check stream type and PID
if gotStreams == nil {
t.Fatalf("gotStreams should not be nil")
}
if len(gotStreams) == 0 {
t.Fatalf("gotStreams should not be 0 length")
}
var (
gotStreamPID uint16
gotStreamType uint8
)
for k, v := range gotStreams {
gotStreamPID = k
gotStreamType = v
}
if gotStreamType != test.want.streamType {
t.Errorf("did not get expected stream type for test %d\nGot: %v\nWant: %v\n", i, gotStreamType, test.want.streamType)
}
if gotStreamPID != test.want.streamPID {
t.Errorf("did not get expected stream PID for test %d\nGot: %v\nWant: %v\n", i, gotStreamPID, test.want.streamPID)
}
// Check meta
if !reflect.DeepEqual(gotMeta, test.want.meta) {
t.Errorf("did not get expected meta for test %d\nGot: %v\nWant: %v\n", i, gotMeta, test.want.meta)
}
}
}
}