diff --git a/cmd/revid-cli/main.go b/cmd/revid-cli/main.go index b0fc89a1..76acfcff 100644 --- a/cmd/revid-cli/main.go +++ b/cmd/revid-cli/main.go @@ -107,8 +107,9 @@ func handleFlags() revid.Config { var ( cpuprofile = flag.String("cpuprofile", "", "write cpu profile to `file`") - inputPtr = flag.String("Input", "", "The input type: Raspivid, File, v4l, Audio") inputCodecPtr = flag.String("InputCodec", "", "The codec of the input: H264, Mjpeg, PCM, ADPCM") + inputPtr = flag.String("Input", "", "The input type: Raspivid, File, v4l, Audio, RTSP") + rtspURLPtr = flag.String("RTSPURL", "", "The URL for an RTSP server.") rtmpMethodPtr = flag.String("RtmpMethod", "", "The method used to send over rtmp: Ffmpeg, Librtmp") quantizePtr = flag.Bool("Quantize", false, "Quantize input (non-variable bitrate)") verbosityPtr = flag.String("Verbosity", "Info", "Verbosity: Debug, Info, Warning, Error, Fatal") @@ -190,6 +191,8 @@ func handleFlags() revid.Config { cfg.Input = revid.File case "Audio": cfg.Input = revid.Audio + case "RTSP": + cfg.Input = revid.RTSP case "": default: log.Log(logger.Error, pkg+"bad input argument") @@ -214,10 +217,6 @@ func handleFlags() revid.Config { cfg.WriteRate = float64(*frameRatePtr) } - if len(outputs) == 0 { - cfg.Outputs = make([]uint8, 1) - } - for _, o := range outputs { switch o { case "File": @@ -248,6 +247,7 @@ func handleFlags() revid.Config { netsender.ConfigFile = *configFilePtr } + cfg.RTSPURL = *rtspURLPtr cfg.Quantize = *quantizePtr cfg.Rotation = *rotationPtr cfg.FlipHorizontal = *horizontalFlipPtr diff --git a/codec/adpcm/adpcm.go b/codec/adpcm/adpcm.go index f90af10b..ce8ae9f7 100644 --- a/codec/adpcm/adpcm.go +++ b/codec/adpcm/adpcm.go @@ -2,9 +2,6 @@ NAME adpcm.go -DESCRIPTION - adpcm.go contains functions for encoding/compressing pcm into adpcm and decoding/decompressing back to pcm. - AUTHOR Trek Hopton @@ -30,47 +27,25 @@ LICENSE Reference algorithms for ADPCM compression and decompression are in part 6. 
*/ +// Package adpcm provides functions to transcode between PCM and ADPCM. package adpcm import ( - "bytes" "encoding/binary" "fmt" + "io" + "math" ) -// encoder is used to encode to ADPCM from PCM data. -// pred and index hold state that persists between calls to encodeSample and calcHead. -// dest is the output buffer that implements io.writer and io.bytewriter, ie. where the encoded ADPCM data is written to. -type encoder struct { - dest *bytes.Buffer - pred int16 - index int16 -} - -// decoder is used to decode from ADPCM to PCM data. -// pred, index, and step hold state that persists between calls to decodeSample. -// dest is the output buffer that implements io.writer and io.bytewriter, ie. where the decoded PCM data is written to. -type decoder struct { - dest *bytes.Buffer - pred int16 - index int16 - step int16 -} - -// BytesOutput will return the number of adpcm bytes that will be generated for the given pcm data -func BytesOutput(pcm int) int { - // for X pcm bytes, 2 bytes are left uncompressed, the rest is compressed by a factor of 4 - // and a start index and padding byte are added. - return (pcm-2)/4 + 2 + 1 + 1 -} - -// PcmBS is the size of the blocks that an encoder uses. -// 'encodeBlock' will encode PcmBS bytes at a time and the output will be AdpcmBS bytes long. -const PcmBS = 1010 - -// AdpcmBS is the size of the blocks that a decoder uses. -// 'decodeBlock' will decode AdpcmBS bytes at a time and the output will be PcmBS bytes long. -const AdpcmBS = 256 +const ( + byteDepth = 2 // We are working with 16-bit samples. TODO(Trek): make configurable. + initSamps = 2 // Number of samples used to initialise the encoder. + initBytes = initSamps * byteDepth + headBytes = 4 // Number of bytes in the header of ADPCM. + samplesPerEnc = 2 // Number of sample encoded at a time eg. 2 16-bit samples get encoded into 1 byte. + bytesPerEnc = samplesPerEnc * byteDepth + compFact = 4 // In general ADPCM compresses by a factor of 4. 
+) // Table of index changes (see spec). var indexTable = []int16{ @@ -94,28 +69,35 @@ var stepTable = []int16{ 32767, } -// NewEncoder retuns a new ADPCM encoder. -func NewEncoder(dst *bytes.Buffer) *encoder { - e := encoder{ - dest: dst, - } - return &e +// Encoder is used to encode to ADPCM from PCM data. +type Encoder struct { + // dst is the destination for ADPCM-encoded data. + dst io.Writer + + est int16 // Estimation of sample based on quantised ADPCM nibble. + idx int16 // Index to step used for estimation. } -// NewDecoder retuns a new ADPCM decoder. -func NewDecoder(dst *bytes.Buffer) *decoder { - d := decoder{ - step: stepTable[0], - dest: dst, - } - return &d +// Decoder is used to decode from ADPCM to PCM data. +type Decoder struct { + // dst is the destination for PCM-encoded data. + dst io.Writer + + est int16 // Estimation of sample based on quantised ADPCM nibble. + idx int16 // Index to step used for estimation. + step int16 +} + +// NewEncoder retuns a new ADPCM Encoder. +func NewEncoder(dst io.Writer) *Encoder { + return &Encoder{dst: dst} } // encodeSample takes a single 16 bit PCM sample and // returns a byte of which the last 4 bits are an encoded ADPCM nibble. -func (e *encoder) encodeSample(sample int16) byte { - // Find difference of actual sample from encoder's prediction. - delta := sample - e.pred +func (e *Encoder) encodeSample(sample int16) byte { + // Find difference between the sample and the previous estimation. + delta := capAdd16(sample, -e.est) // Create and set sign bit for nibble and find absolute value of difference. 
var nib byte @@ -124,217 +106,250 @@ func (e *encoder) encodeSample(sample int16) byte { delta = -delta } - step := stepTable[e.index] + step := stepTable[e.idx] diff := step >> 3 var mask byte = 4 for i := 0; i < 3; i++ { if delta > step { nib |= mask - delta -= step - diff += step + delta = capAdd16(delta, -step) + diff = capAdd16(diff, step) } mask >>= 1 step >>= 1 } - // Adjust predicted sample based on calculated difference. if nib&8 != 0 { - e.pred -= diff - } else { - e.pred += diff + diff = -diff } - e.index += indexTable[nib&7] + // Adjust estimated sample based on calculated difference. + e.est = capAdd16(e.est, diff) + + e.idx += indexTable[nib&7] // Check for underflow and overflow. - if e.index < 0 { - e.index = 0 - } else if e.index > int16(len(stepTable)-1) { - e.index = int16(len(stepTable) - 1) + if e.idx < 0 { + e.idx = 0 + } else if e.idx > int16(len(stepTable)-1) { + e.idx = int16(len(stepTable) - 1) } return nib } +// calcHead sets the state for the Encoder by running the first sample through +// the Encoder, and writing the first sample to the Encoder's io.Writer (dst). +// It returns the number of bytes written to the Encoder's destination and the first error encountered. +func (e *Encoder) calcHead(sample []byte, pad bool) (int, error) { + // Check that we are given 1 sample. + if len(sample) != byteDepth { + return 0, fmt.Errorf("length of given byte array is: %v, expected: %v", len(sample), byteDepth) + } + + n, err := e.dst.Write(sample) + if err != nil { + return n, err + } + + _n, err := e.dst.Write([]byte{byte(int16(e.idx))}) + if err != nil { + return n, err + } + n += _n + + if pad { + _n, err = e.dst.Write([]byte{0x01}) + } else { + _n, err = e.dst.Write([]byte{0x00}) + } + n += _n + if err != nil { + return n, err + } + return n, nil +} + +// init initializes the Encoder's estimation to the first uncompressed sample and the index to +// point to a suitable quantizer step size. 
+// The suitable step size is the closest step size in the stepTable to half the absolute difference of the first two samples. +func (e *Encoder) init(samples []byte) { + int1 := int16(binary.LittleEndian.Uint16(samples[:byteDepth])) + int2 := int16(binary.LittleEndian.Uint16(samples[byteDepth:initBytes])) + e.est = int1 + + halfDiff := math.Abs(math.Abs(float64(int1)) - math.Abs(float64(int2))/2) + closest := math.Abs(float64(stepTable[0]) - halfDiff) + var cInd int16 + for i, step := range stepTable { + if math.Abs(float64(step)-halfDiff) < closest { + closest = math.Abs(float64(step) - halfDiff) + cInd = int16(i) + } + } + e.idx = cInd +} + +// Write takes a slice of bytes of arbitrary length representing pcm and encodes it into adpcm. +// It writes its output to the Encoder's dst. +// The number of bytes written out is returned along with any error that occurred. +func (e *Encoder) Write(b []byte) (int, error) { + // Check that pcm has enough data to initialize the Encoder. + pcmLen := len(b) + if pcmLen < initBytes { + return 0, fmt.Errorf("length of given byte array must be >= %v", initBytes) + } + + // Determine if there will be a byte that won't contain two full nibbles and will need padding. + pad := false + if (pcmLen-byteDepth)%bytesPerEnc != 0 { + pad = true + } + + e.init(b[:initBytes]) + n, err := e.calcHead(b[:byteDepth], pad) + if err != nil { + return n, err + } + // Skip the first sample, then encode every two samples into a byte of adpcm. 
+ for i := byteDepth; i+bytesPerEnc-1 < pcmLen; i += bytesPerEnc { + nib1 := e.encodeSample(int16(binary.LittleEndian.Uint16(b[i : i+byteDepth]))) + nib2 := e.encodeSample(int16(binary.LittleEndian.Uint16(b[i+byteDepth : i+bytesPerEnc]))) + _n, err := e.dst.Write([]byte{byte((nib2 << 4) | nib1)}) + n += _n + if err != nil { + return n, err + } + } + // If we've reached the end of the pcm data and there's a sample left over, + // compress it to a nibble and leave the first half of the byte padded with 0s. + if pad { + nib := e.encodeSample(int16(binary.LittleEndian.Uint16(b[pcmLen-byteDepth : pcmLen]))) + _n, err := e.dst.Write([]byte{nib}) + n += _n + if err != nil { + return n, err + } + } + return n, nil +} + +// NewDecoder retuns a new ADPCM Decoder. +func NewDecoder(dst io.Writer) *Decoder { + return &Decoder{dst: dst} +} + // decodeSample takes a byte, the last 4 bits of which contain a single // 4 bit ADPCM nibble, and returns a 16 bit decoded PCM sample. -func (d *decoder) decodeSample(nibble byte) int16 { +func (d *Decoder) decodeSample(nibble byte) int16 { // Calculate difference. var diff int16 if nibble&4 != 0 { - diff += d.step + diff = capAdd16(diff, d.step) } if nibble&2 != 0 { - diff += d.step >> 1 + diff = capAdd16(diff, d.step>>1) } if nibble&1 != 0 { - diff += d.step >> 2 + diff = capAdd16(diff, d.step>>2) } - diff += d.step >> 3 + diff = capAdd16(diff, d.step>>3) // Account for sign bit. if nibble&8 != 0 { diff = -diff } - // Adjust predicted sample based on calculated difference. - d.pred += diff + // Adjust estimated sample based on calculated difference. + d.est = capAdd16(d.est, diff) // Adjust index into step size lookup table using nibble. - d.index += indexTable[nibble] + d.idx += indexTable[nibble] // Check for overflow and underflow. 
- if d.index < 0 { - d.index = 0 - } else if d.index > int16(len(stepTable)-1) { - d.index = int16(len(stepTable) - 1) + if d.idx < 0 { + d.idx = 0 + } else if d.idx > int16(len(stepTable)-1) { + d.idx = int16(len(stepTable) - 1) } // Find new quantizer step size. - d.step = stepTable[d.index] + d.step = stepTable[d.idx] - return d.pred + return d.est } -// calcHead sets the state for the encoder by running the first sample through -// the encoder, and writing the first sample to the encoder's io.Writer (dest). -// It returns the number of bytes written to the encoder's io.Writer (dest) along with any errors. -func (e *encoder) calcHead(sample []byte) (int, error) { - // Check that we are given 1 16-bit sample (2 bytes). - const sampSize = 2 - if len(sample) != sampSize { - return 0, fmt.Errorf("length of given byte array is: %v, expected: %v", len(sample), sampSize) - } - - intSample := int16(binary.LittleEndian.Uint16(sample)) - e.encodeSample(intSample) - - n, err := e.dest.Write(sample) - if err != nil { - return n, err - } - - err = e.dest.WriteByte(byte(uint16(e.index))) - if err != nil { - return n, err - } - n++ - - err = e.dest.WriteByte(byte(0x00)) - if err != nil { - return n, err - } - n++ - return n, nil -} - -// encodeBlock takes a slice of 1010 bytes (505 16-bit PCM samples). -// It writes encoded (compressed) bytes (each byte containing two ADPCM nibbles) to the encoder's io.Writer (dest). -// The number of bytes written is returned along with any errors. -// Note: nibbles are output in little endian order, eg. n1n0 n3n2 n5n4... -// Note: first 4 bytes are for initializing the decoder before decoding a block. -// - First two bytes contain the first 16-bit sample uncompressed. -// - Third byte is the decoder's starting index for the block, the fourth is padding and ignored. -func (e *encoder) encodeBlock(block []byte) (int, error) { - if len(block) != PcmBS { - return 0, fmt.Errorf("unsupported block size. Given: %v, expected: %v, ie. 
505 16-bit PCM samples", len(block), PcmBS) - } - - n, err := e.calcHead(block[0:2]) - if err != nil { - return n, err - } - - for i := 3; i < PcmBS; i += 4 { - nib1 := e.encodeSample(int16(binary.LittleEndian.Uint16(block[i-1 : i+1]))) - nib2 := e.encodeSample(int16(binary.LittleEndian.Uint16(block[i+1 : i+3]))) - err = e.dest.WriteByte(byte((nib2 << 4) | nib1)) - if err != nil { - return n, err - } - n++ - } - - return n, nil -} - -// decodeBlock takes a slice of 256 bytes, each byte after the first 4 should contain two ADPCM encoded nibbles. -// It writes the resulting decoded (decompressed) 16-bit PCM samples to the decoder's io.Writer (dest). -// The number of bytes written is returned along with any errors. -func (d *decoder) decodeBlock(block []byte) (int, error) { - if len(block) != AdpcmBS { - return 0, fmt.Errorf("unsupported block size. Given: %v, expected: %v", len(block), AdpcmBS) - } - - // Initialize decoder with first 4 bytes of the block. - d.pred = int16(binary.LittleEndian.Uint16(block[0:2])) - d.index = int16(block[2]) - d.step = stepTable[d.index] - n, err := d.dest.Write(block[0:2]) +// Write takes a slice of bytes of arbitrary length representing adpcm and decodes it into pcm. +// It writes its output to the Decoder's dst. +// The number of bytes written out is returned along with any error that occured. +func (d *Decoder) Write(b []byte) (int, error) { + // Initialize Decoder with first 4 bytes of b. + d.est = int16(binary.LittleEndian.Uint16(b[:byteDepth])) + d.idx = int16(b[byteDepth]) + d.step = stepTable[d.idx] + n, err := d.dst.Write(b[:byteDepth]) if err != nil { return n, err } // For each byte, seperate it into two nibbles (each nibble is a compressed sample), // then decode each nibble and output the resulting 16-bit samples. - for i := 4; i < AdpcmBS; i++ { - twoNibs := block[i] + // If padding flag is true (Adpcm[3]), only decode up until the last byte, then decode that separately. 
+ for i := headBytes; i < len(b)-int(b[3]); i++ { + twoNibs := b[i] nib2 := byte(twoNibs >> 4) nib1 := byte((nib2 << 4) ^ twoNibs) - firstBytes := make([]byte, 2) + firstBytes := make([]byte, byteDepth) binary.LittleEndian.PutUint16(firstBytes, uint16(d.decodeSample(nib1))) - _n, err := d.dest.Write(firstBytes) + _n, err := d.dst.Write(firstBytes) n += _n if err != nil { return n, err } - secondBytes := make([]byte, 2) + secondBytes := make([]byte, byteDepth) binary.LittleEndian.PutUint16(secondBytes, uint16(d.decodeSample(nib2))) - _n, err = d.dest.Write(secondBytes) + _n, err = d.dst.Write(secondBytes) n += _n if err != nil { return n, err } } - - return n, nil -} - -// Write takes a slice of bytes of arbitrary length representing pcm and encodes in into adpcm. -// It writes its output to the encoder's dest. -// The number of bytes written out is returned along with any error that occured. -func (e *encoder) Write(inPcm []byte) (int, error) { - numBlocks := len(inPcm) / PcmBS - n := 0 - for i := 0; i < numBlocks; i++ { - block := inPcm[PcmBS*i : PcmBS*(i+1)] - _n, err := e.encodeBlock(block) + if b[3] == 0x01 { + padNib := b[len(b)-1] + samp := make([]byte, byteDepth) + binary.LittleEndian.PutUint16(samp, uint16(d.decodeSample(padNib))) + _n, err := d.dst.Write(samp) n += _n if err != nil { return n, err } } - return n, nil } -// Write takes a slice of bytes of arbitrary length representing adpcm and decodes in into pcm. -// It writes its output to the decoder's dest. -// The number of bytes written out is returned along with any error that occured. 
-func (d *decoder) Write(inAdpcm []byte) (int, error) { - numBlocks := len(inAdpcm) / AdpcmBS - n := 0 - for i := 0; i < numBlocks; i++ { - block := inAdpcm[AdpcmBS*i : AdpcmBS*(i+1)] - _n, err := d.decodeBlock(block) - n += _n - if err != nil { - return n, err - } +// capAdd16 adds two int16s together and caps at max/min int16 instead of overflowing +func capAdd16(a, b int16) int16 { + c := int32(a) + int32(b) + switch { + case c < math.MinInt16: + return math.MinInt16 + case c > math.MaxInt16: + return math.MaxInt16 + default: + return int16(c) } - - return n, nil +} + +// EncBytes will return the number of adpcm bytes that will be generated when encoding the given amount of pcm bytes (n). +func EncBytes(n int) int { + // For 'n' pcm bytes, 1 sample is left uncompressed, the rest is compressed by a factor of 4 + // and a start index and padding-flag byte are added. + // Also if there are an even number of samples, there will be half a byte of padding added to the last byte. + if n%bytesPerEnc == 0 { + return (n-byteDepth)/compFact + headBytes + 1 + } + return (n-byteDepth)/compFact + headBytes } diff --git a/codec/adpcm/adpcm_test.go b/codec/adpcm/adpcm_test.go index 9df24028..8b825696 100644 --- a/codec/adpcm/adpcm_test.go +++ b/codec/adpcm/adpcm_test.go @@ -37,14 +37,13 @@ import ( // then compare the result with expected ADPCM. func TestEncodeBlock(t *testing.T) { // Read input pcm. - pcm, err := ioutil.ReadFile("../../../test/test-data/av/input/raw-voice.pcm") + pcm, err := ioutil.ReadFile("../../../test/test-data/av/input/original_8kHz_adpcm_test.pcm") if err != nil { t.Errorf("Unable to read input PCM file: %v", err) } // Encode adpcm. - numBlocks := len(pcm) / PcmBS - comp := bytes.NewBuffer(make([]byte, 0, AdpcmBS*numBlocks)) + comp := bytes.NewBuffer(make([]byte, 0, EncBytes(len(pcm)))) enc := NewEncoder(comp) _, err = enc.Write(pcm) if err != nil { @@ -52,7 +51,7 @@ func TestEncodeBlock(t *testing.T) { } // Read expected adpcm file. 
- exp, err := ioutil.ReadFile("../../../test/test-data/av/output/encoded-voice.adpcm") + exp, err := ioutil.ReadFile("../../../test/test-data/av/output/encoded_8kHz_adpcm_test.adpcm") if err != nil { t.Errorf("Unable to read expected ADPCM file: %v", err) } @@ -66,14 +65,13 @@ func TestEncodeBlock(t *testing.T) { // resulting PCM with the expected decoded PCM. func TestDecodeBlock(t *testing.T) { // Read adpcm. - comp, err := ioutil.ReadFile("../../../test/test-data/av/input/encoded-voice.adpcm") + comp, err := ioutil.ReadFile("../../../test/test-data/av/input/encoded_8kHz_adpcm_test.adpcm") if err != nil { t.Errorf("Unable to read input ADPCM file: %v", err) } // Decode adpcm. - numBlocks := len(comp) / AdpcmBS - decoded := bytes.NewBuffer(make([]byte, 0, PcmBS*numBlocks)) + decoded := bytes.NewBuffer(make([]byte, 0, len(comp)*4)) dec := NewDecoder(decoded) _, err = dec.Write(comp) if err != nil { @@ -81,7 +79,7 @@ func TestDecodeBlock(t *testing.T) { } // Read expected pcm file. - exp, err := ioutil.ReadFile("../../../test/test-data/av/output/decoded-voice.pcm") + exp, err := ioutil.ReadFile("../../../test/test-data/av/output/decoded_8kHz_adpcm_test.pcm") if err != nil { t.Errorf("Unable to read expected PCM file: %v", err) } diff --git a/codec/codecutil/bytescanner.go b/codec/codecutil/bytescanner.go new file mode 100644 index 00000000..03f825d6 --- /dev/null +++ b/codec/codecutil/bytescanner.go @@ -0,0 +1,95 @@ +/* +NAME + bytescanner.go + +AUTHOR + Dan Kortschak + +LICENSE + This is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// Package codecutil implements a byte-level scanner. +package codecutil + +import "io" + +// ByteScanner is a byte scanner. +type ByteScanner struct { + buf []byte + off int + + // r is the source of data for the scanner. + r io.Reader +} + +// NewByteScanner returns a scanner initialised with an io.Reader and a read buffer. +func NewByteScanner(r io.Reader, buf []byte) *ByteScanner { + return &ByteScanner{r: r, buf: buf[:0]} +} + +// ScanUntil scans the scanner's underlying io.Reader until a delim byte +// has been read, appending all read bytes to dst. It returns the resulting +// appended data, the last read byte and any error encountered. +func (c *ByteScanner) ScanUntil(dst []byte, delim byte) (res []byte, b byte, err error) { +outer: + for { + var i int + for i, b = range c.buf[c.off:] { + if b != delim { + continue + } + dst = append(dst, c.buf[c.off:c.off+i+1]...) + c.off += i + 1 + break outer + } + dst = append(dst, c.buf[c.off:]...) + err = c.reload() + if err != nil { + break + } + } + return dst, b, err +} + +// ReadByte returns the next byte from the scanner, reloading its buffer from the underlying io.Reader when exhausted. +func (c *ByteScanner) ReadByte() (byte, error) { + if c.off >= len(c.buf) { + err := c.reload() + if err != nil { + return 0, err + } + } + b := c.buf[c.off] + c.off++ + return b, nil +} + +// reload re-fills the scanner's buffer. 
+func (c *ByteScanner) reload() error { + n, err := c.r.Read(c.buf[:cap(c.buf)]) + c.buf = c.buf[:n] + if err != nil { + if err != io.EOF { + return err + } + if n == 0 { + return io.EOF + } + } + c.off = 0 + return nil +} diff --git a/codec/codecutil/bytescanner_test.go b/codec/codecutil/bytescanner_test.go new file mode 100644 index 00000000..68db0006 --- /dev/null +++ b/codec/codecutil/bytescanner_test.go @@ -0,0 +1,82 @@ +/* +NAME + bytescanner_test.go + +DESCRIPTION + See Readme.md + +AUTHOR + Dan Kortschak + +LICENSE + This is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. 
+*/ + +package codecutil + +import ( + "bytes" + "reflect" + "testing" +) + +type chunkEncoder [][]byte + +func (e *chunkEncoder) Encode(b []byte) error { + *e = append(*e, b) + return nil +} + +func (*chunkEncoder) Stream() <-chan []byte { panic("INVALID USE") } + +func TestScannerReadByte(t *testing.T) { + data := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + + for _, size := range []int{1, 2, 8, 1 << 10} { + r := NewByteScanner(bytes.NewReader(data), make([]byte, size)) + var got []byte + for { + b, err := r.ReadByte() + if err != nil { + break + } + got = append(got, b) + } + if !bytes.Equal(got, data) { + t.Errorf("unexpected result for buffer size %d:\ngot :%q\nwant:%q", size, got, data) + } + } +} + +func TestScannerScanUntilZero(t *testing.T) { + data := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit,\x00 sed do eiusmod tempor incididunt ut \x00labore et dolore magna aliqua.") + + for _, size := range []int{1, 2, 8, 1 << 10} { + r := NewByteScanner(bytes.NewReader(data), make([]byte, size)) + var got [][]byte + for { + buf, _, err := r.ScanUntil(nil, 0x0) + got = append(got, buf) + if err != nil { + break + } + } + want := bytes.SplitAfter(data, []byte{0}) + if !reflect.DeepEqual(got, want) { + t.Errorf("unexpected result for buffer zie %d:\ngot :%q\nwant:%q", size, got, want) + } + } +} diff --git a/codec/h264/lex.go b/codec/h264/lex.go new file mode 100644 index 00000000..9a071715 --- /dev/null +++ b/codec/h264/lex.go @@ -0,0 +1,135 @@ +/* +NAME + lex.go + +DESCRIPTION + lex.go provides a lexer to lex h264 bytestream into access units. 
+ +AUTHOR + Dan Kortschak + +LICENSE + lex.go is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// lex.go provides a lexer to lex h264 bytestream into access units. + +package h264 + +import ( + "io" + "time" + + "bitbucket.org/ausocean/av/codec/codecutil" +) + +var noDelay = make(chan time.Time) + +func init() { + close(noDelay) +} + +var h264Prefix = [...]byte{0x00, 0x00, 0x01, 0x09, 0xf0} + +// Lex lexes H.264 NAL units read from src into separate writes to dst with +// successive writes being performed not earlier than the specified delay. +// NAL units are split after type 1 (Coded slice of a non-IDR picture), 5 +// (Coded slice of a IDR picture) and 8 (Picture parameter set). +func Lex(dst io.Writer, src io.Reader, delay time.Duration) error { + var tick <-chan time.Time + if delay == 0 { + tick = noDelay + } else { + ticker := time.NewTicker(delay) + defer ticker.Stop() + tick = ticker.C + } + + const bufSize = 8 << 10 + + c := codecutil.NewByteScanner(src, make([]byte, 4<<10)) // Standard file buffer size. 
+ + buf := make([]byte, len(h264Prefix), bufSize) + copy(buf, h264Prefix[:]) + writeOut := false +outer: + for { + var b byte + var err error + buf, b, err = c.ScanUntil(buf, 0x00) + if err != nil { + if err != io.EOF { + return err + } + break + } + + for n := 1; b == 0x0 && n < 4; n++ { + b, err = c.ReadByte() + if err != nil { + if err != io.EOF { + return err + } + break outer + } + buf = append(buf, b) + + if b != 0x1 || (n != 2 && n != 3) { + continue + } + + if writeOut { + <-tick + _, err := dst.Write(buf[:len(buf)-(n+1)]) + if err != nil { + return err + } + buf = make([]byte, len(h264Prefix)+n, bufSize) + copy(buf, h264Prefix[:]) + buf = append(buf, 1) + writeOut = false + } + + b, err = c.ReadByte() + if err != nil { + if err != io.EOF { + return err + } + break outer + } + buf = append(buf, b) + + // http://www.itu.int/rec/dologin_pub.asp?lang=e&id=T-REC-H.264-200305-S!!PDF-E&type=items + // Table 7-1 NAL unit type codes + const ( + nonIdrPic = 1 + idrPic = 5 + suppEnhInfo = 6 + paramSet = 8 + ) + switch nalTyp := b & 0x1f; nalTyp { + case nonIdrPic, idrPic, paramSet, suppEnhInfo: + writeOut = true + } + } + } + if len(buf) == len(h264Prefix) { + return nil + } + <-tick + _, err := dst.Write(buf) + return err +} diff --git a/codec/lex/lex_test.go b/codec/h264/lex_test.go similarity index 67% rename from codec/lex/lex_test.go rename to codec/h264/lex_test.go index a107b253..d2eeae2a 100644 --- a/codec/lex/lex_test.go +++ b/codec/h264/lex_test.go @@ -3,7 +3,7 @@ NAME lex_test.go DESCRIPTION - See Readme.md + lex_test.go provides tests for the lexer in lex.go. AUTHOR Dan Kortschak @@ -25,12 +25,11 @@ LICENSE along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. */ -package lex +// lex_test.go provides tests for the lexer in lex.go. 
+ +package h264 import ( - "bytes" - "reflect" - "testing" "time" ) @@ -207,7 +206,7 @@ var h264Tests = []struct { func TestH264(t *testing.T) { for _, test := range h264Tests { var buf chunkEncoder - err := H264(&buf, bytes.NewReader(test.input), test.delay) + err := Lex(&buf, bytes.NewReader(test.input), test.delay) if fmt.Sprint(err) != fmt.Sprint(test.err) { t.Errorf("unexpected error for %q: got:%v want:%v", test.name, err, test.err) } @@ -221,131 +220,3 @@ func TestH264(t *testing.T) { } } */ - -var mjpegTests = []struct { - name string - input []byte - delay time.Duration - want [][]byte - err error -}{ - { - name: "empty", - }, - { - name: "null", - input: []byte{0xff, 0xd8, 0xff, 0xd9}, - delay: 0, - want: [][]byte{{0xff, 0xd8, 0xff, 0xd9}}, - }, - { - name: "null delayed", - input: []byte{0xff, 0xd8, 0xff, 0xd9}, - delay: time.Millisecond, - want: [][]byte{{0xff, 0xd8, 0xff, 0xd9}}, - }, - { - name: "full", - input: []byte{ - 0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9, - 0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9, - 0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9, - 0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9, - 0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9, - }, - delay: 0, - want: [][]byte{ - {0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9}, - {0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9}, - {0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9}, - {0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9}, - {0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9}, - }, - }, - { - name: "full delayed", - input: []byte{ - 0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9, - 0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9, - 0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9, - 0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9, - 0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9, - }, - delay: time.Millisecond, - want: [][]byte{ - {0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9}, - {0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9}, - {0xff, 0xd8, 'w', 'i', 't', 'h', 
0xff, 0xd9}, - {0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9}, - {0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9}, - }, - }, -} - -// FIXME this needs to be adapted -/* -func TestMJEG(t *testing.T) { - for _, test := range mjpegTests { - var buf chunkEncoder - err := MJPEG(&buf, bytes.NewReader(test.input), test.delay) - if fmt.Sprint(err) != fmt.Sprint(test.err) { - t.Errorf("unexpected error for %q: got:%v want:%v", test.name, err, test.err) - } - if err != nil { - continue - } - got := [][]byte(buf) - if !reflect.DeepEqual(got, test.want) { - t.Errorf("unexpected result for %q:\ngot :%#v\nwant:%#v", test.name, got, test.want) - } - } -} -*/ - -type chunkEncoder [][]byte - -func (e *chunkEncoder) Encode(b []byte) error { - *e = append(*e, b) - return nil -} - -func (*chunkEncoder) Stream() <-chan []byte { panic("INVALID USE") } - -func TestScannerReadByte(t *testing.T) { - data := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - - for _, size := range []int{1, 2, 8, 1 << 10} { - r := newScanner(bytes.NewReader(data), make([]byte, size)) - var got []byte - for { - b, err := r.readByte() - if err != nil { - break - } - got = append(got, b) - } - if !bytes.Equal(got, data) { - t.Errorf("unexpected result for buffer size %d:\ngot :%q\nwant:%q", size, got, data) - } - } -} - -func TestScannerScanUntilZero(t *testing.T) { - data := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit,\x00 sed do eiusmod tempor incididunt ut \x00labore et dolore magna aliqua.") - - for _, size := range []int{1, 2, 8, 1 << 10} { - r := newScanner(bytes.NewReader(data), make([]byte, size)) - var got [][]byte - for { - buf, _, err := r.scanUntilZeroInto(nil) - got = append(got, buf) - if err != nil { - break - } - } - want := bytes.SplitAfter(data, []byte{0}) - if !reflect.DeepEqual(got, want) { - t.Errorf("unexpected result for buffer zie %d:\ngot :%q\nwant:%q", size, got, want) 
- } - } -} diff --git a/codec/h265/lex.go b/codec/h265/lex.go new file mode 100644 index 00000000..ebe34013 --- /dev/null +++ b/codec/h265/lex.go @@ -0,0 +1,203 @@ +/* +NAME + lex.go + +DESCRIPTION + lex.go provides a lexer for taking RTP HEVC (H265) and lexing into access units. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + Copyright (C) 2019 the Australian Ocean Lab (AusOcean). + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// Package h265 provides an RTP h265 lexer that can extract h265 access units +// from an RTP stream. +package h265 + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "time" + + "bitbucket.org/ausocean/av/protocol/rtp" +) + +// NALU types. +const ( + typeAggregation = 48 + typeFragmentation = 49 + typePACI = 50 +) + +// Buffer sizes. +const ( + maxAUSize = 100000 + maxRTPSize = 4096 +) + +// Lexer is an H265 lexer. +type Lexer struct { + donl bool // Indicates whether DONL and DOND will be used for the RTP stream. + buf *bytes.Buffer // Holds the current access unit. + frag bool // Indicates if we're currently dealing with a fragmentation packet. +} + +// NewLexer returns a new Lexer. +func NewLexer(donl bool) *Lexer { + return &Lexer{ + donl: donl, + buf: bytes.NewBuffer(make([]byte, 0, maxAUSize)), + } +} + +// Lex continually reads RTP packets from the io.Reader src and lexes into +// access units which are written to the io.Writer dst. 
Lex expects that for +// each read from src, a single RTP packet is received. +func (l *Lexer) Lex(dst io.Writer, src io.Reader, delay time.Duration) error { + buf := make([]byte, maxRTPSize) + for { + n, err := src.Read(buf) + switch err { + case nil: // Do nothing. + case io.EOF: + return nil + default: + return fmt.Errorf("source read error: %v\n", err) + } + + // Get payload from RTP packet. + payload, err := rtp.Payload(buf[:n]) + if err != nil { + return fmt.Errorf("could not get rtp payload, failed with err: %v\n", err) + } + nalType := (payload[0] >> 1) & 0x3f + + // If not currently fragmented then we ignore current write. + if l.frag && nalType != typeFragmentation { + l.buf.Reset() + l.frag = false + continue + } + + switch nalType { + case typeAggregation: + l.handleAggregation(payload) + case typeFragmentation: + l.handleFragmentation(payload) + case typePACI: + l.handlePACI(payload) + default: + l.writeWithPrefix(payload) + } + + markerIsSet, err := rtp.Marker(buf[:n]) + if err != nil { + return fmt.Errorf("could not get marker bit, failed with err: %v\n", err) + } + + if markerIsSet { + _, err := l.buf.WriteTo(dst) + if err != nil { + // TODO: work out what to do here. + } + l.buf.Reset() + } + } + return nil +} + +// handleAggregation parses NAL units from an aggregation packet and writes +// them to the Lexers buffer buf. +func (l *Lexer) handleAggregation(d []byte) { + idx := 2 + for idx < len(d) { + if l.donl { + switch idx { + case 2: + idx += 2 + default: + idx++ + } + } + size := int(binary.BigEndian.Uint16(d[idx:])) + idx += 2 + nalu := d[idx : idx+size] + idx += size + l.writeWithPrefix(nalu) + } +} + +// handleFragmentation parses NAL units from fragmentation packets and writes +// them to the Lexer's buf. +func (l *Lexer) handleFragmentation(d []byte) { + // Get start and end indiciators from FU header. 
+ start := d[2]&0x80 != 0 + end := d[2]&0x40 != 0 + + b1 := (d[0] & 0x81) | ((d[2] & 0x3f) << 1) + b2 := d[1] + if start { + d = d[1:] + if l.donl { + d = d[2:] + } + d[0] = b1 + d[1] = b2 + } else { + d = d[3:] + if l.donl { + d = d[2:] + } + } + + switch { + case start && !end: + l.frag = true + l.writeWithPrefix(d) + case !start && end: + l.frag = false + fallthrough + case !start && !end: + l.writeNoPrefix(d) + default: + panic("bad fragmentation packet") + } +} + +// handlePACI will handl PACI packets +// +// TODO: complete this +func (l *Lexer) handlePACI(d []byte) { + panic("unsupported nal type") +} + +// write writes a NAL unit to the Lexer's buf in byte stream format using the +// start code. +func (l *Lexer) writeWithPrefix(d []byte) { + const prefix = "\x00\x00\x00\x01" + l.buf.Write([]byte(prefix)) + l.buf.Write(d) +} + +// writeNoPrefix writes data to the Lexer's buf. This is used for non start +// fragmentations of a NALU. +func (l *Lexer) writeNoPrefix(d []byte) { + l.buf.Write(d) +} diff --git a/codec/h265/lex_test.go b/codec/h265/lex_test.go new file mode 100644 index 00000000..1a409e4c --- /dev/null +++ b/codec/h265/lex_test.go @@ -0,0 +1,262 @@ +/* +NAME + lex_test.go + +DESCRIPTION + lex_test.go provides tests to check validity of the Lexer found in lex.go. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + Copyright (C) 2019 the Australian Ocean Lab (AusOcean). + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. 
If not, see http://www.gnu.org/licenses. +*/ + +package h265 + +import ( + "io" + "testing" +) + +// rtpReader provides the RTP stream. +type rtpReader struct { + packets [][]byte + idx int +} + +// Read implements io.Reader. +func (r *rtpReader) Read(p []byte) (int, error) { + if r.idx == len(r.packets) { + return 0, io.EOF + } + b := r.packets[r.idx] + n := copy(p, b) + if n < len(r.packets[r.idx]) { + r.packets[r.idx] = r.packets[r.idx][n:] + } else { + r.idx++ + } + return n, nil +} + +// destination holds the access units extracted during the lexing process. +type destination [][]byte + +// Write implements io.Writer. +func (d *destination) Write(p []byte) (int, error) { + t := make([]byte, len(p)) + copy(t, p) + *d = append([][]byte(*d), t) + return len(p), nil +} + +// TestLex checks that the Lexer can correctly extract H265 access units from +// HEVC RTP stream in RTP payload format. +func TestLex(t *testing.T) { + const rtpVer = 2 + + tests := []struct { + donl bool + packets [][]byte + expect [][]byte + }{ + { + donl: false, + packets: [][]byte{ + { // Single NAL unit. + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x01, 0x02, 0x03, 0x04, // NAL Data. + }, + { // Fragmentation (start packet). + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x62, 0x00, // NAL header (type49). + 0x80, // FU header. + 0x01, 0x02, 0x03, // FU payload. + }, + { // Fragmentation (middle packet) + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x62, 0x00, // NAL header (type 49). + 0x00, // FU header. + 0x04, 0x05, 0x06, // FU payload. + }, + { // Fragmentation (end packet) + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x62, 0x00, // NAL header (type 49). + 0x40, // FU header. + 0x07, 0x08, 0x09, // FU payload + }, + + { // Aggregation. 
Make last packet of access unit => marker bit true. + 0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x60, 0x00, // NAL header (type 49). + 0x00, 0x04, // NAL 1 size. + 0x01, 0x02, 0x03, 0x04, // NAL 1 data. + 0x00, 0x04, // NAL 2 size. + 0x01, 0x02, 0x03, 0x04, // NAL 2 data. + }, + { // Singla NAL + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x01, 0x02, 0x03, 0x04, // NAL data. + }, + { // Singla NAL. Make last packet of access unit => marker bit true. + 0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x01, 0x02, 0x03, 0x04, // NAL data. + }, + }, + expect: [][]byte{ + // First access unit. + { + // NAL 1 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x01, 0x02, 0x03, 0x04, // NAL data. + // NAL 2 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x00, 0x00, 0x01, 0x02, 0x03, // FU payload. + 0x04, 0x05, 0x06, // FU payload. + 0x07, 0x08, 0x09, // FU payload. + // NAL 3 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x01, 0x02, 0x03, 0x04, // NAL data. + // NAL 4 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x01, 0x02, 0x03, 0x04, // NAL 2 data + }, + // Second access unit. + { + // NAL 1 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x01, 0x02, 0x03, 0x04, // Data. + // NAL 2 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x01, 0x02, 0x03, 0x04, // Data. + }, + }, + }, + { + donl: true, + packets: [][]byte{ + { // Single NAL unit. + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x00, 0x00, // DONL + 0x01, 0x02, 0x03, 0x04, // NAL Data. + }, + { // Fragmentation (start packet). 
+ 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x62, 0x00, // NAL header (type49). + 0x80, // FU header. + 0x00, 0x00, // DONL + 0x01, 0x02, 0x03, // FU payload. + }, + { // Fragmentation (middle packet) + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x62, 0x00, // NAL header (type 49). + 0x00, // FU header. + 0x00, 0x00, // DONL + 0x04, 0x05, 0x06, // FU payload. + }, + { // Fragmentation (end packet) + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x62, 0x00, // NAL header (type 49). + 0x40, // FU header. + 0x00, 0x00, // DONL + 0x07, 0x08, 0x09, // FU payload + }, + + { // Aggregation. Make last packet of access unit => marker bit true. + 0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x60, 0x00, // NAL header (type 49). + 0x00, 0x00, // DONL + 0x00, 0x04, // NAL 1 size. + 0x01, 0x02, 0x03, 0x04, // NAL 1 data. + 0x00, // DOND + 0x00, 0x04, // NAL 2 size. + 0x01, 0x02, 0x03, 0x04, // NAL 2 data. + }, + { // Singla NAL + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x40, 0x00, // NAL header (type=32 VPS) + 0x00, 0x00, // DONL. + 0x01, 0x02, 0x03, 0x04, // NAL data. + }, + { // Singla NAL. Make last packet of access unit => marker bit true. + 0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x00, 0x00, // DONL + 0x01, 0x02, 0x03, 0x04, // NAL data. + }, + }, + expect: [][]byte{ + // First access unit. + { + // NAL 1 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x00, 0x00, // DONL + 0x01, 0x02, 0x03, 0x04, // NAL data. + // NAL 2 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x00, 0x00, 0x01, 0x02, 0x03, // FU payload. + 0x04, 0x05, 0x06, // FU payload. + 0x07, 0x08, 0x09, // FU payload. 
+ // NAL 3 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x01, 0x02, 0x03, 0x04, // NAL data. + // NAL 4 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x01, 0x02, 0x03, 0x04, // NAL 2 data + }, + // Second access unit. + { + // NAL 1 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x00, 0x00, // DONL + 0x01, 0x02, 0x03, 0x04, // Data. + // NAL 2 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x00, 0x00, // DONL + 0x01, 0x02, 0x03, 0x04, // Data. + }, + }, + }, + } + + for testNum, test := range tests { + r := &rtpReader{packets: test.packets} + d := &destination{} + err := NewLexer(test.donl).Lex(d, r, 0) + if err != nil { + t.Fatalf("error lexing: %v\n", err) + } + + for i, accessUnit := range test.expect { + for j, part := range accessUnit { + if part != [][]byte(*d)[i][j] { + t.Fatalf("did not get expected data for test: %v.\nGot: %v\nWant: %v\n", testNum, d, test.expect) + } + } + } + } +} diff --git a/codec/lex/lex.go b/codec/lex/lex.go deleted file mode 100644 index 6f6a7ba0..00000000 --- a/codec/lex/lex.go +++ /dev/null @@ -1,280 +0,0 @@ -/* -NAME - lex.go - -AUTHOR - Dan Kortschak - Trek Hopton - -LICENSE - lex.go is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) - - It is free software: you can redistribute it and/or modify them - under the terms of the GNU General Public License as published by the - Free Software Foundation, either version 3 of the License, or (at your - option) any later version. - - It is distributed in the hope that it will be useful, but WITHOUT - ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - for more details. - - You should have received a copy of the GNU General Public License - along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. -*/ - -// Package lex provides lexers for video encodings. 
-package lex - -import ( - "bufio" - "bytes" - "fmt" - "io" - "time" -) - -var noDelay = make(chan time.Time) - -func init() { - close(noDelay) -} - -var h264Prefix = [...]byte{0x00, 0x00, 0x01, 0x09, 0xf0} - -// H264 lexes H.264 NAL units read from src into separate writes to dst with -// successive writes being performed not earlier than the specified delay. -// NAL units are split after type 1 (Coded slice of a non-IDR picture), 5 -// (Coded slice of a IDR picture) and 8 (Picture parameter set). -func H264(dst io.Writer, src io.Reader, delay time.Duration, bufSize int) error { - var tick <-chan time.Time - if delay == 0 { - tick = noDelay - } else { - ticker := time.NewTicker(delay) - defer ticker.Stop() - tick = ticker.C - } - - bufSize = 8 << 10 //TODO(Trek): Pass this in rather than set it in here. - - c := newScanner(src, make([]byte, 4<<10)) // Standard file buffer size. - - buf := make([]byte, len(h264Prefix), bufSize) - copy(buf, h264Prefix[:]) - writeOut := false -outer: - for { - var b byte - var err error - buf, b, err = c.scanUntilZeroInto(buf) - if err != nil { - if err != io.EOF { - return err - } - break - } - - for n := 1; b == 0x0 && n < 4; n++ { - b, err = c.readByte() - if err != nil { - if err != io.EOF { - return err - } - break outer - } - buf = append(buf, b) - - if b != 0x1 || (n != 2 && n != 3) { - continue - } - - if writeOut { - <-tick - _, err := dst.Write(buf[:len(buf)-(n+1)]) - if err != nil { - return err - } - buf = make([]byte, len(h264Prefix)+n, bufSize) - copy(buf, h264Prefix[:]) - buf = append(buf, 1) - writeOut = false - } - - b, err = c.readByte() - if err != nil { - if err != io.EOF { - return err - } - break outer - } - buf = append(buf, b) - - // http://www.itu.int/rec/dologin_pub.asp?lang=e&id=T-REC-H.264-200305-S!!PDF-E&type=items - // Table 7-1 NAL unit type codes - const ( - nonIdrPic = 1 - idrPic = 5 - suppEnhInfo = 6 - paramSet = 8 - ) - switch nalTyp := b & 0x1f; nalTyp { - case nonIdrPic, idrPic, paramSet, 
suppEnhInfo: - writeOut = true - } - } - } - if len(buf) == len(h264Prefix) { - return nil - } - <-tick - _, err := dst.Write(buf) - return err -} - -// scanner is a byte scanner. -type scanner struct { - buf []byte - off int - - // r is the source of data for the scanner. - r io.Reader -} - -// newScanner returns a scanner initialised with an io.Reader and a read buffer. -func newScanner(r io.Reader, buf []byte) *scanner { - return &scanner{r: r, buf: buf[:0]} -} - -// scanUntilZeroInto scans the scanner's underlying io.Reader until a zero byte -// has been read, appending all read bytes to dst. The resulting appended data, -// the last read byte and whether the last read byte was zero are returned. -func (c *scanner) scanUntilZeroInto(dst []byte) (res []byte, b byte, err error) { -outer: - for { - var i int - for i, b = range c.buf[c.off:] { - if b != 0x0 { - continue - } - dst = append(dst, c.buf[c.off:c.off+i+1]...) - c.off += i + 1 - break outer - } - dst = append(dst, c.buf[c.off:]...) - err = c.reload() - if err != nil { - break - } - } - return dst, b, err -} - -// readByte is an unexported ReadByte. -func (c *scanner) readByte() (byte, error) { - if c.off >= len(c.buf) { - err := c.reload() - if err != nil { - return 0, err - } - } - b := c.buf[c.off] - c.off++ - return b, nil -} - -// reload re-fills the scanner's buffer. -func (c *scanner) reload() error { - n, err := c.r.Read(c.buf[:cap(c.buf)]) - c.buf = c.buf[:n] - if err != nil { - if err != io.EOF { - return err - } - if n == 0 { - return io.EOF - } - } - c.off = 0 - return nil -} - -// MJPEG parses MJPEG frames read from src into separate writes to dst with -// successive writes being performed not earlier than the specified delay. 
-func MJPEG(dst io.Writer, src io.Reader, delay time.Duration, bufSize int) error { - var tick <-chan time.Time - if delay == 0 { - tick = noDelay - } else { - ticker := time.NewTicker(delay) - defer ticker.Stop() - tick = ticker.C - } - - r := bufio.NewReader(src) - for { - buf := make([]byte, 2, 4<<10) - n, err := r.Read(buf) - if n < 2 { - return nil - } - if err != nil { - return err - } - if !bytes.Equal(buf, []byte{0xff, 0xd8}) { - return fmt.Errorf("parser: not MJPEG frame start: %#v", buf) - } - var last byte - for { - b, err := r.ReadByte() - if err != nil { - return err - } - buf = append(buf, b) - if last == 0xff && b == 0xd9 { - break - } - last = b - } - <-tick - _, err = dst.Write(buf) - if err != nil { - return err - } - } -} - -// PCM reads from the given source and breaks the PCM into chunks that -// are an appropriate size for mts and pes packets. -func PCM(dst io.Writer, src io.Reader, delay time.Duration, bufSize int) error { - var tick <-chan time.Time - if delay == 0 { - tick = noDelay - } else { - ticker := time.NewTicker(delay) - defer ticker.Stop() - tick = ticker.C - } - - for { - <-tick - buf := make([]byte, bufSize) - _, err := src.Read(buf) - if err != nil { - return err - } - _, err = dst.Write(buf) - if err != nil { - return err - } - } -} - -// ADPCM reads from the given source and breaks the ADPCM into chunks that -// are an appropriate size for mts and pes packets. -// Since PCM and ADPCM are not any different when it comes to how they are -// transmitted, ADPCM is just a wrapper for PCM. -func ADPCM(dst io.Writer, src io.Reader, delay time.Duration, bufSize int) error { - err := PCM(dst, src, delay, bufSize) - return err -} diff --git a/codec/mjpeg/lex.go b/codec/mjpeg/lex.go new file mode 100644 index 00000000..da2ecae1 --- /dev/null +++ b/codec/mjpeg/lex.go @@ -0,0 +1,89 @@ +/* +NAME + lex.go + +DESCRIPTION + lex.go provides a lexer to extract separate JPEG images from a MJPEG stream. 
+ +AUTHOR + Dan Kortschak + +LICENSE + lex.go is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// lex.go provides a lexer to extract separate JPEG images from a MJPEG stream. + +package mjpeg + +import ( + "bufio" + "bytes" + "fmt" + "io" + "time" +) + +var noDelay = make(chan time.Time) + +func init() { + close(noDelay) +} + +// Lex parses MJPEG frames read from src into separate writes to dst with +// successive writes being performed not earlier than the specified delay. 
+func Lex(dst io.Writer, src io.Reader, delay time.Duration) error { + var tick <-chan time.Time + if delay == 0 { + tick = noDelay + } else { + ticker := time.NewTicker(delay) + defer ticker.Stop() + tick = ticker.C + } + + r := bufio.NewReader(src) + for { + buf := make([]byte, 2, 4<<10) + n, err := r.Read(buf) + if n < 2 { + return nil + } + if err != nil { + return err + } + if !bytes.Equal(buf, []byte{0xff, 0xd8}) { + return fmt.Errorf("parser: not MJPEG frame start: %#v", buf) + } + var last byte + for { + b, err := r.ReadByte() + if err != nil { + return err + } + buf = append(buf, b) + if last == 0xff && b == 0xd9 { + break + } + last = b + } + <-tick + _, err = dst.Write(buf) + if err != nil { + return err + } + } +} diff --git a/codec/mjpeg/lex_test.go b/codec/mjpeg/lex_test.go new file mode 100644 index 00000000..fd8b6f86 --- /dev/null +++ b/codec/mjpeg/lex_test.go @@ -0,0 +1,114 @@ +/* +NAME + lex_test.go + +DESCRIPTION + lex_test.go provides testing for the lexer in lex.go. + +AUTHOR + Dan Kortschak + +LICENSE + lex_test.go is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// lex_test.go provides testing for the lexer in lex.go. 
+ +package mjpeg + +import ( + "time" +) + +var mjpegTests = []struct { + name string + input []byte + delay time.Duration + want [][]byte + err error +}{ + { + name: "empty", + }, + { + name: "null", + input: []byte{0xff, 0xd8, 0xff, 0xd9}, + delay: 0, + want: [][]byte{{0xff, 0xd8, 0xff, 0xd9}}, + }, + { + name: "null delayed", + input: []byte{0xff, 0xd8, 0xff, 0xd9}, + delay: time.Millisecond, + want: [][]byte{{0xff, 0xd8, 0xff, 0xd9}}, + }, + { + name: "full", + input: []byte{ + 0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9, + 0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9, + 0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9, + 0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9, + 0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9, + }, + delay: 0, + want: [][]byte{ + {0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9}, + {0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9}, + {0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9}, + {0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9}, + {0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9}, + }, + }, + { + name: "full delayed", + input: []byte{ + 0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9, + 0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9, + 0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9, + 0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9, + 0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9, + }, + delay: time.Millisecond, + want: [][]byte{ + {0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9}, + {0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9}, + {0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9}, + {0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9}, + {0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9}, + }, + }, +} + +// FIXME this needs to be adapted +/* +func Lex(t *testing.T) { + for _, test := range mjpegTests { + var buf chunkEncoder + err := MJPEG(&buf, bytes.NewReader(test.input), test.delay) + if fmt.Sprint(err) != fmt.Sprint(test.err) { + t.Errorf("unexpected error for %q: got:%v want:%v", test.name, err, test.err) + } + 
if err != nil { + continue + } + got := [][]byte(buf) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("unexpected result for %q:\ngot :%#v\nwant:%#v", test.name, got, test.want) + } + } +} +*/ diff --git a/codec/pcm/pcm.go b/codec/pcm/pcm.go index 5ead3143..bb200d50 100644 --- a/codec/pcm/pcm.go +++ b/codec/pcm/pcm.go @@ -24,6 +24,8 @@ LICENSE You should have received a copy of the GNU General Public License in gpl.txt. If not, see [GNU licenses](http://www.gnu.org/licenses). */ + +// Package pcm provides functions for processing and converting pcm audio. package pcm import ( diff --git a/container/flv/flv_test.go b/container/flv/flv_test.go new file mode 100644 index 00000000..66f51f4c --- /dev/null +++ b/container/flv/flv_test.go @@ -0,0 +1,115 @@ +/* +NAME + flv_test.go + +DESCRIPTION + flv_test.go provides testing for functionality provided in flv.go. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + Copyright (C) 2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package flv + +import ( + "bytes" + "testing" +) + +// TestVideoTagBytes checks that we can correctly get a []byte representation +// of a VideoTag using VideoTag.Bytes(). 
+func TestVideoTagBytes(t *testing.T) { + tests := []struct { + tag VideoTag + expected []byte + }{ + { + tag: VideoTag{ + TagType: VideoTagType, + DataSize: 12, + Timestamp: 1234, + TimestampExtended: 56, + FrameType: KeyFrameType, + Codec: H264, + PacketType: AVCNALU, + CompositionTime: 0, + Data: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, + }, + expected: []byte{ + 0x09, // TagType. + 0x00, 0x00, 0x0c, // DataSize. + 0x00, 0x04, 0xd2, // Timestamp. + 0x38, // TimestampExtended. + 0x00, 0x00, 0x00, // StreamID. (always 0) + 0x17, // FrameType=0001, Codec=0111 + 0x01, // PacketType. + 0x00, 0x00, 0x00, // CompositionTime + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // VideoData. + 0x00, 0x00, 0x00, 0x00, // previousTagSize. + }, + }, + } + + for testNum, test := range tests { + got := test.tag.Bytes() + if !bytes.Equal(got, test.expected) { + t.Errorf("did not get expected result for test: %v.\n Got: %v\n Want: %v\n", testNum, got, test.expected) + } + } +} + +// TestAudioTagBytes checks that we can correctly get a []byte representation of +// an AudioTag using AudioTag.Bytes(). +func TestAudioTagBytes(t *testing.T) { + tests := []struct { + tag AudioTag + expected []byte + }{ + { + tag: AudioTag{ + TagType: AudioTagType, + DataSize: 8, + Timestamp: 1234, + TimestampExtended: 56, + SoundFormat: AACAudioFormat, + SoundRate: 3, + SoundSize: true, + SoundType: true, + Data: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, + }, + expected: []byte{ + 0x08, // TagType. + 0x00, 0x00, 0x08, // DataSize. + 0x00, 0x04, 0xd2, // Timestamp. + 0x38, // TimestampExtended. + 0x00, 0x00, 0x00, // StreamID. (always 0) + 0xaf, // SoundFormat=1010,SoundRate=11,SoundSize=1,SoundType=1 + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, // AudioData. + 0x00, 0x00, 0x00, 0x00, // previousTagSize. 
+ }, + }, + } + + for testNum, test := range tests { + got := test.tag.Bytes() + if !bytes.Equal(got, test.expected) { + t.Errorf("did not get expected result for test: %v.\n Got: %v\n Want: %v\n", testNum, got, test.expected) + } + } +} diff --git a/container/mts/audio_test.go b/container/mts/audio_test.go deleted file mode 100644 index 39a77dff..00000000 --- a/container/mts/audio_test.go +++ /dev/null @@ -1,144 +0,0 @@ -/* -NAME - audio_test.go - -AUTHOR - Trek Hopton - -LICENSE - audio_test.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean) - - It is free software: you can redistribute it and/or modify them - under the terms of the GNU General Public License as published by the - Free Software Foundation, either version 3 of the License, or (at your - option) any later version. - - It is distributed in the hope that it will be useful, but WITHOUT - ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - for more details. - - You should have received a copy of the GNU General Public License in gpl.txt. - If not, see http://www.gnu.org/licenses. -*/ - -package mts - -import ( - "bytes" - "io" - "io/ioutil" - "testing" - - "github.com/Comcast/gots/packet" - "github.com/Comcast/gots/pes" - - "bitbucket.org/ausocean/av/container/mts/meta" -) - -type nopCloser struct{ io.Writer } - -func (nopCloser) Close() error { return nil } - -// TestEncodePcm tests the mpegts encoder's ability to encode pcm audio data. -// It reads and encodes input pcm data into mpegts, then decodes the mpegts and compares the result to the input pcm. 
-func TestEncodePcm(t *testing.T) { - Meta = meta.New() - - var buf bytes.Buffer - sampleRate := 48000 - sampleSize := 2 - chunkSize := 16000 - writeFreq := float64(sampleRate*sampleSize) / float64(chunkSize) - e := NewEncoder(nopCloser{&buf}, writeFreq, Audio) - - inPath := "../../../test/test-data/av/input/sweep_400Hz_20000Hz_-3dBFS_5s_48khz.pcm" - inPcm, err := ioutil.ReadFile(inPath) - if err != nil { - t.Errorf("unable to read file: %v", err) - } - - // Break pcm into blocks and encode to mts and get the resulting bytes. - for i := 0; i < len(inPcm); i += chunkSize { - if len(inPcm)-i < chunkSize { - block := inPcm[i:] - _, err = e.Write(block) - if err != nil { - t.Errorf("unable to write block: %v", err) - } - } else { - block := inPcm[i : i+chunkSize] - _, err = e.Write(block) - if err != nil { - t.Errorf("unable to write block: %v", err) - } - } - } - clip := buf.Bytes() - - // Get the first MTS packet to check - var pkt packet.Packet - pesPacket := make([]byte, 0, chunkSize) - got := make([]byte, 0, len(inPcm)) - i := 0 - if i+PacketSize <= len(clip) { - copy(pkt[:], clip[i:i+PacketSize]) - } - - // Loop through MTS packets until all the audio data from PES packets has been retrieved - for i+PacketSize <= len(clip) { - - // Check MTS packet - if !(pkt.PID() == AudioPid) { - i += PacketSize - if i+PacketSize <= len(clip) { - copy(pkt[:], clip[i:i+PacketSize]) - } - continue - } - if !pkt.PayloadUnitStartIndicator() { - i += PacketSize - if i+PacketSize <= len(clip) { - copy(pkt[:], clip[i:i+PacketSize]) - } - } else { - // Copy the first MTS payload - payload, err := pkt.Payload() - if err != nil { - t.Errorf("unable to get MTS payload: %v", err) - } - pesPacket = append(pesPacket, payload...) 
- - i += PacketSize - if i+PacketSize <= len(clip) { - copy(pkt[:], clip[i:i+PacketSize]) - } - - // Copy the rest of the MTS payloads that are part of the same PES packet - for (!pkt.PayloadUnitStartIndicator()) && i+PacketSize <= len(clip) { - payload, err = pkt.Payload() - if err != nil { - t.Errorf("unable to get MTS payload: %v", err) - } - pesPacket = append(pesPacket, payload...) - - i += PacketSize - if i+PacketSize <= len(clip) { - copy(pkt[:], clip[i:i+PacketSize]) - } - } - } - // Get the audio data from the current PES packet - pesHeader, err := pes.NewPESHeader(pesPacket) - if err != nil { - t.Errorf("unable to read PES packet: %v", err) - } - got = append(got, pesHeader.Data()...) - pesPacket = pesPacket[:0] - } - - // Compare data from MTS with original data. - if !bytes.Equal(got, inPcm) { - t.Error("data decoded from mts did not match input data") - } -} diff --git a/container/mts/discontinuity.go b/container/mts/discontinuity.go index adccebad..e127ff94 100644 --- a/container/mts/discontinuity.go +++ b/container/mts/discontinuity.go @@ -4,8 +4,8 @@ NAME DESCRIPTION discontinuity.go provides functionality for detecting discontinuities in - mpegts and accounting for using the discontinuity indicator in the adaptation - field. + MPEG-TS and accounting for using the discontinuity indicator in the adaptation + field. AUTHOR Saxon A. Nelson-Milton @@ -33,7 +33,7 @@ import ( "github.com/Comcast/gots/packet" ) -// discontinuityRepairer provides function to detect discontinuities in mpegts +// discontinuityRepairer provides function to detect discontinuities in MPEG-TS // and set the discontinuity indicator as appropriate. 
type DiscontinuityRepairer struct { expCC map[int]int @@ -56,7 +56,7 @@ func (dr *DiscontinuityRepairer) Failed() { dr.decExpectedCC(PatPid) } -// Repair takes a clip of mpegts and checks that the first packet, which should +// Repair takes a clip of MPEG-TS and checks that the first packet, which should // be a PAT, contains a cc that is expected, otherwise the discontinuity indicator // is set to true. func (dr *DiscontinuityRepairer) Repair(d []byte) error { diff --git a/container/mts/encoder.go b/container/mts/encoder.go index 61fc1ade..3f7137f7 100644 --- a/container/mts/encoder.go +++ b/container/mts/encoder.go @@ -55,30 +55,6 @@ var ( }, }, } - - // standardPmt is a minimal PMT, without descriptors for time and location. - standardPmt = psi.PSI{ - Pf: 0x00, - Tid: 0x02, - Ssi: true, - Sl: 0x12, - Tss: &psi.TSS{ - Tide: 0x01, - V: 0, - Cni: true, - Sn: 0, - Lsn: 0, - Sd: &psi.PMT{ - Pcrpid: 0x0100, - Pil: 0, - Essd: &psi.ESSD{ - St: 0x1b, - Epid: 0x0100, - Esil: 0x00, - }, - }, - }, - } ) const ( @@ -94,38 +70,41 @@ var Meta *meta.Data var ( patTable = standardPat.Bytes() - pmtTable = standardPmt.Bytes() + pmtTable []byte ) const ( - sdtPid = 17 - patPid = 0 - pmtPid = 4096 - videoPid = 256 - // AudioPid is the Id for packets containing audio data - AudioPid = 210 - videoStreamID = 0xe0 // First video stream ID. + sdtPid = 17 + patPid = 0 + pmtPid = 4096 + videoPid = 256 + audioPid = 210 + H264ID = 27 + H265ID = 36 audioStreamID = 0xc0 // First audio stream ID. ) -// Video and Audio constants are used to communicate which media type will be encoded when creating a -// new encoder with NewEncoder. +// Constants used to communicate which media codec will be packetized. const ( - Video = iota - Audio + EncodeH264 = iota + EncodeH265 + EncodeAudio ) -// Time related constants. +// Time-related constants. const ( // ptsOffset is the offset added to the clock to determine // the current presentation timestamp. 
ptsOffset = 700 * time.Millisecond - // pcrFreq is the base Program Clock Reference frequency. - pcrFreq = 90000 // Hz + // PCRFrequency is the base Program Clock Reference frequency in Hz. + PCRFrequency = 90000 + + // PTSFrequency is the presentation timestamp frequency in Hz. + PTSFrequency = 90000 ) -// Encoder encapsulates properties of an mpegts generator. +// Encoder encapsulates properties of an MPEG-TS generator. type Encoder struct { dst io.WriteCloser @@ -153,14 +132,41 @@ func NewEncoder(dst io.WriteCloser, rate float64, mediaType int) *Encoder { var mPid int var sid byte switch mediaType { - case Audio: - mPid = AudioPid + case EncodeAudio: + mPid = audioPid sid = audioStreamID - case Video: + case EncodeH265: mPid = videoPid - sid = videoStreamID + sid = H265ID + case EncodeH264: + mPid = videoPid + sid = H264ID } + // standardPmt is a minimal PMT, without descriptors for metadata. + pmtTable = (&psi.PSI{ + Pf: 0x00, + Tid: 0x02, + Ssi: true, + Sl: 0x12, + Tss: &psi.TSS{ + Tide: 0x01, + V: 0, + Cni: true, + Sn: 0, + Lsn: 0, + Sd: &psi.PMT{ + Pcrpid: 0x0100, + Pil: 0, + Essd: &psi.ESSD{ + St: byte(sid), + Epid: 0x0100, + Esil: 0x00, + }, + }, + }, + }).Bytes() + return &Encoder{ dst: dst, @@ -202,7 +208,7 @@ func (e *Encoder) TimeBasedPsi(b bool, sendCount int) { e.pktCount = e.psiSendCount } -// Write implements io.Writer. Write takes raw video or audio data and encodes into mpegts, +// Write implements io.Writer. Write takes raw video or audio data and encodes into MPEG-TS, // then sending it to the encoder's io.Writer destination. func (e *Encoder) Write(data []byte) (int, error) { now := time.Now() @@ -257,7 +263,7 @@ func (e *Encoder) Write(data []byte) (int, error) { return len(data), nil } -// writePSI creates mpegts with pat and pmt tables - with pmt table having updated +// writePSI creates MPEG-TS with pat and pmt tables - with pmt table having updated // location and time data. func (e *Encoder) writePSI() error { // Write PAT. 
@@ -265,7 +271,7 @@ func (e *Encoder) writePSI() error { PUSI: true, PID: PatPid, CC: e.ccFor(PatPid), - AFC: HasPayload, + AFC: hasPayload, Payload: psi.AddPadding(patTable), } _, err := e.dst.Write(patPkt.Bytes(e.tsSpace[:PacketSize])) @@ -283,7 +289,7 @@ func (e *Encoder) writePSI() error { PUSI: true, PID: PmtPid, CC: e.ccFor(PmtPid), - AFC: HasPayload, + AFC: hasPayload, Payload: psi.AddPadding(pmtTable), } _, err = e.dst.Write(pmtPkt.Bytes(e.tsSpace[:PacketSize])) @@ -301,12 +307,12 @@ func (e *Encoder) tick() { // pts retuns the current presentation timestamp. func (e *Encoder) pts() uint64 { - return uint64((e.clock + e.ptsOffset).Seconds() * pcrFreq) + return uint64((e.clock + e.ptsOffset).Seconds() * PTSFrequency) } // pcr returns the current program clock reference. func (e *Encoder) pcr() uint64 { - return uint64(e.clock.Seconds() * pcrFreq) + return uint64(e.clock.Seconds() * PCRFrequency) } // ccFor returns the next continuity counter for pid. diff --git a/container/mts/encoder_test.go b/container/mts/encoder_test.go new file mode 100644 index 00000000..24fb823d --- /dev/null +++ b/container/mts/encoder_test.go @@ -0,0 +1,252 @@ +/* +NAME + encoder_test.go + +AUTHOR + Trek Hopton + Saxon A. Nelson-Milton + +LICENSE + encoder_test.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License in gpl.txt. + If not, see http://www.gnu.org/licenses. 
+*/ + +package mts + +import ( + "bytes" + "io" + "io/ioutil" + "testing" + + "github.com/Comcast/gots/packet" + "github.com/Comcast/gots/pes" + + "bitbucket.org/ausocean/av/container/mts/meta" +) + +type nopCloser struct{ io.Writer } + +func (nopCloser) Close() error { return nil } + +type destination struct { + packets [][]byte +} + +func (d *destination) Write(p []byte) (int, error) { + tmp := make([]byte, PacketSize) + copy(tmp, p) + d.packets = append(d.packets, tmp) + return len(p), nil +} + +// TestEncodeVideo checks that we can correctly encode some dummy data into a +// valid MPEG-TS stream. This checks for correct MPEG-TS headers and also that the +// original data is stored correctly and is retreivable. +func TestEncodeVideo(t *testing.T) { + Meta = meta.New() + + const dataLength = 440 + const numOfPackets = 3 + const stuffingLen = 100 + + // Generate test data. + data := make([]byte, 0, dataLength) + for i := 0; i < dataLength; i++ { + data = append(data, byte(i)) + } + + // Expect headers for PID 256 (video) + // NB: timing fields like PCR are neglected. + expectedHeaders := [][]byte{ + { + 0x47, // Sync byte. + 0x41, // TEI=0, PUSI=1, TP=0, PID=00001 (256). + 0x00, // PID(Cont)=00000000. + 0x30, // TSC=00, AFC=11(adaptation followed by payload), CC=0000(0). + 0x07, // AFL= 7. + 0x50, // DI=0,RAI=1,ESPI=0,PCRF=1,OPCRF=0,SPF=0,TPDF=0, AFEF=0. + }, + { + 0x47, // Sync byte. + 0x01, // TEI=0, PUSI=0, TP=0, PID=00001 (256). + 0x00, // PID(Cont)=00000000. + 0x31, // TSC=00, AFC=11(adaptation followed by payload), CC=0001(1). + 0x01, // AFL= 1. + 0x00, // DI=0,RAI=0,ESPI=0,PCRF=0,OPCRF=0,SPF=0,TPDF=0, AFEF=0. + }, + { + 0x47, // Sync byte. + 0x01, // TEI=0, PUSI=0, TP=0, PID=00001 (256). + 0x00, // PID(Cont)=00000000. + 0x32, // TSC=00, AFC=11(adaptation followed by payload), CC=0010(2). + 0x57, // AFL= 1+stuffingLen. + 0x00, // DI=0,RAI=0,ESPI=0,PCRF=1,OPCRF=0,SPF=0,TPDF=0, AFEF=0. + }, + } + + // Create the dst and write the test data to encoder. 
+ dst := &destination{} + _, err := NewEncoder(nopCloser{dst}, 25, EncodeH264).Write(data) + if err != nil { + t.Fatalf("could not write data to encoder, failed with err: %v\n", err) + } + + // Check headers. + var expectedIdx int + for _, p := range dst.packets { + // Get PID. + var _p packet.Packet + copy(_p[:], p) + pid := packet.Pid(&_p) + if pid == VideoPid { + // Get mts header, excluding PCR. + gotHeader := p[0:6] + wantHeader := expectedHeaders[expectedIdx] + if !bytes.Equal(gotHeader, wantHeader) { + t.Errorf("did not get expected header for idx: %v.\n Got: %v\n Want: %v\n", expectedIdx, gotHeader, wantHeader) + } + expectedIdx++ + } + } + + // Gather payload data from packets to form the total PES packet. + var pesData []byte + for _, p := range dst.packets { + var _p packet.Packet + copy(_p[:], p) + pid := packet.Pid(&_p) + if pid == VideoPid { + payload, err := packet.Payload(&_p) + if err != nil { + t.Fatalf("could not get payload from mts packet, failed with err: %v\n", err) + } + pesData = append(pesData, payload...) + } + } + + // Get data from the PES packet and compare with the original data. + pes, err := pes.NewPESHeader(pesData) + if err != nil { + t.Fatalf("got error from pes creation: %v\n", err) + } + _data := pes.Data() + if !bytes.Equal(data, _data) { + t.Errorf("did not get expected result.\n Got: %v\n Want: %v\n", data, _data) + } +} + +// TestEncodePcm tests the MPEG-TS encoder's ability to encode pcm audio data. +// It reads and encodes input pcm data into MPEG-TS, then decodes the MPEG-TS and compares the result to the input pcm. 
+func TestEncodePcm(t *testing.T) { + Meta = meta.New() + + var buf bytes.Buffer + sampleRate := 48000 + sampleSize := 2 + blockSize := 16000 + writeFreq := float64(sampleRate*sampleSize) / float64(blockSize) + e := NewEncoder(nopCloser{&buf}, writeFreq, EncodeAudio) + + inPath := "../../../test/test-data/av/input/sweep_400Hz_20000Hz_-3dBFS_5s_48khz.pcm" + inPcm, err := ioutil.ReadFile(inPath) + if err != nil { + t.Errorf("unable to read file: %v", err) + } + + // Break pcm into blocks and encode to mts and get the resulting bytes. + for i := 0; i < len(inPcm); i += blockSize { + if len(inPcm)-i < blockSize { + block := inPcm[i:] + _, err = e.Write(block) + if err != nil { + t.Errorf("unable to write block: %v", err) + } + } else { + block := inPcm[i : i+blockSize] + _, err = e.Write(block) + if err != nil { + t.Errorf("unable to write block: %v", err) + } + } + } + clip := buf.Bytes() + + // Get the first MTS packet to check + var pkt packet.Packet + pesPacket := make([]byte, 0, blockSize) + got := make([]byte, 0, len(inPcm)) + i := 0 + if i+PacketSize <= len(clip) { + copy(pkt[:], clip[i:i+PacketSize]) + } + + // Loop through MTS packets until all the audio data from PES packets has been retrieved + for i+PacketSize <= len(clip) { + + // Check MTS packet + if !(pkt.PID() == audioPid) { + i += PacketSize + if i+PacketSize <= len(clip) { + copy(pkt[:], clip[i:i+PacketSize]) + } + continue + } + if !pkt.PayloadUnitStartIndicator() { + i += PacketSize + if i+PacketSize <= len(clip) { + copy(pkt[:], clip[i:i+PacketSize]) + } + } else { + // Copy the first MTS payload + payload, err := pkt.Payload() + if err != nil { + t.Errorf("unable to get MTS payload: %v", err) + } + pesPacket = append(pesPacket, payload...) 
+ + i += PacketSize + if i+PacketSize <= len(clip) { + copy(pkt[:], clip[i:i+PacketSize]) + } + + // Copy the rest of the MTS payloads that are part of the same PES packet + for (!pkt.PayloadUnitStartIndicator()) && i+PacketSize <= len(clip) { + payload, err = pkt.Payload() + if err != nil { + t.Errorf("unable to get MTS payload: %v", err) + } + pesPacket = append(pesPacket, payload...) + + i += PacketSize + if i+PacketSize <= len(clip) { + copy(pkt[:], clip[i:i+PacketSize]) + } + } + } + // Get the audio data from the current PES packet + pesHeader, err := pes.NewPESHeader(pesPacket) + if err != nil { + t.Errorf("unable to read PES packet: %v", err) + } + got = append(got, pesHeader.Data()...) + pesPacket = pesPacket[:0] + } + + // Compare data from MTS with original data. + if !bytes.Equal(got, inPcm) { + t.Error("data decoded from mts did not match input data") + } +} diff --git a/container/mts/metaEncode_test.go b/container/mts/metaEncode_test.go index 939de5b7..83660777 100644 --- a/container/mts/metaEncode_test.go +++ b/container/mts/metaEncode_test.go @@ -48,7 +48,7 @@ const fps = 25 func TestMetaEncode1(t *testing.T) { Meta = meta.New() var buf bytes.Buffer - e := NewEncoder(nopCloser{&buf}, fps, Video) + e := NewEncoder(nopCloser{&buf}, fps, EncodeH264) Meta.Add("ts", "12345678") if err := e.writePSI(); err != nil { t.Errorf(errUnexpectedErr, err.Error()) @@ -76,7 +76,7 @@ func TestMetaEncode1(t *testing.T) { func TestMetaEncode2(t *testing.T) { Meta = meta.New() var buf bytes.Buffer - e := NewEncoder(nopCloser{&buf}, fps, Video) + e := NewEncoder(nopCloser{&buf}, fps, EncodeH264) Meta.Add("ts", "12345678") Meta.Add("loc", "1234,4321,1234") if err := e.writePSI(); err != nil { diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go index ed8cbe02..eb4bee5d 100644 --- a/container/mts/mpegts.go +++ b/container/mts/mpegts.go @@ -1,7 +1,7 @@ /* NAME mpegts.go - provides a data structure intended to encapsulate the properties - of an MpegTs packet and 
also functions to allow manipulation of these packets. + of an MPEG-TS packet and also functions to allow manipulation of these packets. DESCRIPTION See Readme.md @@ -26,6 +26,7 @@ LICENSE along with revid in gpl.txt. If not, see [GNU licenses](http://www.gnu.org/licenses). */ +// Package mts provides MPEGT-TS (mts) encoding and related functions. package mts import ( @@ -33,13 +34,10 @@ import ( "fmt" "github.com/Comcast/gots/packet" + "github.com/Comcast/gots/pes" ) -// General mpegts packet properties. -const ( - PacketSize = 188 - PayloadSize = 176 -) +const PacketSize = 188 // Program ID for various types of ts packets. const ( @@ -52,7 +50,7 @@ const ( // StreamID is the id of the first stream. const StreamID = 0xe0 -// HeadSize is the size of an mpegts packet header. +// HeadSize is the size of an MPEG-TS packet header. const HeadSize = 4 // Consts relating to adaptation field. @@ -163,28 +161,28 @@ type Packet struct { Payload []byte // Mpeg ts Payload } -// FindPmt will take a clip of mpegts and try to find a PMT table - if one +// FindPmt will take a clip of MPEG-TS and try to find a PMT table - if one // is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. func FindPmt(d []byte) ([]byte, int, error) { return FindPid(d, PmtPid) } -// FindPat will take a clip of mpegts and try to find a PAT table - if one +// FindPat will take a clip of MPEG-TS and try to find a PAT table - if one // is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. func FindPat(d []byte) ([]byte, int, error) { return FindPid(d, PatPid) } -// FindPid will take a clip of mpegts and try to find a packet with given PID - if one +// FindPid will take a clip of MPEG-TS and try to find a packet with given PID - if one // is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. 
func FindPid(d []byte, pid uint16) (pkt []byte, i int, err error) { if len(d) < PacketSize { - return nil, -1, errors.New("Mmpegts data not of valid length") + return nil, -1, errors.New("MPEG-TS data not of valid length") } for i = 0; i < len(d); i += PacketSize { p := (uint16(d[i+1]&0x1f) << 8) | uint16(d[i+2]) if p == pid { - pkt = d[i+4 : i+PacketSize] + pkt = d[i : i+PacketSize] return } } @@ -194,16 +192,69 @@ func FindPid(d []byte, pid uint16) (pkt []byte, i int, err error) { // FillPayload takes a channel and fills the packets Payload field until the // channel is empty or we've the packet reaches capacity func (p *Packet) FillPayload(data []byte) int { - currentPktLen := 6 + asInt(p.PCRF)*6 + asInt(p.OPCRF)*6 + - asInt(p.SPF)*1 + asInt(p.TPDF)*1 + len(p.TPD) - if len(data) > PayloadSize-currentPktLen { - p.Payload = make([]byte, PayloadSize-currentPktLen) + currentPktLen := 6 + asInt(p.PCRF)*6 + if len(data) > PacketSize-currentPktLen { + p.Payload = make([]byte, PacketSize-currentPktLen) } else { p.Payload = make([]byte, len(data)) } return copy(p.Payload, data) } +// Bytes interprets the fields of the ts packet instance and outputs a +// corresponding byte slice +func (p *Packet) Bytes(buf []byte) []byte { + if buf == nil || cap(buf) < PacketSize { + buf = make([]byte, PacketSize) + } + + if p.OPCRF { + panic("original program clock reference field unsupported") + } + if p.SPF { + panic("splicing countdown unsupported") + } + if p.TPDF { + panic("transport private data unsupported") + } + if p.AFEF { + panic("adaptation field extension unsupported") + } + + buf = buf[:6] + buf[0] = 0x47 + buf[1] = (asByte(p.TEI)<<7 | asByte(p.PUSI)<<6 | asByte(p.Priority)<<5 | byte((p.PID&0xFF00)>>8)) + buf[2] = byte(p.PID & 0x00FF) + buf[3] = (p.TSC<<6 | p.AFC<<4 | p.CC) + + var maxPayloadSize int + if p.AFC&0x2 != 0 { + maxPayloadSize = PacketSize - 6 - asInt(p.PCRF)*6 + } else { + maxPayloadSize = PacketSize - 4 + } + + stuffingLen := maxPayloadSize - len(p.Payload) + 
if p.AFC&0x2 != 0 { + buf[4] = byte(1 + stuffingLen + asInt(p.PCRF)*6) + buf[5] = (asByte(p.DI)<<7 | asByte(p.RAI)<<6 | asByte(p.ESPI)<<5 | asByte(p.PCRF)<<4 | asByte(p.OPCRF)<<3 | asByte(p.SPF)<<2 | asByte(p.TPDF)<<1 | asByte(p.AFEF)) + } else { + buf = buf[:4] + } + + for i := 40; p.PCRF && i >= 0; i -= 8 { + buf = append(buf, byte((p.PCR<<15)>>uint(i))) + } + + for i := 0; i < stuffingLen; i++ { + buf = append(buf, 0xff) + } + curLen := len(buf) + buf = buf[:PacketSize] + copy(buf[curLen:], p.Payload) + return buf +} + func asInt(b bool) int { if b { return 1 @@ -218,55 +269,6 @@ func asByte(b bool) byte { return 0 } -// Bytes interprets the fields of the ts packet instance and outputs a -// corresponding byte slice -func (p *Packet) Bytes(buf []byte) []byte { - if buf == nil || cap(buf) != PacketSize { - buf = make([]byte, 0, PacketSize) - } - buf = buf[:0] - stuffingLength := 182 - len(p.Payload) - len(p.TPD) - asInt(p.PCRF)*6 - - asInt(p.OPCRF)*6 - asInt(p.SPF) - var stuffing []byte - if stuffingLength > 0 { - stuffing = make([]byte, stuffingLength) - } - for i := range stuffing { - stuffing[i] = 0xFF - } - afl := 1 + asInt(p.PCRF)*6 + asInt(p.OPCRF)*6 + asInt(p.SPF) + asInt(p.TPDF) + len(p.TPD) + len(stuffing) - buf = append(buf, []byte{ - 0x47, - (asByte(p.TEI)<<7 | asByte(p.PUSI)<<6 | asByte(p.Priority)<<5 | byte((p.PID&0xFF00)>>8)), - byte(p.PID & 0x00FF), - (p.TSC<<6 | p.AFC<<4 | p.CC), - }...) - - if p.AFC == 3 || p.AFC == 2 { - buf = append(buf, []byte{ - byte(afl), (asByte(p.DI)<<7 | asByte(p.RAI)<<6 | asByte(p.ESPI)<<5 | - asByte(p.PCRF)<<4 | asByte(p.OPCRF)<<3 | asByte(p.SPF)<<2 | - asByte(p.TPDF)<<1 | asByte(p.AFEF)), - }...) - for i := 40; p.PCRF && i >= 0; i -= 8 { - buf = append(buf, byte((p.PCR<<15)>>uint(i))) - } - for i := 40; p.OPCRF && i >= 0; i -= 8 { - buf = append(buf, byte(p.OPCR>>uint(i))) - } - if p.SPF { - buf = append(buf, p.SC) - } - if p.TPDF { - buf = append(buf, append([]byte{p.TPDL}, p.TPD...)...) 
- } - buf = append(buf, p.Ext...) - buf = append(buf, stuffing...) - } - buf = append(buf, p.Payload...) - return buf -} - type Option func(p *packet.Packet) // addAdaptationField adds an adaptation field to p, and applys the passed options to this field. @@ -315,3 +317,47 @@ func DiscontinuityIndicator(f bool) Option { p[DiscontinuityIndicatorIdx] |= DiscontinuityIndicatorMask & set } } + +// GetPTSRange retreives the first and last PTS of an MPEGTS clip. +func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) { + // Find the first packet with PID pidType. + pkt, _, err := FindPid(clip, pid) + if err != nil { + return [2]uint64{}, err + } + + // Get the payload of the packet, which will be the start of the PES packet. + var _pkt packet.Packet + copy(_pkt[:], pkt) + payload, err := packet.Payload(&_pkt) + if err != nil { + fmt.Printf("_pkt: %v\n", _pkt) + return [2]uint64{}, err + } + + // Get the the first PTS from the PES header. + _pes, err := pes.NewPESHeader(payload) + if err != nil { + return [2]uint64{}, err + } + pts[0] = _pes.PTS() + + // Get the final PTS searching from end of clip for access unit start. + for i := len(clip) - PacketSize; i >= 0; i -= PacketSize { + copy(_pkt[:], clip[i:i+PacketSize]) + if packet.PayloadUnitStartIndicator(&_pkt) && uint16(_pkt.PID()) == pid { + payload, err = packet.Payload(&_pkt) + if err != nil { + return [2]uint64{}, err + } + _pes, err = pes.NewPESHeader(payload) + if err != nil { + return [2]uint64{}, err + } + pts[1] = _pes.PTS() + return + } + } + + return [2]uint64{}, errors.New("could only find one access unit in mpegts clip") +} diff --git a/container/mts/mpegts_test.go b/container/mts/mpegts_test.go new file mode 100644 index 00000000..4c90cc0e --- /dev/null +++ b/container/mts/mpegts_test.go @@ -0,0 +1,266 @@ +/* +NAME + mpegts_test.go + +DESCRIPTION + mpegts_test.go contains testing for functionality found in mpegts.go. + +AUTHORS + Saxon A. 
Nelson-Milton + +LICENSE + Copyright (C) 2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package mts + +import ( + "bytes" + "math/rand" + "testing" + "time" + + "bitbucket.org/ausocean/av/container/mts/pes" + "bitbucket.org/ausocean/av/container/mts/psi" + "github.com/Comcast/gots/packet" +) + +// TestGetPTSRange checks that GetPTSRange can correctly get the first and last +// PTS in an MPEGTS clip. +func TestGetPTSRange(t *testing.T) { + const ( + numOfFrames = 20 + maxFrameSize = 1000 + minFrameSize = 100 + rate = 25 // fps + interval = float64(1) / rate // s + ptsFreq = 90000 // Hz + ) + + // Generate randomly sized data for each frame. + rand.Seed(time.Now().UnixNano()) + frames := make([][]byte, numOfFrames) + for i := range frames { + size := rand.Intn(maxFrameSize-minFrameSize) + minFrameSize + frames[i] = make([]byte, size) + } + + var clip bytes.Buffer + + // Write the PSI first. + err := writePSI(&clip) + if err != nil { + t.Fatalf("did not expect error writing psi: %v", err) + } + + // Now write frames. 
+ var curTime float64 + for _, frame := range frames { + nextPTS := curTime * ptsFreq + + err = writeFrame(&clip, frame, uint64(nextPTS)) + if err != nil { + t.Fatalf("did not expect error writing frame: %v", err) + } + + curTime += interval + } + + got, err := GetPTSRange(clip.Bytes(), videoPid) + if err != nil { + t.Fatalf("did not expect error getting PTS range: %v", err) + } + + want := [2]uint64{0, uint64((numOfFrames - 1) * interval * ptsFreq)} + if got != want { + t.Errorf("did not get expected result.\n Got: %v\n Want: %v\n", got, want) + } +} + +// writePSI is a helper function write the PSI found at the start of a clip. +func writePSI(b *bytes.Buffer) error { + // Write PAT. + pat := Packet{ + PUSI: true, + PID: PatPid, + CC: 0, + AFC: HasPayload, + Payload: psi.AddPadding(patTable), + } + _, err := b.Write(pat.Bytes(nil)) + if err != nil { + return err + } + + // Write PMT. + pmt := Packet{ + PUSI: true, + PID: PmtPid, + CC: 0, + AFC: HasPayload, + Payload: psi.AddPadding(pmtTable), + } + _, err = b.Write(pmt.Bytes(nil)) + if err != nil { + return err + } + return nil +} + +// writeFrame is a helper function used to form a PES packet from a frame, and +// then fragment this across MPEGTS packets where they are then written to the +// given buffer. +func writeFrame(b *bytes.Buffer, frame []byte, pts uint64) error { + // Prepare PES data. + pesPkt := pes.Packet{ + StreamID: H264ID, + PDI: hasPTS, + PTS: pts, + Data: frame, + HeaderLength: 5, + } + buf := pesPkt.Bytes(nil) + + // Write PES data acroos MPEGTS packets. + pusi := true + for len(buf) != 0 { + pkt := Packet{ + PUSI: pusi, + PID: videoPid, + RAI: pusi, + CC: 0, + AFC: hasAdaptationField | hasPayload, + PCRF: pusi, + } + n := pkt.FillPayload(buf) + buf = buf[n:] + + pusi = false + _, err := b.Write(pkt.Bytes(nil)) + if err != nil { + return err + } + } + return nil +} + +// TestBytes checks that Packet.Bytes() correctly produces a []byte +// representation of a Packet. 
+func TestBytes(t *testing.T) { + const payloadLen, payloadChar, stuffingChar = 120, 0x11, 0xff + const stuffingLen = PacketSize - payloadLen - 12 + + tests := []struct { + packet Packet + expectedHeader []byte + }{ + { + packet: Packet{ + PUSI: true, + PID: 1, + RAI: true, + CC: 4, + AFC: HasPayload | HasAdaptationField, + PCRF: true, + PCR: 1, + }, + expectedHeader: []byte{ + 0x47, // Sync byte. + 0x40, // TEI=0, PUSI=1, TP=0, PID=00000. + 0x01, // PID(Cont)=00000001. + 0x34, // TSC=00, AFC=11(adaptation followed by payload), CC=0100(4). + byte(7 + stuffingLen), // AFL=. + 0x50, // DI=0,RAI=1,ESPI=0,PCRF=1,OPCRF=0,SPF=0,TPDF=0, AFEF=0. + 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, // PCR. + }, + }, + } + + for testNum, test := range tests { + // Construct payload. + payload := make([]byte, 0, payloadLen) + for i := 0; i < payloadLen; i++ { + payload = append(payload, payloadChar) + } + + // Fill the packet payload. + test.packet.FillPayload(payload) + + // Create expected packet data and copy in expected header. + expected := make([]byte, len(test.expectedHeader), PacketSize) + copy(expected, test.expectedHeader) + + // Append stuffing. + for i := 0; i < stuffingLen; i++ { + expected = append(expected, stuffingChar) + } + + // Append payload to expected bytes. + expected = append(expected, payload...) + + // Compare got with expected. + got := test.packet.Bytes(nil) + if !bytes.Equal(got, expected) { + t.Errorf("did not get expected result for test: %v.\n Got: %v\n Want: %v\n", testNum, got, expected) + } + } +} + +// TestFindPid checks that FindPid can correctly extract the first instance +// of a PID from an MPEG-TS stream. +func TestFindPid(t *testing.T) { + const targetPacketNum, numOfPackets, targetPid, stdPid = 6, 15, 1, 0 + + // Prepare the stream of packets. 
+ var stream []byte + for i := 0; i < numOfPackets; i++ { + pid := uint16(stdPid) + if i == targetPacketNum { + pid = targetPid + } + + p := Packet{ + PID: pid, + AFC: hasPayload | hasAdaptationField, + } + p.FillPayload([]byte{byte(i)}) + stream = append(stream, p.Bytes(nil)...) + } + + // Try to find the targetPid in the stream. + p, i, err := FindPid(stream, targetPid) + if err != nil { + t.Fatalf("unexpected error finding PID: %v\n", err) + } + + // Check the payload. + var _p packet.Packet + copy(_p[:], p) + payload, err := packet.Payload(&_p) + if err != nil { + t.Fatalf("unexpected error getting packet payload: %v\n", err) + } + got := payload[0] + if got != targetPacketNum { + t.Errorf("payload of found packet is not correct.\nGot: %v, Want: %v\n", got, targetPacketNum) + } + + // Check the index. + _got := i / PacketSize + if _got != targetPacketNum { + t.Errorf("index of found packet is not correct.\nGot: %v, want: %v\n", _got, targetPacketNum) + } +} diff --git a/container/mts/pes/pes.go b/container/mts/pes/pes.go index 58a143cc..1dc2dd3e 100644 --- a/container/mts/pes/pes.go +++ b/container/mts/pes/pes.go @@ -26,7 +26,13 @@ LICENSE package pes +<<<<<<< HEAD const MaxPesSize = 64 * 1 << 10 // 65536 +======= +import "github.com/Comcast/gots" + +const MaxPesSize = 64 * 1 << 10 +>>>>>>> master /* The below data struct encapsulates the fields of an PES packet. Below is @@ -108,16 +114,11 @@ func (p *Packet) Bytes(buf []byte) []byte { boolByte(p.ACIF)<<2 | boolByte(p.CRCF)<<1 | boolByte(p.EF)), p.HeaderLength, }...) + if p.PDI == byte(2) { - pts := 0x2100010001 | (p.PTS&0x1C0000000)<<3 | (p.PTS&0x3FFF8000)<<2 | - (p.PTS&0x7FFF)<<1 - buf = append(buf, []byte{ - byte((pts & 0xFF00000000) >> 32), - byte((pts & 0x00FF000000) >> 24), - byte((pts & 0x0000FF0000) >> 16), - byte((pts & 0x000000FF00) >> 8), - byte(pts & 0x00000000FF), - }...) 
+ ptsIdx := len(buf) + buf = buf[:ptsIdx+5] + gots.InsertPTS(buf[ptsIdx:], p.PTS) } buf = append(buf, append(p.Stuff, p.Data...)...) return buf diff --git a/container/mts/psi/helpers.go b/container/mts/psi/helpers.go index b8bab6b5..621460f5 100644 --- a/container/mts/psi/helpers.go +++ b/container/mts/psi/helpers.go @@ -125,7 +125,7 @@ func trimTo(d []byte, t byte) []byte { } // addPadding adds an appropriate amount of padding to a pat or pmt table for -// addition to an mpegts packet +// addition to an MPEG-TS packet func AddPadding(d []byte) []byte { t := make([]byte, PacketSize) copy(t, d) diff --git a/container/mts/psi/psi.go b/container/mts/psi/psi.go index c93d3011..3703faf4 100644 --- a/container/mts/psi/psi.go +++ b/container/mts/psi/psi.go @@ -32,7 +32,7 @@ import ( "github.com/Comcast/gots/psi" ) -// PacketSize of psi (without mpegts header) +// PacketSize of psi (without MPEG-TS header) const PacketSize = 184 // Lengths of section definitions. diff --git a/exp/adpcm/decode-pcm/decode-pcm.go b/exp/adpcm/decode-pcm/decode-pcm.go index 8d2bd7f6..2d471324 100644 --- a/exp/adpcm/decode-pcm/decode-pcm.go +++ b/exp/adpcm/decode-pcm/decode-pcm.go @@ -2,9 +2,6 @@ NAME decode-pcm.go -DESCRIPTION - decode-pcm.go is a program for decoding/decompressing an adpcm file to a pcm file. - AUTHOR Trek Hopton @@ -25,6 +22,7 @@ LICENSE If not, see [GNU licenses](http://www.gnu.org/licenses). */ +// decode-pcm is a command-line program for decoding/decompressing an adpcm file to a pcm file. package main import ( @@ -54,8 +52,7 @@ func main() { fmt.Println("Read", len(comp), "bytes from file", inPath) // Decode adpcm. 
- numBlocks := len(comp) / adpcm.AdpcmBS - decoded := bytes.NewBuffer(make([]byte, 0, adpcm.PcmBS*numBlocks)) + decoded := bytes.NewBuffer(make([]byte, 0, len(comp)*4)) dec := adpcm.NewDecoder(decoded) _, err = dec.Write(comp) if err != nil { diff --git a/exp/adpcm/encode-pcm/encode-pcm.go b/exp/adpcm/encode-pcm/encode-pcm.go index d283c822..ded88017 100644 --- a/exp/adpcm/encode-pcm/encode-pcm.go +++ b/exp/adpcm/encode-pcm/encode-pcm.go @@ -2,9 +2,6 @@ NAME encode-pcm.go -DESCRIPTION - encode-pcm.go is a program for encoding/compressing a pcm file to an adpcm file. - AUTHOR Trek Hopton @@ -25,6 +22,7 @@ LICENSE If not, see [GNU licenses](http://www.gnu.org/licenses). */ +// encode-pcm is a command-line program for encoding/compressing a pcm file to an adpcm file. package main import ( @@ -54,8 +52,7 @@ func main() { fmt.Println("Read", len(pcm), "bytes from file", inPath) // Encode adpcm. - numBlocks := len(pcm) / adpcm.PcmBS - comp := bytes.NewBuffer(make([]byte, 0, adpcm.AdpcmBS*numBlocks)) + comp := bytes.NewBuffer(make([]byte, 0, adpcm.EncBytes(len(pcm)))) enc := adpcm.NewEncoder(comp) _, err = enc.Write(pcm) if err != nil { diff --git a/exp/pcm/resample/resample.go b/exp/pcm/resample/resample.go index eab7a342..3d595bb8 100644 --- a/exp/pcm/resample/resample.go +++ b/exp/pcm/resample/resample.go @@ -2,9 +2,6 @@ NAME resample.go -DESCRIPTION - resample.go is a program for resampling a pcm file. - AUTHOR Trek Hopton @@ -24,6 +21,8 @@ LICENSE You should have received a copy of the GNU General Public License in gpl.txt. If not, see [GNU licenses](http://www.gnu.org/licenses). */ + +// resample is a command-line program for resampling a pcm file. 
package main import ( diff --git a/exp/pcm/stereo-to-mono/stereo-to-mono.go b/exp/pcm/stereo-to-mono/stereo-to-mono.go index ccbf87bf..7dbfd9a5 100644 --- a/exp/pcm/stereo-to-mono/stereo-to-mono.go +++ b/exp/pcm/stereo-to-mono/stereo-to-mono.go @@ -2,9 +2,6 @@ NAME stereo-to-mono.go -DESCRIPTION - stereo-to-mono.go is a program for converting a mono pcm file to a stereo pcm file. - AUTHOR Trek Hopton @@ -24,6 +21,8 @@ LICENSE You should have received a copy of the GNU General Public License in gpl.txt. If not, see [GNU licenses](http://www.gnu.org/licenses). */ + +// stereo-to-mono is a command-line program for converting a stereo pcm file to a mono pcm file. package main import ( diff --git a/protocol/rtcp/client.go b/protocol/rtcp/client.go new file mode 100644 index 00000000..4ac5b694 --- /dev/null +++ b/protocol/rtcp/client.go @@ -0,0 +1,288 @@ +/* +NAME + client.go + +DESCRIPTION + client.go provides an implementation of a basic RTCP Client that will send + receiver reports, and receive sender reports to parse relevant statistics. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + This is Copyright (C) 2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses.
const (
	clientSSRC          = 1 // SSRC this client reports as its own; any non-zero value will do.
	defaultClientName   = "Client"
	defaultSendInterval = 2 * time.Second // Wait between receiver reports unless SetSendInterval is used.
	delayUnit           = 1.0 / 65536.0   // Seconds per unit of the delay reported by Client.delay.
	pkg                 = "rtcp: "        // Prefix applied to log messages.
	rtcpVer             = 2
	receiverBufSize     = 200 // Size in bytes of the Client's receiver report buffer.
)

// Log describes a function signature required by the RTCP for the purpose of
// logging.
type Log func(lvl int8, msg string, args ...interface{})

// Client is an RTCP Client that will handle receiving SenderReports from a server
// and sending out ReceiverReports.
type Client struct {
	cAddr       *net.UDPAddr          // Local address of the client.
	sAddr       *net.UDPAddr          // Address of the server the connection is dialed to.
	name        string                // Name of the client for source description purposes.
	sourceSSRC  uint32                // Source identifier of this client.
	mu          sync.Mutex            // Guards senderTs and receiveTime, which are shared by the send and recv routines.
	seq         uint32                // Last RTP sequence number.
	senderTs    [8]byte               // The timestamp of the last sender report, seconds then fraction, big endian.
	interval    time.Duration         // Time waited before each receiver report is sent.
	receiveTime time.Time             // Time last sender report was received.
	buf         [receiverBufSize]byte // Buf used to store the receiver report and source descriptions.
	conn        *net.UDPConn          // The UDP connection used for receiving sender reports and sending receiver reports.
	wg          sync.WaitGroup        // This is used to wait for send and recv routines to stop when Client is stopped.
	quit        chan struct{}         // Channel used to communicate quit signal to send and recv routines.
	log         Log                   // Used to log any messages.
	rtpClt      *rtp.Client           // RTP client queried for SSRC, sequence number and cycle count when forming reports.
	err         chan error            // Client will send any errors through this chan. Can be accessed by Err().
}

// NewClient returns a pointer to a new Client.
+func NewClient(clientAddress, serverAddress string, rtpClt *rtp.Client, l Log) (*Client, error) { + c := &Client{ + name: defaultClientName, + quit: make(chan struct{}), + interval: defaultSendInterval, + rtpClt: rtpClt, + log: l, + } + + var err error + c.cAddr, err = net.ResolveUDPAddr("udp", clientAddress) + if err != nil { + return nil, errors.New(fmt.Sprintf("can't resolve Client address, failed with error: %v\n", err)) + } + + c.sAddr, err = net.ResolveUDPAddr("udp", serverAddress) + if err != nil { + return nil, errors.New(fmt.Sprintf("can't resolve server address, failed with error: %v\n", err)) + } + + c.conn, err = net.DialUDP("udp", c.cAddr, c.sAddr) + if err != nil { + return nil, errors.New(fmt.Sprintf("can't dial, failed with error: %v\n", err)) + } + return c, nil +} + +// SetSendInterval sets a custom receiver report send interval (default is 5 seconds.) +func (c *Client) SetSendInterval(d time.Duration) { + c.interval = d +} + +// SetName sets a custom client name for use in receiver report source description. +// Default is "Client". +func (c *Client) SetName(name string) { + c.name = name +} + +// Start starts the listen and send routines. This will start the process of +// receiving and parsing sender reports, and the process of sending receiver +// reports to the server. +func (c *Client) Start() { + c.log(logger.Debug, pkg+"Client is starting") + c.err = make(chan error) + c.wg.Add(2) + go c.recv() + go c.send() +} + +// Stop sends a quit signal to the send and receive routines and closes the +// UDP connection. It will wait until both routines have returned. +func (c *Client) Stop() { + c.log(logger.Debug, pkg+"Client is stopping") + close(c.quit) + c.conn.Close() + c.wg.Wait() + close(c.err) +} + +// Err provides read access to the Client err channel. This must be checked +// otherwise the client will block if an error encountered. 
+func (c *Client) Err() <-chan error { + return c.err +} + +// recv reads from the UDP connection and parses SenderReports. +func (c *Client) recv() { + defer c.wg.Done() + c.log(logger.Debug, pkg+"Client is receiving") + buf := make([]byte, 4096) + for { + select { + case <-c.quit: + return + default: + n, _, err := c.conn.ReadFromUDP(buf) + if err != nil { + c.err <- err + continue + } + c.log(logger.Debug, pkg+"sender report received", "report", buf[:n]) + c.parse(buf[:n]) + } + } +} + +// send writes receiver reports to the server. +func (c *Client) send() { + defer c.wg.Done() + c.log(logger.Debug, pkg+"Client is sending") + for { + select { + case <-c.quit: + return + default: + time.Sleep(c.interval) + + report := ReceiverReport{ + Header: Header{ + Version: rtcpVer, + Padding: false, + ReportCount: 1, + Type: typeReceiverReport, + }, + SenderSSRC: clientSSRC, + Blocks: []ReportBlock{ + ReportBlock{ + SourceIdentifier: c.rtpClt.SSRC(), + FractionLost: 0, + PacketsLost: math.MaxUint32, + HighestSequence: uint32((c.rtpClt.Cycles() << 16) | c.rtpClt.Sequence()), + Jitter: c.jitter(), + SenderReportTs: c.lastSenderTs(), + SenderReportDelay: c.delay(), + }, + }, + Extensions: nil, + } + + description := Description{ + Header: Header{ + Version: rtcpVer, + Padding: false, + ReportCount: 1, + Type: typeDescription, + }, + Chunks: []Chunk{ + Chunk{ + SSRC: clientSSRC, + Items: []SDESItem{ + SDESItem{ + Type: typeCName, + Text: []byte(c.name), + }, + }, + }, + }, + } + + c.log(logger.Debug, pkg+"sending receiver report") + _, err := c.conn.Write(c.formPayload(&report, &description)) + if err != nil { + c.err <- err + } + } + } +} + +// formPayload takes a pointer to a ReceiverReport and a pointer to a +// Source Description and calls Bytes on both, writing to the underlying Client +// buf. A slice to the combined writtem memory is returned. 
+func (c *Client) formPayload(r *ReceiverReport, d *Description) []byte { + rl := len(r.Bytes(c.buf[:])) + dl := len(d.Bytes(c.buf[rl:])) + t := rl + dl + if t > cap(c.buf) { + panic("Client buf not big enough") + } + return c.buf[:t] +} + +// parse will read important statistics from sender reports. +func (c *Client) parse(buf []byte) { + c.markReceivedTime() + t, err := ParseTimestamp(buf) + if err != nil { + c.err <- fmt.Errorf("could not get timestamp from sender report, failed with error: %v", err) + } + c.setSenderTs(t) +} + +// jitter returns the interarrival jitter as described by RTCP specifications: +// https://tools.ietf.org/html/rfc3550 +// TODO(saxon): complete this. +func (c *Client) jitter() uint32 { + return 0 +} + +// setSenderTs allows us to safely set the current sender report timestamp. +func (c *Client) setSenderTs(t Timestamp) { + c.mu.Lock() + binary.BigEndian.PutUint32(c.senderTs[:], t.Seconds) + binary.BigEndian.PutUint32(c.senderTs[4:], t.Fraction) + c.mu.Unlock() +} + +// lastSenderTs returns the timestamp of the most recent sender report. +func (c *Client) lastSenderTs() uint32 { + c.mu.Lock() + t := binary.BigEndian.Uint32(c.senderTs[2:]) + c.mu.Unlock() + return t +} + +// delay returns the duration between the receive time of the last sender report +// and now. This is called when forming a receiver report. +func (c *Client) delay() uint32 { + c.mu.Lock() + t := c.receiveTime + c.mu.Unlock() + return uint32(time.Now().Sub(t).Seconds() / delayUnit) +} + +// markReceivedTime is called when a sender report is received to mark the receive time. 
func (c *Client) markReceivedTime() {
	// The mutex is held because delay reads receiveTime under the same lock.
	c.mu.Lock()
	c.receiveTime = time.Now()
	c.mu.Unlock()
}

// TestFormPayload checks that formPayload is working as expected: the returned
// bytes must be a receiver report followed by a source description, and the
// returned slice must start at the Client's own buf.
func TestFormPayload(t *testing.T) {
	// Expected data from a valid RTCP packet.
	expect := []byte{
		// Receiver report.
		0x81, 0xc9, 0x00, 0x07,
		0xd6, 0xe0, 0x98, 0xda,
		0x6f, 0xad, 0x40, 0xc6,
		0x00, 0xff, 0xff, 0xff,
		0x00, 0x01, 0x83, 0x08,
		0x00, 0x00, 0x00, 0x20,
		0xb9, 0xe1, 0x25, 0x2a,
		0x00, 0x00, 0x2b, 0xf9,
		// Source description.
		0x81, 0xca, 0x00, 0x04,
		0xd6, 0xe0, 0x98, 0xda,
		0x01, 0x08, 0x73, 0x61,
		0x78, 0x6f, 0x6e, 0x2d,
		0x70, 0x63, 0x00, 0x00,
	}

	report := ReceiverReport{
		Header: Header{
			Version:     2,
			Padding:     false,
			ReportCount: 1,
			Type:        typeReceiverReport,
		},
		SenderSSRC: 3605043418,
		Blocks: []ReportBlock{
			ReportBlock{
				SourceIdentifier:  1873625286,
				FractionLost:      0,
				PacketsLost:       math.MaxUint32,
				HighestSequence:   99080,
				Jitter:            32,
				SenderReportTs:    3118540074,
				SenderReportDelay: 11257,
			},
		},
		Extensions: nil,
	}

	description := Description{
		Header: Header{
			Version:     2,
			Padding:     false,
			ReportCount: 1,
			Type:        typeDescription,
		},
		Chunks: []Chunk{
			Chunk{
				SSRC: 3605043418,
				Items: []SDESItem{
					SDESItem{
						Type: typeCName,
						Text: []byte("saxon-pc"),
					},
				},
			},
		},
	}

	c := &Client{}
	p := c.formPayload(&report, &description)

	if !bytes.Equal(p, expect) {
		t.Fatalf("unexpected result.\nGot: %v\n Want: %v\n", p, expect)
	}

	// Comparing the formatted pointers confirms that p starts at c.buf rather
	// than at a freshly allocated slice.
	bufAddr := fmt.Sprintf("%p", c.buf[:])
	pAddr := fmt.Sprintf("%p", p)
	if bufAddr != pAddr {
		t.Errorf("unexpected result.\nGot: %v\n want: %v\n", pAddr, bufAddr)
	}
}

// dummyLogger will allow logging to be done by the testing pkg.
+type dummyLogger testing.T + +func (dl *dummyLogger) log(lvl int8, msg string, args ...interface{}) { + var l string + switch lvl { + case logger.Warning: + l = "warning" + case logger.Debug: + l = "debug" + case logger.Info: + l = "info" + case logger.Error: + l = "error" + case logger.Fatal: + l = "fatal" + } + msg = l + ": " + msg + for i := 0; i < len(args); i++ { + msg += " %v" + } + if len(args) == 0 { + dl.Log(msg + "\n") + return + } + dl.Logf(msg+"\n", args) +} + +// TestReceiveAndSend tests basic RTCP client behaviour with a basic RTCP server. +// The RTCP client will send through receiver reports, and the RTCP server will +// respond with sender reports. +func TestReceiveAndSend(t *testing.T) { + const clientAddr, serverAddr = "localhost:8000", "localhost:8001" + rtpClt, err := rtp.NewClient("localhost:8002") + if err != nil { + t.Fatalf("unexpected error when creating RTP client: %v", err) + } + + c, err := NewClient( + clientAddr, + serverAddr, + rtpClt, + (*dummyLogger)(t).log, + ) + if err != nil { + t.Fatalf("unexpected error when creating client: %v\n", err) + } + + go func() { + for { + err, ok := <-c.Err() + if ok { + const errConnClosed = "use of closed network connection" + if !strings.Contains(err.Error(), errConnClosed) { + t.Fatalf("error received from client error chan: %v\n", err) + } + } else { + return + } + } + }() + + c.Start() + + sAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + t.Fatalf("could not resolve test server address, failed with error: %v", err) + } + + cAddr, err := net.ResolveUDPAddr("udp", clientAddr) + if err != nil { + t.Fatalf("could not resolve client address, failed with error: %v", err) + } + + conn, err := net.DialUDP("udp", sAddr, cAddr) + if err != nil { + t.Fatalf("could not dial, failed with error: %v\n", err) + } + + buf := make([]byte, 4096) + for i := 0; i < 5; i++ { + t.Log("SERVER: waiting for receiver report\n") + n, _, _ := conn.ReadFromUDP(buf) + t.Logf("SERVER: receiver report 
received: \n%v\n", buf[:n]) + + now := time.Now().Second() + var time [8]byte + binary.BigEndian.PutUint64(time[:], uint64(now)) + msw := binary.BigEndian.Uint32(time[:4]) + lsw := binary.BigEndian.Uint32(time[4:]) + + report := SenderReport{ + Header: Header{ + Version: rtcpVer, + Padding: false, + ReportCount: 0, + Type: typeSenderReport, + }, + SSRC: 1234567, + TimestampMSW: msw, + TimestampLSW: lsw, + RTPTimestamp: 0, + PacketCount: 0, + OctetCount: 0, + } + r := report.Bytes() + t.Logf("SERVER: sending sender report: \n%v\n", r) + _, err := conn.Write(r) + if err != nil { + t.Errorf("did not expect error: %v\n", err) + } + } + c.Stop() +} diff --git a/protocol/rtcp/parse.go b/protocol/rtcp/parse.go new file mode 100644 index 00000000..2007a26d --- /dev/null +++ b/protocol/rtcp/parse.go @@ -0,0 +1,60 @@ +/* +NAME + parse.go + +DESCRIPTION + parse.go contains functionality for parsing RTCP packets. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + This is Copyright (C) 2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package rtcp + +import ( + "encoding/binary" + "errors" +) + +// Timestamp describes an NTP timestamp, see https://tools.ietf.org/html/rfc1305 +type Timestamp struct { + Seconds uint32 + Fraction uint32 +} + +// ParseTimestamp gets the timestamp from a receiver report and returns it as +// a Timestamp as defined above. 
If the given bytes do not represent a valid +// receiver report, an error is returned. +func ParseTimestamp(buf []byte) (Timestamp, error) { + if len(buf) < 4 { + return Timestamp{}, errors.New("bad RTCP packet, not of sufficient length") + } + if (buf[0]&0xc0)>>6 != rtcpVer { + return Timestamp{}, errors.New("incompatible RTCP version") + } + + if buf[1] != typeSenderReport { + return Timestamp{}, errors.New("RTCP packet is not of sender report type") + } + + return Timestamp{ + Seconds: binary.BigEndian.Uint32(buf[8:]), + Fraction: binary.BigEndian.Uint32(buf[12:]), + }, nil +} diff --git a/protocol/rtcp/parse_test.go b/protocol/rtcp/parse_test.go new file mode 100644 index 00000000..ec63aac2 --- /dev/null +++ b/protocol/rtcp/parse_test.go @@ -0,0 +1,61 @@ +/* +NAME + parse_test.go + +DESCRIPTION + parse_test.go provides testing utilities for functionality found in parse.go. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + This is Copyright (C) 2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package rtcp + +import ( + "testing" +) + +// TestTimestamp checks that Timestamp correctly returns the most signicicant +// word, and least signiciant word, of a receiver report timestamp. 
+func TestTimestamp(t *testing.T) { + const expectedMSW = 2209003992 + const expectedLSW = 1956821460 + report := []byte{ + 0x80, 0xc8, 0x00, 0x06, + 0x6f, 0xad, 0x40, 0xc6, + 0x83, 0xaa, 0xb9, 0xd8, // Most significant word of timestamp (2209003992) + 0x74, 0xa2, 0xb9, 0xd4, // Least significant word of timestamp (1956821460) + 0x4b, 0x1c, 0x5a, 0xa5, + 0x00, 0x00, 0x00, 0x66, + 0x00, 0x01, 0xc2, 0xc5, + } + + ts, err := ParseTimestamp(report) + if err != nil { + t.Fatalf("did not expect error: %v", err) + } + + if ts.Seconds != expectedMSW { + t.Errorf("most significant word of timestamp is not what's expected. \nGot: %v\n Want: %v\n", ts.Seconds, expectedMSW) + } + + if ts.Fraction != expectedLSW { + t.Errorf("least significant word of timestamp is not what's expected. \nGot: %v\n Want: %v\n", ts.Fraction, expectedLSW) + } +} diff --git a/protocol/rtcp/rtcp.go b/protocol/rtcp/rtcp.go new file mode 100644 index 00000000..7debdf79 --- /dev/null +++ b/protocol/rtcp/rtcp.go @@ -0,0 +1,222 @@ +/* +NAME + rtcp.go + +DESCRIPTION + rtcp.go contains structs to describe RTCP packets, and functionality to form + []bytes of these structs. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + This is Copyright (C) 2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// Package RTCP provides RTCP data structures and a client for communicating +// with an RTCP service. 
// RTCP packet types.
const (
	typeSenderReport   = 200
	typeReceiverReport = 201
	typeDescription    = 202
)

// Source Description Item types.
const (
	typeCName = 1
)

const (
	reportBlockSize  = 6  // Size of a report block in 32-bit words.
	senderReportSize = 28 // Size of a sender report without report blocks in bytes.
)

// ReceiverReport describes an RTCP receiver report packet.
type ReceiverReport struct {
	Header                   // Standard RTCP packet header.
	SenderSSRC uint32        // SSRC of the sender of this report.
	Blocks     []ReportBlock // Report blocks.
	Extensions [][4]byte     // Contains any extensions to the packet.
}

// Bytes returns a []byte of the ReceiverReport r. The report is written into
// buf if its capacity is sufficient, otherwise a new slice is allocated.
func (r *ReceiverReport) Bytes(buf []byte) []byte {
	l := 8 + 4*reportBlockSize*len(r.Blocks) + 4*len(r.Extensions)
	if cap(buf) < l {
		buf = make([]byte, l)
	}
	buf = buf[:l]

	// The header length field is the packet length in 32-bit words minus one.
	r.writeHeader(buf, l/4-1)
	binary.BigEndian.PutUint32(buf[4:], r.SenderSSRC)

	idx := 8
	for _, b := range r.Blocks {
		binary.BigEndian.PutUint32(buf[idx:], b.SourceIdentifier)
		// The packets lost count only has 24 bits; its top byte is overwritten
		// by the fraction lost.
		binary.BigEndian.PutUint32(buf[idx+4:], b.PacketsLost)
		buf[idx+4] = b.FractionLost
		binary.BigEndian.PutUint32(buf[idx+8:], b.HighestSequence)
		binary.BigEndian.PutUint32(buf[idx+12:], b.Jitter)
		binary.BigEndian.PutUint32(buf[idx+16:], b.SenderReportTs)
		binary.BigEndian.PutUint32(buf[idx+20:], b.SenderReportDelay)
		idx += 4 * reportBlockSize
	}

	for _, e := range r.Extensions {
		copy(buf[idx:], e[:])
		idx += 4
	}

	return buf
}

// ReportBlock describes an RTCP report block used in Sender/Receiver Reports.
type ReportBlock struct {
	SourceIdentifier  uint32 // Source identifier.
	FractionLost      uint8  // Fraction of packets lost.
	PacketsLost       uint32 // Cumulative number of packets lost.
	HighestSequence   uint32 // Extended highest sequence number received.
	Jitter            uint32 // Interarrival jitter.
	SenderReportTs    uint32 // Last sender report timestamp.
	SenderReportDelay uint32 // Delay since last sender report.
}

// Description describes a source description RTCP packet.
type Description struct {
	Header         // Standard RTCP packet header.
	Chunks []Chunk // Chunks to describe items of each SSRC.
}

// Bytes returns an []byte of the Description d. The description is written
// into buf if its capacity is sufficient, otherwise a new slice is allocated.
// The body is zero padded to a multiple of four bytes.
func (d *Description) Bytes(buf []byte) []byte {
	bodyLen := d.bodyLen()
	rem := bodyLen % 4
	if rem != 0 {
		bodyLen += 4 - rem
	}

	l := 4 + bodyLen
	if cap(buf) < l {
		buf = make([]byte, l)
	}
	buf = buf[:l]

	d.writeHeader(buf, bodyLen/4)
	idx := 4
	for _, c := range d.Chunks {
		binary.BigEndian.PutUint32(buf[idx:], c.SSRC)
		idx += 4
		for _, i := range c.Items {
			buf[idx] = i.Type
			buf[idx+1] = byte(len(i.Text))
			idx += 2
			copy(buf[idx:], i.Text)
			idx += len(i.Text)
		}
	}

	// A caller supplied buf may hold old data, so the padding is zeroed
	// explicitly.
	for ; idx < l; idx++ {
		buf[idx] = 0
	}
	return buf
}

// bodyLen calculates the body length of a source description packet in bytes.
func (d *Description) bodyLen() int {
	var l int
	for _, c := range d.Chunks {
		l += c.len()
	}
	return l
}

// SenderReport describes an RTCP sender report.
type SenderReport struct {
	Header              // Standard RTCP header.
	SSRC         uint32 // SSRC of sender.
	TimestampMSW uint32 // Most significant word of timestamp.
	TimestampLSW uint32 // Least significant word of timestamp.
	RTPTimestamp uint32 // Current RTP timestamp.
	PacketCount  uint32 // Senders packet count.
	OctetCount   uint32 // Senders octet count.

	// Report blocks (unimplemented)
	// ...
}

// Bytes returns a []byte of the SenderReport.
func (r *SenderReport) Bytes() []byte {
	buf := make([]byte, senderReportSize)
	// The header length field is the packet length in 32-bit words minus one.
	r.writeHeader(buf, senderReportSize/4-1)
	for i, w := range []uint32{
		r.SSRC,
		r.TimestampMSW,
		r.TimestampLSW,
		r.RTPTimestamp,
		r.PacketCount,
		r.OctetCount,
	} {
		// Each word takes four bytes, following the four byte header.
		binary.BigEndian.PutUint32(buf[4+4*i:], w)
	}
	return buf
}

// Header describes a standard RTCP packet header.
type Header struct {
	Version     uint8 // RTCP version.
	Padding     bool  // Padding indicator.
	ReportCount uint8 // Number of reports contained.
	Type        uint8 // Type of RTCP packet.
}

// SDESItem describes a source description item.
type SDESItem struct {
	Type uint8  // Type of item.
	Text []byte // Item text.
}

// Chunk describes a source description chunk for a given SSRC.
type Chunk struct {
	SSRC  uint32     // SSRC of the source being described by the below items.
	Items []SDESItem // Items describing the source.
}

// len returns the len of a chunk in bytes.
func (c *Chunk) len() int {
	tot := 4
	for _, i := range c.Items {
		tot += 2 + len(i.Text)
	}
	return tot
}

// writeHeader writes the standard RTCP header given a buffer to write to and l
// the RTCP body length that needs to be encoded into the header.
func (h Header) writeHeader(buf []byte, l int) {
	buf[0] = h.Version<<6 | asByte(h.Padding)<<5 | 0x1f&h.ReportCount
	buf[1] = h.Type
	binary.BigEndian.PutUint16(buf[2:], uint16(l))
}

// asByte returns 0x01 if b is true and 0x00 otherwise.
func asByte(b bool) byte {
	if b {
		return 0x01
	}
	return 0x00
}
+ + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package rtcp + +import ( + "bytes" + "math" + "testing" +) + +// TestReceiverReportBytes checks that we can correctly obtain a []byte of an +// RTCP receiver report from the struct representation. +func TestReceiverReportBytes(t *testing.T) { + expect := []byte{ + 0x81, 0xc9, 0x00, 0x07, + 0xd6, 0xe0, 0x98, 0xda, + 0x6f, 0xad, 0x40, 0xc6, + 0x00, 0xff, 0xff, 0xff, + 0x00, 0x01, 0x83, 0x08, + 0x00, 0x00, 0x00, 0x20, + 0xb9, 0xe1, 0x25, 0x2a, + 0x00, 0x00, 0x2b, 0xf9, + } + + report := ReceiverReport{ + Header: Header{ + Version: 2, + Padding: false, + ReportCount: 1, + Type: typeReceiverReport, + }, + SenderSSRC: 3605043418, + Blocks: []ReportBlock{ + ReportBlock{ + SourceIdentifier: 1873625286, + FractionLost: 0, + PacketsLost: math.MaxUint32, + HighestSequence: 99080, + Jitter: 32, + SenderReportTs: 3118540074, + SenderReportDelay: 11257, + }, + }, + Extensions: nil, + } + + got := report.Bytes(nil) + if !bytes.Equal(got, expect) { + t.Errorf("did not get expected result. \nGot: %v\nWant: %v\n", got, expect) + } +} + +// TestSourceDescriptionBytes checks that we can correctly obtain a []byte of an +// RTCP source description from the struct representation. 
+func TestSourceDescriptionBytes(t *testing.T) { + expect := []byte{ + 0x81, 0xca, 0x00, 0x04, + 0xd6, 0xe0, 0x98, 0xda, + 0x01, 0x08, 0x73, 0x61, + 0x78, 0x6f, 0x6e, 0x2d, + 0x70, 0x63, 0x00, 0x00, + } + + description := Description{ + Header: Header{ + Version: 2, + Padding: false, + ReportCount: 1, + Type: typeDescription, + }, + Chunks: []Chunk{ + Chunk{ + SSRC: 3605043418, + Items: []SDESItem{ + SDESItem{ + Type: typeCName, + Text: []byte("saxon-pc"), + }, + }, + }, + }, + } + got := description.Bytes(nil) + if !bytes.Equal(got, expect) { + t.Errorf("Did not get expected result.\nGot: %v\n Want: %v\n", got, expect) + } +} diff --git a/protocol/rtmp/rtmp_test.go b/protocol/rtmp/rtmp_test.go index 1cf056cb..e1e79796 100644 --- a/protocol/rtmp/rtmp_test.go +++ b/protocol/rtmp/rtmp_test.go @@ -38,7 +38,7 @@ import ( "testing" "time" - "bitbucket.org/ausocean/av/codec/lex" + "bitbucket.org/ausocean/av/codec/h264" "bitbucket.org/ausocean/av/container/flv" ) @@ -199,7 +199,7 @@ func TestFromFrame(t *testing.T) { if err != nil { t.Errorf("Failed to create flv encoder with error: %v", err) } - err = lex.H264(flvEncoder, bytes.NewReader(videoData), time.Second/time.Duration(frameRate)) + err = h264.Lex(flvEncoder, bytes.NewReader(videoData), time.Second/time.Duration(frameRate)) if err != nil { t.Errorf("Lexing failed with error: %v", err) } @@ -251,7 +251,7 @@ func TestFromFile(t *testing.T) { if err != nil { t.Fatalf("failed to create encoder: %v", err) } - err = lex.H264(flvEncoder, f, time.Second/time.Duration(25)) + err = h264.Lex(flvEncoder, f, time.Second/time.Duration(25)) if err != nil { t.Errorf("Lexing and encoding failed with error: %v", err) } diff --git a/protocol/rtp/client.go b/protocol/rtp/client.go new file mode 100644 index 00000000..e8418b0d --- /dev/null +++ b/protocol/rtp/client.go @@ -0,0 +1,118 @@ +/* +NAME + client.go + +DESCRIPTION + client.go provides an RTP client. + +AUTHOR + Saxon A. 
Nelson-Milton + +LICENSE + This is Copyright (C) 2019 the Australian Ocean Lab (AusOcean). + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package rtp + +import ( + "net" + "sync" +) + +// Client describes an RTP client that can receive an RTP stream and implements +// io.Reader. +type Client struct { + r *PacketReader + ssrc uint32 + mu sync.Mutex + sequence uint16 + cycles uint16 +} + +// NewClient returns a pointer to a new Client. +// +// addr is the address of form : that we expect to receive +// RTP at. +func NewClient(addr string) (*Client, error) { + c := &Client{r: &PacketReader{}} + + a, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + c.r.PacketConn, err = net.ListenUDP("udp", a) + if err != nil { + return nil, err + } + return c, nil +} + +// SSRC returns the identified for the source from which the RTP packets being +// received are coming from. +func (c *Client) SSRC() uint32 { + return c.ssrc +} + +// Read implements io.Reader. +func (c *Client) Read(p []byte) (int, error) { + n, err := c.r.Read(p) + if err != nil { + return n, err + } + if c.ssrc == 0 { + c.ssrc, _ = SSRC(p[:n]) + } + s, _ := Sequence(p[:n]) + c.setSequence(s) + return n, err +} + +// setSequence sets the most recently received sequence number, and updates the +// cycles count if the sequence number has rolled over. 
+func (c *Client) setSequence(s uint16) { + c.mu.Lock() + if s < c.sequence { + c.cycles++ + } + c.sequence = s + c.mu.Unlock() +} + +// Sequence returns the most recent RTP packet sequence number received. +func (c *Client) Sequence() uint16 { + c.mu.Lock() + defer c.mu.Unlock() + return c.sequence +} + +// Cycles returns the number of RTP sequence number cycles that have been received. +func (c *Client) Cycles() uint16 { + c.mu.Lock() + defer c.mu.Unlock() + return c.cycles +} + +// PacketReader provides an io.Reader interface to an underlying UDP PacketConn. +type PacketReader struct { + net.PacketConn +} + +// Read implements io.Reader. +func (r PacketReader) Read(b []byte) (int, error) { + n, _, err := r.PacketConn.ReadFrom(b) + return n, err +} diff --git a/protocol/rtp/client_test.go b/protocol/rtp/client_test.go new file mode 100644 index 00000000..39810fef --- /dev/null +++ b/protocol/rtp/client_test.go @@ -0,0 +1,125 @@ +/* +NAME + client_test.go + +DESCRIPTION + client_test.go provides testing utilities to check RTP client functionality + provided in client.go. + +AUTHOR + Saxon A. Nelson-Milton + +LICENSE + This is Copyright (C) 2019 the Australian Ocean Lab (AusOcean). + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. 
+*/ + +package rtp + +import ( + "bytes" + "fmt" + "io" + "net" + "testing" +) + +// TestReceive checks that the Client can correctly receive RTP packets and +// perform a specificed operation on the packets before storing in the ringBuffer. +func TestReceive(t *testing.T) { + const ( + clientAddr = "localhost:8000" + packetsToSend = 20 + ) + + testErr := make(chan error) + serverErr := make(chan error) + done := make(chan struct{}) + clientReady := make(chan struct{}) + var c *Client + + // Start routine to read from client. + go func() { + // Create and start the client. + var err error + c, err = NewClient(clientAddr) + if err != nil { + testErr <- fmt.Errorf("could not create client, failed with error: %v\n", err) + } + close(clientReady) + + // Read packets using the client and check them with expected. + var packetsReceived int + buf := make([]byte, 4096) + for packetsReceived != packetsToSend { + n, err := c.Read(buf) + switch err { + case nil: + case io.EOF: + continue + default: + testErr <- fmt.Errorf("unexpected error from c.Read: %v\n", err) + } + + // Create expected data and apply operation if there is one. + expect := (&Packet{V: rtpVer, Payload: []byte{byte(packetsReceived)}}).Bytes(nil) + + // Compare. + got := buf[:n] + if !bytes.Equal(got, expect) { + testErr <- fmt.Errorf("did not get expected result. \nGot: %v\n Want: %v\n", got, expect) + } + packetsReceived++ + } + close(done) + }() + + // Start the RTP server. + go func() { + <-clientReady + cAddr, err := net.ResolveUDPAddr("udp", clientAddr) + if err != nil { + serverErr <- fmt.Errorf("could not resolve server address, failed with err: %v\n", err) + } + + conn, err := net.DialUDP("udp", nil, cAddr) + if err != nil { + serverErr <- fmt.Errorf("could not dial udp, failed with err: %v\n", err) + } + + // Send packets to the client. 
+ for i := 0; i < packetsToSend; i++ { + p := (&Packet{V: rtpVer, Payload: []byte{byte(i)}}).Bytes(nil) + _, err := conn.Write(p) + if err != nil { + serverErr <- fmt.Errorf("could not write packet to conn, failed with err: %v\n", err) + } + } + }() + + <-clientReady +loop: + for { + select { + case err := <-testErr: + t.Fatal(err) + case err := <-serverErr: + t.Fatal(err) + case <-done: + break loop + default: + } + } +} diff --git a/protocol/rtp/encoder.go b/protocol/rtp/encoder.go index 587f64c1..d74ea97c 100644 --- a/protocol/rtp/encoder.go +++ b/protocol/rtp/encoder.go @@ -97,7 +97,7 @@ func min(a, b int) int { // Encode takes a nalu unit and encodes it into an rtp packet and // writes to the io.Writer given in NewEncoder func (e *Encoder) Encode(payload []byte) error { - pkt := Pkt{ + pkt := Packet{ V: rtpVer, // version X: false, // header extension CC: 0, // CSRC count diff --git a/protocol/rtp/parse.go b/protocol/rtp/parse.go index d658aa20..16e64c5d 100644 --- a/protocol/rtp/parse.go +++ b/protocol/rtp/parse.go @@ -34,14 +34,25 @@ import ( const badVer = "incompatible RTP version" -// Payload returns the payload from an RTP packet provided the version is -// compatible, otherwise an error is returned. -func Payload(d []byte) ([]byte, error) { +// Marker returns the state of the RTP marker bit, and an error if parsing fails. +func Marker(d []byte) (bool, error) { if len(d) < defaultHeadSize { panic("invalid RTP packet length") } + if version(d) != rtpVer { - return nil, errors.New(badVer) + return false, errors.New(badVer) + } + + return d[1]&0x80 != 0, nil +} + +// Payload returns the payload from an RTP packet provided the version is +// compatible, otherwise an error is returned. +func Payload(d []byte) ([]byte, error) { + err := checkPacket(d) + if err != nil { + return nil, err } extLen := 0 if hasExt(d) { @@ -51,6 +62,38 @@ func Payload(d []byte) ([]byte, error) { return d[payloadIdx:], nil } +// SSRC returns the source identifier from an RTP packet. 
An error is return if +// the packet is not valid. +func SSRC(d []byte) (uint32, error) { + err := checkPacket(d) + if err != nil { + return 0, err + } + return binary.BigEndian.Uint32(d[8:]), nil +} + +// Sequence returns the sequence number of an RTP packet. An error is returned +// if the packet is not valid. +func Sequence(d []byte) (uint16, error) { + err := checkPacket(d) + if err != nil { + return 0, err + } + return binary.BigEndian.Uint16(d[2:]), nil +} + +// checkPacket checks the validity of the packet, firstly by checking size and +// then also checking that version is compatible with these utilities. +func checkPacket(d []byte) error { + if len(d) < defaultHeadSize { + return errors.New("invalid RTP packet length") + } + if version(d) != rtpVer { + return errors.New(badVer) + } + return nil +} + // hasExt returns true if an extension is present in the RTP packet. func hasExt(d []byte) bool { return (d[0] & 0x10 >> 4) == 1 diff --git a/protocol/rtp/parse_test.go b/protocol/rtp/parse_test.go index f3468c57..1f046f68 100644 --- a/protocol/rtp/parse_test.go +++ b/protocol/rtp/parse_test.go @@ -35,7 +35,7 @@ import ( // TestVersion checks that we can correctly get the version from an RTP packet. func TestVersion(t *testing.T) { const expect = 1 - got := version((&Pkt{V: expect}).Bytes(nil)) + got := version((&Packet{V: expect}).Bytes(nil)) if got != expect { t.Errorf("unexpected version for RTP packet. Got: %v\n Want: %v\n", got, expect) } @@ -46,7 +46,7 @@ func TestVersion(t *testing.T) { func TestCsrcCount(t *testing.T) { const ver, expect = 2, 2 - pkt := (&Pkt{ + pkt := (&Packet{ V: ver, CC: expect, CSRC: make([][4]byte, expect), @@ -64,7 +64,7 @@ func TestHasExt(t *testing.T) { const ver = 2 // First check for when there is an extension field. 
- pkt := &Pkt{ + pkt := &Packet{ V: ver, X: true, Extension: ExtensionHeader{ @@ -93,19 +93,19 @@ func TestPayload(t *testing.T) { expect := []byte{0x01, 0x02, 0x03, 0x04, 0x05} testPkts := [][]byte{ - (&Pkt{ + (&Packet{ V: ver, Payload: expect, }).Bytes(nil), - (&Pkt{ + (&Packet{ V: ver, CC: 3, CSRC: make([][4]byte, 3), Payload: expect, }).Bytes(nil), - (&Pkt{ + (&Packet{ V: ver, X: true, Extension: ExtensionHeader{ @@ -115,7 +115,7 @@ func TestPayload(t *testing.T) { Payload: expect, }).Bytes(nil), - (&Pkt{ + (&Packet{ V: ver, CC: 3, CSRC: make([][4]byte, 3), diff --git a/protocol/rtp/rtp.go b/protocol/rtp/rtp.go index 73f6f15b..ba9ab8f5 100644 --- a/protocol/rtp/rtp.go +++ b/protocol/rtp/rtp.go @@ -46,7 +46,7 @@ const ( // Pkt provides fields consistent with RFC3550 definition of an rtp packet // The padding indicator does not need to be set manually, only the padding length -type Pkt struct { +type Packet struct { V uint8 // Version (currently 2). p bool // Padding indicator (0 => padding, 1 => padding). X bool // Extension header indicator. @@ -69,7 +69,7 @@ type ExtensionHeader struct { } // Bytes provides a byte slice of the packet -func (p *Pkt) Bytes(buf []byte) []byte { +func (p *Packet) Bytes(buf []byte) []byte { // Calculate the required length for the RTP packet. headerExtensionLen := 0 if p.X { diff --git a/protocol/rtp/rtp_test.go b/protocol/rtp/rtp_test.go index 2622fb81..438f6035 100644 --- a/protocol/rtp/rtp_test.go +++ b/protocol/rtp/rtp_test.go @@ -35,13 +35,13 @@ import ( // TODO (saxon): add more tests var rtpTests = []struct { num int - pkt Pkt + pkt Packet want []byte }{ // No padding, no CSRC and no extension. { num: 1, - pkt: Pkt{ + pkt: Packet{ V: 2, p: false, X: false, @@ -67,7 +67,7 @@ var rtpTests = []struct { // With padding. { num: 2, - pkt: Pkt{ + pkt: Packet{ V: 2, p: true, X: false, @@ -101,7 +101,7 @@ var rtpTests = []struct { // With padding and CSRC. 
{ num: 3, - pkt: Pkt{ + pkt: Packet{ V: 2, p: true, X: false, @@ -141,7 +141,7 @@ var rtpTests = []struct { // With padding, CSRC and extension. { num: 4, - pkt: Pkt{ + pkt: Packet{ V: 2, p: true, X: true, diff --git a/protocol/rtsp/client.go b/protocol/rtsp/client.go new file mode 100644 index 00000000..f6c9d0eb --- /dev/null +++ b/protocol/rtsp/client.go @@ -0,0 +1,141 @@ +/* +NAME + client.go + +DESCRIPTION + client.go provides a Client type providing functionality to send RTSP requests + of methods DESCRIBE, OPTIONS, SETUP and PLAY to an RTSP server. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + This is Copyright (C) 2019 the Australian Ocean Lab (AusOcean). + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package rtsp + +import ( + "net" + "net/url" + "strconv" +) + +// Client describes an RTSP Client. +type Client struct { + cSeq int + addr string + url *url.URL + conn net.Conn + sessionID string +} + +// NewClient returns a pointer to a new Client and the local address of the +// RTSP connection. The address addr will be parsed and a connection to the +// RTSP server will be made. 
+func NewClient(addr string) (c *Client, local, remote *net.TCPAddr, err error) { + c = &Client{addr: addr} + c.url, err = url.Parse(addr) + if err != nil { + return nil, nil,nil, err + } + c.conn, err = net.Dial("tcp", c.url.Host) + if err != nil { + return nil, nil, nil, err + } + local = c.conn.LocalAddr().(*net.TCPAddr) + remote = c.conn.RemoteAddr().(*net.TCPAddr) + return +} + +// Close closes the RTSP connection. +func (c *Client) Close() error { + return c.conn.Close() +} + +// Describe forms and sends an RTSP request of method DESCRIBE to the RTSP server. +func (c *Client) Describe() (*Response, error) { + req, err := NewRequest("DESCRIBE", c.nextCSeq(), c.url, nil) + if err != nil { + return nil, err + } + req.Header.Add("Accept", "application/sdp") + return c.Do(req) +} + +// Options forms and sends an RTSP request of method OPTIONS to the RTSP server. +func (c *Client) Options() (*Response, error) { + req, err := NewRequest("OPTIONS", c.nextCSeq(), c.url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + +// Setup forms and sends an RTSP request of method SETUP to the RTSP server. +func (c *Client) Setup(track, transport string) (*Response, error) { + u, err := url.Parse(c.addr + "/" + track) + if err != nil { + return nil, err + } + + req, err := NewRequest("SETUP", c.nextCSeq(), u, nil) + if err != nil { + return nil, err + } + req.Header.Add("Transport", transport) + + resp, err := c.Do(req) + if err != nil { + return nil, err + } + c.sessionID = resp.Header.Get("Session") + + return resp, err +} + +// Play forms and sends an RTSP request of method PLAY to the RTSP server +func (c *Client) Play() (*Response, error) { + req, err := NewRequest("PLAY", c.nextCSeq(), c.url, nil) + if err != nil { + return nil, err + } + req.Header.Add("Session", c.sessionID) + + return c.Do(req) +} + +// Do sends the given RTSP request req, reads any responses and returns the response +// and any errors. 
+func (c *Client) Do(req *Request) (*Response, error) { + err := req.Write(c.conn) + if err != nil { + return nil, err + } + + resp, err := ReadResponse(c.conn) + if err != nil { + return nil, err + } + + return resp, nil +} + +// nextCSeq provides the next CSeq number for the next RTSP request. +func (c *Client) nextCSeq() string { + c.cSeq++ + return strconv.Itoa(c.cSeq) +} diff --git a/protocol/rtsp/rtsp.go b/protocol/rtsp/rtsp.go new file mode 100644 index 00000000..9e0995c8 --- /dev/null +++ b/protocol/rtsp/rtsp.go @@ -0,0 +1,182 @@ +/* +NAME + rtsp.go + +DESCRIPTION + rtsp.go provides functionality for forming and sending RTSP requests for + methods, DESCRIBE, OPTIONS, SETUP and PLAY, as described by + the RTSP standards, see https://tools.ietf.org/html/rfc7826 + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + This is Copyright (C) 2019 the Australian Ocean Lab (AusOcean). + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// Package rtsp provides an RTSP client implementation and methods for +// communication with an RTSP server to request video. +package rtsp + +import ( + "bufio" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "strconv" + "strings" +) + +// Minimum response size to be considered valid in bytes. +const minResponse = 12 + +var errInvalidResponse = errors.New("invalid response") + +// Request describes an RTSP request. 
+type Request struct { + Method string + URL *url.URL + Proto string + ProtoMajor int + ProtoMinor int + Header http.Header + ContentLength int + Body io.ReadCloser +} + +// NewRequest returns a pointer to a new Request. +func NewRequest(method, cSeq string, u *url.URL, body io.ReadCloser) (*Request, error) { + req := &Request{ + Method: method, + URL: u, + Proto: "RTSP", + ProtoMajor: 1, + ProtoMinor: 0, + Header: map[string][]string{"CSeq": []string{cSeq}}, + Body: body, + } + return req, nil +} + +// Write writes the request r to the given io.Writer w. +func (r *Request) Write(w io.Writer) error { + _, err := w.Write([]byte(r.String())) + return err +} + +// String returns a formatted string of the Request. +func (r Request) String() string { + var b strings.Builder + fmt.Fprintf(&b, "%s %s %s/%d.%d\r\n", r.Method, r.URL.String(), r.Proto, r.ProtoMajor, r.ProtoMinor) + for k, v := range r.Header { + for _, v := range v { + fmt.Fprintf(&b, "%s: %s\r\n", k, v) + } + } + b.WriteString("\r\n") + if r.Body != nil { + s, _ := ioutil.ReadAll(r.Body) + b.WriteString(string(s)) + } + return b.String() +} + +// Response describes an RTSP response. +type Response struct { + Proto string + ProtoMajor int + ProtoMinor int + StatusCode int + ContentLength int + Header http.Header + Body io.ReadCloser +} + +// String returns a formatted string of the Response. +func (r Response) String() string { + var b strings.Builder + fmt.Fprintf(&b, "%s/%d.%d %d\n", r.Proto, r.ProtoMajor, r.ProtoMinor, r.StatusCode) + for k, v := range r.Header { + for _, v := range v { + fmt.Fprintf(&b, "%s: %s", k, v) + } + } + return b.String() +} + +// ReadResponse will read the response of the RTSP request from the connection, +// and return a pointer to a new Response. +func ReadResponse(r io.Reader) (*Response, error) { + resp := &Response{Header: make(map[string][]string)} + + scanner := bufio.NewScanner(r) + + // Read the first line. 
+ scanner.Scan() + err := scanner.Err() + if err != nil { + return nil, err + } + s := scanner.Text() + + if len(s) < minResponse || !strings.HasPrefix(s, "RTSP/") { + return nil, errInvalidResponse + } + resp.Proto = "RTSP" + + n, err := fmt.Sscanf(s[5:], "%d.%d %d", &resp.ProtoMajor, &resp.ProtoMinor, &resp.StatusCode) + if err != nil || n != 3 { + return nil, fmt.Errorf("could not Sscanf response, error: %v", err) + } + + // Read headers. + for scanner.Scan() { + err = scanner.Err() + if err != nil { + return nil, err + } + parts := strings.SplitN(scanner.Text(), ":", 2) + if len(parts) < 2 { + break + } + resp.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])) + } + // Get the content length from the header. + resp.ContentLength, _ = strconv.Atoi(resp.Header.Get("Content-Length")) + + resp.Body = closer{r} + return resp, nil +} + +type closer struct { + io.Reader +} + +func (c closer) Close() error { + if c.Reader == nil { + return nil + } + defer func() { + c.Reader = nil + }() + if r, ok := c.Reader.(io.ReadCloser); ok { + return r.Close() + } + return nil +} diff --git a/protocol/rtsp/rtsp_test.go b/protocol/rtsp/rtsp_test.go new file mode 100644 index 00000000..4157cfba --- /dev/null +++ b/protocol/rtsp/rtsp_test.go @@ -0,0 +1,344 @@ +/* +NAME + 0x r,tsp_test.go + +DESCRIPTION + rtsp_test.go provides a test to check functionality provided in rtsp.go. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + This is Copyright (C) 2019 the Australian Ocean Lab (AusOcean). + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. 
+ + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package rtsp + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "net/url" + "strings" + "testing" + "time" + "unicode" +) + +// The max request size we should get in bytes. +const maxRequest = 1024 + +// TestMethods checks that we can correctly form requests for each of the RTSP +// methods supported in the rtsp pkg. This test also checks that communication +// over a TCP connection is performed correctly. +func TestMethods(t *testing.T) { + const dummyURL = "rtsp://admin:admin@192.168.0.50:8554/CH001.sdp" + url, err := url.Parse(dummyURL) + if err != nil { + t.Fatalf("could not parse dummy address, failed with err: %v", err) + } + + // tests holds tests which consist of a function used to create and write a + // request of a particular method, and also the expected request bytes + // to be received on the server side. The bytes in these tests have been + // obtained from a valid RTSP communication cltion.. 
+ tests := []struct { + method func(c *Client) (*Response, error) + expected []byte + }{ + { + method: func(c *Client) (*Response, error) { + req, err := NewRequest("DESCRIBE", c.nextCSeq(), url, nil) + if err != nil { + return nil, err + } + req.Header.Add("Accept", "application/sdp") + return c.Do(req) + }, + expected: []byte{ + 0x44, 0x45, 0x53, 0x43, 0x52, 0x49, 0x42, 0x45, 0x20, 0x72, 0x74, 0x73, 0x70, 0x3a, 0x2f, 0x2f, // |DESCRIBE rtsp://| + 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x3a, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x40, 0x31, 0x39, 0x32, 0x2e, // |admin:admin@192.| + 0x31, 0x36, 0x38, 0x2e, 0x30, 0x2e, 0x35, 0x30, 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, // |168.0.50:8554/CH| + 0x30, 0x30, 0x31, 0x2e, 0x73, 0x64, 0x70, 0x20, 0x52, 0x54, 0x53, 0x50, 0x2f, 0x31, 0x2e, 0x30, // |001.sdp RTSP/1.0| + 0x0d, 0x0a, 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x32, 0x0d, 0x0a, 0x41, 0x63, 0x63, 0x65, 0x70, // |..CSeq: 2..Accep| + 0x74, 0x3a, 0x20, 0x61, 0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x73, // |t: application/s| + 0x64, 0x70, 0x0d, 0x0a, 0x0d, 0x0a, /* */ // |dp....| + }, + }, + { + method: func(c *Client) (*Response, error) { + req, err := NewRequest("OPTIONS", c.nextCSeq(), url, nil) + if err != nil { + return nil, err + } + return c.Do(req) + }, + expected: []byte{ + 0x4f, 0x50, 0x54, 0x49, 0x4f, 0x4e, 0x53, 0x20, 0x72, 0x74, 0x73, 0x70, 0x3a, 0x2f, 0x2f, 0x61, // |OPTIONS rtsp://a| + 0x64, 0x6d, 0x69, 0x6e, 0x3a, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x40, 0x31, 0x39, 0x32, 0x2e, 0x31, // |dmin:admin@192.1| + 0x36, 0x38, 0x2e, 0x30, 0x2e, 0x35, 0x30, 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, 0x30, // |68.0.50:8554/CH0| + 0x30, 0x31, 0x2e, 0x73, 0x64, 0x70, 0x20, 0x52, 0x54, 0x53, 0x50, 0x2f, 0x31, 0x2e, 0x30, 0x0d, // |01.sdp RTSP/1.0.| + 0x0a, 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x31, 0x0d, 0x0a, 0x0d, 0x0a, /* */ // |.CSeq: 1....| + }, + }, + { + method: func(c *Client) (*Response, error) { + u, err := url.Parse(dummyURL + "/track1") + if 
err != nil { + return nil, err + } + + req, err := NewRequest("SETUP", c.nextCSeq(), u, nil) + if err != nil { + return nil, err + } + req.Header.Add("Transport", fmt.Sprintf("RTP/AVP;unicast;client_port=%d-%d", 6870, 6871)) + + return c.Do(req) + }, + expected: []byte{ + 0x53, 0x45, 0x54, 0x55, 0x50, 0x20, 0x72, 0x74, 0x73, 0x70, 0x3a, 0x2f, 0x2f, 0x61, 0x64, 0x6d, // |SETUP rtsp://adm| + 0x69, 0x6e, 0x3a, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x40, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, // |in:admin@192.168| + 0x2e, 0x30, 0x2e, 0x35, 0x30, 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, 0x30, 0x30, 0x31, // |.0.50:8554/CH001| + 0x2e, 0x73, 0x64, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x63, 0x6b, 0x31, 0x20, 0x52, 0x54, 0x53, 0x50, // |.sdp/track1 RTSP| + 0x2f, 0x31, 0x2e, 0x30, 0x0d, 0x0a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x3a, // |/1.0..Transport:| + 0x20, 0x52, 0x54, 0x50, 0x2f, 0x41, 0x56, 0x50, 0x3b, 0x75, 0x6e, 0x69, 0x63, 0x61, 0x73, 0x74, // | RTP/AVP;unicast| + 0x3b, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x3d, 0x36, 0x38, 0x37, // |;client_port=687| + 0x30, 0x2d, 0x36, 0x38, 0x37, 0x31, 0x0d, 0x0a, 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x33, 0x0d, // |0-6871..CSeq: 3.| + 0x0a, 0x0d, 0x0a, /* */ // |...| + }, + }, + { + method: func(c *Client) (*Response, error) { + req, err := NewRequest("PLAY", c.nextCSeq(), url, nil) + if err != nil { + return nil, err + } + req.Header.Add("Session", "00000021") + + return c.Do(req) + }, + expected: []byte{ + 0x50, 0x4c, 0x41, 0x59, 0x20, 0x72, 0x74, 0x73, 0x70, 0x3a, 0x2f, 0x2f, 0x61, 0x64, 0x6d, 0x69, // |PLAY rtsp://admi| + 0x6e, 0x3a, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x40, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, 0x2e, // |n:admin@192.168.| + 0x30, 0x2e, 0x35, 0x30, 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, 0x30, 0x30, 0x31, 0x2e, // |0.50:8554/CH001.| + 0x73, 0x64, 0x70, 0x20, 0x52, 0x54, 0x53, 0x50, 0x2f, 0x31, 0x2e, 0x30, 0x0d, 0x0a, 0x43, 0x53, // |sdp RTSP/1.0..CS| + 0x65, 
0x71, 0x3a, 0x20, 0x34, 0x0d, 0x0a, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x3a, 0x20, // |eq: 4..Session: | + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x32, 0x31, 0x0d, 0x0a, 0x0d, 0x0a, /* */ // |00000021....| + }, + }, + } + + const serverAddr = "rtsp://localhost:8005" + const retries = 10 + + clientErr := make(chan error) + serverErr := make(chan error) + done := make(chan struct{}) + + // This routine acts as the server. + go func() { + l, err := net.Listen("tcp", strings.TrimLeft(serverAddr, "rtsp://")) + if err != nil { + serverErr <- errors.New(fmt.Sprintf("server could not listen, error: %v", err)) + } + + conn, err := l.Accept() + if err != nil { + serverErr <- errors.New(fmt.Sprintf("server could not accept connection, error: %v", err)) + } + + buf := make([]byte, maxRequest) + var n int + for i, test := range tests { + loop: + for { + n, err = conn.Read(buf) + err, ok := err.(net.Error) + + switch { + case err == nil: + break loop + case err == io.EOF: + case ok && err.Timeout(): + default: + serverErr <- errors.New(fmt.Sprintf("server could not read conn, error: %v", err)) + return + } + } + + // Write a dummy response, client won't care. + conn.Write([]byte{'\n'}) + + want := test.expected + got := buf[:n] + if !equal(got, want) { + serverErr <- errors.New(fmt.Sprintf("unexpected result for test: %v. \nGot: %v\n Want: %v\n", i, got, want)) + } + } + close(done) + }() + + // This routine acts as the client. + go func() { + var clt *Client + var err error + + // Keep trying to connect to server. + // TODO: use generalised retry utility when available. 
+ for retry := 0; ; retry++ { + clt, _, _, err = NewClient(serverAddr) + if err == nil { + break + } + + if retry > retries { + clientErr <- errors.New(fmt.Sprintf("client could not connect to server, error: %v", err)) + } + time.Sleep(10 * time.Millisecond) + } + + for i, test := range tests { + _, err = test.method(clt) + if err != nil && err != io.EOF && err != errInvalidResponse { + clientErr <- errors.New(fmt.Sprintf("error request for: %v err: %v", i, err)) + } + } + }() + + // We check for errors or a done signal from the server and client routines. + for { + select { + case err := <-clientErr: + t.Fatalf("client error: %v", err) + case err := <-serverErr: + t.Fatalf("server error: %v", err) + case <-done: + return + default: + } + } +} + +// equal checks that the got slice is considered equivalent to the want slice, +// neglecting unimportant differences such as order of items in header and the +// CSeq number. +func equal(got, want []byte) bool { + const eol = "\r\n" + gotParts := strings.Split(strings.TrimRight(string(got), eol), eol) + wantParts := strings.Split(strings.TrimRight(string(want), eol), eol) + gotParts, ok := rmSeqNum(gotParts) + if !ok { + return false + } + wantParts, ok = rmSeqNum(wantParts) + if !ok { + return false + } + for _, gotStr := range gotParts { + for i, wantStr := range wantParts { + if gotStr == wantStr { + wantParts = append(wantParts[:i], wantParts[i+1:]...) + } + } + } + return len(wantParts) == 0 +} + +// rmSeqNum removes the CSeq number from a string in []string that contains it. +// If a CSeq field is not found nil and false is returned. +func rmSeqNum(s []string) ([]string, bool) { + for i, _s := range s { + if strings.Contains(_s, "CSeq") { + s[i] = strings.TrimFunc(s[i], func(r rune) bool { return unicode.IsNumber(r) }) + return s, true + } + } + return nil, false +} + +// TestReadResponse checks that ReadResponse behaves as expected. 
+func TestReadResponse(t *testing.T) { + // input has been obtained from a valid RTSP response. + input := []byte{ + 0x52, 0x54, 0x53, 0x50, 0x2f, 0x31, 0x2e, 0x30, 0x20, 0x32, 0x30, 0x30, 0x20, 0x4f, 0x4b, 0x0d, // |RTSP/1.0 200 OK.| + 0x0a, 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x32, 0x0d, 0x0a, 0x44, 0x61, 0x74, 0x65, 0x3a, 0x20, // |.CSeq: 2..Date: | + 0x57, 0x65, 0x64, 0x2c, 0x20, 0x4a, 0x61, 0x6e, 0x20, 0x32, 0x31, 0x20, 0x31, 0x39, 0x37, 0x30, // |Wed, Jan 21 1970| + 0x20, 0x30, 0x32, 0x3a, 0x33, 0x37, 0x3a, 0x31, 0x34, 0x20, 0x47, 0x4d, 0x54, 0x0d, 0x0a, 0x50, // | 02:37:14 GMT..P| + 0x75, 0x62, 0x6c, 0x69, 0x63, 0x3a, 0x20, 0x4f, 0x50, 0x54, 0x49, 0x4f, 0x4e, 0x53, 0x2c, 0x20, // |ublic: OPTIONS, | + 0x44, 0x45, 0x53, 0x43, 0x52, 0x49, 0x42, 0x45, 0x2c, 0x20, 0x53, 0x45, 0x54, 0x55, 0x50, 0x2c, // |DESCRIBE, SETUP,| + 0x20, 0x54, 0x45, 0x41, 0x52, 0x44, 0x4f, 0x57, 0x4e, 0x2c, 0x20, 0x50, 0x4c, 0x41, 0x59, 0x2c, // | TEARDOWN, PLAY,| + 0x20, 0x47, 0x45, 0x54, 0x5f, 0x50, 0x41, 0x52, 0x41, 0x4d, 0x45, 0x54, 0x45, 0x52, 0x2c, 0x20, // | GET_PARAMETER, | + 0x53, 0x45, 0x54, 0x5f, 0x50, 0x41, 0x52, 0x41, 0x4d, 0x45, 0x54, 0x45, 0x52, 0x0d, 0x0a, 0x0d, // |SET_PARAMETER...| + 0x0a, + } + + expect := Response{ + Proto: "RTSP", + ProtoMajor: 1, + ProtoMinor: 0, + StatusCode: 200, + ContentLength: 0, + Header: map[string][]string{ + "Cseq": []string{"2"}, + "Date": []string{"Wed, Jan 21 1970 02:37:14 GMT"}, + "Public": []string{"OPTIONS, DESCRIBE, SETUP, TEARDOWN, PLAY, GET_PARAMETER, SET_PARAMETER"}, + }, + } + + got, err := ReadResponse(bytes.NewReader(input)) + if err != nil { + t.Fatalf("should not have got error: %v", err) + } + + if !respEqual(*got, expect) { + t.Errorf("did not get expected result.\nGot: %+v\n Want: %+v\n", got, expect) + } +} + +// respEqual checks the equality of two Responses. 
+func respEqual(got, want Response) bool { + for _, f := range [][2]interface{}{ + {got.Proto, want.Proto}, + {got.ProtoMajor, want.ProtoMajor}, + {got.ProtoMinor, want.ProtoMinor}, + {got.StatusCode, want.StatusCode}, + {got.ContentLength, want.ContentLength}, + } { + if f[0] != f[1] { + return false + } + } + + if len(got.Header) != len(want.Header) { + return false + } + + for k, v := range got.Header { + if len(v) != len(want.Header[k]) { + return false + } + + for i, _v := range v { + if _v != want.Header[k][i] { + return false + } + } + } + return true +} diff --git a/revid/config.go b/revid/config.go index e4ff1a46..7e5c2d20 100644 --- a/revid/config.go +++ b/revid/config.go @@ -45,6 +45,7 @@ type Config struct { Packetization uint8 Quantize bool // Determines whether input to revid will have constant or variable bitrate. RtmpUrl string + RTSPURL string Bitrate uint OutputPath string InputPath string @@ -137,6 +138,7 @@ const ( Udp MpegtsRtp Rtp + RTSP ) // Default config settings @@ -186,7 +188,7 @@ func (c *Config) Validate(r *Revid) error { } switch c.Input { - case Raspivid, V4L, File, Audio: + case Raspivid, V4L, File, Audio, RTSP: case NothingDefined: c.Logger.Log(logger.Info, pkg+"no input type defined, defaulting", "input", defaultInput) c.Input = defaultInput diff --git a/revid/revid.go b/revid/revid.go index dc3a0951..ec34b358 100644 --- a/revid/revid.go +++ b/revid/revid.go @@ -32,6 +32,7 @@ import ( "errors" "fmt" "io" + "net" "os" "os/exec" "strconv" @@ -39,9 +40,14 @@ import ( "sync" "time" - "bitbucket.org/ausocean/av/codec/lex" + "bitbucket.org/ausocean/av/codec/codecutil" + "bitbucket.org/ausocean/av/codec/h264" + "bitbucket.org/ausocean/av/codec/h265" "bitbucket.org/ausocean/av/container/flv" "bitbucket.org/ausocean/av/container/mts" + "bitbucket.org/ausocean/av/protocol/rtcp" + "bitbucket.org/ausocean/av/protocol/rtp" + "bitbucket.org/ausocean/av/protocol/rtsp" "bitbucket.org/ausocean/iot/pi/netsender" "bitbucket.org/ausocean/utils/ioext" 
"bitbucket.org/ausocean/utils/logger" @@ -59,6 +65,12 @@ const ( rtmpConnectionTimeout = 10 ) +const ( + rtpPort = 60000 + rtcpPort = 60001 + defaultServerRTCPPort = 17301 +) + const pkg = "revid:" type Logger interface { @@ -161,8 +173,17 @@ func (r *Revid) reset(config Config) error { r.config.Logger.SetLevel(config.LogLevel) err = r.setupPipeline( - func(dst io.WriteCloser, fps float64, medType int) (io.WriteCloser, error) { - e := mts.NewEncoder(dst, fps, medType) + func(dst io.WriteCloser, fps int, medType int) (io.WriteCloser, error) { + var st int + switch r.config.Input { + case Raspivid, File, V4L: + st = mts.EncodeH264 + case RTSP: + st = mts.EncodeH265 + case Audio: + st = mts.EncodeAudio + } + e := mts.NewEncoder(dst, float64(fps), st) return e, nil }, func(dst io.WriteCloser, fps int) (io.WriteCloser, error) { @@ -267,28 +288,21 @@ func (r *Revid) setupPipeline(mtsEnc func(dst io.WriteCloser, rate float64, medi switch r.config.Input { case Raspivid: r.setupInput = r.startRaspivid + r.lexTo = h264.Lex case V4L: r.setupInput = r.startV4L + r.lexTo = h264.Lex case File: r.setupInput = r.setupInputForFile + r.lexTo = h264.Lex + case RTSP: + r.setupInput = r.startRTSPCamera + r.lexTo = h265.NewLexer(false).Lex case Audio: r.setupInput = r.startAudioDevice + r.lexTo = codecutil.LexBytes } - switch r.config.InputCodec { - case H264: - r.config.Logger.Log(logger.Info, pkg+"using H264 lexer") - r.lexTo = lex.H264 - case Mjpeg: - r.config.Logger.Log(logger.Info, pkg+"using MJPEG lexer") - r.lexTo = lex.MJPEG - case PCM: - r.config.Logger.Log(logger.Info, pkg+"using PCM lexer") - r.lexTo = lex.PCM - case ADPCM: - r.config.Logger.Log(logger.Info, pkg+"using ADPCM lexer") - r.lexTo = lex.ADPCM - } return nil } @@ -647,6 +661,108 @@ func (r *Revid) startAudioDevice() (func() error, error) { }, nil } +// startRTSPCamera uses RTSP to request an RTP stream from an IP camera. 
An RTP +// client is created from which RTP packets containing either h264/h265 can read +// by the selected lexer. +func (r *Revid) startRTSPCamera() (func() error, error) { + rtspClt, local, remote, err := rtsp.NewClient(r.config.RTSPURL) + if err != nil { + return nil, err + } + + resp, err := rtspClt.Options() + if err != nil { + return nil, err + } + r.config.Logger.Log(logger.Info, pkg+"RTSP OPTIONS response", "response", resp.String()) + + resp, err = rtspClt.Describe() + if err != nil { + return nil, err + } + r.config.Logger.Log(logger.Info, pkg+"RTSP DESCRIBE response", "response", resp.String()) + + resp, err = rtspClt.Setup("track1", fmt.Sprintf("RTP/AVP;unicast;client_port=%d-%d", rtpPort, rtcpPort)) + if err != nil { + return nil, err + } + r.config.Logger.Log(logger.Info, pkg+"RTSP SETUP response", "response", resp.String()) + rtpCltAddr, rtcpCltAddr, rtcpSvrAddr, err := formAddrs(local, remote, *resp) + if err != nil { + return nil, err + } + + resp, err = rtspClt.Play() + if err != nil { + return nil, err + } + r.config.Logger.Log(logger.Info, pkg+"RTSP server PLAY response", "response", resp.String()) + + rtpClt, err := rtp.NewClient(rtpCltAddr) + if err != nil { + return nil, err + } + + rtcpClt, err := rtcp.NewClient(rtcpCltAddr, rtcpSvrAddr, rtpClt, r.config.Logger.Log) + if err != nil { + return nil, err + } + + // Check errors from RTCP client until it has stopped running. + go func() { + for { + err, ok := <-rtcpClt.Err() + if ok { + r.config.Logger.Log(logger.Warning, pkg+"RTCP error", "error", err.Error()) + } else { + return + } + } + }() + + // Start the RTCP client. + rtcpClt.Start() + + // Start reading data from the RTP client. 
+ r.wg.Add(1) + go r.processFrom(rtpClt, time.Second/time.Duration(r.config.FrameRate)) + + return func() error { + rtspClt.Close() + rtcpClt.Stop() + return nil + }, nil +} + +// formAddrs is a helper function to form the addresses for the RTP client, +// RTCP client, and the RTSP server's RTCP addr using the local, remote addresses +// of the RTSP conn, and the SETUP method response. +func formAddrs(local, remote *net.TCPAddr, setupResp rtsp.Response) (rtpCltAddr, rtcpCltAddr, rtcpSvrAddr string, err error) { + svrRTCPPort, err := parseSvrRTCPPort(setupResp) + if err != nil { + return "", "", "", err + } + rtpCltAddr = strings.Split(local.String(), ":")[0] + ":" + strconv.Itoa(rtpPort) + rtcpCltAddr = strings.Split(local.String(), ":")[0] + ":" + strconv.Itoa(rtcpPort) + rtcpSvrAddr = strings.Split(remote.String(), ":")[0] + ":" + strconv.Itoa(svrRTCPPort) + return +} + +// parseSvrRTCPPort is a helper function to get the RTSP server's RTCP port. +func parseSvrRTCPPort(resp rtsp.Response) (int, error) { + transport := resp.Header.Get("Transport") + for _, p := range strings.Split(transport, ";") { + if strings.Contains(p, "server_port") { + port, err := strconv.Atoi(strings.Split(p, "-")[1]) + if err != nil { + return 0, err + } + return port, nil + } + } + return 0, errors.New("SETUP response did not provide RTCP port") +} + func (r *Revid) processFrom(read io.Reader, delay time.Duration, bufSize int) { r.config.Logger.Log(logger.Info, pkg+"reading input data") r.err <- r.lexTo(r.encoders, read, delay, bufSize) diff --git a/revid/senders_test.go b/revid/senders_test.go index 80293759..b7f67959 100644 --- a/revid/senders_test.go +++ b/revid/senders_test.go @@ -134,7 +134,7 @@ func TestMtsSenderSegment(t *testing.T) { const numberOfClips = 11 dst := &destination{t: t, done: make(chan struct{}), doneAt: numberOfClips} sender := newMtsSender(dst, (*dummyLogger)(t).log, rbSize, rbElementSize, 0) - encoder := mts.NewEncoder(sender, 25, mts.Video) + encoder := 
mts.NewEncoder(sender, 25, mts.EncodeH264) // Turn time based PSI writing off for encoder. const psiSendCount = 10 @@ -212,7 +212,7 @@ func TestMtsSenderFailedSend(t *testing.T) { const clipToFailAt = 3 dst := &destination{t: t, testFails: true, failAt: clipToFailAt, done: make(chan struct{})} sender := newMtsSender(dst, (*dummyLogger)(t).log, rbSize, rbElementSize, 0) - encoder := mts.NewEncoder(sender, 25, mts.Video) + encoder := mts.NewEncoder(sender, 25, mts.EncodeH264) // Turn time based PSI writing off for encoder and send PSI every 10 packets. const psiSendCount = 10 @@ -292,7 +292,7 @@ func TestMtsSenderDiscontinuity(t *testing.T) { const clipToDelay = 3 dst := &destination{t: t, sendDelay: 10 * time.Millisecond, delayAt: clipToDelay, done: make(chan struct{})} sender := newMtsSender(dst, (*dummyLogger)(t).log, 1, rbElementSize, 0) - encoder := mts.NewEncoder(sender, 25, mts.Video) + encoder := mts.NewEncoder(sender, 25, mts.EncodeH264) // Turn time based PSI writing off for encoder. const psiSendCount = 10