diff --git a/cmd/revid-cli/main.go b/cmd/revid-cli/main.go index 39872022..fc90dbcc 100644 --- a/cmd/revid-cli/main.go +++ b/cmd/revid-cli/main.go @@ -193,10 +193,6 @@ func handleFlags() revid.Config { log.Log(logger.Error, pkg+"bad input codec argument") } - if len(outputs) == 0 { - cfg.Outputs = make([]uint8, 1) - } - for _, o := range outputs { switch o { case "File": diff --git a/codec/codecutil/bytescanner.go b/codec/codecutil/bytescanner.go new file mode 100644 index 00000000..981bd6e0 --- /dev/null +++ b/codec/codecutil/bytescanner.go @@ -0,0 +1,95 @@ +/* +NAME + bytescanner.go + +AUTHOR + Dan Kortschak + +LICENSE + This is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// Package bytescan implements a byte-level scanner. +package bytescan + +import "io" + +// ByteScanner is a byte scanner. +type ByteScanner struct { + buf []byte + off int + + // r is the source of data for the scanner. + r io.Reader +} + +// NewByteScanner returns a scanner initialised with an io.Reader and a read buffer. +func NewByteScanner(r io.Reader, buf []byte) *ByteScanner { + return &ByteScanner{r: r, buf: buf[:0]} +} + +// ScanUntil scans the scanner's underlying io.Reader until a delim byte +// has been read, appending all read bytes to dst. The resulting appended data, +// the last read byte and whether the last read byte was the delimiter. +func (c *ByteScanner) ScanUntil(dst []byte, delim byte) (res []byte, b byte, err error) { +outer: + for { + var i int + for i, b = range c.buf[c.off:] { + if b != delim { + continue + } + dst = append(dst, c.buf[c.off:c.off+i+1]...) + c.off += i + 1 + break outer + } + dst = append(dst, c.buf[c.off:]...) + err = c.reload() + if err != nil { + break + } + } + return dst, b, err +} + +// ReadByte is an unexported ReadByte. +func (c *ByteScanner) ReadByte() (byte, error) { + if c.off >= len(c.buf) { + err := c.reload() + if err != nil { + return 0, err + } + } + b := c.buf[c.off] + c.off++ + return b, nil +} + +// reload re-fills the scanner's buffer. +func (c *ByteScanner) reload() error { + n, err := c.r.Read(c.buf[:cap(c.buf)]) + c.buf = c.buf[:n] + if err != nil { + if err != io.EOF { + return err + } + if n == 0 { + return io.EOF + } + } + c.off = 0 + return nil +} diff --git a/codec/codecutil/bytescanner_test.go b/codec/codecutil/bytescanner_test.go new file mode 100644 index 00000000..64c17c31 --- /dev/null +++ b/codec/codecutil/bytescanner_test.go @@ -0,0 +1,82 @@ +/* +NAME + bytescanner_test.go + +DESCRIPTION + See Readme.md + +AUTHOR + Dan Kortschak + +LICENSE + This is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package bytescan + +import ( + "bytes" + "reflect" + "testing" +) + +type chunkEncoder [][]byte + +func (e *chunkEncoder) Encode(b []byte) error { + *e = append(*e, b) + return nil +} + +func (*chunkEncoder) Stream() <-chan []byte { panic("INVALID USE") } + +func TestScannerReadByte(t *testing.T) { + data := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + + for _, size := range []int{1, 2, 8, 1 << 10} { + r := NewByteScanner(bytes.NewReader(data), make([]byte, size)) + var got []byte + for { + b, err := r.ReadByte() + if err != nil { + break + } + got = append(got, b) + } + if !bytes.Equal(got, data) { + t.Errorf("unexpected result for buffer size %d:\ngot :%q\nwant:%q", size, got, data) + } + } +} + +func TestScannerScanUntilZero(t *testing.T) { + data := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit,\x00 sed do eiusmod tempor incididunt ut \x00labore et dolore magna aliqua.") + + for _, size := range []int{1, 2, 8, 1 << 10} { + r := NewByteScanner(bytes.NewReader(data), make([]byte, size)) + var got [][]byte + for { + buf, _, err := r.ScanUntil(nil, 0x0) + got = append(got, buf) + if err != nil { + break + } + } + want := bytes.SplitAfter(data, []byte{0}) + if !reflect.DeepEqual(got, want) { + t.Errorf("unexpected result for buffer zie %d:\ngot :%q\nwant:%q", size, got, want) + } + } +} diff --git a/container/mts/discontinuity.go b/container/mts/discontinuity.go index adccebad..e127ff94 100644 --- a/container/mts/discontinuity.go +++ b/container/mts/discontinuity.go @@ -4,8 +4,8 @@ NAME DESCRIPTION discontinuity.go provides functionality for detecting discontinuities in - mpegts and accounting for using the discontinuity indicator in the adaptation - field. + MPEG-TS and accounting for using the discontinuity indicator in the adaptation + field. AUTHOR Saxon A. Nelson-Milton @@ -33,7 +33,7 @@ import ( "github.com/Comcast/gots/packet" ) -// discontinuityRepairer provides function to detect discontinuities in mpegts +// discontinuityRepairer provides function to detect discontinuities in MPEG-TS // and set the discontinuity indicator as appropriate. type DiscontinuityRepairer struct { expCC map[int]int @@ -56,7 +56,7 @@ func (dr *DiscontinuityRepairer) Failed() { dr.decExpectedCC(PatPid) } -// Repair takes a clip of mpegts and checks that the first packet, which should +// Repair takes a clip of MPEG-TS and checks that the first packet, which should // be a PAT, contains a cc that is expected, otherwise the discontinuity indicator // is set to true. func (dr *DiscontinuityRepairer) Repair(d []byte) error { diff --git a/container/mts/encoder.go b/container/mts/encoder.go index 155ea03b..3208276b 100644 --- a/container/mts/encoder.go +++ b/container/mts/encoder.go @@ -127,7 +127,7 @@ const ( PTSFrequency = 90000 ) -// Encoder encapsulates properties of an mpegts generator. +// Encoder encapsulates properties of an MPEG-TS generator. type Encoder struct { dst io.WriteCloser @@ -204,7 +204,7 @@ func (e *Encoder) TimeBasedPsi(b bool, sendCount int) { e.pktCount = e.psiSendCount } -// Write implements io.Writer. Write takes raw h264 and encodes into mpegts, +// Write implements io.Writer. Write takes raw h264 and encodes into MPEG-TS, // then sending it to the encoder's io.Writer destination. func (e *Encoder) Write(data []byte) (int, error) { now := time.Now() @@ -259,7 +259,7 @@ func (e *Encoder) Write(data []byte) (int, error) { return len(data), nil } -// writePSI creates mpegts with pat and pmt tables - with pmt table having updated +// writePSI creates MPEG-TS with pat and pmt tables - with pmt table having updated // location and time data. func (e *Encoder) writePSI() error { // Write PAT. @@ -267,7 +267,7 @@ func (e *Encoder) writePSI() error { PUSI: true, PID: PatPid, CC: e.ccFor(PatPid), - AFC: HasPayload, + AFC: hasPayload, Payload: psi.AddPadding(patTable), } _, err := e.dst.Write(patPkt.Bytes(e.tsSpace[:PacketSize])) @@ -285,7 +285,7 @@ func (e *Encoder) writePSI() error { PUSI: true, PID: PmtPid, CC: e.ccFor(PmtPid), - AFC: HasPayload, + AFC: hasPayload, Payload: psi.AddPadding(pmtTable), } _, err = e.dst.Write(pmtPkt.Bytes(e.tsSpace[:PacketSize])) diff --git a/container/mts/audio_test.go b/container/mts/encoder_test.go similarity index 50% rename from container/mts/audio_test.go rename to container/mts/encoder_test.go index f785930d..8436d241 100644 --- a/container/mts/audio_test.go +++ b/container/mts/encoder_test.go @@ -1,12 +1,13 @@ /* NAME - audio_test.go + encoder_test.go AUTHOR Trek Hopton + Saxon A. Nelson-Milton LICENSE - audio_test.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean) + encoder_test.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean) It is free software: you can redistribute it and/or modify them under the terms of the GNU General Public License as published by the @@ -16,7 +17,7 @@ LICENSE It is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - for more details. + for more details. You should have received a copy of the GNU General Public License in gpl.txt. If not, see http://www.gnu.org/licenses. @@ -40,8 +41,115 @@ type nopCloser struct{ io.Writer } func (nopCloser) Close() error { return nil } -// TestEncodePcm tests the mpegts encoder's ability to encode pcm audio data. -// It reads and encodes input pcm data into mpegts, then decodes the mpegts and compares the result to the input pcm. +type destination struct { + packets [][]byte +} + +func (d *destination) Write(p []byte) (int, error) { + tmp := make([]byte, PacketSize) + copy(tmp, p) + d.packets = append(d.packets, tmp) + return len(p), nil +} + +// TestEncodeVideo checks that we can correctly encode some dummy data into a +// valid MPEG-TS stream. This checks for correct MPEG-TS headers and also that the +// original data is stored correctly and is retreivable. +func TestEncodeVideo(t *testing.T) { + Meta = meta.New() + + const dataLength = 440 + const numOfPackets = 3 + const stuffingLen = 100 + + // Generate test data. + data := make([]byte, 0, dataLength) + for i := 0; i < dataLength; i++ { + data = append(data, byte(i)) + } + + // Expect headers for PID 256 (video) + // NB: timing fields like PCR are neglected. + expectedHeaders := [][]byte{ + { + 0x47, // Sync byte. + 0x41, // TEI=0, PUSI=1, TP=0, PID=00001 (256). + 0x00, // PID(Cont)=00000000. + 0x30, // TSC=00, AFC=11(adaptation followed by payload), CC=0000(0). + 0x07, // AFL= 7. + 0x50, // DI=0,RAI=1,ESPI=0,PCRF=1,OPCRF=0,SPF=0,TPDF=0, AFEF=0. + }, + { + 0x47, // Sync byte. + 0x01, // TEI=0, PUSI=0, TP=0, PID=00001 (256). + 0x00, // PID(Cont)=00000000. + 0x31, // TSC=00, AFC=11(adaptation followed by payload), CC=0001(1). + 0x01, // AFL= 1. + 0x00, // DI=0,RAI=0,ESPI=0,PCRF=0,OPCRF=0,SPF=0,TPDF=0, AFEF=0. + }, + { + 0x47, // Sync byte. + 0x01, // TEI=0, PUSI=0, TP=0, PID=00001 (256). + 0x00, // PID(Cont)=00000000. + 0x32, // TSC=00, AFC=11(adaptation followed by payload), CC=0010(2). + 0x57, // AFL= 1+stuffingLen. + 0x00, // DI=0,RAI=0,ESPI=0,PCRF=1,OPCRF=0,SPF=0,TPDF=0, AFEF=0. + }, + } + + // Create the dst and write the test data to encoder. + dst := &destination{} + _, err := NewEncoder(nopCloser{dst}, 25, Video).Write(data) + if err != nil { + t.Fatalf("could not write data to encoder, failed with err: %v\n", err) + } + + // Check headers. + var expectedIdx int + for _, p := range dst.packets { + // Get PID. + var _p packet.Packet + copy(_p[:], p) + pid := packet.Pid(&_p) + if pid == VideoPid { + // Get mts header, excluding PCR. + gotHeader := p[0:6] + wantHeader := expectedHeaders[expectedIdx] + if !bytes.Equal(gotHeader, wantHeader) { + t.Errorf("did not get expected header for idx: %v.\n Got: %v\n Want: %v\n", expectedIdx, gotHeader, wantHeader) + } + expectedIdx++ + } + } + + // Gather payload data from packets to form the total PES packet. + var pesData []byte + for _, p := range dst.packets { + var _p packet.Packet + copy(_p[:], p) + pid := packet.Pid(&_p) + if pid == VideoPid { + payload, err := packet.Payload(&_p) + if err != nil { + t.Fatalf("could not get payload from mts packet, failed with err: %v\n", err) + } + pesData = append(pesData, payload...) + } + } + + // Get data from the PES packet and compare with the original data. + pes, err := pes.NewPESHeader(pesData) + if err != nil { + t.Fatalf("got error from pes creation: %v\n", err) + } + _data := pes.Data() + if !bytes.Equal(data, _data) { + t.Errorf("did not get expected result.\n Got: %v\n Want: %v\n", data, _data) + } +} + +// TestEncodePcm tests the MPEG-TS encoder's ability to encode pcm audio data. +// It reads and encodes input pcm data into MPEG-TS, then decodes the MPEG-TS and compares the result to the input pcm. func TestEncodePcm(t *testing.T) { Meta = meta.New() diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go index b590b8d3..eb4bee5d 100644 --- a/container/mts/mpegts.go +++ b/container/mts/mpegts.go @@ -1,7 +1,7 @@ /* NAME mpegts.go - provides a data structure intended to encapsulate the properties - of an MpegTs packet and also functions to allow manipulation of these packets. + of an MPEG-TS packet and also functions to allow manipulation of these packets. DESCRIPTION See Readme.md @@ -37,11 +37,7 @@ import ( "github.com/Comcast/gots/pes" ) -// General mpegts packet properties. -const ( - PacketSize = 188 - PayloadSize = 176 -) +const PacketSize = 188 // Program ID for various types of ts packets. const ( @@ -54,7 +50,7 @@ const ( // StreamID is the id of the first stream. const StreamID = 0xe0 -// HeadSize is the size of an mpegts packet header. +// HeadSize is the size of an MPEG-TS packet header. const HeadSize = 4 // Consts relating to adaptation field. @@ -165,23 +161,23 @@ type Packet struct { Payload []byte // Mpeg ts Payload } -// FindPmt will take a clip of mpegts and try to find a PMT table - if one +// FindPmt will take a clip of MPEG-TS and try to find a PMT table - if one // is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. func FindPmt(d []byte) ([]byte, int, error) { return FindPid(d, PmtPid) } -// FindPat will take a clip of mpegts and try to find a PAT table - if one +// FindPat will take a clip of MPEG-TS and try to find a PAT table - if one // is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. func FindPat(d []byte) ([]byte, int, error) { return FindPid(d, PatPid) } -// FindPid will take a clip of mpegts and try to find a packet with given PID - if one +// FindPid will take a clip of MPEG-TS and try to find a packet with given PID - if one // is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. func FindPid(d []byte, pid uint16) (pkt []byte, i int, err error) { if len(d) < PacketSize { - return nil, -1, errors.New("Mmpegts data not of valid length") + return nil, -1, errors.New("MPEG-TS data not of valid length") } for i = 0; i < len(d); i += PacketSize { p := (uint16(d[i+1]&0x1f) << 8) | uint16(d[i+2]) @@ -196,16 +192,69 @@ func FindPid(d []byte, pid uint16) (pkt []byte, i int, err error) { // FillPayload takes a channel and fills the packets Payload field until the // channel is empty or we've the packet reaches capacity func (p *Packet) FillPayload(data []byte) int { - currentPktLen := 6 + asInt(p.PCRF)*6 + asInt(p.OPCRF)*6 + - asInt(p.SPF)*1 + asInt(p.TPDF)*1 + len(p.TPD) - if len(data) > PayloadSize-currentPktLen { - p.Payload = make([]byte, PayloadSize-currentPktLen) + currentPktLen := 6 + asInt(p.PCRF)*6 + if len(data) > PacketSize-currentPktLen { + p.Payload = make([]byte, PacketSize-currentPktLen) } else { p.Payload = make([]byte, len(data)) } return copy(p.Payload, data) } +// Bytes interprets the fields of the ts packet instance and outputs a +// corresponding byte slice +func (p *Packet) Bytes(buf []byte) []byte { + if buf == nil || cap(buf) < PacketSize { + buf = make([]byte, PacketSize) + } + + if p.OPCRF { + panic("original program clock reference field unsupported") + } + if p.SPF { + panic("splicing countdown unsupported") + } + if p.TPDF { + panic("transport private data unsupported") + } + if p.AFEF { + panic("adaptation field extension unsupported") + } + + buf = buf[:6] + buf[0] = 0x47 + buf[1] = (asByte(p.TEI)<<7 | asByte(p.PUSI)<<6 | asByte(p.Priority)<<5 | byte((p.PID&0xFF00)>>8)) + buf[2] = byte(p.PID & 0x00FF) + buf[3] = (p.TSC<<6 | p.AFC<<4 | p.CC) + + var maxPayloadSize int + if p.AFC&0x2 != 0 { + maxPayloadSize = PacketSize - 6 - asInt(p.PCRF)*6 + } else { + maxPayloadSize = PacketSize - 4 + } + + stuffingLen := maxPayloadSize - len(p.Payload) + if p.AFC&0x2 != 0 { + buf[4] = byte(1 + stuffingLen + asInt(p.PCRF)*6) + buf[5] = (asByte(p.DI)<<7 | asByte(p.RAI)<<6 | asByte(p.ESPI)<<5 | asByte(p.PCRF)<<4 | asByte(p.OPCRF)<<3 | asByte(p.SPF)<<2 | asByte(p.TPDF)<<1 | asByte(p.AFEF)) + } else { + buf = buf[:4] + } + + for i := 40; p.PCRF && i >= 0; i -= 8 { + buf = append(buf, byte((p.PCR<<15)>>uint(i))) + } + + for i := 0; i < stuffingLen; i++ { + buf = append(buf, 0xff) + } + curLen := len(buf) + buf = buf[:PacketSize] + copy(buf[curLen:], p.Payload) + return buf +} + func asInt(b bool) int { if b { return 1 @@ -220,55 +269,6 @@ func asByte(b bool) byte { return 0 } -// Bytes interprets the fields of the ts packet instance and outputs a -// corresponding byte slice -func (p *Packet) Bytes(buf []byte) []byte { - if buf == nil || cap(buf) != PacketSize { - buf = make([]byte, 0, PacketSize) - } - buf = buf[:0] - stuffingLength := 182 - len(p.Payload) - len(p.TPD) - asInt(p.PCRF)*6 - - asInt(p.OPCRF)*6 - asInt(p.SPF) - var stuffing []byte - if stuffingLength > 0 { - stuffing = make([]byte, stuffingLength) - } - for i := range stuffing { - stuffing[i] = 0xFF - } - afl := 1 + asInt(p.PCRF)*6 + asInt(p.OPCRF)*6 + asInt(p.SPF) + asInt(p.TPDF) + len(p.TPD) + len(stuffing) - buf = append(buf, []byte{ - 0x47, - (asByte(p.TEI)<<7 | asByte(p.PUSI)<<6 | asByte(p.Priority)<<5 | byte((p.PID&0xFF00)>>8)), - byte(p.PID & 0x00FF), - (p.TSC<<6 | p.AFC<<4 | p.CC), - }...) - - if p.AFC == 3 || p.AFC == 2 { - buf = append(buf, []byte{ - byte(afl), (asByte(p.DI)<<7 | asByte(p.RAI)<<6 | asByte(p.ESPI)<<5 | - asByte(p.PCRF)<<4 | asByte(p.OPCRF)<<3 | asByte(p.SPF)<<2 | - asByte(p.TPDF)<<1 | asByte(p.AFEF)), - }...) - for i := 40; p.PCRF && i >= 0; i -= 8 { - buf = append(buf, byte((p.PCR<<15)>>uint(i))) - } - for i := 40; p.OPCRF && i >= 0; i -= 8 { - buf = append(buf, byte(p.OPCR>>uint(i))) - } - if p.SPF { - buf = append(buf, p.SC) - } - if p.TPDF { - buf = append(buf, append([]byte{p.TPDL}, p.TPD...)...) - } - buf = append(buf, p.Ext...) - buf = append(buf, stuffing...) - } - buf = append(buf, p.Payload...) - return buf -} - type Option func(p *packet.Packet) // addAdaptationField adds an adaptation field to p, and applys the passed options to this field. diff --git a/container/mts/mpegts_test.go b/container/mts/mpegts_test.go index 579acd65..650fecab 100644 --- a/container/mts/mpegts_test.go +++ b/container/mts/mpegts_test.go @@ -1,3 +1,30 @@ +/* +NAME + mpegts_test.go + +DESCRIPTION + mpegts_test.go contains testing for functionality found in mpegts.go. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + Copyright (C) 2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + package mts import ( @@ -8,6 +35,7 @@ import ( "bitbucket.org/ausocean/av/container/mts/pes" "bitbucket.org/ausocean/av/container/mts/psi" + "github.com/Comcast/gots/packet" ) // TestGetPTSRange checks that GetPTSRange can correctly get the first and last @@ -128,3 +156,111 @@ func writeFrame(b *bytes.Buffer, frame []byte, pts uint64) error { } return nil } + +// TestBytes checks that Packet.Bytes() correctly produces a []byte +// representation of a Packet. +func TestBytes(t *testing.T) { + const payloadLen, payloadChar, stuffingChar = 120, 0x11, 0xff + const stuffingLen = PacketSize - payloadLen - 12 + + tests := []struct { + packet Packet + expectedHeader []byte + }{ + { + packet: Packet{ + PUSI: true, + PID: 1, + RAI: true, + CC: 4, + AFC: HasPayload | HasAdaptationField, + PCRF: true, + PCR: 1, + }, + expectedHeader: []byte{ + 0x47, // Sync byte. + 0x40, // TEI=0, PUSI=1, TP=0, PID=00000. + 0x01, // PID(Cont)=00000001. + 0x34, // TSC=00, AFC=11(adaptation followed by payload), CC=0100(4). + byte(7 + stuffingLen), // AFL=. + 0x50, // DI=0,RAI=1,ESPI=0,PCRF=1,OPCRF=0,SPF=0,TPDF=0, AFEF=0. + 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, // PCR. + }, + }, + } + + for testNum, test := range tests { + // Construct payload. + payload := make([]byte, 0, payloadLen) + for i := 0; i < payloadLen; i++ { + payload = append(payload, payloadChar) + } + + // Fill the packet payload. + test.packet.FillPayload(payload) + + // Create expected packet data and copy in expected header. + expected := make([]byte, len(test.expectedHeader), PacketSize) + copy(expected, test.expectedHeader) + + // Append stuffing. + for i := 0; i < stuffingLen; i++ { + expected = append(expected, stuffingChar) + } + + // Append payload to expected bytes. + expected = append(expected, payload...) + + // Compare got with expected. + got := test.packet.Bytes(nil) + if !bytes.Equal(got, expected) { + t.Errorf("did not get expected result for test: %v.\n Got: %v\n Want: %v\n", testNum, got, expected) + } + } +} + +// TestFindPid checks that FindPid can correctly extract the first instance +// of a PID from an MPEG-TS stream. +func TestFindPid(t *testing.T) { + const targetPacketNum, numOfPackets, targetPid, stdPid = 6, 15, 1, 0 + + // Prepare the stream of packets. + var stream []byte + for i := 0; i < numOfPackets; i++ { + pid := uint16(stdPid) + if i == targetPacketNum { + pid = targetPid + } + + p := Packet{ + PID: pid, + AFC: hasPayload | hasAdaptationField, + } + p.FillPayload([]byte{byte(i)}) + stream = append(stream, p.Bytes(nil)...) + } + + // Try to find the targetPid in the stream. + p, i, err := FindPid(stream, targetPid) + if err != nil { + t.Fatalf("unexpected error finding PID: %v\n", err) + } + + // Check the payload. + var _p packet.Packet + copy(_p[:], p) + payload, err := packet.Payload(&_p) + if err != nil { + t.Fatalf("unexpected error getting packet payload: %v\n", err) + } + got := payload[0] + if got != targetPacketNum { + t.Errorf("payload of found packet is not correct.\nGot: %v, Want: %v\n", got, targetPacketNum) + } + + // Check the index. + _got := i / PacketSize + if _got != targetPacketNum { + t.Errorf("index of found packet is not correct.\nGot: %v, want: %v\n", _got, targetPacketNum) + } +} diff --git a/container/mts/psi/helpers.go b/container/mts/psi/helpers.go index b8bab6b5..621460f5 100644 --- a/container/mts/psi/helpers.go +++ b/container/mts/psi/helpers.go @@ -125,7 +125,7 @@ func trimTo(d []byte, t byte) []byte { } // addPadding adds an appropriate amount of padding to a pat or pmt table for -// addition to an mpegts packet +// addition to an MPEG-TS packet func AddPadding(d []byte) []byte { t := make([]byte, PacketSize) copy(t, d) diff --git a/container/mts/psi/psi.go b/container/mts/psi/psi.go index c93d3011..3703faf4 100644 --- a/container/mts/psi/psi.go +++ b/container/mts/psi/psi.go @@ -32,7 +32,7 @@ import ( "github.com/Comcast/gots/psi" ) -// PacketSize of psi (without mpegts header) +// PacketSize of psi (without MPEG-TS header) const PacketSize = 184 // Lengths of section definitions. diff --git a/protocol/rtp/client.go b/protocol/rtp/client.go index f47f9baf..3ab856b9 100644 --- a/protocol/rtp/client.go +++ b/protocol/rtp/client.go @@ -34,7 +34,7 @@ import ( // Client describes an RTP client that can receive an RTP stream and implements // io.Reader. type Client struct { - conn *net.UDPConn + r *PacketReader } // NewClient returns a pointer to a new Client. @@ -42,14 +42,14 @@ type Client struct { // addr is the address of form : that we expect to receive // RTP at. func NewClient(addr string) (*Client, error) { - c := &Client{} + c := &Client{r: &PacketReader{}} a, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } - c.conn, err = net.ListenUDP("udp", a) + c.r.PacketConn, err = net.ListenUDP("udp", a) if err != nil { return nil, err } @@ -57,7 +57,18 @@ func NewClient(addr string) (*Client, error) { return c, nil } -// Read implements io.Reader. This wraps the Read for the connection. +// Read implements io.Reader. func (c *Client) Read(p []byte) (int, error) { - return c.conn.Read(p) + return c.r.Read(p) +} + +// PacketReader provides an io.Reader interface to an underlying UDP PacketConn. +type PacketReader struct { + net.PacketConn +} + +// Read implements io.Reader. +func (r PacketReader) Read(b []byte) (int, error) { + n, _, err := r.PacketConn.ReadFrom(b) + return n, err }