diff --git a/cmd/revid-cli/main.go b/cmd/revid-cli/main.go index 98804a5b..46f39905 100644 --- a/cmd/revid-cli/main.go +++ b/cmd/revid-cli/main.go @@ -105,7 +105,8 @@ func handleFlags() revid.Config { var ( cpuprofile = flag.String("cpuprofile", "", "write cpu profile to `file`") - inputPtr = flag.String("Input", "", "The input type: Raspivid, File, Webcam") + inputPtr = flag.String("Input", "", "The input type: Raspivid, File, Webcam, RTSP") + rtspURLPtr = flag.String("RTSPURL", "", "The URL for an RTSP server.") inputCodecPtr = flag.String("InputCodec", "", "The codec of the input: H264, Mjpeg") quantizePtr = flag.Bool("Quantize", false, "Quantize input (non-variable bitrate)") verbosityPtr = flag.String("Verbosity", "Info", "Verbosity: Debug, Info, Warning, Error, Fatal") @@ -178,6 +179,8 @@ func handleFlags() revid.Config { cfg.Input = revid.V4L case "File": cfg.Input = revid.File + case "RTSP": + cfg.Input = revid.RTSP case "": default: log.Log(logger.Error, pkg+"bad input argument") @@ -211,6 +214,7 @@ func handleFlags() revid.Config { netsender.ConfigFile = *configFilePtr } + cfg.RTSPURL = *rtspURLPtr cfg.Quantize = *quantizePtr cfg.Rotation = *rotationPtr cfg.FlipHorizontal = *horizontalFlipPtr diff --git a/codec/codecutil/bytescanner.go b/codec/codecutil/bytescanner.go new file mode 100644 index 00000000..03f825d6 --- /dev/null +++ b/codec/codecutil/bytescanner.go @@ -0,0 +1,95 @@ +/* +NAME + bytescanner.go + +AUTHOR + Dan Kortschak + +LICENSE + This is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// Package bytescan implements a byte-level scanner. +package codecutil + +import "io" + +// ByteScanner is a byte scanner. +type ByteScanner struct { + buf []byte + off int + + // r is the source of data for the scanner. + r io.Reader +} + +// NewByteScanner returns a scanner initialised with an io.Reader and a read buffer. +func NewByteScanner(r io.Reader, buf []byte) *ByteScanner { + return &ByteScanner{r: r, buf: buf[:0]} +} + +// ScanUntil scans the scanner's underlying io.Reader until a delim byte +// has been read, appending all read bytes to dst. The resulting appended data, +// the last read byte and whether the last read byte was the delimiter. +func (c *ByteScanner) ScanUntil(dst []byte, delim byte) (res []byte, b byte, err error) { +outer: + for { + var i int + for i, b = range c.buf[c.off:] { + if b != delim { + continue + } + dst = append(dst, c.buf[c.off:c.off+i+1]...) + c.off += i + 1 + break outer + } + dst = append(dst, c.buf[c.off:]...) + err = c.reload() + if err != nil { + break + } + } + return dst, b, err +} + +// ReadByte is an unexported ReadByte. +func (c *ByteScanner) ReadByte() (byte, error) { + if c.off >= len(c.buf) { + err := c.reload() + if err != nil { + return 0, err + } + } + b := c.buf[c.off] + c.off++ + return b, nil +} + +// reload re-fills the scanner's buffer. +func (c *ByteScanner) reload() error { + n, err := c.r.Read(c.buf[:cap(c.buf)]) + c.buf = c.buf[:n] + if err != nil { + if err != io.EOF { + return err + } + if n == 0 { + return io.EOF + } + } + c.off = 0 + return nil +} diff --git a/codec/codecutil/bytescanner_test.go b/codec/codecutil/bytescanner_test.go new file mode 100644 index 00000000..68db0006 --- /dev/null +++ b/codec/codecutil/bytescanner_test.go @@ -0,0 +1,82 @@ +/* +NAME + bytescanner_test.go + +DESCRIPTION + See Readme.md + +AUTHOR + Dan Kortschak + +LICENSE + This is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package codecutil + +import ( + "bytes" + "reflect" + "testing" +) + +type chunkEncoder [][]byte + +func (e *chunkEncoder) Encode(b []byte) error { + *e = append(*e, b) + return nil +} + +func (*chunkEncoder) Stream() <-chan []byte { panic("INVALID USE") } + +func TestScannerReadByte(t *testing.T) { + data := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") + + for _, size := range []int{1, 2, 8, 1 << 10} { + r := NewByteScanner(bytes.NewReader(data), make([]byte, size)) + var got []byte + for { + b, err := r.ReadByte() + if err != nil { + break + } + got = append(got, b) + } + if !bytes.Equal(got, data) { + t.Errorf("unexpected result for buffer size %d:\ngot :%q\nwant:%q", size, got, data) + } + } +} + +func TestScannerScanUntilZero(t *testing.T) { + data := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit,\x00 sed do eiusmod tempor incididunt ut \x00labore et dolore magna aliqua.") + + for _, size := range []int{1, 2, 8, 1 << 10} { + r := NewByteScanner(bytes.NewReader(data), make([]byte, size)) + var got [][]byte + for { + buf, _, err := r.ScanUntil(nil, 0x0) + got = append(got, buf) + if err != nil { + break + } + } + want := bytes.SplitAfter(data, []byte{0}) + if !reflect.DeepEqual(got, want) { + t.Errorf("unexpected result for buffer zie %d:\ngot :%q\nwant:%q", size, got, want) + } + } +} diff --git a/codec/h264/lex.go b/codec/h264/lex.go new file mode 100644 index 00000000..9a071715 --- /dev/null +++ b/codec/h264/lex.go @@ -0,0 +1,135 @@ +/* +NAME + lex.go + +DESCRIPTION + lex.go provides a lexer to lex h264 bytestream into access units. + +AUTHOR + Dan Kortschak + +LICENSE + lex.go is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// lex.go provides a lexer to lex h264 bytestream into access units. + +package h264 + +import ( + "io" + "time" + + "bitbucket.org/ausocean/av/codec/codecutil" +) + +var noDelay = make(chan time.Time) + +func init() { + close(noDelay) +} + +var h264Prefix = [...]byte{0x00, 0x00, 0x01, 0x09, 0xf0} + +// Lex lexes H.264 NAL units read from src into separate writes to dst with +// successive writes being performed not earlier than the specified delay. +// NAL units are split after type 1 (Coded slice of a non-IDR picture), 5 +// (Coded slice of a IDR picture) and 8 (Picture parameter set). +func Lex(dst io.Writer, src io.Reader, delay time.Duration) error { + var tick <-chan time.Time + if delay == 0 { + tick = noDelay + } else { + ticker := time.NewTicker(delay) + defer ticker.Stop() + tick = ticker.C + } + + const bufSize = 8 << 10 + + c := codecutil.NewByteScanner(src, make([]byte, 4<<10)) // Standard file buffer size. + + buf := make([]byte, len(h264Prefix), bufSize) + copy(buf, h264Prefix[:]) + writeOut := false +outer: + for { + var b byte + var err error + buf, b, err = c.ScanUntil(buf, 0x00) + if err != nil { + if err != io.EOF { + return err + } + break + } + + for n := 1; b == 0x0 && n < 4; n++ { + b, err = c.ReadByte() + if err != nil { + if err != io.EOF { + return err + } + break outer + } + buf = append(buf, b) + + if b != 0x1 || (n != 2 && n != 3) { + continue + } + + if writeOut { + <-tick + _, err := dst.Write(buf[:len(buf)-(n+1)]) + if err != nil { + return err + } + buf = make([]byte, len(h264Prefix)+n, bufSize) + copy(buf, h264Prefix[:]) + buf = append(buf, 1) + writeOut = false + } + + b, err = c.ReadByte() + if err != nil { + if err != io.EOF { + return err + } + break outer + } + buf = append(buf, b) + + // http://www.itu.int/rec/dologin_pub.asp?lang=e&id=T-REC-H.264-200305-S!!PDF-E&type=items + // Table 7-1 NAL unit type codes + const ( + nonIdrPic = 1 + idrPic = 5 + suppEnhInfo = 6 + paramSet = 8 + ) + switch nalTyp := b & 0x1f; nalTyp { + case nonIdrPic, idrPic, paramSet, suppEnhInfo: + writeOut = true + } + } + } + if len(buf) == len(h264Prefix) { + return nil + } + <-tick + _, err := dst.Write(buf) + return err +} diff --git a/codec/lex/lex_test.go b/codec/h264/lex_test.go similarity index 67% rename from codec/lex/lex_test.go rename to codec/h264/lex_test.go index a107b253..d2eeae2a 100644 --- a/codec/lex/lex_test.go +++ b/codec/h264/lex_test.go @@ -3,7 +3,7 @@ NAME lex_test.go DESCRIPTION - See Readme.md + lex_test.go provides tests for the lexer in lex.go. AUTHOR Dan Kortschak @@ -25,12 +25,11 @@ LICENSE along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. */ -package lex +// lex_test.go provides tests for the lexer in lex.go. + +package h264 import ( - "bytes" - "reflect" - "testing" "time" ) @@ -207,7 +206,7 @@ var h264Tests = []struct { func TestH264(t *testing.T) { for _, test := range h264Tests { var buf chunkEncoder - err := H264(&buf, bytes.NewReader(test.input), test.delay) + err := Lex(&buf, bytes.NewReader(test.input), test.delay) if fmt.Sprint(err) != fmt.Sprint(test.err) { t.Errorf("unexpected error for %q: got:%v want:%v", test.name, err, test.err) } @@ -221,131 +220,3 @@ func TestH264(t *testing.T) { } } */ - -var mjpegTests = []struct { - name string - input []byte - delay time.Duration - want [][]byte - err error -}{ - { - name: "empty", - }, - { - name: "null", - input: []byte{0xff, 0xd8, 0xff, 0xd9}, - delay: 0, - want: [][]byte{{0xff, 0xd8, 0xff, 0xd9}}, - }, - { - name: "null delayed", - input: []byte{0xff, 0xd8, 0xff, 0xd9}, - delay: time.Millisecond, - want: [][]byte{{0xff, 0xd8, 0xff, 0xd9}}, - }, - { - name: "full", - input: []byte{ - 0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9, - 0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9, - 0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9, - 0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9, - 0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9, - }, - delay: 0, - want: [][]byte{ - {0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9}, - {0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9}, - {0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9}, - {0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9}, - {0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9}, - }, - }, - { - name: "full delayed", - input: []byte{ - 0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9, - 0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9, - 0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9, - 0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9, - 0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9, - }, - delay: time.Millisecond, - want: [][]byte{ - {0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9}, - {0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9}, - {0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9}, - {0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9}, - {0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9}, - }, - }, -} - -// FIXME this needs to be adapted -/* -func TestMJEG(t *testing.T) { - for _, test := range mjpegTests { - var buf chunkEncoder - err := MJPEG(&buf, bytes.NewReader(test.input), test.delay) - if fmt.Sprint(err) != fmt.Sprint(test.err) { - t.Errorf("unexpected error for %q: got:%v want:%v", test.name, err, test.err) - } - if err != nil { - continue - } - got := [][]byte(buf) - if !reflect.DeepEqual(got, test.want) { - t.Errorf("unexpected result for %q:\ngot :%#v\nwant:%#v", test.name, got, test.want) - } - } -} -*/ - -type chunkEncoder [][]byte - -func (e *chunkEncoder) Encode(b []byte) error { - *e = append(*e, b) - return nil -} - -func (*chunkEncoder) Stream() <-chan []byte { panic("INVALID USE") } - -func TestScannerReadByte(t *testing.T) { - data := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") - - for _, size := range []int{1, 2, 8, 1 << 10} { - r := newScanner(bytes.NewReader(data), make([]byte, size)) - var got []byte - for { - b, err := r.readByte() - if err != nil { - break - } - got = append(got, b) - } - if !bytes.Equal(got, data) { - t.Errorf("unexpected result for buffer size %d:\ngot :%q\nwant:%q", size, got, data) - } - } -} - -func TestScannerScanUntilZero(t *testing.T) { - data := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit,\x00 sed do eiusmod tempor incididunt ut \x00labore et dolore magna aliqua.") - - for _, size := range []int{1, 2, 8, 1 << 10} { - r := newScanner(bytes.NewReader(data), make([]byte, size)) - var got [][]byte - for { - buf, _, err := r.scanUntilZeroInto(nil) - got = append(got, buf) - if err != nil { - break - } - } - want := bytes.SplitAfter(data, []byte{0}) - if !reflect.DeepEqual(got, want) { - t.Errorf("unexpected result for buffer zie %d:\ngot :%q\nwant:%q", size, got, want) - } - } -} diff --git a/codec/h265/lex.go b/codec/h265/lex.go new file mode 100644 index 00000000..ebe34013 --- /dev/null +++ b/codec/h265/lex.go @@ -0,0 +1,203 @@ +/* +NAME + lex.go + +DESCRIPTION + lex.go provides a lexer for taking RTP HEVC (H265) and lexing into access units. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + Copyright (C) 2019 the Australian Ocean Lab (AusOcean). + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// Package h265 provides an RTP h265 lexer that can extract h265 access units +// from an RTP stream. +package h265 + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "time" + + "bitbucket.org/ausocean/av/protocol/rtp" +) + +// NALU types. +const ( + typeAggregation = 48 + typeFragmentation = 49 + typePACI = 50 +) + +// Buffer sizes. +const ( + maxAUSize = 100000 + maxRTPSize = 4096 +) + +// Lexer is an H265 lexer. +type Lexer struct { + donl bool // Indicates whether DONL and DOND will be used for the RTP stream. + buf *bytes.Buffer // Holds the current access unit. + frag bool // Indicates if we're currently dealing with a fragmentation packet. +} + +// NewLexer returns a new Lexer. +func NewLexer(donl bool) *Lexer { + return &Lexer{ + donl: donl, + buf: bytes.NewBuffer(make([]byte, 0, maxAUSize)), + } +} + +// Lex continually reads RTP packets from the io.Reader src and lexes into +// access units which are written to the io.Writer dst. Lex expects that for +// each read from src, a single RTP packet is received. +func (l *Lexer) Lex(dst io.Writer, src io.Reader, delay time.Duration) error { + buf := make([]byte, maxRTPSize) + for { + n, err := src.Read(buf) + switch err { + case nil: // Do nothing. + case io.EOF: + return nil + default: + return fmt.Errorf("source read error: %v\n", err) + } + + // Get payload from RTP packet. + payload, err := rtp.Payload(buf[:n]) + if err != nil { + return fmt.Errorf("could not get rtp payload, failed with err: %v\n", err) + } + nalType := (payload[0] >> 1) & 0x3f + + // If not currently fragmented then we ignore current write. + if l.frag && nalType != typeFragmentation { + l.buf.Reset() + l.frag = false + continue + } + + switch nalType { + case typeAggregation: + l.handleAggregation(payload) + case typeFragmentation: + l.handleFragmentation(payload) + case typePACI: + l.handlePACI(payload) + default: + l.writeWithPrefix(payload) + } + + markerIsSet, err := rtp.Marker(buf[:n]) + if err != nil { + return fmt.Errorf("could not get marker bit, failed with err: %v\n", err) + } + + if markerIsSet { + _, err := l.buf.WriteTo(dst) + if err != nil { + // TODO: work out what to do here. + } + l.buf.Reset() + } + } + return nil +} + +// handleAggregation parses NAL units from an aggregation packet and writes +// them to the Lexers buffer buf. +func (l *Lexer) handleAggregation(d []byte) { + idx := 2 + for idx < len(d) { + if l.donl { + switch idx { + case 2: + idx += 2 + default: + idx++ + } + } + size := int(binary.BigEndian.Uint16(d[idx:])) + idx += 2 + nalu := d[idx : idx+size] + idx += size + l.writeWithPrefix(nalu) + } +} + +// handleFragmentation parses NAL units from fragmentation packets and writes +// them to the Lexer's buf. +func (l *Lexer) handleFragmentation(d []byte) { + // Get start and end indiciators from FU header. + start := d[2]&0x80 != 0 + end := d[2]&0x40 != 0 + + b1 := (d[0] & 0x81) | ((d[2] & 0x3f) << 1) + b2 := d[1] + if start { + d = d[1:] + if l.donl { + d = d[2:] + } + d[0] = b1 + d[1] = b2 + } else { + d = d[3:] + if l.donl { + d = d[2:] + } + } + + switch { + case start && !end: + l.frag = true + l.writeWithPrefix(d) + case !start && end: + l.frag = false + fallthrough + case !start && !end: + l.writeNoPrefix(d) + default: + panic("bad fragmentation packet") + } +} + +// handlePACI will handl PACI packets +// +// TODO: complete this +func (l *Lexer) handlePACI(d []byte) { + panic("unsupported nal type") +} + +// write writes a NAL unit to the Lexer's buf in byte stream format using the +// start code. +func (l *Lexer) writeWithPrefix(d []byte) { + const prefix = "\x00\x00\x00\x01" + l.buf.Write([]byte(prefix)) + l.buf.Write(d) +} + +// writeNoPrefix writes data to the Lexer's buf. This is used for non start +// fragmentations of a NALU. +func (l *Lexer) writeNoPrefix(d []byte) { + l.buf.Write(d) +} diff --git a/codec/h265/lex_test.go b/codec/h265/lex_test.go new file mode 100644 index 00000000..1a409e4c --- /dev/null +++ b/codec/h265/lex_test.go @@ -0,0 +1,262 @@ +/* +NAME + lex_test.go + +DESCRIPTION + lex_test.go provides tests to check validity of the Lexer found in lex.go. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + Copyright (C) 2019 the Australian Ocean Lab (AusOcean). + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package h265 + +import ( + "io" + "testing" +) + +// rtpReader provides the RTP stream. +type rtpReader struct { + packets [][]byte + idx int +} + +// Read implements io.Reader. +func (r *rtpReader) Read(p []byte) (int, error) { + if r.idx == len(r.packets) { + return 0, io.EOF + } + b := r.packets[r.idx] + n := copy(p, b) + if n < len(r.packets[r.idx]) { + r.packets[r.idx] = r.packets[r.idx][n:] + } else { + r.idx++ + } + return n, nil +} + +// destination holds the access units extracted during the lexing process. +type destination [][]byte + +// Write implements io.Writer. +func (d *destination) Write(p []byte) (int, error) { + t := make([]byte, len(p)) + copy(t, p) + *d = append([][]byte(*d), t) + return len(p), nil +} + +// TestLex checks that the Lexer can correctly extract H265 access units from +// HEVC RTP stream in RTP payload format. +func TestLex(t *testing.T) { + const rtpVer = 2 + + tests := []struct { + donl bool + packets [][]byte + expect [][]byte + }{ + { + donl: false, + packets: [][]byte{ + { // Single NAL unit. + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x01, 0x02, 0x03, 0x04, // NAL Data. + }, + { // Fragmentation (start packet). + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x62, 0x00, // NAL header (type49). + 0x80, // FU header. + 0x01, 0x02, 0x03, // FU payload. + }, + { // Fragmentation (middle packet) + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x62, 0x00, // NAL header (type 49). + 0x00, // FU header. + 0x04, 0x05, 0x06, // FU payload. + }, + { // Fragmentation (end packet) + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x62, 0x00, // NAL header (type 49). + 0x40, // FU header. + 0x07, 0x08, 0x09, // FU payload + }, + + { // Aggregation. Make last packet of access unit => marker bit true. + 0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x60, 0x00, // NAL header (type 49). + 0x00, 0x04, // NAL 1 size. + 0x01, 0x02, 0x03, 0x04, // NAL 1 data. + 0x00, 0x04, // NAL 2 size. + 0x01, 0x02, 0x03, 0x04, // NAL 2 data. + }, + { // Singla NAL + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x01, 0x02, 0x03, 0x04, // NAL data. + }, + { // Singla NAL. Make last packet of access unit => marker bit true. + 0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x01, 0x02, 0x03, 0x04, // NAL data. + }, + }, + expect: [][]byte{ + // First access unit. + { + // NAL 1 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x01, 0x02, 0x03, 0x04, // NAL data. + // NAL 2 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x00, 0x00, 0x01, 0x02, 0x03, // FU payload. + 0x04, 0x05, 0x06, // FU payload. + 0x07, 0x08, 0x09, // FU payload. + // NAL 3 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x01, 0x02, 0x03, 0x04, // NAL data. + // NAL 4 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x01, 0x02, 0x03, 0x04, // NAL 2 data + }, + // Second access unit. + { + // NAL 1 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x01, 0x02, 0x03, 0x04, // Data. + // NAL 2 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x01, 0x02, 0x03, 0x04, // Data. + }, + }, + }, + { + donl: true, + packets: [][]byte{ + { // Single NAL unit. + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x00, 0x00, // DONL + 0x01, 0x02, 0x03, 0x04, // NAL Data. + }, + { // Fragmentation (start packet). + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x62, 0x00, // NAL header (type49). + 0x80, // FU header. + 0x00, 0x00, // DONL + 0x01, 0x02, 0x03, // FU payload. + }, + { // Fragmentation (middle packet) + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x62, 0x00, // NAL header (type 49). + 0x00, // FU header. + 0x00, 0x00, // DONL + 0x04, 0x05, 0x06, // FU payload. + }, + { // Fragmentation (end packet) + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x62, 0x00, // NAL header (type 49). + 0x40, // FU header. + 0x00, 0x00, // DONL + 0x07, 0x08, 0x09, // FU payload + }, + + { // Aggregation. Make last packet of access unit => marker bit true. + 0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x60, 0x00, // NAL header (type 49). + 0x00, 0x00, // DONL + 0x00, 0x04, // NAL 1 size. + 0x01, 0x02, 0x03, 0x04, // NAL 1 data. + 0x00, // DOND + 0x00, 0x04, // NAL 2 size. + 0x01, 0x02, 0x03, 0x04, // NAL 2 data. + }, + { // Singla NAL + 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x40, 0x00, // NAL header (type=32 VPS) + 0x00, 0x00, // DONL. + 0x01, 0x02, 0x03, 0x04, // NAL data. + }, + { // Singla NAL. Make last packet of access unit => marker bit true. + 0x80, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // RTP header. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x00, 0x00, // DONL + 0x01, 0x02, 0x03, 0x04, // NAL data. + }, + }, + expect: [][]byte{ + // First access unit. + { + // NAL 1 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x00, 0x00, // DONL + 0x01, 0x02, 0x03, 0x04, // NAL data. + // NAL 2 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x00, 0x00, 0x01, 0x02, 0x03, // FU payload. + 0x04, 0x05, 0x06, // FU payload. + 0x07, 0x08, 0x09, // FU payload. + // NAL 3 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x01, 0x02, 0x03, 0x04, // NAL data. + // NAL 4 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x01, 0x02, 0x03, 0x04, // NAL 2 data + }, + // Second access unit. + { + // NAL 1 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x00, 0x00, // DONL + 0x01, 0x02, 0x03, 0x04, // Data. + // NAL 2 + 0x00, 0x00, 0x00, 0x01, // Start code. + 0x40, 0x00, // NAL header (type=32 VPS). + 0x00, 0x00, // DONL + 0x01, 0x02, 0x03, 0x04, // Data. + }, + }, + }, + } + + for testNum, test := range tests { + r := &rtpReader{packets: test.packets} + d := &destination{} + err := NewLexer(test.donl).Lex(d, r, 0) + if err != nil { + t.Fatalf("error lexing: %v\n", err) + } + + for i, accessUnit := range test.expect { + for j, part := range accessUnit { + if part != [][]byte(*d)[i][j] { + t.Fatalf("did not get expected data for test: %v.\nGot: %v\nWant: %v\n", testNum, d, test.expect) + } + } + } + } +} diff --git a/codec/lex/lex.go b/codec/lex/lex.go deleted file mode 100644 index da0dd1b6..00000000 --- a/codec/lex/lex.go +++ /dev/null @@ -1,247 +0,0 @@ -/* -NAME - lex.go - -DESCRIPTION - See Readme.md - -AUTHOR - Dan Kortschak - -LICENSE - lex.go is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) - - It is free software: you can redistribute it and/or modify them - under the terms of the GNU General Public License as published by the - Free Software Foundation, either version 3 of the License, or (at your - option) any later version. - - It is distributed in the hope that it will be useful, but WITHOUT - ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - for more details. - - You should have received a copy of the GNU General Public License - along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. -*/ - -// Package lex provides lexers for video encodings. -package lex - -import ( - "bufio" - "bytes" - "fmt" - "io" - "time" -) - -var noDelay = make(chan time.Time) - -func init() { - close(noDelay) -} - -var h264Prefix = [...]byte{0x00, 0x00, 0x01, 0x09, 0xf0} - -// H264 lexes H.264 NAL units read from src into separate writes to dst with -// successive writes being performed not earlier than the specified delay. -// NAL units are split after type 1 (Coded slice of a non-IDR picture), 5 -// (Coded slice of a IDR picture) and 8 (Picture parameter set). -func H264(dst io.Writer, src io.Reader, delay time.Duration) error { - var tick <-chan time.Time - if delay == 0 { - tick = noDelay - } else { - ticker := time.NewTicker(delay) - defer ticker.Stop() - tick = ticker.C - } - - const bufSize = 8 << 10 - - c := newScanner(src, make([]byte, 4<<10)) // Standard file buffer size. - - buf := make([]byte, len(h264Prefix), bufSize) - copy(buf, h264Prefix[:]) - writeOut := false -outer: - for { - var b byte - var err error - buf, b, err = c.scanUntilZeroInto(buf) - if err != nil { - if err != io.EOF { - return err - } - break - } - - for n := 1; b == 0x0 && n < 4; n++ { - b, err = c.readByte() - if err != nil { - if err != io.EOF { - return err - } - break outer - } - buf = append(buf, b) - - if b != 0x1 || (n != 2 && n != 3) { - continue - } - - if writeOut { - <-tick - _, err := dst.Write(buf[:len(buf)-(n+1)]) - if err != nil { - return err - } - buf = make([]byte, len(h264Prefix)+n, bufSize) - copy(buf, h264Prefix[:]) - buf = append(buf, 1) - writeOut = false - } - - b, err = c.readByte() - if err != nil { - if err != io.EOF { - return err - } - break outer - } - buf = append(buf, b) - - // http://www.itu.int/rec/dologin_pub.asp?lang=e&id=T-REC-H.264-200305-S!!PDF-E&type=items - // Table 7-1 NAL unit type codes - const ( - nonIdrPic = 1 - idrPic = 5 - suppEnhInfo = 6 - paramSet = 8 - ) - switch nalTyp := b & 0x1f; nalTyp { - case nonIdrPic, idrPic, paramSet, suppEnhInfo: - writeOut = true - } - } - } - if len(buf) == len(h264Prefix) { - return nil - } - <-tick - _, err := dst.Write(buf) - return err -} - -// scanner is a byte scanner. -type scanner struct { - buf []byte - off int - - // r is the source of data for the scanner. - r io.Reader -} - -// newScanner returns a scanner initialised with an io.Reader and a read buffer. -func newScanner(r io.Reader, buf []byte) *scanner { - return &scanner{r: r, buf: buf[:0]} -} - -// scanUntilZeroInto scans the scanner's underlying io.Reader until a zero byte -// has been read, appending all read bytes to dst. The resulting appended data, -// the last read byte and whether the last read byte was zero are returned. -func (c *scanner) scanUntilZeroInto(dst []byte) (res []byte, b byte, err error) { -outer: - for { - var i int - for i, b = range c.buf[c.off:] { - if b != 0x0 { - continue - } - dst = append(dst, c.buf[c.off:c.off+i+1]...) - c.off += i + 1 - break outer - } - dst = append(dst, c.buf[c.off:]...) - err = c.reload() - if err != nil { - break - } - } - return dst, b, err -} - -// readByte is an unexported ReadByte. -func (c *scanner) readByte() (byte, error) { - if c.off >= len(c.buf) { - err := c.reload() - if err != nil { - return 0, err - } - } - b := c.buf[c.off] - c.off++ - return b, nil -} - -// reload re-fills the scanner's buffer. -func (c *scanner) reload() error { - n, err := c.r.Read(c.buf[:cap(c.buf)]) - c.buf = c.buf[:n] - if err != nil { - if err != io.EOF { - return err - } - if n == 0 { - return io.EOF - } - } - c.off = 0 - return nil -} - -// MJPEG parses MJPEG frames read from src into separate writes to dst with -// successive writes being performed not earlier than the specified delay. -func MJPEG(dst io.Writer, src io.Reader, delay time.Duration) error { - var tick <-chan time.Time - if delay == 0 { - tick = noDelay - } else { - ticker := time.NewTicker(delay) - defer ticker.Stop() - tick = ticker.C - } - - r := bufio.NewReader(src) - for { - buf := make([]byte, 2, 4<<10) - n, err := r.Read(buf) - if n < 2 { - return nil - } - if err != nil { - return err - } - if !bytes.Equal(buf, []byte{0xff, 0xd8}) { - return fmt.Errorf("parser: not MJPEG frame start: %#v", buf) - } - var last byte - for { - b, err := r.ReadByte() - if err != nil { - return err - } - buf = append(buf, b) - if last == 0xff && b == 0xd9 { - break - } - last = b - } - <-tick - _, err = dst.Write(buf) - if err != nil { - return err - } - } -} diff --git a/codec/mjpeg/lex.go b/codec/mjpeg/lex.go new file mode 100644 index 00000000..da2ecae1 --- /dev/null +++ b/codec/mjpeg/lex.go @@ -0,0 +1,89 @@ +/* +NAME + lex.go + +DESCRIPTION + lex.go provides a lexer to extract separate JPEG images from a MJPEG stream. + +AUTHOR + Dan Kortschak + +LICENSE + lex.go is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// lex.go provides a lexer to extract separate JPEG images from a MJPEG stream. + +package mjpeg + +import ( + "bufio" + "bytes" + "fmt" + "io" + "time" +) + +var noDelay = make(chan time.Time) + +func init() { + close(noDelay) +} + +// Lex parses MJPEG frames read from src into separate writes to dst with +// successive writes being performed not earlier than the specified delay. +func Lex(dst io.Writer, src io.Reader, delay time.Duration) error { + var tick <-chan time.Time + if delay == 0 { + tick = noDelay + } else { + ticker := time.NewTicker(delay) + defer ticker.Stop() + tick = ticker.C + } + + r := bufio.NewReader(src) + for { + buf := make([]byte, 2, 4<<10) + n, err := r.Read(buf) + if n < 2 { + return nil + } + if err != nil { + return err + } + if !bytes.Equal(buf, []byte{0xff, 0xd8}) { + return fmt.Errorf("parser: not MJPEG frame start: %#v", buf) + } + var last byte + for { + b, err := r.ReadByte() + if err != nil { + return err + } + buf = append(buf, b) + if last == 0xff && b == 0xd9 { + break + } + last = b + } + <-tick + _, err = dst.Write(buf) + if err != nil { + return err + } + } +} diff --git a/codec/mjpeg/lex_test.go b/codec/mjpeg/lex_test.go new file mode 100644 index 00000000..fd8b6f86 --- /dev/null +++ b/codec/mjpeg/lex_test.go @@ -0,0 +1,114 @@ +/* +NAME + lex_test.go + +DESCRIPTION + lex_test.go provides testing for the lexer in lex.go. + +AUTHOR + Dan Kortschak + +LICENSE + lex_test.go is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// lex_test.go provides testing for the lexer in lex.go. + +package mjpeg + +import ( + "time" +) + +var mjpegTests = []struct { + name string + input []byte + delay time.Duration + want [][]byte + err error +}{ + { + name: "empty", + }, + { + name: "null", + input: []byte{0xff, 0xd8, 0xff, 0xd9}, + delay: 0, + want: [][]byte{{0xff, 0xd8, 0xff, 0xd9}}, + }, + { + name: "null delayed", + input: []byte{0xff, 0xd8, 0xff, 0xd9}, + delay: time.Millisecond, + want: [][]byte{{0xff, 0xd8, 0xff, 0xd9}}, + }, + { + name: "full", + input: []byte{ + 0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9, + 0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9, + 0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9, + 0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9, + 0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9, + }, + delay: 0, + want: [][]byte{ + {0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9}, + {0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9}, + {0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9}, + {0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9}, + {0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9}, + }, + }, + { + name: "full delayed", + input: []byte{ + 0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9, + 0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9, + 0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9, + 0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9, + 0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9, + }, + delay: time.Millisecond, + want: [][]byte{ + {0xff, 0xd8, 'f', 'u', 'l', 'l', 0xff, 0xd9}, + {0xff, 0xd8, 'f', 'r', 'a', 'm', 'e', 0xff, 0xd9}, + {0xff, 0xd8, 'w', 'i', 't', 'h', 0xff, 0xd9}, + {0xff, 0xd8, 'l', 'e', 'n', 'g', 't', 'h', 0xff, 0xd9}, + {0xff, 0xd8, 's', 'p', 'r', 'e', 'a', 'd', 0xff, 0xd9}, + }, + }, +} + +// FIXME this needs to be adapted +/* +func Lex(t *testing.T) { + for _, test := range mjpegTests { + var buf chunkEncoder + err := MJPEG(&buf, bytes.NewReader(test.input), test.delay) + if fmt.Sprint(err) != fmt.Sprint(test.err) { + t.Errorf("unexpected error for %q: got:%v want:%v", test.name, err, test.err) + } + if err != nil { + continue + } + got := [][]byte(buf) + if !reflect.DeepEqual(got, test.want) { + t.Errorf("unexpected result for %q:\ngot :%#v\nwant:%#v", test.name, got, test.want) + } + } +} +*/ diff --git a/container/mts/audio_test.go b/container/mts/audio_test.go deleted file mode 100644 index f785930d..00000000 --- a/container/mts/audio_test.go +++ /dev/null @@ -1,144 +0,0 @@ -/* -NAME - audio_test.go - -AUTHOR - Trek Hopton - -LICENSE - audio_test.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean) - - It is free software: you can redistribute it and/or modify them - under the terms of the GNU General Public License as published by the - Free Software Foundation, either version 3 of the License, or (at your - option) any later version. - - It is distributed in the hope that it will be useful, but WITHOUT - ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - for more details. - - You should have received a copy of the GNU General Public License in gpl.txt. - If not, see http://www.gnu.org/licenses. -*/ - -package mts - -import ( - "bytes" - "io" - "io/ioutil" - "testing" - - "github.com/Comcast/gots/packet" - "github.com/Comcast/gots/pes" - - "bitbucket.org/ausocean/av/container/mts/meta" -) - -type nopCloser struct{ io.Writer } - -func (nopCloser) Close() error { return nil } - -// TestEncodePcm tests the mpegts encoder's ability to encode pcm audio data. -// It reads and encodes input pcm data into mpegts, then decodes the mpegts and compares the result to the input pcm. -func TestEncodePcm(t *testing.T) { - Meta = meta.New() - - var buf bytes.Buffer - sampleRate := 48000 - sampleSize := 2 - blockSize := 16000 - writeFreq := float64(sampleRate*sampleSize) / float64(blockSize) - e := NewEncoder(nopCloser{&buf}, writeFreq, Audio) - - inPath := "../../../test/test-data/av/input/sweep_400Hz_20000Hz_-3dBFS_5s_48khz.pcm" - inPcm, err := ioutil.ReadFile(inPath) - if err != nil { - t.Errorf("unable to read file: %v", err) - } - - // Break pcm into blocks and encode to mts and get the resulting bytes. - for i := 0; i < len(inPcm); i += blockSize { - if len(inPcm)-i < blockSize { - block := inPcm[i:] - _, err = e.Write(block) - if err != nil { - t.Errorf("unable to write block: %v", err) - } - } else { - block := inPcm[i : i+blockSize] - _, err = e.Write(block) - if err != nil { - t.Errorf("unable to write block: %v", err) - } - } - } - clip := buf.Bytes() - - // Get the first MTS packet to check - var pkt packet.Packet - pesPacket := make([]byte, 0, blockSize) - got := make([]byte, 0, len(inPcm)) - i := 0 - if i+PacketSize <= len(clip) { - copy(pkt[:], clip[i:i+PacketSize]) - } - - // Loop through MTS packets until all the audio data from PES packets has been retrieved - for i+PacketSize <= len(clip) { - - // Check MTS packet - if !(pkt.PID() == audioPid) { - i += PacketSize - if i+PacketSize <= len(clip) { - copy(pkt[:], clip[i:i+PacketSize]) - } - continue - } - if !pkt.PayloadUnitStartIndicator() { - i += PacketSize - if i+PacketSize <= len(clip) { - copy(pkt[:], clip[i:i+PacketSize]) - } - } else { - // Copy the first MTS payload - payload, err := pkt.Payload() - if err != nil { - t.Errorf("unable to get MTS payload: %v", err) - } - pesPacket = append(pesPacket, payload...) - - i += PacketSize - if i+PacketSize <= len(clip) { - copy(pkt[:], clip[i:i+PacketSize]) - } - - // Copy the rest of the MTS payloads that are part of the same PES packet - for (!pkt.PayloadUnitStartIndicator()) && i+PacketSize <= len(clip) { - payload, err = pkt.Payload() - if err != nil { - t.Errorf("unable to get MTS payload: %v", err) - } - pesPacket = append(pesPacket, payload...) - - i += PacketSize - if i+PacketSize <= len(clip) { - copy(pkt[:], clip[i:i+PacketSize]) - } - } - } - // Get the audio data from the current PES packet - pesHeader, err := pes.NewPESHeader(pesPacket) - if err != nil { - t.Errorf("unable to read PES packet: %v", err) - } - got = append(got, pesHeader.Data()...) - pesPacket = pesPacket[:0] - } - - // Compare data from MTS with original data. - if !bytes.Equal(got, inPcm) { - t.Error("data decoded from mts did not match input data") - } -} diff --git a/container/mts/discontinuity.go b/container/mts/discontinuity.go index adccebad..e127ff94 100644 --- a/container/mts/discontinuity.go +++ b/container/mts/discontinuity.go @@ -4,8 +4,8 @@ NAME DESCRIPTION discontinuity.go provides functionality for detecting discontinuities in - mpegts and accounting for using the discontinuity indicator in the adaptation - field. + MPEG-TS and accounting for using the discontinuity indicator in the adaptation + field. AUTHOR Saxon A. Nelson-Milton @@ -33,7 +33,7 @@ import ( "github.com/Comcast/gots/packet" ) -// discontinuityRepairer provides function to detect discontinuities in mpegts +// discontinuityRepairer provides function to detect discontinuities in MPEG-TS // and set the discontinuity indicator as appropriate. type DiscontinuityRepairer struct { expCC map[int]int @@ -56,7 +56,7 @@ func (dr *DiscontinuityRepairer) Failed() { dr.decExpectedCC(PatPid) } -// Repair takes a clip of mpegts and checks that the first packet, which should +// Repair takes a clip of MPEG-TS and checks that the first packet, which should // be a PAT, contains a cc that is expected, otherwise the discontinuity indicator // is set to true. func (dr *DiscontinuityRepairer) Repair(d []byte) error { diff --git a/container/mts/encoder.go b/container/mts/encoder.go index e9efbd97..a1598dd8 100644 --- a/container/mts/encoder.go +++ b/container/mts/encoder.go @@ -55,30 +55,6 @@ var ( }, }, } - - // standardPmt is a minimal PMT, without descriptors for time and location. - standardPmt = psi.PSI{ - Pf: 0x00, - Tid: 0x02, - Ssi: true, - Sl: 0x12, - Tss: &psi.TSS{ - Tide: 0x01, - V: 0, - Cni: true, - Sn: 0, - Lsn: 0, - Sd: &psi.PMT{ - Pcrpid: 0x0100, - Pil: 0, - Essd: &psi.ESSD{ - St: 0x1b, - Epid: 0x0100, - Esil: 0x00, - }, - }, - }, - } ) const ( @@ -94,7 +70,7 @@ var Meta *meta.Data var ( patTable = standardPat.Bytes() - pmtTable = standardPmt.Bytes() + pmtTable []byte ) const ( @@ -103,28 +79,32 @@ const ( pmtPid = 4096 videoPid = 256 audioPid = 210 - videoStreamID = 0xe0 // First video stream ID. + H264ID = 27 + H265ID = 36 audioStreamID = 0xc0 // First audio stream ID. ) -// Video and Audio constants are used to communicate which media type will be encoded when creating a -// new encoder with NewEncoder. +// Constants used to communicate which media codec will be packetized. const ( - Video = iota - Audio + EncodeH264 = iota + EncodeH265 + EncodeAudio ) -// Time related constants. +// Time-related constants. const ( // ptsOffset is the offset added to the clock to determine // the current presentation timestamp. ptsOffset = 700 * time.Millisecond - // pcrFreq is the base Program Clock Reference frequency. - pcrFreq = 90000 // Hz + // PCRFrequency is the base Program Clock Reference frequency in Hz. + PCRFrequency = 90000 + + // PTSFrequency is the presentation timestamp frequency in Hz. + PTSFrequency = 90000 ) -// Encoder encapsulates properties of an mpegts generator. +// Encoder encapsulates properties of an MPEG-TS generator. type Encoder struct { dst io.WriteCloser @@ -152,14 +132,41 @@ func NewEncoder(dst io.WriteCloser, rate float64, mediaType int) *Encoder { var mPid int var sid byte switch mediaType { - case Audio: + case EncodeAudio: mPid = audioPid sid = audioStreamID - case Video: + case EncodeH265: mPid = videoPid - sid = videoStreamID + sid = H265ID + case EncodeH264: + mPid = videoPid + sid = H264ID } + // standardPmt is a minimal PMT, without descriptors for metadata. + pmtTable = (&psi.PSI{ + Pf: 0x00, + Tid: 0x02, + Ssi: true, + Sl: 0x12, + Tss: &psi.TSS{ + Tide: 0x01, + V: 0, + Cni: true, + Sn: 0, + Lsn: 0, + Sd: &psi.PMT{ + Pcrpid: 0x0100, + Pil: 0, + Essd: &psi.ESSD{ + St: byte(sid), + Epid: 0x0100, + Esil: 0x00, + }, + }, + }, + }).Bytes() + return &Encoder{ dst: dst, @@ -201,7 +208,7 @@ func (e *Encoder) TimeBasedPsi(b bool, sendCount int) { e.pktCount = e.psiSendCount } -// Write implements io.Writer. Write takes raw h264 and encodes into mpegts, +// Write implements io.Writer. Write takes raw h264 and encodes into MPEG-TS, // then sending it to the encoder's io.Writer destination. func (e *Encoder) Write(data []byte) (int, error) { now := time.Now() @@ -256,7 +263,7 @@ func (e *Encoder) Write(data []byte) (int, error) { return len(data), nil } -// writePSI creates mpegts with pat and pmt tables - with pmt table having updated +// writePSI creates MPEG-TS with pat and pmt tables - with pmt table having updated // location and time data. func (e *Encoder) writePSI() error { // Write PAT. @@ -264,7 +271,7 @@ func (e *Encoder) writePSI() error { PUSI: true, PID: PatPid, CC: e.ccFor(PatPid), - AFC: HasPayload, + AFC: hasPayload, Payload: psi.AddPadding(patTable), } _, err := e.dst.Write(patPkt.Bytes(e.tsSpace[:PacketSize])) @@ -282,7 +289,7 @@ func (e *Encoder) writePSI() error { PUSI: true, PID: PmtPid, CC: e.ccFor(PmtPid), - AFC: HasPayload, + AFC: hasPayload, Payload: psi.AddPadding(pmtTable), } _, err = e.dst.Write(pmtPkt.Bytes(e.tsSpace[:PacketSize])) @@ -300,12 +307,12 @@ func (e *Encoder) tick() { // pts retuns the current presentation timestamp. func (e *Encoder) pts() uint64 { - return uint64((e.clock + e.ptsOffset).Seconds() * pcrFreq) + return uint64((e.clock + e.ptsOffset).Seconds() * PTSFrequency) } // pcr returns the current program clock reference. func (e *Encoder) pcr() uint64 { - return uint64(e.clock.Seconds() * pcrFreq) + return uint64(e.clock.Seconds() * PCRFrequency) } // ccFor returns the next continuity counter for pid. diff --git a/container/mts/encoder_test.go b/container/mts/encoder_test.go new file mode 100644 index 00000000..24fb823d --- /dev/null +++ b/container/mts/encoder_test.go @@ -0,0 +1,252 @@ +/* +NAME + encoder_test.go + +AUTHOR + Trek Hopton + Saxon A. Nelson-Milton + +LICENSE + encoder_test.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License in gpl.txt. + If not, see http://www.gnu.org/licenses. +*/ + +package mts + +import ( + "bytes" + "io" + "io/ioutil" + "testing" + + "github.com/Comcast/gots/packet" + "github.com/Comcast/gots/pes" + + "bitbucket.org/ausocean/av/container/mts/meta" +) + +type nopCloser struct{ io.Writer } + +func (nopCloser) Close() error { return nil } + +type destination struct { + packets [][]byte +} + +func (d *destination) Write(p []byte) (int, error) { + tmp := make([]byte, PacketSize) + copy(tmp, p) + d.packets = append(d.packets, tmp) + return len(p), nil +} + +// TestEncodeVideo checks that we can correctly encode some dummy data into a +// valid MPEG-TS stream. This checks for correct MPEG-TS headers and also that the +// original data is stored correctly and is retreivable. +func TestEncodeVideo(t *testing.T) { + Meta = meta.New() + + const dataLength = 440 + const numOfPackets = 3 + const stuffingLen = 100 + + // Generate test data. + data := make([]byte, 0, dataLength) + for i := 0; i < dataLength; i++ { + data = append(data, byte(i)) + } + + // Expect headers for PID 256 (video) + // NB: timing fields like PCR are neglected. + expectedHeaders := [][]byte{ + { + 0x47, // Sync byte. + 0x41, // TEI=0, PUSI=1, TP=0, PID=00001 (256). + 0x00, // PID(Cont)=00000000. + 0x30, // TSC=00, AFC=11(adaptation followed by payload), CC=0000(0). + 0x07, // AFL= 7. + 0x50, // DI=0,RAI=1,ESPI=0,PCRF=1,OPCRF=0,SPF=0,TPDF=0, AFEF=0. + }, + { + 0x47, // Sync byte. + 0x01, // TEI=0, PUSI=0, TP=0, PID=00001 (256). + 0x00, // PID(Cont)=00000000. + 0x31, // TSC=00, AFC=11(adaptation followed by payload), CC=0001(1). + 0x01, // AFL= 1. + 0x00, // DI=0,RAI=0,ESPI=0,PCRF=0,OPCRF=0,SPF=0,TPDF=0, AFEF=0. + }, + { + 0x47, // Sync byte. + 0x01, // TEI=0, PUSI=0, TP=0, PID=00001 (256). + 0x00, // PID(Cont)=00000000. + 0x32, // TSC=00, AFC=11(adaptation followed by payload), CC=0010(2). + 0x57, // AFL= 1+stuffingLen. + 0x00, // DI=0,RAI=0,ESPI=0,PCRF=1,OPCRF=0,SPF=0,TPDF=0, AFEF=0. + }, + } + + // Create the dst and write the test data to encoder. + dst := &destination{} + _, err := NewEncoder(nopCloser{dst}, 25, EncodeH264).Write(data) + if err != nil { + t.Fatalf("could not write data to encoder, failed with err: %v\n", err) + } + + // Check headers. + var expectedIdx int + for _, p := range dst.packets { + // Get PID. + var _p packet.Packet + copy(_p[:], p) + pid := packet.Pid(&_p) + if pid == VideoPid { + // Get mts header, excluding PCR. + gotHeader := p[0:6] + wantHeader := expectedHeaders[expectedIdx] + if !bytes.Equal(gotHeader, wantHeader) { + t.Errorf("did not get expected header for idx: %v.\n Got: %v\n Want: %v\n", expectedIdx, gotHeader, wantHeader) + } + expectedIdx++ + } + } + + // Gather payload data from packets to form the total PES packet. + var pesData []byte + for _, p := range dst.packets { + var _p packet.Packet + copy(_p[:], p) + pid := packet.Pid(&_p) + if pid == VideoPid { + payload, err := packet.Payload(&_p) + if err != nil { + t.Fatalf("could not get payload from mts packet, failed with err: %v\n", err) + } + pesData = append(pesData, payload...) + } + } + + // Get data from the PES packet and compare with the original data. + pes, err := pes.NewPESHeader(pesData) + if err != nil { + t.Fatalf("got error from pes creation: %v\n", err) + } + _data := pes.Data() + if !bytes.Equal(data, _data) { + t.Errorf("did not get expected result.\n Got: %v\n Want: %v\n", data, _data) + } +} + +// TestEncodePcm tests the MPEG-TS encoder's ability to encode pcm audio data. +// It reads and encodes input pcm data into MPEG-TS, then decodes the MPEG-TS and compares the result to the input pcm. +func TestEncodePcm(t *testing.T) { + Meta = meta.New() + + var buf bytes.Buffer + sampleRate := 48000 + sampleSize := 2 + blockSize := 16000 + writeFreq := float64(sampleRate*sampleSize) / float64(blockSize) + e := NewEncoder(nopCloser{&buf}, writeFreq, EncodeAudio) + + inPath := "../../../test/test-data/av/input/sweep_400Hz_20000Hz_-3dBFS_5s_48khz.pcm" + inPcm, err := ioutil.ReadFile(inPath) + if err != nil { + t.Errorf("unable to read file: %v", err) + } + + // Break pcm into blocks and encode to mts and get the resulting bytes. + for i := 0; i < len(inPcm); i += blockSize { + if len(inPcm)-i < blockSize { + block := inPcm[i:] + _, err = e.Write(block) + if err != nil { + t.Errorf("unable to write block: %v", err) + } + } else { + block := inPcm[i : i+blockSize] + _, err = e.Write(block) + if err != nil { + t.Errorf("unable to write block: %v", err) + } + } + } + clip := buf.Bytes() + + // Get the first MTS packet to check + var pkt packet.Packet + pesPacket := make([]byte, 0, blockSize) + got := make([]byte, 0, len(inPcm)) + i := 0 + if i+PacketSize <= len(clip) { + copy(pkt[:], clip[i:i+PacketSize]) + } + + // Loop through MTS packets until all the audio data from PES packets has been retrieved + for i+PacketSize <= len(clip) { + + // Check MTS packet + if !(pkt.PID() == audioPid) { + i += PacketSize + if i+PacketSize <= len(clip) { + copy(pkt[:], clip[i:i+PacketSize]) + } + continue + } + if !pkt.PayloadUnitStartIndicator() { + i += PacketSize + if i+PacketSize <= len(clip) { + copy(pkt[:], clip[i:i+PacketSize]) + } + } else { + // Copy the first MTS payload + payload, err := pkt.Payload() + if err != nil { + t.Errorf("unable to get MTS payload: %v", err) + } + pesPacket = append(pesPacket, payload...) + + i += PacketSize + if i+PacketSize <= len(clip) { + copy(pkt[:], clip[i:i+PacketSize]) + } + + // Copy the rest of the MTS payloads that are part of the same PES packet + for (!pkt.PayloadUnitStartIndicator()) && i+PacketSize <= len(clip) { + payload, err = pkt.Payload() + if err != nil { + t.Errorf("unable to get MTS payload: %v", err) + } + pesPacket = append(pesPacket, payload...) + + i += PacketSize + if i+PacketSize <= len(clip) { + copy(pkt[:], clip[i:i+PacketSize]) + } + } + } + // Get the audio data from the current PES packet + pesHeader, err := pes.NewPESHeader(pesPacket) + if err != nil { + t.Errorf("unable to read PES packet: %v", err) + } + got = append(got, pesHeader.Data()...) + pesPacket = pesPacket[:0] + } + + // Compare data from MTS with original data. + if !bytes.Equal(got, inPcm) { + t.Error("data decoded from mts did not match input data") + } +} diff --git a/container/mts/metaEncode_test.go b/container/mts/metaEncode_test.go index 939de5b7..83660777 100644 --- a/container/mts/metaEncode_test.go +++ b/container/mts/metaEncode_test.go @@ -48,7 +48,7 @@ const fps = 25 func TestMetaEncode1(t *testing.T) { Meta = meta.New() var buf bytes.Buffer - e := NewEncoder(nopCloser{&buf}, fps, Video) + e := NewEncoder(nopCloser{&buf}, fps, EncodeH264) Meta.Add("ts", "12345678") if err := e.writePSI(); err != nil { t.Errorf(errUnexpectedErr, err.Error()) @@ -76,7 +76,7 @@ func TestMetaEncode1(t *testing.T) { func TestMetaEncode2(t *testing.T) { Meta = meta.New() var buf bytes.Buffer - e := NewEncoder(nopCloser{&buf}, fps, Video) + e := NewEncoder(nopCloser{&buf}, fps, EncodeH264) Meta.Add("ts", "12345678") Meta.Add("loc", "1234,4321,1234") if err := e.writePSI(); err != nil { diff --git a/container/mts/mpegts.go b/container/mts/mpegts.go index ed8cbe02..eb4bee5d 100644 --- a/container/mts/mpegts.go +++ b/container/mts/mpegts.go @@ -1,7 +1,7 @@ /* NAME mpegts.go - provides a data structure intended to encapsulate the properties - of an MpegTs packet and also functions to allow manipulation of these packets. + of an MPEG-TS packet and also functions to allow manipulation of these packets. DESCRIPTION See Readme.md @@ -26,6 +26,7 @@ LICENSE along with revid in gpl.txt. If not, see [GNU licenses](http://www.gnu.org/licenses). */ +// Package mts provides MPEGT-TS (mts) encoding and related functions. package mts import ( @@ -33,13 +34,10 @@ import ( "fmt" "github.com/Comcast/gots/packet" + "github.com/Comcast/gots/pes" ) -// General mpegts packet properties. -const ( - PacketSize = 188 - PayloadSize = 176 -) +const PacketSize = 188 // Program ID for various types of ts packets. const ( @@ -52,7 +50,7 @@ const ( // StreamID is the id of the first stream. const StreamID = 0xe0 -// HeadSize is the size of an mpegts packet header. +// HeadSize is the size of an MPEG-TS packet header. const HeadSize = 4 // Consts relating to adaptation field. @@ -163,28 +161,28 @@ type Packet struct { Payload []byte // Mpeg ts Payload } -// FindPmt will take a clip of mpegts and try to find a PMT table - if one +// FindPmt will take a clip of MPEG-TS and try to find a PMT table - if one // is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. func FindPmt(d []byte) ([]byte, int, error) { return FindPid(d, PmtPid) } -// FindPat will take a clip of mpegts and try to find a PAT table - if one +// FindPat will take a clip of MPEG-TS and try to find a PAT table - if one // is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. func FindPat(d []byte) ([]byte, int, error) { return FindPid(d, PatPid) } -// FindPid will take a clip of mpegts and try to find a packet with given PID - if one +// FindPid will take a clip of MPEG-TS and try to find a packet with given PID - if one // is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. func FindPid(d []byte, pid uint16) (pkt []byte, i int, err error) { if len(d) < PacketSize { - return nil, -1, errors.New("Mmpegts data not of valid length") + return nil, -1, errors.New("MPEG-TS data not of valid length") } for i = 0; i < len(d); i += PacketSize { p := (uint16(d[i+1]&0x1f) << 8) | uint16(d[i+2]) if p == pid { - pkt = d[i+4 : i+PacketSize] + pkt = d[i : i+PacketSize] return } } @@ -194,16 +192,69 @@ func FindPid(d []byte, pid uint16) (pkt []byte, i int, err error) { // FillPayload takes a channel and fills the packets Payload field until the // channel is empty or we've the packet reaches capacity func (p *Packet) FillPayload(data []byte) int { - currentPktLen := 6 + asInt(p.PCRF)*6 + asInt(p.OPCRF)*6 + - asInt(p.SPF)*1 + asInt(p.TPDF)*1 + len(p.TPD) - if len(data) > PayloadSize-currentPktLen { - p.Payload = make([]byte, PayloadSize-currentPktLen) + currentPktLen := 6 + asInt(p.PCRF)*6 + if len(data) > PacketSize-currentPktLen { + p.Payload = make([]byte, PacketSize-currentPktLen) } else { p.Payload = make([]byte, len(data)) } return copy(p.Payload, data) } +// Bytes interprets the fields of the ts packet instance and outputs a +// corresponding byte slice +func (p *Packet) Bytes(buf []byte) []byte { + if buf == nil || cap(buf) < PacketSize { + buf = make([]byte, PacketSize) + } + + if p.OPCRF { + panic("original program clock reference field unsupported") + } + if p.SPF { + panic("splicing countdown unsupported") + } + if p.TPDF { + panic("transport private data unsupported") + } + if p.AFEF { + panic("adaptation field extension unsupported") + } + + buf = buf[:6] + buf[0] = 0x47 + buf[1] = (asByte(p.TEI)<<7 | asByte(p.PUSI)<<6 | asByte(p.Priority)<<5 | byte((p.PID&0xFF00)>>8)) + buf[2] = byte(p.PID & 0x00FF) + buf[3] = (p.TSC<<6 | p.AFC<<4 | p.CC) + + var maxPayloadSize int + if p.AFC&0x2 != 0 { + maxPayloadSize = PacketSize - 6 - asInt(p.PCRF)*6 + } else { + maxPayloadSize = PacketSize - 4 + } + + stuffingLen := maxPayloadSize - len(p.Payload) + if p.AFC&0x2 != 0 { + buf[4] = byte(1 + stuffingLen + asInt(p.PCRF)*6) + buf[5] = (asByte(p.DI)<<7 | asByte(p.RAI)<<6 | asByte(p.ESPI)<<5 | asByte(p.PCRF)<<4 | asByte(p.OPCRF)<<3 | asByte(p.SPF)<<2 | asByte(p.TPDF)<<1 | asByte(p.AFEF)) + } else { + buf = buf[:4] + } + + for i := 40; p.PCRF && i >= 0; i -= 8 { + buf = append(buf, byte((p.PCR<<15)>>uint(i))) + } + + for i := 0; i < stuffingLen; i++ { + buf = append(buf, 0xff) + } + curLen := len(buf) + buf = buf[:PacketSize] + copy(buf[curLen:], p.Payload) + return buf +} + func asInt(b bool) int { if b { return 1 @@ -218,55 +269,6 @@ func asByte(b bool) byte { return 0 } -// Bytes interprets the fields of the ts packet instance and outputs a -// corresponding byte slice -func (p *Packet) Bytes(buf []byte) []byte { - if buf == nil || cap(buf) != PacketSize { - buf = make([]byte, 0, PacketSize) - } - buf = buf[:0] - stuffingLength := 182 - len(p.Payload) - len(p.TPD) - asInt(p.PCRF)*6 - - asInt(p.OPCRF)*6 - asInt(p.SPF) - var stuffing []byte - if stuffingLength > 0 { - stuffing = make([]byte, stuffingLength) - } - for i := range stuffing { - stuffing[i] = 0xFF - } - afl := 1 + asInt(p.PCRF)*6 + asInt(p.OPCRF)*6 + asInt(p.SPF) + asInt(p.TPDF) + len(p.TPD) + len(stuffing) - buf = append(buf, []byte{ - 0x47, - (asByte(p.TEI)<<7 | asByte(p.PUSI)<<6 | asByte(p.Priority)<<5 | byte((p.PID&0xFF00)>>8)), - byte(p.PID & 0x00FF), - (p.TSC<<6 | p.AFC<<4 | p.CC), - }...) - - if p.AFC == 3 || p.AFC == 2 { - buf = append(buf, []byte{ - byte(afl), (asByte(p.DI)<<7 | asByte(p.RAI)<<6 | asByte(p.ESPI)<<5 | - asByte(p.PCRF)<<4 | asByte(p.OPCRF)<<3 | asByte(p.SPF)<<2 | - asByte(p.TPDF)<<1 | asByte(p.AFEF)), - }...) - for i := 40; p.PCRF && i >= 0; i -= 8 { - buf = append(buf, byte((p.PCR<<15)>>uint(i))) - } - for i := 40; p.OPCRF && i >= 0; i -= 8 { - buf = append(buf, byte(p.OPCR>>uint(i))) - } - if p.SPF { - buf = append(buf, p.SC) - } - if p.TPDF { - buf = append(buf, append([]byte{p.TPDL}, p.TPD...)...) - } - buf = append(buf, p.Ext...) - buf = append(buf, stuffing...) - } - buf = append(buf, p.Payload...) - return buf -} - type Option func(p *packet.Packet) // addAdaptationField adds an adaptation field to p, and applys the passed options to this field. @@ -315,3 +317,47 @@ func DiscontinuityIndicator(f bool) Option { p[DiscontinuityIndicatorIdx] |= DiscontinuityIndicatorMask & set } } + +// GetPTSRange retreives the first and last PTS of an MPEGTS clip. +func GetPTSRange(clip []byte, pid uint16) (pts [2]uint64, err error) { + // Find the first packet with PID pidType. + pkt, _, err := FindPid(clip, pid) + if err != nil { + return [2]uint64{}, err + } + + // Get the payload of the packet, which will be the start of the PES packet. + var _pkt packet.Packet + copy(_pkt[:], pkt) + payload, err := packet.Payload(&_pkt) + if err != nil { + fmt.Printf("_pkt: %v\n", _pkt) + return [2]uint64{}, err + } + + // Get the the first PTS from the PES header. + _pes, err := pes.NewPESHeader(payload) + if err != nil { + return [2]uint64{}, err + } + pts[0] = _pes.PTS() + + // Get the final PTS searching from end of clip for access unit start. + for i := len(clip) - PacketSize; i >= 0; i -= PacketSize { + copy(_pkt[:], clip[i:i+PacketSize]) + if packet.PayloadUnitStartIndicator(&_pkt) && uint16(_pkt.PID()) == pid { + payload, err = packet.Payload(&_pkt) + if err != nil { + return [2]uint64{}, err + } + _pes, err = pes.NewPESHeader(payload) + if err != nil { + return [2]uint64{}, err + } + pts[1] = _pes.PTS() + return + } + } + + return [2]uint64{}, errors.New("could only find one access unit in mpegts clip") +} diff --git a/container/mts/mpegts_test.go b/container/mts/mpegts_test.go new file mode 100644 index 00000000..4c90cc0e --- /dev/null +++ b/container/mts/mpegts_test.go @@ -0,0 +1,266 @@ +/* +NAME + mpegts_test.go + +DESCRIPTION + mpegts_test.go contains testing for functionality found in mpegts.go. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + Copyright (C) 2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package mts + +import ( + "bytes" + "math/rand" + "testing" + "time" + + "bitbucket.org/ausocean/av/container/mts/pes" + "bitbucket.org/ausocean/av/container/mts/psi" + "github.com/Comcast/gots/packet" +) + +// TestGetPTSRange checks that GetPTSRange can correctly get the first and last +// PTS in an MPEGTS clip. +func TestGetPTSRange(t *testing.T) { + const ( + numOfFrames = 20 + maxFrameSize = 1000 + minFrameSize = 100 + rate = 25 // fps + interval = float64(1) / rate // s + ptsFreq = 90000 // Hz + ) + + // Generate randomly sized data for each frame. + rand.Seed(time.Now().UnixNano()) + frames := make([][]byte, numOfFrames) + for i := range frames { + size := rand.Intn(maxFrameSize-minFrameSize) + minFrameSize + frames[i] = make([]byte, size) + } + + var clip bytes.Buffer + + // Write the PSI first. + err := writePSI(&clip) + if err != nil { + t.Fatalf("did not expect error writing psi: %v", err) + } + + // Now write frames. + var curTime float64 + for _, frame := range frames { + nextPTS := curTime * ptsFreq + + err = writeFrame(&clip, frame, uint64(nextPTS)) + if err != nil { + t.Fatalf("did not expect error writing frame: %v", err) + } + + curTime += interval + } + + got, err := GetPTSRange(clip.Bytes(), videoPid) + if err != nil { + t.Fatalf("did not expect error getting PTS range: %v", err) + } + + want := [2]uint64{0, uint64((numOfFrames - 1) * interval * ptsFreq)} + if got != want { + t.Errorf("did not get expected result.\n Got: %v\n Want: %v\n", got, want) + } +} + +// writePSI is a helper function write the PSI found at the start of a clip. +func writePSI(b *bytes.Buffer) error { + // Write PAT. + pat := Packet{ + PUSI: true, + PID: PatPid, + CC: 0, + AFC: HasPayload, + Payload: psi.AddPadding(patTable), + } + _, err := b.Write(pat.Bytes(nil)) + if err != nil { + return err + } + + // Write PMT. + pmt := Packet{ + PUSI: true, + PID: PmtPid, + CC: 0, + AFC: HasPayload, + Payload: psi.AddPadding(pmtTable), + } + _, err = b.Write(pmt.Bytes(nil)) + if err != nil { + return err + } + return nil +} + +// writeFrame is a helper function used to form a PES packet from a frame, and +// then fragment this across MPEGTS packets where they are then written to the +// given buffer. +func writeFrame(b *bytes.Buffer, frame []byte, pts uint64) error { + // Prepare PES data. + pesPkt := pes.Packet{ + StreamID: H264ID, + PDI: hasPTS, + PTS: pts, + Data: frame, + HeaderLength: 5, + } + buf := pesPkt.Bytes(nil) + + // Write PES data acroos MPEGTS packets. + pusi := true + for len(buf) != 0 { + pkt := Packet{ + PUSI: pusi, + PID: videoPid, + RAI: pusi, + CC: 0, + AFC: hasAdaptationField | hasPayload, + PCRF: pusi, + } + n := pkt.FillPayload(buf) + buf = buf[n:] + + pusi = false + _, err := b.Write(pkt.Bytes(nil)) + if err != nil { + return err + } + } + return nil +} + +// TestBytes checks that Packet.Bytes() correctly produces a []byte +// representation of a Packet. +func TestBytes(t *testing.T) { + const payloadLen, payloadChar, stuffingChar = 120, 0x11, 0xff + const stuffingLen = PacketSize - payloadLen - 12 + + tests := []struct { + packet Packet + expectedHeader []byte + }{ + { + packet: Packet{ + PUSI: true, + PID: 1, + RAI: true, + CC: 4, + AFC: HasPayload | HasAdaptationField, + PCRF: true, + PCR: 1, + }, + expectedHeader: []byte{ + 0x47, // Sync byte. + 0x40, // TEI=0, PUSI=1, TP=0, PID=00000. + 0x01, // PID(Cont)=00000001. + 0x34, // TSC=00, AFC=11(adaptation followed by payload), CC=0100(4). + byte(7 + stuffingLen), // AFL=. + 0x50, // DI=0,RAI=1,ESPI=0,PCRF=1,OPCRF=0,SPF=0,TPDF=0, AFEF=0. + 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, // PCR. + }, + }, + } + + for testNum, test := range tests { + // Construct payload. + payload := make([]byte, 0, payloadLen) + for i := 0; i < payloadLen; i++ { + payload = append(payload, payloadChar) + } + + // Fill the packet payload. + test.packet.FillPayload(payload) + + // Create expected packet data and copy in expected header. + expected := make([]byte, len(test.expectedHeader), PacketSize) + copy(expected, test.expectedHeader) + + // Append stuffing. + for i := 0; i < stuffingLen; i++ { + expected = append(expected, stuffingChar) + } + + // Append payload to expected bytes. + expected = append(expected, payload...) + + // Compare got with expected. + got := test.packet.Bytes(nil) + if !bytes.Equal(got, expected) { + t.Errorf("did not get expected result for test: %v.\n Got: %v\n Want: %v\n", testNum, got, expected) + } + } +} + +// TestFindPid checks that FindPid can correctly extract the first instance +// of a PID from an MPEG-TS stream. +func TestFindPid(t *testing.T) { + const targetPacketNum, numOfPackets, targetPid, stdPid = 6, 15, 1, 0 + + // Prepare the stream of packets. + var stream []byte + for i := 0; i < numOfPackets; i++ { + pid := uint16(stdPid) + if i == targetPacketNum { + pid = targetPid + } + + p := Packet{ + PID: pid, + AFC: hasPayload | hasAdaptationField, + } + p.FillPayload([]byte{byte(i)}) + stream = append(stream, p.Bytes(nil)...) + } + + // Try to find the targetPid in the stream. + p, i, err := FindPid(stream, targetPid) + if err != nil { + t.Fatalf("unexpected error finding PID: %v\n", err) + } + + // Check the payload. + var _p packet.Packet + copy(_p[:], p) + payload, err := packet.Payload(&_p) + if err != nil { + t.Fatalf("unexpected error getting packet payload: %v\n", err) + } + got := payload[0] + if got != targetPacketNum { + t.Errorf("payload of found packet is not correct.\nGot: %v, Want: %v\n", got, targetPacketNum) + } + + // Check the index. + _got := i / PacketSize + if _got != targetPacketNum { + t.Errorf("index of found packet is not correct.\nGot: %v, want: %v\n", _got, targetPacketNum) + } +} diff --git a/container/mts/pes/pes.go b/container/mts/pes/pes.go index b0e40f86..5b5cb612 100644 --- a/container/mts/pes/pes.go +++ b/container/mts/pes/pes.go @@ -26,6 +26,8 @@ LICENSE package pes +import "github.com/Comcast/gots" + const MaxPesSize = 64 * 1 << 10 /* @@ -108,16 +110,11 @@ func (p *Packet) Bytes(buf []byte) []byte { boolByte(p.ACIF)<<2 | boolByte(p.CRCF)<<1 | boolByte(p.EF)), p.HeaderLength, }...) + if p.PDI == byte(2) { - pts := 0x2100010001 | (p.PTS&0x1C0000000)<<3 | (p.PTS&0x3FFF8000)<<2 | - (p.PTS&0x7FFF)<<1 - buf = append(buf, []byte{ - byte((pts & 0xFF00000000) >> 32), - byte((pts & 0x00FF000000) >> 24), - byte((pts & 0x0000FF0000) >> 16), - byte((pts & 0x000000FF00) >> 8), - byte(pts & 0x00000000FF), - }...) + ptsIdx := len(buf) + buf = buf[:ptsIdx+5] + gots.InsertPTS(buf[ptsIdx:], p.PTS) } buf = append(buf, append(p.Stuff, p.Data...)...) return buf diff --git a/container/mts/psi/helpers.go b/container/mts/psi/helpers.go index b8bab6b5..621460f5 100644 --- a/container/mts/psi/helpers.go +++ b/container/mts/psi/helpers.go @@ -125,7 +125,7 @@ func trimTo(d []byte, t byte) []byte { } // addPadding adds an appropriate amount of padding to a pat or pmt table for -// addition to an mpegts packet +// addition to an MPEG-TS packet func AddPadding(d []byte) []byte { t := make([]byte, PacketSize) copy(t, d) diff --git a/container/mts/psi/psi.go b/container/mts/psi/psi.go index c93d3011..3703faf4 100644 --- a/container/mts/psi/psi.go +++ b/container/mts/psi/psi.go @@ -32,7 +32,7 @@ import ( "github.com/Comcast/gots/psi" ) -// PacketSize of psi (without mpegts header) +// PacketSize of psi (without MPEG-TS header) const PacketSize = 184 // Lengths of section definitions. diff --git a/protocol/rtcp/client.go b/protocol/rtcp/client.go index 7d6c995c..4ac5b694 100644 --- a/protocol/rtcp/client.go +++ b/protocol/rtcp/client.go @@ -37,16 +37,18 @@ import ( "sync" "time" + "bitbucket.org/ausocean/av/protocol/rtp" "bitbucket.org/ausocean/utils/logger" ) const ( - senderSSRC = 1 // Any non-zero value will do. - defaultClientName = "Client" - delayUnit = 1.0 / 65536.0 - pkg = "rtcp: " - rtcpVer = 2 - receiverBufSize = 200 + clientSSRC = 1 // Any non-zero value will do. + defaultClientName = "Client" + defaultSendInterval = 2 * time.Second + delayUnit = 1.0 / 65536.0 + pkg = "rtcp: " + rtcpVer = 2 + receiverBufSize = 200 ) // Log describes a function signature required by the RTCP for the purpose of @@ -70,23 +72,18 @@ type Client struct { wg sync.WaitGroup // This is used to wait for send and recv routines to stop when Client is stopped. quit chan struct{} // Channel used to communicate quit signal to send and recv routines. log Log // Used to log any messages. - - err chan error // Client will send any errors through this chan. Can be accessed by Err(). + rtpClt *rtp.Client + err chan error // Client will send any errors through this chan. Can be accessed by Err(). } // NewClient returns a pointer to a new Client. -func NewClient(clientAddress, serverAddress, name string, sendInterval time.Duration, rtpSSRC uint32, l Log) (*Client, error) { - if name == "" { - name = defaultClientName - } - +func NewClient(clientAddress, serverAddress string, rtpClt *rtp.Client, l Log) (*Client, error) { c := &Client{ - name: name, - err: make(chan error), - quit: make(chan struct{}), - interval: sendInterval, - sourceSSRC: rtpSSRC, - log: l, + name: defaultClientName, + quit: make(chan struct{}), + interval: defaultSendInterval, + rtpClt: rtpClt, + log: l, } var err error @@ -107,11 +104,23 @@ func NewClient(clientAddress, serverAddress, name string, sendInterval time.Dura return c, nil } +// SetSendInterval sets a custom receiver report send interval (default is 5 seconds.) +func (c *Client) SetSendInterval(d time.Duration) { + c.interval = d +} + +// SetName sets a custom client name for use in receiver report source description. +// Default is "Client". +func (c *Client) SetName(name string) { + c.name = name +} + // Start starts the listen and send routines. This will start the process of // receiving and parsing sender reports, and the process of sending receiver // reports to the server. func (c *Client) Start() { c.log(logger.Debug, pkg+"Client is starting") + c.err = make(chan error) c.wg.Add(2) go c.recv() go c.send() @@ -124,6 +133,7 @@ func (c *Client) Stop() { close(c.quit) c.conn.Close() c.wg.Wait() + close(c.err) } // Err provides read access to the Client err channel. This must be checked @@ -171,13 +181,13 @@ func (c *Client) send() { ReportCount: 1, Type: typeReceiverReport, }, - SenderSSRC: senderSSRC, + SenderSSRC: clientSSRC, Blocks: []ReportBlock{ ReportBlock{ - SourceIdentifier: c.sourceSSRC, + SourceIdentifier: c.rtpClt.SSRC(), FractionLost: 0, PacketsLost: math.MaxUint32, - HighestSequence: c.sequence(), + HighestSequence: uint32((c.rtpClt.Cycles() << 16) | c.rtpClt.Sequence()), Jitter: c.jitter(), SenderReportTs: c.lastSenderTs(), SenderReportDelay: c.delay(), @@ -195,7 +205,7 @@ func (c *Client) send() { }, Chunks: []Chunk{ Chunk{ - SSRC: senderSSRC, + SSRC: clientSSRC, Items: []SDESItem{ SDESItem{ Type: typeCName, @@ -238,22 +248,6 @@ func (c *Client) parse(buf []byte) { c.setSenderTs(t) } -// SetSequence will allow updating of the highest sequence number received -// through an RTP stream. -func (c *Client) SetSequence(s uint32) { - c.mu.Lock() - c.seq = s - c.mu.Unlock() -} - -// sequence will return the highest sequence number received through RTP. -func (c *Client) sequence() uint32 { - c.mu.Lock() - s := c.seq - c.mu.Unlock() - return s -} - // jitter returns the interarrival jitter as described by RTCP specifications: // https://tools.ietf.org/html/rfc3550 // TODO(saxon): complete this. diff --git a/protocol/rtcp/client_test.go b/protocol/rtcp/client_test.go index 64a4d685..6c95c75d 100644 --- a/protocol/rtcp/client_test.go +++ b/protocol/rtcp/client_test.go @@ -37,6 +37,7 @@ import ( "testing" "time" + "bitbucket.org/ausocean/av/protocol/rtp" "bitbucket.org/ausocean/utils/logger" ) @@ -148,12 +149,15 @@ func (dl *dummyLogger) log(lvl int8, msg string, args ...interface{}) { // respond with sender reports. func TestReceiveAndSend(t *testing.T) { const clientAddr, serverAddr = "localhost:8000", "localhost:8001" + rtpClt, err := rtp.NewClient("localhost:8002") + if err != nil { + t.Fatalf("unexpected error when creating RTP client: %v", err) + } + c, err := NewClient( clientAddr, serverAddr, - "testClient", - 10*time.Millisecond, - 12345, + rtpClt, (*dummyLogger)(t).log, ) if err != nil { @@ -162,14 +166,14 @@ func TestReceiveAndSend(t *testing.T) { go func() { for { - select { - case err := <-c.Err(): + err, ok := <-c.Err() + if ok { const errConnClosed = "use of closed network connection" if !strings.Contains(err.Error(), errConnClosed) { t.Fatalf("error received from client error chan: %v\n", err) } - - default: + } else { + return } } }() @@ -197,8 +201,6 @@ func TestReceiveAndSend(t *testing.T) { n, _, _ := conn.ReadFromUDP(buf) t.Logf("SERVER: receiver report received: \n%v\n", buf[:n]) - c.SetSequence(uint32(i)) - now := time.Now().Second() var time [8]byte binary.BigEndian.PutUint64(time[:], uint64(now)) diff --git a/protocol/rtcp/parse.go b/protocol/rtcp/parse.go index 2f756f4b..2007a26d 100644 --- a/protocol/rtcp/parse.go +++ b/protocol/rtcp/parse.go @@ -38,9 +38,9 @@ type Timestamp struct { Fraction uint32 } -// Timestamp gets the timestamp from a receiver report and returns it as the most -// significant word, and the least significant word. If the given bytes do not -// represent a valid receiver report, an error is returned. +// ParseTimestamp gets the timestamp from a receiver report and returns it as +// a Timestamp as defined above. If the given bytes do not represent a valid +// receiver report, an error is returned. func ParseTimestamp(buf []byte) (Timestamp, error) { if len(buf) < 4 { return Timestamp{}, errors.New("bad RTCP packet, not of sufficient length") diff --git a/protocol/rtmp/rtmp_test.go b/protocol/rtmp/rtmp_test.go index 1cf056cb..e1e79796 100644 --- a/protocol/rtmp/rtmp_test.go +++ b/protocol/rtmp/rtmp_test.go @@ -38,7 +38,7 @@ import ( "testing" "time" - "bitbucket.org/ausocean/av/codec/lex" + "bitbucket.org/ausocean/av/codec/h264" "bitbucket.org/ausocean/av/container/flv" ) @@ -199,7 +199,7 @@ func TestFromFrame(t *testing.T) { if err != nil { t.Errorf("Failed to create flv encoder with error: %v", err) } - err = lex.H264(flvEncoder, bytes.NewReader(videoData), time.Second/time.Duration(frameRate)) + err = h264.Lex(flvEncoder, bytes.NewReader(videoData), time.Second/time.Duration(frameRate)) if err != nil { t.Errorf("Lexing failed with error: %v", err) } @@ -251,7 +251,7 @@ func TestFromFile(t *testing.T) { if err != nil { t.Fatalf("failed to create encoder: %v", err) } - err = lex.H264(flvEncoder, f, time.Second/time.Duration(25)) + err = h264.Lex(flvEncoder, f, time.Second/time.Duration(25)) if err != nil { t.Errorf("Lexing and encoding failed with error: %v", err) } diff --git a/protocol/rtp/client.go b/protocol/rtp/client.go index f47f9baf..e8418b0d 100644 --- a/protocol/rtp/client.go +++ b/protocol/rtp/client.go @@ -29,12 +29,17 @@ package rtp import ( "net" + "sync" ) // Client describes an RTP client that can receive an RTP stream and implements // io.Reader. type Client struct { - conn *net.UDPConn + r *PacketReader + ssrc uint32 + mu sync.Mutex + sequence uint16 + cycles uint16 } // NewClient returns a pointer to a new Client. @@ -42,22 +47,72 @@ type Client struct { // addr is the address of form : that we expect to receive // RTP at. func NewClient(addr string) (*Client, error) { - c := &Client{} + c := &Client{r: &PacketReader{}} a, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } - c.conn, err = net.ListenUDP("udp", a) + c.r.PacketConn, err = net.ListenUDP("udp", a) if err != nil { return nil, err } - return c, nil } -// Read implements io.Reader. This wraps the Read for the connection. -func (c *Client) Read(p []byte) (int, error) { - return c.conn.Read(p) +// SSRC returns the identified for the source from which the RTP packets being +// received are coming from. +func (c *Client) SSRC() uint32 { + return c.ssrc +} + +// Read implements io.Reader. +func (c *Client) Read(p []byte) (int, error) { + n, err := c.r.Read(p) + if err != nil { + return n, err + } + if c.ssrc == 0 { + c.ssrc, _ = SSRC(p[:n]) + } + s, _ := Sequence(p[:n]) + c.setSequence(s) + return n, err +} + +// setSequence sets the most recently received sequence number, and updates the +// cycles count if the sequence number has rolled over. +func (c *Client) setSequence(s uint16) { + c.mu.Lock() + if s < c.sequence { + c.cycles++ + } + c.sequence = s + c.mu.Unlock() +} + +// Sequence returns the most recent RTP packet sequence number received. +func (c *Client) Sequence() uint16 { + c.mu.Lock() + defer c.mu.Unlock() + return c.sequence +} + +// Cycles returns the number of RTP sequence number cycles that have been received. +func (c *Client) Cycles() uint16 { + c.mu.Lock() + defer c.mu.Unlock() + return c.cycles +} + +// PacketReader provides an io.Reader interface to an underlying UDP PacketConn. +type PacketReader struct { + net.PacketConn +} + +// Read implements io.Reader. +func (r PacketReader) Read(b []byte) (int, error) { + n, _, err := r.PacketConn.ReadFrom(b) + return n, err } diff --git a/protocol/rtp/parse.go b/protocol/rtp/parse.go index d658aa20..16e64c5d 100644 --- a/protocol/rtp/parse.go +++ b/protocol/rtp/parse.go @@ -34,14 +34,25 @@ import ( const badVer = "incompatible RTP version" -// Payload returns the payload from an RTP packet provided the version is -// compatible, otherwise an error is returned. -func Payload(d []byte) ([]byte, error) { +// Marker returns the state of the RTP marker bit, and an error if parsing fails. +func Marker(d []byte) (bool, error) { if len(d) < defaultHeadSize { panic("invalid RTP packet length") } + if version(d) != rtpVer { - return nil, errors.New(badVer) + return false, errors.New(badVer) + } + + return d[1]&0x80 != 0, nil +} + +// Payload returns the payload from an RTP packet provided the version is +// compatible, otherwise an error is returned. +func Payload(d []byte) ([]byte, error) { + err := checkPacket(d) + if err != nil { + return nil, err } extLen := 0 if hasExt(d) { @@ -51,6 +62,38 @@ func Payload(d []byte) ([]byte, error) { return d[payloadIdx:], nil } +// SSRC returns the source identifier from an RTP packet. An error is return if +// the packet is not valid. +func SSRC(d []byte) (uint32, error) { + err := checkPacket(d) + if err != nil { + return 0, err + } + return binary.BigEndian.Uint32(d[8:]), nil +} + +// Sequence returns the sequence number of an RTP packet. An error is returned +// if the packet is not valid. +func Sequence(d []byte) (uint16, error) { + err := checkPacket(d) + if err != nil { + return 0, err + } + return binary.BigEndian.Uint16(d[2:]), nil +} + +// checkPacket checks the validity of the packet, firstly by checking size and +// then also checking that version is compatible with these utilities. +func checkPacket(d []byte) error { + if len(d) < defaultHeadSize { + return errors.New("invalid RTP packet length") + } + if version(d) != rtpVer { + return errors.New(badVer) + } + return nil +} + // hasExt returns true if an extension is present in the RTP packet. func hasExt(d []byte) bool { return (d[0] & 0x10 >> 4) == 1 diff --git a/protocol/rtsp/client.go b/protocol/rtsp/client.go new file mode 100644 index 00000000..f6c9d0eb --- /dev/null +++ b/protocol/rtsp/client.go @@ -0,0 +1,141 @@ +/* +NAME + client.go + +DESCRIPTION + client.go provides a Client type providing functionality to send RTSP requests + of methods DESCRIBE, OPTIONS, SETUP and PLAY to an RTSP server. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + This is Copyright (C) 2019 the Australian Ocean Lab (AusOcean). + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package rtsp + +import ( + "net" + "net/url" + "strconv" +) + +// Client describes an RTSP Client. +type Client struct { + cSeq int + addr string + url *url.URL + conn net.Conn + sessionID string +} + +// NewClient returns a pointer to a new Client and the local address of the +// RTSP connection. The address addr will be parsed and a connection to the +// RTSP server will be made. +func NewClient(addr string) (c *Client, local, remote *net.TCPAddr, err error) { + c = &Client{addr: addr} + c.url, err = url.Parse(addr) + if err != nil { + return nil, nil,nil, err + } + c.conn, err = net.Dial("tcp", c.url.Host) + if err != nil { + return nil, nil, nil, err + } + local = c.conn.LocalAddr().(*net.TCPAddr) + remote = c.conn.RemoteAddr().(*net.TCPAddr) + return +} + +// Close closes the RTSP connection. +func (c *Client) Close() error { + return c.conn.Close() +} + +// Describe forms and sends an RTSP request of method DESCRIBE to the RTSP server. +func (c *Client) Describe() (*Response, error) { + req, err := NewRequest("DESCRIBE", c.nextCSeq(), c.url, nil) + if err != nil { + return nil, err + } + req.Header.Add("Accept", "application/sdp") + return c.Do(req) +} + +// Options forms and sends an RTSP request of method OPTIONS to the RTSP server. +func (c *Client) Options() (*Response, error) { + req, err := NewRequest("OPTIONS", c.nextCSeq(), c.url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} + +// Setup forms and sends an RTSP request of method SETUP to the RTSP server. +func (c *Client) Setup(track, transport string) (*Response, error) { + u, err := url.Parse(c.addr + "/" + track) + if err != nil { + return nil, err + } + + req, err := NewRequest("SETUP", c.nextCSeq(), u, nil) + if err != nil { + return nil, err + } + req.Header.Add("Transport", transport) + + resp, err := c.Do(req) + if err != nil { + return nil, err + } + c.sessionID = resp.Header.Get("Session") + + return resp, err +} + +// Play forms and sends an RTSP request of method PLAY to the RTSP server +func (c *Client) Play() (*Response, error) { + req, err := NewRequest("PLAY", c.nextCSeq(), c.url, nil) + if err != nil { + return nil, err + } + req.Header.Add("Session", c.sessionID) + + return c.Do(req) +} + +// Do sends the given RTSP request req, reads any responses and returns the response +// and any errors. +func (c *Client) Do(req *Request) (*Response, error) { + err := req.Write(c.conn) + if err != nil { + return nil, err + } + + resp, err := ReadResponse(c.conn) + if err != nil { + return nil, err + } + + return resp, nil +} + +// nextCSeq provides the next CSeq number for the next RTSP request. +func (c *Client) nextCSeq() string { + c.cSeq++ + return strconv.Itoa(c.cSeq) +} diff --git a/protocol/rtsp/rtsp.go b/protocol/rtsp/rtsp.go new file mode 100644 index 00000000..9e0995c8 --- /dev/null +++ b/protocol/rtsp/rtsp.go @@ -0,0 +1,182 @@ +/* +NAME + rtsp.go + +DESCRIPTION + rtsp.go provides functionality for forming and sending RTSP requests for + methods, DESCRIBE, OPTIONS, SETUP and PLAY, as described by + the RTSP standards, see https://tools.ietf.org/html/rfc7826 + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + This is Copyright (C) 2019 the Australian Ocean Lab (AusOcean). + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +// Package rtsp provides an RTSP client implementation and methods for +// communication with an RTSP server to request video. +package rtsp + +import ( + "bufio" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "strconv" + "strings" +) + +// Minimum response size to be considered valid in bytes. +const minResponse = 12 + +var errInvalidResponse = errors.New("invalid response") + +// Request describes an RTSP request. +type Request struct { + Method string + URL *url.URL + Proto string + ProtoMajor int + ProtoMinor int + Header http.Header + ContentLength int + Body io.ReadCloser +} + +// NewRequest returns a pointer to a new Request. +func NewRequest(method, cSeq string, u *url.URL, body io.ReadCloser) (*Request, error) { + req := &Request{ + Method: method, + URL: u, + Proto: "RTSP", + ProtoMajor: 1, + ProtoMinor: 0, + Header: map[string][]string{"CSeq": []string{cSeq}}, + Body: body, + } + return req, nil +} + +// Write writes the request r to the given io.Writer w. +func (r *Request) Write(w io.Writer) error { + _, err := w.Write([]byte(r.String())) + return err +} + +// String returns a formatted string of the Request. +func (r Request) String() string { + var b strings.Builder + fmt.Fprintf(&b, "%s %s %s/%d.%d\r\n", r.Method, r.URL.String(), r.Proto, r.ProtoMajor, r.ProtoMinor) + for k, v := range r.Header { + for _, v := range v { + fmt.Fprintf(&b, "%s: %s\r\n", k, v) + } + } + b.WriteString("\r\n") + if r.Body != nil { + s, _ := ioutil.ReadAll(r.Body) + b.WriteString(string(s)) + } + return b.String() +} + +// Response describes an RTSP response. +type Response struct { + Proto string + ProtoMajor int + ProtoMinor int + StatusCode int + ContentLength int + Header http.Header + Body io.ReadCloser +} + +// String returns a formatted string of the Response. +func (r Response) String() string { + var b strings.Builder + fmt.Fprintf(&b, "%s/%d.%d %d\n", r.Proto, r.ProtoMajor, r.ProtoMinor, r.StatusCode) + for k, v := range r.Header { + for _, v := range v { + fmt.Fprintf(&b, "%s: %s", k, v) + } + } + return b.String() +} + +// ReadResponse will read the response of the RTSP request from the connection, +// and return a pointer to a new Response. +func ReadResponse(r io.Reader) (*Response, error) { + resp := &Response{Header: make(map[string][]string)} + + scanner := bufio.NewScanner(r) + + // Read the first line. + scanner.Scan() + err := scanner.Err() + if err != nil { + return nil, err + } + s := scanner.Text() + + if len(s) < minResponse || !strings.HasPrefix(s, "RTSP/") { + return nil, errInvalidResponse + } + resp.Proto = "RTSP" + + n, err := fmt.Sscanf(s[5:], "%d.%d %d", &resp.ProtoMajor, &resp.ProtoMinor, &resp.StatusCode) + if err != nil || n != 3 { + return nil, fmt.Errorf("could not Sscanf response, error: %v", err) + } + + // Read headers. + for scanner.Scan() { + err = scanner.Err() + if err != nil { + return nil, err + } + parts := strings.SplitN(scanner.Text(), ":", 2) + if len(parts) < 2 { + break + } + resp.Header.Add(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])) + } + // Get the content length from the header. + resp.ContentLength, _ = strconv.Atoi(resp.Header.Get("Content-Length")) + + resp.Body = closer{r} + return resp, nil +} + +type closer struct { + io.Reader +} + +func (c closer) Close() error { + if c.Reader == nil { + return nil + } + defer func() { + c.Reader = nil + }() + if r, ok := c.Reader.(io.ReadCloser); ok { + return r.Close() + } + return nil +} diff --git a/protocol/rtsp/rtsp_test.go b/protocol/rtsp/rtsp_test.go new file mode 100644 index 00000000..4157cfba --- /dev/null +++ b/protocol/rtsp/rtsp_test.go @@ -0,0 +1,344 @@ +/* +NAME + 0x r,tsp_test.go + +DESCRIPTION + rtsp_test.go provides a test to check functionality provided in rtsp.go. + +AUTHORS + Saxon A. Nelson-Milton + +LICENSE + This is Copyright (C) 2019 the Australian Ocean Lab (AusOcean). + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package rtsp + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "net/url" + "strings" + "testing" + "time" + "unicode" +) + +// The max request size we should get in bytes. +const maxRequest = 1024 + +// TestMethods checks that we can correctly form requests for each of the RTSP +// methods supported in the rtsp pkg. This test also checks that communication +// over a TCP connection is performed correctly. +func TestMethods(t *testing.T) { + const dummyURL = "rtsp://admin:admin@192.168.0.50:8554/CH001.sdp" + url, err := url.Parse(dummyURL) + if err != nil { + t.Fatalf("could not parse dummy address, failed with err: %v", err) + } + + // tests holds tests which consist of a function used to create and write a + // request of a particular method, and also the expected request bytes + // to be received on the server side. The bytes in these tests have been + // obtained from a valid RTSP communication cltion.. + tests := []struct { + method func(c *Client) (*Response, error) + expected []byte + }{ + { + method: func(c *Client) (*Response, error) { + req, err := NewRequest("DESCRIBE", c.nextCSeq(), url, nil) + if err != nil { + return nil, err + } + req.Header.Add("Accept", "application/sdp") + return c.Do(req) + }, + expected: []byte{ + 0x44, 0x45, 0x53, 0x43, 0x52, 0x49, 0x42, 0x45, 0x20, 0x72, 0x74, 0x73, 0x70, 0x3a, 0x2f, 0x2f, // |DESCRIBE rtsp://| + 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x3a, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x40, 0x31, 0x39, 0x32, 0x2e, // |admin:admin@192.| + 0x31, 0x36, 0x38, 0x2e, 0x30, 0x2e, 0x35, 0x30, 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, // |168.0.50:8554/CH| + 0x30, 0x30, 0x31, 0x2e, 0x73, 0x64, 0x70, 0x20, 0x52, 0x54, 0x53, 0x50, 0x2f, 0x31, 0x2e, 0x30, // |001.sdp RTSP/1.0| + 0x0d, 0x0a, 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x32, 0x0d, 0x0a, 0x41, 0x63, 0x63, 0x65, 0x70, // |..CSeq: 2..Accep| + 0x74, 0x3a, 0x20, 0x61, 0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x73, // |t: application/s| + 0x64, 0x70, 0x0d, 0x0a, 0x0d, 0x0a, /* */ // |dp....| + }, + }, + { + method: func(c *Client) (*Response, error) { + req, err := NewRequest("OPTIONS", c.nextCSeq(), url, nil) + if err != nil { + return nil, err + } + return c.Do(req) + }, + expected: []byte{ + 0x4f, 0x50, 0x54, 0x49, 0x4f, 0x4e, 0x53, 0x20, 0x72, 0x74, 0x73, 0x70, 0x3a, 0x2f, 0x2f, 0x61, // |OPTIONS rtsp://a| + 0x64, 0x6d, 0x69, 0x6e, 0x3a, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x40, 0x31, 0x39, 0x32, 0x2e, 0x31, // |dmin:admin@192.1| + 0x36, 0x38, 0x2e, 0x30, 0x2e, 0x35, 0x30, 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, 0x30, // |68.0.50:8554/CH0| + 0x30, 0x31, 0x2e, 0x73, 0x64, 0x70, 0x20, 0x52, 0x54, 0x53, 0x50, 0x2f, 0x31, 0x2e, 0x30, 0x0d, // |01.sdp RTSP/1.0.| + 0x0a, 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x31, 0x0d, 0x0a, 0x0d, 0x0a, /* */ // |.CSeq: 1....| + }, + }, + { + method: func(c *Client) (*Response, error) { + u, err := url.Parse(dummyURL + "/track1") + if err != nil { + return nil, err + } + + req, err := NewRequest("SETUP", c.nextCSeq(), u, nil) + if err != nil { + return nil, err + } + req.Header.Add("Transport", fmt.Sprintf("RTP/AVP;unicast;client_port=%d-%d", 6870, 6871)) + + return c.Do(req) + }, + expected: []byte{ + 0x53, 0x45, 0x54, 0x55, 0x50, 0x20, 0x72, 0x74, 0x73, 0x70, 0x3a, 0x2f, 0x2f, 0x61, 0x64, 0x6d, // |SETUP rtsp://adm| + 0x69, 0x6e, 0x3a, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x40, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, // |in:admin@192.168| + 0x2e, 0x30, 0x2e, 0x35, 0x30, 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, 0x30, 0x30, 0x31, // |.0.50:8554/CH001| + 0x2e, 0x73, 0x64, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x63, 0x6b, 0x31, 0x20, 0x52, 0x54, 0x53, 0x50, // |.sdp/track1 RTSP| + 0x2f, 0x31, 0x2e, 0x30, 0x0d, 0x0a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x3a, // |/1.0..Transport:| + 0x20, 0x52, 0x54, 0x50, 0x2f, 0x41, 0x56, 0x50, 0x3b, 0x75, 0x6e, 0x69, 0x63, 0x61, 0x73, 0x74, // | RTP/AVP;unicast| + 0x3b, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x3d, 0x36, 0x38, 0x37, // |;client_port=687| + 0x30, 0x2d, 0x36, 0x38, 0x37, 0x31, 0x0d, 0x0a, 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x33, 0x0d, // |0-6871..CSeq: 3.| + 0x0a, 0x0d, 0x0a, /* */ // |...| + }, + }, + { + method: func(c *Client) (*Response, error) { + req, err := NewRequest("PLAY", c.nextCSeq(), url, nil) + if err != nil { + return nil, err + } + req.Header.Add("Session", "00000021") + + return c.Do(req) + }, + expected: []byte{ + 0x50, 0x4c, 0x41, 0x59, 0x20, 0x72, 0x74, 0x73, 0x70, 0x3a, 0x2f, 0x2f, 0x61, 0x64, 0x6d, 0x69, // |PLAY rtsp://admi| + 0x6e, 0x3a, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x40, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36, 0x38, 0x2e, // |n:admin@192.168.| + 0x30, 0x2e, 0x35, 0x30, 0x3a, 0x38, 0x35, 0x35, 0x34, 0x2f, 0x43, 0x48, 0x30, 0x30, 0x31, 0x2e, // |0.50:8554/CH001.| + 0x73, 0x64, 0x70, 0x20, 0x52, 0x54, 0x53, 0x50, 0x2f, 0x31, 0x2e, 0x30, 0x0d, 0x0a, 0x43, 0x53, // |sdp RTSP/1.0..CS| + 0x65, 0x71, 0x3a, 0x20, 0x34, 0x0d, 0x0a, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x3a, 0x20, // |eq: 4..Session: | + 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x32, 0x31, 0x0d, 0x0a, 0x0d, 0x0a, /* */ // |00000021....| + }, + }, + } + + const serverAddr = "rtsp://localhost:8005" + const retries = 10 + + clientErr := make(chan error) + serverErr := make(chan error) + done := make(chan struct{}) + + // This routine acts as the server. + go func() { + l, err := net.Listen("tcp", strings.TrimLeft(serverAddr, "rtsp://")) + if err != nil { + serverErr <- errors.New(fmt.Sprintf("server could not listen, error: %v", err)) + } + + conn, err := l.Accept() + if err != nil { + serverErr <- errors.New(fmt.Sprintf("server could not accept connection, error: %v", err)) + } + + buf := make([]byte, maxRequest) + var n int + for i, test := range tests { + loop: + for { + n, err = conn.Read(buf) + err, ok := err.(net.Error) + + switch { + case err == nil: + break loop + case err == io.EOF: + case ok && err.Timeout(): + default: + serverErr <- errors.New(fmt.Sprintf("server could not read conn, error: %v", err)) + return + } + } + + // Write a dummy response, client won't care. + conn.Write([]byte{'\n'}) + + want := test.expected + got := buf[:n] + if !equal(got, want) { + serverErr <- errors.New(fmt.Sprintf("unexpected result for test: %v. \nGot: %v\n Want: %v\n", i, got, want)) + } + } + close(done) + }() + + // This routine acts as the client. + go func() { + var clt *Client + var err error + + // Keep trying to connect to server. + // TODO: use generalised retry utility when available. + for retry := 0; ; retry++ { + clt, _, _, err = NewClient(serverAddr) + if err == nil { + break + } + + if retry > retries { + clientErr <- errors.New(fmt.Sprintf("client could not connect to server, error: %v", err)) + } + time.Sleep(10 * time.Millisecond) + } + + for i, test := range tests { + _, err = test.method(clt) + if err != nil && err != io.EOF && err != errInvalidResponse { + clientErr <- errors.New(fmt.Sprintf("error request for: %v err: %v", i, err)) + } + } + }() + + // We check for errors or a done signal from the server and client routines. + for { + select { + case err := <-clientErr: + t.Fatalf("client error: %v", err) + case err := <-serverErr: + t.Fatalf("server error: %v", err) + case <-done: + return + default: + } + } +} + +// equal checks that the got slice is considered equivalent to the want slice, +// neglecting unimportant differences such as order of items in header and the +// CSeq number. +func equal(got, want []byte) bool { + const eol = "\r\n" + gotParts := strings.Split(strings.TrimRight(string(got), eol), eol) + wantParts := strings.Split(strings.TrimRight(string(want), eol), eol) + gotParts, ok := rmSeqNum(gotParts) + if !ok { + return false + } + wantParts, ok = rmSeqNum(wantParts) + if !ok { + return false + } + for _, gotStr := range gotParts { + for i, wantStr := range wantParts { + if gotStr == wantStr { + wantParts = append(wantParts[:i], wantParts[i+1:]...) + } + } + } + return len(wantParts) == 0 +} + +// rmSeqNum removes the CSeq number from a string in []string that contains it. +// If a CSeq field is not found nil and false is returned. +func rmSeqNum(s []string) ([]string, bool) { + for i, _s := range s { + if strings.Contains(_s, "CSeq") { + s[i] = strings.TrimFunc(s[i], func(r rune) bool { return unicode.IsNumber(r) }) + return s, true + } + } + return nil, false +} + +// TestReadResponse checks that ReadResponse behaves as expected. +func TestReadResponse(t *testing.T) { + // input has been obtained from a valid RTSP response. + input := []byte{ + 0x52, 0x54, 0x53, 0x50, 0x2f, 0x31, 0x2e, 0x30, 0x20, 0x32, 0x30, 0x30, 0x20, 0x4f, 0x4b, 0x0d, // |RTSP/1.0 200 OK.| + 0x0a, 0x43, 0x53, 0x65, 0x71, 0x3a, 0x20, 0x32, 0x0d, 0x0a, 0x44, 0x61, 0x74, 0x65, 0x3a, 0x20, // |.CSeq: 2..Date: | + 0x57, 0x65, 0x64, 0x2c, 0x20, 0x4a, 0x61, 0x6e, 0x20, 0x32, 0x31, 0x20, 0x31, 0x39, 0x37, 0x30, // |Wed, Jan 21 1970| + 0x20, 0x30, 0x32, 0x3a, 0x33, 0x37, 0x3a, 0x31, 0x34, 0x20, 0x47, 0x4d, 0x54, 0x0d, 0x0a, 0x50, // | 02:37:14 GMT..P| + 0x75, 0x62, 0x6c, 0x69, 0x63, 0x3a, 0x20, 0x4f, 0x50, 0x54, 0x49, 0x4f, 0x4e, 0x53, 0x2c, 0x20, // |ublic: OPTIONS, | + 0x44, 0x45, 0x53, 0x43, 0x52, 0x49, 0x42, 0x45, 0x2c, 0x20, 0x53, 0x45, 0x54, 0x55, 0x50, 0x2c, // |DESCRIBE, SETUP,| + 0x20, 0x54, 0x45, 0x41, 0x52, 0x44, 0x4f, 0x57, 0x4e, 0x2c, 0x20, 0x50, 0x4c, 0x41, 0x59, 0x2c, // | TEARDOWN, PLAY,| + 0x20, 0x47, 0x45, 0x54, 0x5f, 0x50, 0x41, 0x52, 0x41, 0x4d, 0x45, 0x54, 0x45, 0x52, 0x2c, 0x20, // | GET_PARAMETER, | + 0x53, 0x45, 0x54, 0x5f, 0x50, 0x41, 0x52, 0x41, 0x4d, 0x45, 0x54, 0x45, 0x52, 0x0d, 0x0a, 0x0d, // |SET_PARAMETER...| + 0x0a, + } + + expect := Response{ + Proto: "RTSP", + ProtoMajor: 1, + ProtoMinor: 0, + StatusCode: 200, + ContentLength: 0, + Header: map[string][]string{ + "Cseq": []string{"2"}, + "Date": []string{"Wed, Jan 21 1970 02:37:14 GMT"}, + "Public": []string{"OPTIONS, DESCRIBE, SETUP, TEARDOWN, PLAY, GET_PARAMETER, SET_PARAMETER"}, + }, + } + + got, err := ReadResponse(bytes.NewReader(input)) + if err != nil { + t.Fatalf("should not have got error: %v", err) + } + + if !respEqual(*got, expect) { + t.Errorf("did not get expected result.\nGot: %+v\n Want: %+v\n", got, expect) + } +} + +// respEqual checks the equality of two Responses. +func respEqual(got, want Response) bool { + for _, f := range [][2]interface{}{ + {got.Proto, want.Proto}, + {got.ProtoMajor, want.ProtoMajor}, + {got.ProtoMinor, want.ProtoMinor}, + {got.StatusCode, want.StatusCode}, + {got.ContentLength, want.ContentLength}, + } { + if f[0] != f[1] { + return false + } + } + + if len(got.Header) != len(want.Header) { + return false + } + + for k, v := range got.Header { + if len(v) != len(want.Header[k]) { + return false + } + + for i, _v := range v { + if _v != want.Header[k][i] { + return false + } + } + } + return true +} diff --git a/revid/config.go b/revid/config.go index b33eb7ad..f2d04193 100644 --- a/revid/config.go +++ b/revid/config.go @@ -74,6 +74,7 @@ const ( // Inputs. Raspivid V4L + RTSP // Outputs. RTMP @@ -83,6 +84,7 @@ const ( // Codecs. H264 + H265 MJPEG ) @@ -127,6 +129,8 @@ type Config struct { // Read from webcam. // File: // Location must be specified in InputPath field. + // RTSP: + // RTSPURL must also be defined. Input uint8 // InputCodec defines the input codec we wish to use, and therefore define the @@ -164,6 +168,10 @@ type Config struct { // RTMP is to be used as an output. RTMPURL string + // RTSPURL specifies the RTSP server URL for RTSP input. This must be defined + // when Input is RTSP. + RTSPURL string + // OutputPath defines the output destination for File output. This must be // defined if File output is to be used. OutputPath string @@ -233,7 +241,7 @@ func (c *Config) Validate(r *Revid) error { } switch c.Input { - case Raspivid, V4L, File: + case Raspivid, V4L, File, RTSP: case NothingDefined: c.Logger.Log(logger.Info, pkg+"no input type defined, defaulting", "input", defaultInput) c.Input = defaultInput diff --git a/revid/revid.go b/revid/revid.go index 88657a2d..836a9045 100644 --- a/revid/revid.go +++ b/revid/revid.go @@ -33,6 +33,7 @@ import ( "errors" "fmt" "io" + "net" "os" "os/exec" "strconv" @@ -40,9 +41,13 @@ import ( "sync" "time" - "bitbucket.org/ausocean/av/codec/lex" + "bitbucket.org/ausocean/av/codec/h264" + "bitbucket.org/ausocean/av/codec/h265" "bitbucket.org/ausocean/av/container/flv" "bitbucket.org/ausocean/av/container/mts" + "bitbucket.org/ausocean/av/protocol/rtcp" + "bitbucket.org/ausocean/av/protocol/rtp" + "bitbucket.org/ausocean/av/protocol/rtsp" "bitbucket.org/ausocean/iot/pi/netsender" "bitbucket.org/ausocean/utils/ioext" "bitbucket.org/ausocean/utils/logger" @@ -60,6 +65,12 @@ const ( rtmpConnectionTimeout = 10 ) +const ( + rtpPort = 60000 + rtcpPort = 60001 + defaultServerRTCPPort = 17301 +) + const pkg = "revid:" type Logger interface { @@ -163,7 +174,14 @@ func (r *Revid) reset(config Config) error { err = r.setupPipeline( func(dst io.WriteCloser, fps int) (io.WriteCloser, error) { - e := mts.NewEncoder(dst, float64(fps), mts.Video) + var st int + switch r.config.Input { + case Raspivid, File, V4L: + st = mts.EncodeH264 + case RTSP: + st = mts.EncodeH265 + } + e := mts.NewEncoder(dst, float64(fps), st) return e, nil }, func(dst io.WriteCloser, fps int) (io.WriteCloser, error) { @@ -262,20 +280,18 @@ func (r *Revid) setupPipeline(mtsEnc, flvEnc func(dst io.WriteCloser, rate int) switch r.config.Input { case Raspivid: r.setupInput = r.startRaspivid + r.lexTo = h264.Lex case V4L: r.setupInput = r.startV4L + r.lexTo = h264.Lex case File: r.setupInput = r.setupInputForFile + r.lexTo = h264.Lex + case RTSP: + r.setupInput = r.startRTSPCamera + r.lexTo = h265.NewLexer(false).Lex } - switch r.config.InputCodec { - case H264: - r.config.Logger.Log(logger.Info, pkg+"using H264 lexer") - r.lexTo = lex.H264 - case MJPEG: - r.config.Logger.Log(logger.Info, pkg+"using MJPEG lexer") - r.lexTo = lex.MJPEG - } return nil } @@ -604,6 +620,108 @@ func (r *Revid) setupInputForFile() (func() error, error) { return func() error { return f.Close() }, nil } +// startRTSPCamera uses RTSP to request an RTP stream from an IP camera. An RTP +// client is created from which RTP packets containing either h264/h265 can read +// by the selected lexer. +func (r *Revid) startRTSPCamera() (func() error, error) { + rtspClt, local, remote, err := rtsp.NewClient(r.config.RTSPURL) + if err != nil { + return nil, err + } + + resp, err := rtspClt.Options() + if err != nil { + return nil, err + } + r.config.Logger.Log(logger.Info, pkg+"RTSP OPTIONS response", "response", resp.String()) + + resp, err = rtspClt.Describe() + if err != nil { + return nil, err + } + r.config.Logger.Log(logger.Info, pkg+"RTSP DESCRIBE response", "response", resp.String()) + + resp, err = rtspClt.Setup("track1", fmt.Sprintf("RTP/AVP;unicast;client_port=%d-%d", rtpPort, rtcpPort)) + if err != nil { + return nil, err + } + r.config.Logger.Log(logger.Info, pkg+"RTSP SETUP response", "response", resp.String()) + rtpCltAddr, rtcpCltAddr, rtcpSvrAddr, err := formAddrs(local, remote, *resp) + if err != nil { + return nil, err + } + + resp, err = rtspClt.Play() + if err != nil { + return nil, err + } + r.config.Logger.Log(logger.Info, pkg+"RTSP server PLAY response", "response", resp.String()) + + rtpClt, err := rtp.NewClient(rtpCltAddr) + if err != nil { + return nil, err + } + + rtcpClt, err := rtcp.NewClient(rtcpCltAddr, rtcpSvrAddr, rtpClt, r.config.Logger.Log) + if err != nil { + return nil, err + } + + // Check errors from RTCP client until it has stopped running. + go func() { + for { + err, ok := <-rtcpClt.Err() + if ok { + r.config.Logger.Log(logger.Warning, pkg+"RTCP error", "error", err.Error()) + } else { + return + } + } + }() + + // Start the RTCP client. + rtcpClt.Start() + + // Start reading data from the RTP client. + r.wg.Add(1) + go r.processFrom(rtpClt, time.Second/time.Duration(r.config.FrameRate)) + + return func() error { + rtspClt.Close() + rtcpClt.Stop() + return nil + }, nil +} + +// formAddrs is a helper function to form the addresses for the RTP client, +// RTCP client, and the RTSP server's RTCP addr using the local, remote addresses +// of the RTSP conn, and the SETUP method response. +func formAddrs(local, remote *net.TCPAddr, setupResp rtsp.Response) (rtpCltAddr, rtcpCltAddr, rtcpSvrAddr string, err error) { + svrRTCPPort, err := parseSvrRTCPPort(setupResp) + if err != nil { + return "", "", "", err + } + rtpCltAddr = strings.Split(local.String(), ":")[0] + ":" + strconv.Itoa(rtpPort) + rtcpCltAddr = strings.Split(local.String(), ":")[0] + ":" + strconv.Itoa(rtcpPort) + rtcpSvrAddr = strings.Split(remote.String(), ":")[0] + ":" + strconv.Itoa(svrRTCPPort) + return +} + +// parseServerRTCPPort is a helper function to get the RTSP server's RTCP port. +func parseSvrRTCPPort(resp rtsp.Response) (int, error) { + transport := resp.Header.Get("Transport") + for _, p := range strings.Split(transport, ";") { + if strings.Contains(p, "server_port") { + port, err := strconv.Atoi(strings.Split(p, "-")[1]) + if err != nil { + return 0, err + } + return port, nil + } + } + return 0, errors.New("SETUP response did not provide RTCP port") +} + func (r *Revid) processFrom(read io.Reader, delay time.Duration) { r.config.Logger.Log(logger.Info, pkg+"reading input data") r.err <- r.lexTo(r.encoders, read, delay) diff --git a/revid/senders_test.go b/revid/senders_test.go index 80293759..b7f67959 100644 --- a/revid/senders_test.go +++ b/revid/senders_test.go @@ -134,7 +134,7 @@ func TestMtsSenderSegment(t *testing.T) { const numberOfClips = 11 dst := &destination{t: t, done: make(chan struct{}), doneAt: numberOfClips} sender := newMtsSender(dst, (*dummyLogger)(t).log, rbSize, rbElementSize, 0) - encoder := mts.NewEncoder(sender, 25, mts.Video) + encoder := mts.NewEncoder(sender, 25, mts.EncodeH264) // Turn time based PSI writing off for encoder. const psiSendCount = 10 @@ -212,7 +212,7 @@ func TestMtsSenderFailedSend(t *testing.T) { const clipToFailAt = 3 dst := &destination{t: t, testFails: true, failAt: clipToFailAt, done: make(chan struct{})} sender := newMtsSender(dst, (*dummyLogger)(t).log, rbSize, rbElementSize, 0) - encoder := mts.NewEncoder(sender, 25, mts.Video) + encoder := mts.NewEncoder(sender, 25, mts.EncodeH264) // Turn time based PSI writing off for encoder and send PSI every 10 packets. const psiSendCount = 10 @@ -292,7 +292,7 @@ func TestMtsSenderDiscontinuity(t *testing.T) { const clipToDelay = 3 dst := &destination{t: t, sendDelay: 10 * time.Millisecond, delayAt: clipToDelay, done: make(chan struct{})} sender := newMtsSender(dst, (*dummyLogger)(t).log, 1, rbElementSize, 0) - encoder := mts.NewEncoder(sender, 25, mts.Video) + encoder := mts.NewEncoder(sender, 25, mts.EncodeH264) // Turn time based PSI writing off for encoder. const psiSendCount = 10