diff --git a/.circleci/config.yml b/.circleci/config.yml index af93f767..b1ae37f2 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -13,6 +13,8 @@ jobs: steps: - checkout + - run: git clone --depth=1 https://bitbucket.org/ausocean/test.git ${GOPATH}/src/bitbucket.org/ausocean/test + - restore_cache: keys: - v1-pkg-cache diff --git a/cmd/revid-cli/main.go b/cmd/revid-cli/main.go index 5b826b91..056687de 100644 --- a/cmd/revid-cli/main.go +++ b/cmd/revid-cli/main.go @@ -37,7 +37,10 @@ import ( "time" "bitbucket.org/ausocean/av/revid" + "bitbucket.org/ausocean/av/stream/mts" + "bitbucket.org/ausocean/av/stream/mts/meta" "bitbucket.org/ausocean/iot/pi/netsender" + "bitbucket.org/ausocean/iot/pi/sds" "bitbucket.org/ausocean/iot/pi/smartlogger" "bitbucket.org/ausocean/utils/logger" ) @@ -47,7 +50,14 @@ const ( progName = "revid-cli" // Logging is set to INFO level. - defaultLogVerbosity = logger.Debug + defaultLogVerbosity = logger.Info +) + +// Revid modes +const ( + normal = "Normal" + paused = "Paused" + burst = "Burst" ) // Other misc consts @@ -65,30 +75,33 @@ var canProfile = true // The logger that will be used throughout var log *logger.Logger +const ( + metaPreambleKey = "copyright" + metaPreambleData = "ausocean.org/license/content2019" +) + func main() { + mts.Meta = meta.NewWith([][2]string{{metaPreambleKey, metaPreambleData}}) + useNetsender := flag.Bool("NetSender", false, "Are we checking vars through netsender?") runDurationPtr := flag.Duration("runDuration", defaultRunDuration, "How long do you want revid to run for?") cfg := handleFlags() - if !*useNetsender { - // run revid for the specified duration - rv, _, err := startRevid(nil, cfg) + rv, err := revid.New(cfg, nil) if err != nil { + cfg.Logger.Log(logger.Fatal, pkg+"failed to initialiase revid", "error", err.Error()) + } + if err = rv.Start(); err != nil { cfg.Logger.Log(logger.Fatal, pkg+"failed to start revid", "error", err.Error()) } time.Sleep(*runDurationPtr) - err = stopRevid(rv) - if err != nil { - cfg.Logger.Log(logger.Error, pkg+"failed to stop revid before program termination", "error", err.Error()) - } + rv.Stop() return } - err := run(nil, cfg) - if err != nil { + if err := run(cfg); err != nil { log.Log(logger.Fatal, pkg+"failed to run revid", "error", err.Error()) - os.Exit(1) } } @@ -105,12 +118,12 @@ func handleFlags() revid.Config { rtmpMethodPtr = flag.String("RtmpMethod", "", "The method used to send over rtmp: Ffmpeg, Librtmp") packetizationPtr = flag.String("Packetization", "", "The method of data packetisation: Flv, Mpegts, None") quantizePtr = flag.Bool("Quantize", false, "Quantize input (non-variable bitrate)") - verbosityPtr = flag.String("Verbosity", "", "Verbosity: Info, Warning, Error, Fatal") + verbosityPtr = flag.String("Verbosity", "Info", "Verbosity: Info, Warning, Error, Fatal") framesPerClipPtr = flag.Uint("FramesPerClip", 0, "Number of frames per clip sent") rtmpUrlPtr = flag.String("RtmpUrl", "", "Url of rtmp endpoint") bitratePtr = flag.Uint("Bitrate", 0, "Bitrate of recorded video") - outputFileNamePtr = flag.String("OutputFileName", "", "The directory of the output file") - inputFileNamePtr = flag.String("InputFileName", "", "The directory of the input file") + outputPathPtr = flag.String("OutputPath", "", "The directory of the output file") + inputFilePtr = flag.String("InputPath", "", "The directory of the input file") heightPtr = flag.Uint("Height", 0, "Height in pixels") widthPtr = flag.Uint("Width", 0, "Width in pixels") frameRatePtr = flag.Uint("FrameRate", 0, "Frame rate of captured video") @@ -122,6 +135,7 @@ func handleFlags() revid.Config { rtpAddrPtr = flag.String("RtpAddr", "", "Rtp destination address: : (port is generally 6970-6999)") logPathPtr = flag.String("LogPath", defaultLogPath, "The log path") configFilePtr = flag.String("ConfigFile", "", "NetSender config file") + sendRetryPtr = flag.Bool("retry", false, "Specify whether a failed send should be retried.") ) var outputs flagStrings @@ -129,7 +143,22 @@ func handleFlags() revid.Config { flag.Parse() - log = logger.New(defaultLogVerbosity, &smartlogger.New(*logPathPtr).LogRoller) + switch *verbosityPtr { + case "Debug": + cfg.LogLevel = logger.Debug + case "Info": + cfg.LogLevel = logger.Info + case "Warning": + cfg.LogLevel = logger.Warning + case "Error": + cfg.LogLevel = logger.Error + case "Fatal": + cfg.LogLevel = logger.Fatal + default: + cfg.LogLevel = defaultLogVerbosity + } + + log = logger.New(cfg.LogLevel, &smartlogger.New(*logPathPtr).LogRoller) cfg.Logger = log @@ -168,6 +197,10 @@ func handleFlags() revid.Config { log.Log(logger.Error, pkg+"bad input codec argument") } + if len(outputs) == 0 { + cfg.Outputs = make([]uint8, 1) + } + for _, o := range outputs { switch o { case "File": @@ -209,17 +242,6 @@ func handleFlags() revid.Config { log.Log(logger.Error, pkg+"bad packetization argument") } - switch *verbosityPtr { - case "No": - cfg.LogLevel = logger.Fatal - case "Debug": - cfg.LogLevel = logger.Debug - //logger.SetLevel(logger.Debug) - case "": - default: - log.Log(logger.Error, pkg+"bad verbosity argument") - } - if *configFilePtr != "" { netsender.ConfigFile = *configFilePtr } @@ -230,8 +252,8 @@ func handleFlags() revid.Config { cfg.FramesPerClip = *framesPerClipPtr cfg.RtmpUrl = *rtmpUrlPtr cfg.Bitrate = *bitratePtr - cfg.OutputFileName = *outputFileNamePtr - cfg.InputFileName = *inputFileNamePtr + cfg.OutputPath = *outputPathPtr + cfg.InputPath = *inputFilePtr cfg.Height = *heightPtr cfg.Width = *widthPtr cfg.FrameRate = *frameRatePtr @@ -239,231 +261,115 @@ func handleFlags() revid.Config { cfg.Quantization = *quantizationPtr cfg.IntraRefreshPeriod = *intraRefreshPeriodPtr cfg.RtpAddress = *rtpAddrPtr + cfg.SendRetry = *sendRetryPtr return cfg } // initialize then run the main NetSender client -func run(rv *revid.Revid, cfg revid.Config) error { - // initialize NetSender and use NetSender's logger - //config.Logger = netsender.Logger() +func run(cfg revid.Config) error { log.Log(logger.Info, pkg+"running in NetSender mode") - var ns netsender.Sender - err := ns.Init(log, nil, nil, nil) + var vars map[string]string + + var rv *revid.Revid + + readPin := func(pin *netsender.Pin) error { + switch { + case pin.Name == "X23": + pin.Value = rv.Bitrate() + case pin.Name[0] == 'X': + return sds.ReadSystem(pin) + default: + pin.Value = -1 + } + return nil // Return error only if we want NetSender to generate an error + } + + ns, err := netsender.New(log, nil, readPin, nil) if err != nil { return err } - vars, _ := ns.Vars() - vs := ns.VarSum() - paused := false - if vars["mode"] == "Paused" { - paused = true + + rv, err = revid.New(cfg, ns) + if err != nil { + log.Log(logger.Fatal, pkg+"could not initialise revid", "error", err.Error()) } - if !paused { - rv, cfg, err = updateRevid(&ns, rv, cfg, vars, false) + + vars, _ = ns.Vars() + vs := ns.VarSum() + + // Update revid to get latest config settings from netreceiver. + err = rv.Update(vars) + if err != nil { + return err + } + + // If mode on netreceiver isn't paused then we can start revid. + if ns.Mode() != paused && ns.Mode() != burst { + err = rv.Start() if err != nil { return err } } + if ns.Mode() == burst { + ns.SetMode(paused, &vs) + } + for { - if err := send(&ns, rv); err != nil { - log.Log(logger.Error, pkg+"polling failed", "error", err.Error()) + err = ns.Run() + if err != nil { + log.Log(logger.Error, pkg+"Run Failed. Retrying...", "error", err.Error()) time.Sleep(netSendRetryTime) continue } - if vs != ns.VarSum() { - // vars changed - vars, err := ns.Vars() - if err != nil { - log.Log(logger.Error, pkg+"netSender failed to get vars", "error", err.Error()) - time.Sleep(netSendRetryTime) - continue - } - vs = ns.VarSum() - if vars["mode"] == "Paused" { - if !paused { - log.Log(logger.Info, pkg+"pausing revid") - err = stopRevid(rv) - if err != nil { - log.Log(logger.Error, pkg+"failed to stop revide", "error", err.Error()) - continue - } - paused = true - } - } else { - rv, cfg, err = updateRevid(&ns, rv, cfg, vars, !paused) - if err != nil { - return err - } - if paused { - paused = false - } - } + // If var sum hasn't change we continue + if vs == ns.VarSum() { + goto sleep + } + + vars, err = ns.Vars() + if err != nil { + log.Log(logger.Error, pkg+"netSender failed to get vars", "error", err.Error()) + time.Sleep(netSendRetryTime) + continue + } + vs = ns.VarSum() + + err = rv.Update(vars) + if err != nil { + return err + } + + switch ns.Mode() { + case paused: + case normal: + err = rv.Start() + if err != nil { + return err + } + case burst: + log.Log(logger.Info, pkg+"Starting burst...") + err = rv.Start() + if err != nil { + return err + } + time.Sleep(time.Duration(rv.Config().BurstPeriod) * time.Second) + log.Log(logger.Info, pkg+"Stopping burst...") + rv.Stop() + ns.SetMode(paused, &vs) + } + sleep: + sleepTime, err := strconv.Atoi(ns.Param("mp")) + if err != nil { + return err } - sleepTime, _ := strconv.Atoi(ns.Param("mp")) time.Sleep(time.Duration(sleepTime) * time.Second) } } -// send implements our main NetSender client and handles NetReceiver configuration -// (as distinct from httpSender which just sends video data). -func send(ns *netsender.Sender, rv *revid.Revid) error { - // populate input values, if any - inputs := netsender.MakePins(ns.Param("ip"), "X") - if rv != nil { - for i, pin := range inputs { - if pin.Name == "X23" { - inputs[i].Value = rv.Bitrate() - } - } - } - - _, reconfig, err := ns.Send(netsender.RequestPoll, inputs) - if err != nil { - return err - } - if reconfig { - return ns.Config() - } - return nil -} - -// wrappers for stopping and starting revid -func startRevid(ns *netsender.Sender, cfg revid.Config) (*revid.Revid, revid.Config, error) { - rv, err := revid.New(cfg, ns) - if err != nil { - return nil, cfg, err - } - err = rv.Start() - return rv, cfg, err -} - -func stopRevid(rv *revid.Revid) error { - err := rv.Stop() - if err != nil { - return err - } - - // FIXME(kortschak): Is this waiting on completion of work? - // Use a wait group and Wait method if it is. - time.Sleep(revidStopTime) - return nil -} - -func updateRevid(ns *netsender.Sender, rv *revid.Revid, cfg revid.Config, vars map[string]string, stop bool) (*revid.Revid, revid.Config, error) { - if stop { - err := stopRevid(rv) - if err != nil { - return nil, cfg, err - } - } - - //look through the vars and update revid where needed - for key, value := range vars { - switch key { - case "Output": - // FIXME(kortschak): There can be only one! - // How do we specify outputs after the first? - // - // Maybe we shouldn't be doing this! - switch value { - case "File": - cfg.Outputs[0] = revid.File - case "Http": - cfg.Outputs[0] = revid.Http - case "Rtmp": - cfg.Outputs[0] = revid.Rtmp - case "FfmpegRtmp": - cfg.Outputs[0] = revid.FfmpegRtmp - default: - log.Log(logger.Warning, pkg+"invalid Output1 param", "value", value) - continue - } - case "FramesPerClip": - f, err := strconv.ParseUint(value, 10, 0) - if err != nil { - log.Log(logger.Warning, pkg+"invalid framesperclip param", "value", value) - break - } - cfg.FramesPerClip = uint(f) - case "RtmpUrl": - cfg.RtmpUrl = value - case "Bitrate": - r, err := strconv.ParseUint(value, 10, 0) - if err != nil { - log.Log(logger.Warning, pkg+"invalid framerate param", "value", value) - break - } - cfg.Bitrate = uint(r) - case "OutputFileName": - cfg.OutputFileName = value - case "InputFileName": - cfg.InputFileName = value - case "Height": - h, err := strconv.ParseUint(value, 10, 0) - if err != nil { - log.Log(logger.Warning, pkg+"invalid height param", "value", value) - break - } - cfg.Height = uint(h) - case "Width": - w, err := strconv.ParseUint(value, 10, 0) - if err != nil { - log.Log(logger.Warning, pkg+"invalid width param", "value", value) - break - } - cfg.Width = uint(w) - case "FrameRate": - r, err := strconv.ParseUint(value, 10, 0) - if err != nil { - log.Log(logger.Warning, pkg+"invalid framerate param", "value", value) - break - } - cfg.FrameRate = uint(r) - case "HttpAddress": - cfg.HttpAddress = value - case "Quantization": - q, err := strconv.ParseUint(value, 10, 0) - if err != nil { - log.Log(logger.Warning, pkg+"invalid quantization param", "value", value) - break - } - cfg.Quantization = uint(q) - case "IntraRefreshPeriod": - p, err := strconv.ParseUint(value, 10, 0) - if err != nil { - log.Log(logger.Warning, pkg+"invalid intrarefreshperiod param", "value", value) - break - } - cfg.IntraRefreshPeriod = uint(p) - case "HorizontalFlip": - switch strings.ToLower(value) { - case "true": - cfg.FlipHorizontal = true - case "false": - cfg.FlipHorizontal = false - default: - log.Log(logger.Warning, pkg+"invalid HorizontalFlip param", "value", value) - } - case "VerticalFlip": - switch strings.ToLower(value) { - case "true": - cfg.FlipVertical = true - case "false": - cfg.FlipVertical = false - default: - log.Log(logger.Warning, pkg+"invalid VerticalFlip param", "value", value) - } - default: - } - } - - return startRevid(ns, cfg) -} - // flagStrings implements an appending string set flag. type flagStrings []string diff --git a/exp/flac/decode.go b/exp/flac/decode.go new file mode 100644 index 00000000..34d42057 --- /dev/null +++ b/exp/flac/decode.go @@ -0,0 +1,144 @@ +/* +NAME + decode.go + +DESCRIPTION + decode.go provides functionality for the decoding of FLAC compressed audio + +AUTHOR + Saxon Nelson-Milton + +LICENSE + decode.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ +package flac + +import ( + "bytes" + "errors" + "io" + + "github.com/go-audio/audio" + "github.com/go-audio/wav" + "github.com/mewkiz/flac" +) + +const wavFormat = 1 + +// writeSeeker implements a memory based io.WriteSeeker. +type writeSeeker struct { + buf []byte + pos int +} + +// Bytes returns the bytes contained in the writeSeekers buffer. +func (ws *writeSeeker) Bytes() []byte { + return ws.buf +} + +// Write writes len(p) bytes from p to the writeSeeker's buf and returns the number +// of bytes written. If less than len(p) bytes are written, an error is returned. +func (ws *writeSeeker) Write(p []byte) (n int, err error) { + minCap := ws.pos + len(p) + if minCap > cap(ws.buf) { // Make sure buf has enough capacity: + buf2 := make([]byte, len(ws.buf), minCap+len(p)) // add some extra + copy(buf2, ws.buf) + ws.buf = buf2 + } + if minCap > len(ws.buf) { + ws.buf = ws.buf[:minCap] + } + copy(ws.buf[ws.pos:], p) + ws.pos += len(p) + return len(p), nil +} + +// Seek sets the offset for the next Read or Write to offset, interpreted according +// to whence: SeekStart means relative to the start of the file, SeekCurrent means +// relative to the current offset, and SeekEnd means relative to the end. Seek returns +// the new offset relative to the start of the file and an error, if any. +func (ws *writeSeeker) Seek(offset int64, whence int) (int64, error) { + newPos, offs := 0, int(offset) + switch whence { + case io.SeekStart: + newPos = offs + case io.SeekCurrent: + newPos = ws.pos + offs + case io.SeekEnd: + newPos = len(ws.buf) + offs + } + if newPos < 0 { + return 0, errors.New("negative result pos") + } + ws.pos = newPos + return int64(newPos), nil +} + +// Decode takes buf, a slice of FLAC, and decodes to WAV. If complete decoding +// fails, an error is returned. +func Decode(buf []byte) ([]byte, error) { + + // Lex the FLAC into a stream to hold audio and it's properties. + r := bytes.NewReader(buf) + stream, err := flac.Parse(r) + if err != nil { + return nil, errors.New("Could not parse FLAC") + } + + // Create WAV encoder and pass writeSeeker that will store output WAV. + ws := &writeSeeker{} + sr := int(stream.Info.SampleRate) + bps := int(stream.Info.BitsPerSample) + nc := int(stream.Info.NChannels) + enc := wav.NewEncoder(ws, sr, bps, nc, wavFormat) + defer enc.Close() + + // Decode FLAC into frames of samples + intBuf := &audio.IntBuffer{ + Format: &audio.Format{NumChannels: nc, SampleRate: sr}, + SourceBitDepth: bps, + } + return decodeFrames(stream, intBuf, enc, ws) +} + +// decodeFrames parses frames from the stream and encodes them into WAV until +// the end of the stream is reached. The bytes from writeSeeker buffer are then +// returned. If any errors occur during encodeing, nil bytes and the error is returned. +func decodeFrames(s *flac.Stream, intBuf *audio.IntBuffer, e *wav.Encoder, ws *writeSeeker) ([]byte, error) { + var data []int + for { + frame, err := s.ParseNext() + + // If we've reached the end of the stream then we can output the writeSeeker's buffer. + if err == io.EOF { + return ws.Bytes(), nil + } else if err != nil { + return nil, err + } + + // Encode WAV audio samples. + data = data[:0] + for i := 0; i < frame.Subframes[0].NSamples; i++ { + for _, subframe := range frame.Subframes { + data = append(data, int(subframe.Samples[i])) + } + } + intBuf.Data = data + if err := e.Write(intBuf); err != nil { + return nil, err + } + } +} diff --git a/exp/flac/flac_test.go b/exp/flac/flac_test.go new file mode 100644 index 00000000..1f8019e5 --- /dev/null +++ b/exp/flac/flac_test.go @@ -0,0 +1,121 @@ +/* +NAME + flac_test.go + +DESCRIPTION + flac_test.go provides utilities to test FLAC audio decoding + +AUTHOR + Saxon Nelson-Milton + +LICENSE + flac_test.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ +package flac + +import ( + "io" + "io/ioutil" + "os" + "testing" +) + +const ( + testFile = "../../../test/test-data/av/input/robot.flac" + outFile = "testOut.wav" +) + +// TestWriteSeekerWrite checks that basic writing to the ws works as expected. +func TestWriteSeekerWrite(t *testing.T) { + ws := &writeSeeker{} + + const tstStr1 = "hello" + ws.Write([]byte(tstStr1)) + got := string(ws.buf) + if got != tstStr1 { + t.Errorf("Write failed, got: %v, want: %v", got, tstStr1) + } + + const tstStr2 = " world" + const want = "hello world" + ws.Write([]byte(tstStr2)) + got = string(ws.buf) + if got != want { + t.Errorf("Second write failed, got: %v, want: %v", got, want) + } +} + +// TestWriteSeekerSeek checks that writing and seeking works as expected, i.e. we +// can write, then seek to a knew place in the buf, and write again, either replacing +// bytes, or appending bytes. +func TestWriteSeekerSeek(t *testing.T) { + ws := &writeSeeker{} + + const tstStr1 = "hello" + want1 := tstStr1 + ws.Write([]byte(tstStr1)) + got := string(ws.buf) + if got != tstStr1 { + t.Errorf("Unexpected output, got: %v, want: %v", got, want1) + } + + const tstStr2 = " world" + const want2 = tstStr1 + tstStr2 + ws.Write([]byte(tstStr2)) + got = string(ws.buf) + if got != want2 { + t.Errorf("Unexpected output, got: %v, want: %v", got, want2) + } + + const tstStr3 = "k!" + const want3 = "hello work!" + ws.Seek(-2, io.SeekEnd) + ws.Write([]byte(tstStr3)) + got = string(ws.buf) + if got != want3 { + t.Errorf("Unexpected output, got: %v, want: %v", got, want3) + } + + const tstStr4 = "gopher" + const want4 = "hello gopher" + ws.Seek(6, io.SeekStart) + ws.Write([]byte(tstStr4)) + got = string(ws.buf) + if got != want4 { + t.Errorf("Unexpected output, got: %v, want: %v", got, want4) + } +} + +// TestDecodeFlac checks that we can load a flac file and decode to wav, writing +// to a wav file. +func TestDecodeFlac(t *testing.T) { + b, err := ioutil.ReadFile(testFile) + if err != nil { + t.Fatalf("Could not read test file, failed with err: %v", err.Error()) + } + out, err := Decode(b) + if err != nil { + t.Errorf("Could not decode, failed with err: %v", err.Error()) + } + f, err := os.Create(outFile) + if err != nil { + t.Fatalf("Could not create output file, failed with err: %v", err.Error()) + } + _, err = f.Write(out) + if err != nil { + t.Fatalf("Could not write to output file, failed with err: %v", err.Error()) + } +} diff --git a/exp/ts-repair/main.go b/exp/ts-repair/main.go new file mode 100644 index 00000000..97c350f6 --- /dev/null +++ b/exp/ts-repair/main.go @@ -0,0 +1,262 @@ +/* +NAME + ts-repair/main.go + +DESCRIPTION + This program attempts to repair mpegts discontinuities using one of two methods + as selected by the mode flag. Setting the mode flag to 0 will result in repair + by shifting all CCs such that they are continuous. Setting the mode flag to 1 + will result in repair through setting the discontinuity indicator to true at + packets where a discontinuity exists. + + Specify the input file with the in flag, and the output file with out flag. + +AUTHOR + Saxon A. Nelson-Milton + +LICENSE + mpegts.go is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see [GNU licenses](http://www.gnu.org/licenses). +*/ + +package main + +import ( + "errors" + "flag" + "fmt" + "io" + "os" + + "bitbucket.org/ausocean/av/stream/mts" + "github.com/Comcast/gots/packet" +) + +const ( + PatPid = 0 + PmtPid = 4096 + VideoPid = 256 + HeadSize = 4 + DefaultAdaptationSize = 2 + AdaptationIdx = 4 + AdaptationControlIdx = 3 + AdaptationBodyIdx = AdaptationIdx + 1 + AdaptationControlMask = 0x30 + DefaultAdaptationBodySize = 1 + DiscontinuityIndicatorMask = 0x80 + DiscontinuityIndicatorIdx = AdaptationIdx + 1 +) + +// Various errors that we can encounter. +const ( + errBadInPath = "No file path provided, or file does not exist" + errCantCreateOut = "Can't create output file" + errCantGetPid = "Can't get pid from packet" + errReadFail = "Read failed" + errWriteFail = "Write to file failed" + errBadMode = "Bad fix mode" + errAdaptationPresent = "Adaptation field is already present in packet" + errNoAdaptationField = "No adaptation field in this packet" +) + +// Consts describing flag usage. +const ( + inUsage = "The path to the file to be repaired" + outUsage = "Output file path" + modeUsage = "Fix mode: 0 = cc-shift, 1 = di-update" +) + +// Repair modes. +const ( + ccShift = iota + diUpdate +) + +var ccMap = map[int]byte{ + PatPid: 16, + PmtPid: 16, + VideoPid: 16, +} + +// packetNo will keep track of the ts packet number for reference. +var packetNo int + +// Option defines a func that performs an action on p in order to change a ts option. +type Option func(p *Packet) + +// Packet is a byte array of size PacketSize i.e. 188 bytes. We define this +// to allow us to write receiver funcs for the [PacketSize]byte type. +type Packet [mts.PacketSize]byte + +// CC returns the CC of p. +func (p *Packet) CC() byte { + return (*p)[3] & 0x0f +} + +// setCC sets the CC of p. +func (p *Packet) setCC(cc byte) { + (*p)[3] |= cc & 0xf +} + +// setDI sets the discontinuity counter of p. +func (p *Packet) setDI(di bool) { + if di { + p[5] |= 0x80 + } else { + p[5] &= 0x7f + } +} + +// addAdaptationField adds an adaptation field to p, and applys the passed options to this field. +// TODO: this will probably break if we already have adaptation field. +func (p *Packet) addAdaptationField(options ...Option) error { + if p.hasAdaptation() { + return errors.New(errAdaptationPresent) + } + // Create space for adaptation field. + copy(p[HeadSize+DefaultAdaptationSize:], p[HeadSize:len(p)-DefaultAdaptationSize]) + + // TODO: seperate into own function + // Update adaptation field control. + p[AdaptationControlIdx] &= 0xff ^ AdaptationControlMask + p[AdaptationControlIdx] |= AdaptationControlMask + // Default the adaptationfield. + p.resetAdaptation() + + // Apply and options that have bee passed. + for _, option := range options { + option(p) + } + return nil +} + +// resetAdaptation sets fields in ps adaptation field to 0 if the adaptation field +// exists, otherwise an error is returned. +func (p *Packet) resetAdaptation() error { + if !p.hasAdaptation() { + return errors.New(errNoAdaptationField) + } + p[AdaptationIdx] = DefaultAdaptationBodySize + p[AdaptationBodyIdx] = 0x00 + return nil +} + +// hasAdaptation returns true if p has an adaptation field and false otherwise. +func (p *Packet) hasAdaptation() bool { + afc := p[AdaptationControlIdx] & AdaptationControlMask + if afc == 0x20 || afc == 0x30 { + return true + } else { + return false + } +} + +// DiscontinuityIndicator returns and Option that will set p's discontinuity +// indicator according to f. +func DiscontinuityIndicator(f bool) Option { + return func(p *Packet) { + set := byte(DiscontinuityIndicatorMask) + if !f { + set = 0x00 + } + p[DiscontinuityIndicatorIdx] &= 0xff ^ DiscontinuityIndicatorMask + p[DiscontinuityIndicatorIdx] |= DiscontinuityIndicatorMask & set + } +} + +func main() { + // Deal with input flags + inPtr := flag.String("in", "", inUsage) + outPtr := flag.String("out", "out.ts", outUsage) + modePtr := flag.Int("mode", diUpdate, modeUsage) + flag.Parse() + + // Try and open the given input file, otherwise panic - we can't do anything + inFile, err := os.Open(*inPtr) + defer inFile.Close() + if err != nil { + panic(errBadInPath) + } + + // Try and create output file, otherwise panic - we can't do anything + outFile, err := os.Create(*outPtr) + defer outFile.Close() + if err != nil { + panic(errCantCreateOut) + } + + // Read each packet from the input file reader + var p Packet + for { + // If we get an end of file then return, otherwise we panic - can't do anything else + if _, err := inFile.Read(p[:mts.PacketSize]); err == io.EOF { + return + } else if err != nil { + panic(errReadFail + ": " + err.Error()) + } + packetNo++ + + // Get the pid from the packet + pid := packet.Pid((*packet.Packet)(&p)) + + // Get the cc from the packet and also the expected cc (if exists) + cc := p.CC() + expect, exists := expectedCC(int(pid)) + if !exists { + updateCCMap(int(pid), cc) + } else { + switch *modePtr { + // ccShift mode shifts all CC regardless of presence of Discontinuities or not + case ccShift: + p.setCC(expect) + // diUpdate mode finds discontinuities and sets the discontinuity indicator to true. + // If we have a pat or pmt then we need to add an adaptation field and then set the DI. + case diUpdate: + if cc != expect { + fmt.Printf("***** Discontinuity found (packetNo: %v pid: %v, cc: %v, expect: %v)\n", packetNo, pid, cc, expect) + if p.hasAdaptation() { + p.setDI(true) + } else { + p.addAdaptationField(DiscontinuityIndicator(true)) + } + updateCCMap(int(pid), p.CC()) + } + default: + panic(errBadMode) + } + } + + // Write this packet to the output file. + if _, err := outFile.Write(p[:]); err != nil { + panic(errWriteFail + ": " + err.Error()) + } + } +} + +// expectedCC returns the expected cc for the given pid. If the cc hasn't been +// used yet, then 16 and false is returned. +func expectedCC(pid int) (byte, bool) { + cc := ccMap[pid] + if cc == 16 { + return 16, false + } + ccMap[pid] = (cc + 1) & 0xf + return cc, true +} + +// updateCCMap updates the cc for the passed pid. +func updateCCMap(pid int, cc byte) { + ccMap[pid] = (cc + 1) & 0xf +} diff --git a/revid/cmd/h264-file-to-flv-rtmp/main.go b/revid/cmd/h264-file-to-flv-rtmp/main.go deleted file mode 100644 index 4f7c9d7c..00000000 --- a/revid/cmd/h264-file-to-flv-rtmp/main.go +++ /dev/null @@ -1,75 +0,0 @@ -/* -NAME - main.go - -DESCRIPTION - See Readme.md - -AUTHOR - Saxon Nelson-Milton - -LICENSE - main.go is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) - - It is free software: you can redistribute it and/or modify them - under the terms of the GNU General Public License as published by the - Free Software Foundation, either version 3 of the License, or (at your - option) any later version. - - It is distributed in the hope that it will be useful, but WITHOUT - ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - for more details. - - You should have received a copy of the GNU General Public License - along with revid in gpl.txt. If not, see [GNU licenses](http://www.gnu.org/licenses). -*/ - -package main - -import ( - "flag" - "log" - "time" - - "bitbucket.org/ausocean/av/revid" - "bitbucket.org/ausocean/iot/pi/smartlogger" - "bitbucket.org/ausocean/utils/logger" -) - -const ( - inputFile = "../../../../test/test-data/av/input/betterInput.h264" - frameRate = "25" - runDuration = 120 * time.Second - logPath = "/var/log" -) - -// Test h264 inputfile to flv format into rtmp using librtmp c wrapper -func main() { - // Get the rtmp url from a cmd flag - rtmpUrlPtr := flag.String("rtmpUrl", "", "The rtmp url you would like to stream to.") - flag.Parse() - if *rtmpUrlPtr == "" { - log.Println("No RTMP url passed!") - return - } - - config := revid.Config{ - Input: revid.File, - InputFileName: inputFile, - InputCodec: revid.H264, - Outputs: []byte{revid.Rtmp}, - RtmpMethod: revid.LibRtmp, - RtmpUrl: *rtmpUrlPtr, - Packetization: revid.Flv, - Logger: logger.New(logger.Info, &smartlogger.New(logPath).LogRoller), - } - revidInst, err := revid.New(config, nil) - if err != nil { - config.Logger.Log(logger.Error, "Should not have got an error!: ", err.Error()) - return - } - revidInst.Start() - time.Sleep(runDuration) - revidInst.Stop() -} diff --git a/revid/cmd/h264-file-to-mpgets-file/main.go b/revid/cmd/h264-file-to-mpgets-file/main.go deleted file mode 100644 index 768f560b..00000000 --- a/revid/cmd/h264-file-to-mpgets-file/main.go +++ /dev/null @@ -1,66 +0,0 @@ -/* -NAME - main.go - -DESCRIPTION - See Readme.md - -AUTHOR - Saxon Nelson-Milton - -LICENSE - main.go is Copyright (C) 2017 the Australian Ocean Lab (AusOcean) - - It is free software: you can redistribute it and/or modify them - under the terms of the GNU General Public License as published by the - Free Software Foundation, either version 3 of the License, or (at your - option) any later version. - - It is distributed in the hope that it will be useful, but WITHOUT - ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or - FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License - for more details. - - You should have received a copy of the GNU General Public License - along with revid in gpl.txt. If not, see [GNU licenses](http://www.gnu.org/licenses). -*/ - -package main - -import ( - "time" - - "bitbucket.org/ausocean/av/revid" - "bitbucket.org/ausocean/iot/pi/smartlogger" - "bitbucket.org/ausocean/utils/logger" -) - -const ( - inputFile = "../../../../test/test-data/av/input/betterInput.h264" - outputFile = "output.ts" - frameRate = "25" - runDuration = 120 * time.Second - logPath = "/var/log" -) - -// Test h264 inputfile to flv format into rtmp using librtmp c wrapper -func main() { - - config := revid.Config{ - Input: revid.File, - InputFileName: inputFile, - InputCodec: revid.H264, - Outputs: []byte{revid.File}, - OutputFileName: outputFile, - Packetization: revid.Mpegts, - Logger: logger.New(logger.Info, &smartlogger.New(logPath).LogRoller), - } - revidInst, err := revid.New(config, nil) - if err != nil { - config.Logger.Log(logger.Error, "Should not have got an error!:", err.Error()) - return - } - revidInst.Start() - time.Sleep(runDuration) - revidInst.Stop() -} diff --git a/revid/config.go b/revid/config.go index dc9c5a8d..2cafc52d 100644 --- a/revid/config.go +++ b/revid/config.go @@ -57,8 +57,8 @@ type Config struct { FramesPerClip uint RtmpUrl string Bitrate uint - OutputFileName string - InputFileName string + OutputPath string + InputPath string Height uint Width uint FrameRate uint @@ -68,6 +68,7 @@ type Config struct { RtpAddress string Logger Logger SendRetry bool + BurstPeriod uint } // Enums for config struct @@ -114,6 +115,7 @@ const ( defaultInputCodec = H264 defaultVerbosity = No // FIXME(kortschak): This makes no sense whatsoever. No is currently 15. defaultRtpAddr = "localhost:6970" + defaultBurstPeriod = 10 // Seconds ) // Validate checks for any errors in the config fields and defaults settings @@ -146,11 +148,9 @@ func (c *Config) Validate(r *Revid) error { // Configuration really needs to be rethought here. if c.Quantize && c.Quantization == 0 { c.Quantization = defaultQuantization - } else { - c.Bitrate = defaultBitrate } - if (c.Bitrate > 0 && c.Quantization > 0) || (c.Bitrate == 0 && c.Quantization == 0) { + if (c.Bitrate > 0 && c.Quantize) || (c.Bitrate == 0 && !c.Quantize) { return errors.New("bad bitrate and quantization combination for H264 input") } @@ -200,6 +200,11 @@ func (c *Config) Validate(r *Revid) error { } } + if c.BurstPeriod == 0 { + c.Logger.Log(logger.Warning, pkg+"no burst period defined, defaulting", "burstPeriod", defaultBurstPeriod) + c.BurstPeriod = defaultBurstPeriod + } + if c.FramesPerClip < 1 { c.Logger.Log(logger.Warning, pkg+"no FramesPerClip defined, defaulting", "framesPerClip", defaultFramesPerClip) diff --git a/revid/revid.go b/revid/revid.go index 6833ec2b..474280ef 100644 --- a/revid/revid.go +++ b/revid/revid.go @@ -51,7 +51,7 @@ import ( // Ring buffer sizes and read/write timeouts. const ( - ringBufferSize = 10000 + ringBufferSize = 1000 ringBufferElementSize = 150000 writeTimeout = 10 * time.Millisecond readTimeout = 10 * time.Millisecond @@ -119,10 +119,11 @@ type Revid struct { // bitrate hold the last send bitrate calculation result. bitrate int - // isRunning is a loaded and cocked foot-gun. mu sync.Mutex isRunning bool + wg sync.WaitGroup + err chan error } @@ -144,22 +145,27 @@ func (p *packer) Write(frame []byte) (int, error) { p.owner.config.Logger.Log(logger.Warning, pkg+"frame was too big", "frame size", len(frame)) return len(frame), nil } - n, err := p.owner.buffer.Write(frame) + + if len(p.owner.destination) != 0 { + n, err := p.owner.buffer.Write(frame) + if err != nil { + if err == ring.ErrDropped { + p.owner.config.Logger.Log(logger.Warning, pkg+"dropped frame", "frame size", len(frame)) + return len(frame), nil + } + p.owner.config.Logger.Log(logger.Error, pkg+"unexpected ring buffer write error", "error", err.Error()) + return n, err + } + } + // If we have an rtp sender bypass ringbuffer and give straight to sender if p.owner.rtpSender != nil { - err = p.owner.rtpSender.send(frame) + err := p.owner.rtpSender.send(frame) if err != nil { p.owner.config.Logger.Log(logger.Error, pkg+"rtp send failed with error", "error", err.Error()) } } - if err != nil { - if err == ring.ErrDropped { - p.owner.config.Logger.Log(logger.Warning, pkg+"dropped frame", "frame size", len(frame)) - return len(frame), nil - } - p.owner.config.Logger.Log(logger.Error, pkg+"unexpected ring buffer write error", "error", err.Error()) - return n, err - } + p.packetCount++ var hasRtmp bool for _, d := range p.owner.config.Outputs { @@ -181,7 +187,6 @@ func (p *packer) Write(frame []byte) (int, error) { // an error if construction of the new instance was not successful. func New(c Config, ns *netsender.Sender) (*Revid, error) { r := Revid{ns: ns, err: make(chan error)} - r.buffer = ring.NewBuffer(ringBufferSize, ringBufferElementSize, writeTimeout) r.packer.owner = &r err := r.reset(c) if err != nil { @@ -197,13 +202,10 @@ func (r *Revid) handleErrors() { err := <-r.err if err != nil { r.config.Logger.Log(logger.Error, pkg+"async error", "error", err.Error()) - err = r.Stop() - if err != nil { - r.config.Logger.Log(logger.Fatal, pkg+"failed to stop", "error", err.Error()) - } + r.Stop() err = r.Start() if err != nil { - r.config.Logger.Log(logger.Fatal, pkg+"failed to restart", "error", err.Error()) + r.config.Logger.Log(logger.Error, pkg+"failed to restart revid", "error", err.Error()) } } } @@ -224,20 +226,13 @@ func (r *Revid) reset(config Config) error { } r.config = config - for _, dest := range r.destination { - if dest != nil { - err = dest.close() - if err != nil { - return err - } - } - } + r.buffer = ring.NewBuffer(ringBufferSize, ringBufferElementSize, writeTimeout) r.destination = r.destination[:0] for _, typ := range r.config.Outputs { switch typ { case File: - s, err := newFileSender(config.OutputFileName) + s, err := newFileSender(config.OutputPath) if err != nil { return err } @@ -326,7 +321,14 @@ func (r *Revid) IsRunning() bool { return ret } -// setIsRunning sets revid.isRunning using b. +func (r *Revid) Config() Config { + r.mu.Lock() + cfg := r.config + r.mu.Unlock() + return cfg +} + +// setIsRunning sets r.isRunning using b. func (r *Revid) setIsRunning(b bool) { r.mu.Lock() r.isRunning = b @@ -337,12 +339,15 @@ func (r *Revid) setIsRunning(b bool) { // and packetising (if theres packetization) to a defined output. func (r *Revid) Start() error { if r.IsRunning() { - return errors.New(pkg + "start called but revid is already running") + r.config.Logger.Log(logger.Warning, pkg+"start called, but revid already running") + return nil } r.config.Logger.Log(logger.Info, pkg+"starting Revid") + // TODO: this doesn't need to be here r.config.Logger.Log(logger.Debug, pkg+"setting up output") r.setIsRunning(true) r.config.Logger.Log(logger.Info, pkg+"starting output routine") + r.wg.Add(1) go r.outputClips() r.config.Logger.Log(logger.Info, pkg+"setting up input and receiving content") err := r.setupInput() @@ -350,9 +355,10 @@ func (r *Revid) Start() error { } // Stop halts any processing of video data from a camera or file -func (r *Revid) Stop() error { +func (r *Revid) Stop() { if !r.IsRunning() { - return errors.New(pkg + "stop called but revid is already stopped") + r.config.Logger.Log(logger.Warning, pkg+"stop called but revid isn't running") + return } r.config.Logger.Log(logger.Info, pkg+"stopping revid") @@ -363,12 +369,138 @@ func (r *Revid) Stop() error { if r.cmd != nil && r.cmd.Process != nil { r.cmd.Process.Kill() } - return nil + r.wg.Wait() +} + +func (r *Revid) Update(vars map[string]string) error { + if r.IsRunning() { + r.Stop() + } + //look through the vars and update revid where needed + for key, value := range vars { + switch key { + case "Output": + r.config.Outputs = make([]uint8, 1) + // FIXME(kortschak): There can be only one! + // How do we specify outputs after the first? + // + // Maybe we shouldn't be doing this! + switch value { + case "File": + r.config.Outputs[0] = File + case "Http": + r.config.Outputs[0] = Http + case "Rtmp": + r.config.Outputs[0] = Rtmp + case "FfmpegRtmp": + r.config.Outputs[0] = FfmpegRtmp + default: + r.config.Logger.Log(logger.Warning, pkg+"invalid Output1 param", "value", value) + continue + } + + case "Packetization": + switch value { + case "Mpegts": + r.config.Packetization = Mpegts + case "Flv": + r.config.Packetization = Flv + default: + r.config.Logger.Log(logger.Warning, pkg+"invalid packetization param", "value", value) + continue + } + case "FramesPerClip": + f, err := strconv.ParseUint(value, 10, 0) + if err != nil { + r.config.Logger.Log(logger.Warning, pkg+"invalid framesperclip param", "value", value) + break + } + r.config.FramesPerClip = uint(f) + case "RtmpUrl": + r.config.RtmpUrl = value + case "Bitrate": + v, err := strconv.ParseUint(value, 10, 0) + if err != nil { + r.config.Logger.Log(logger.Warning, pkg+"invalid framerate param", "value", value) + break + } + r.config.Bitrate = uint(v) + case "OutputPath": + r.config.OutputPath = value + case "InputPath": + r.config.InputPath = value + case "Height": + h, err := strconv.ParseUint(value, 10, 0) + if err != nil { + r.config.Logger.Log(logger.Warning, pkg+"invalid height param", "value", value) + break + } + r.config.Height = uint(h) + case "Width": + w, err := strconv.ParseUint(value, 10, 0) + if err != nil { + r.config.Logger.Log(logger.Warning, pkg+"invalid width param", "value", value) + break + } + r.config.Width = uint(w) + case "FrameRate": + v, err := strconv.ParseUint(value, 10, 0) + if err != nil { + r.config.Logger.Log(logger.Warning, pkg+"invalid framerate param", "value", value) + break + } + r.config.FrameRate = uint(v) + case "HttpAddress": + r.config.HttpAddress = value + case "Quantization": + q, err := strconv.ParseUint(value, 10, 0) + if err != nil { + r.config.Logger.Log(logger.Warning, pkg+"invalid quantization param", "value", value) + break + } + r.config.Quantization = uint(q) + case "IntraRefreshPeriod": + p, err := strconv.ParseUint(value, 10, 0) + if err != nil { + r.config.Logger.Log(logger.Warning, pkg+"invalid intrarefreshperiod param", "value", value) + break + } + r.config.IntraRefreshPeriod = uint(p) + case "HorizontalFlip": + switch strings.ToLower(value) { + case "true": + r.config.FlipHorizontal = true + case "false": + r.config.FlipHorizontal = false + default: + r.config.Logger.Log(logger.Warning, pkg+"invalid HorizontalFlip param", "value", value) + } + case "VerticalFlip": + switch strings.ToLower(value) { + case "true": + r.config.FlipVertical = true + case "false": + r.config.FlipVertical = false + default: + r.config.Logger.Log(logger.Warning, pkg+"invalid VerticalFlip param", "value", value) + } + case "BurstPeriod": + v, err := strconv.ParseUint(value, 10, 0) + if err != nil { + r.config.Logger.Log(logger.Warning, pkg+"invalid BurstPeriod param", "value", value) + break + } + r.config.BurstPeriod = uint(v) + } + } + r.config.Logger.Log(logger.Info, pkg+"revid config changed", "config", fmt.Sprintf("%+v", r.config)) + return r.reset(r.config) } // outputClips takes the clips produced in the packClips method and outputs them // to the desired output defined in the revid config func (r *Revid) outputClips() { + defer r.wg.Done() lastTime := time.Now() var count int loop: @@ -379,7 +511,7 @@ loop: case nil: // Do nothing. case ring.ErrTimeout: - r.config.Logger.Log(logger.Warning, pkg+"ring buffer read timeout") + r.config.Logger.Log(logger.Debug, pkg+"ring buffer read timeout") continue default: r.config.Logger.Log(logger.Error, pkg+"unexpected error", "error", err.Error()) @@ -507,6 +639,7 @@ func (r *Revid) startRaspivid() error { r.config.Logger.Log(logger.Fatal, pkg+"cannot start raspivid", "error", err.Error()) } + r.wg.Add(1) go r.processFrom(stdout, 0) return nil } @@ -515,13 +648,13 @@ func (r *Revid) startV4L() error { const defaultVideo = "/dev/video0" r.config.Logger.Log(logger.Info, pkg+"starting webcam") - if r.config.InputFileName == "" { + if r.config.InputPath == "" { r.config.Logger.Log(logger.Info, pkg+"using default video device", "device", defaultVideo) - r.config.InputFileName = defaultVideo + r.config.InputPath = defaultVideo } args := []string{ - "-i", r.config.InputFileName, + "-i", r.config.InputPath, "-f", "h264", "-r", fmt.Sprint(r.config.FrameRate), } @@ -555,13 +688,14 @@ func (r *Revid) startV4L() error { return err } + r.wg.Add(1) go r.processFrom(stdout, time.Duration(0)) return nil } // setupInputForFile sets things up for getting input from a file func (r *Revid) setupInputForFile() error { - f, err := os.Open(r.config.InputFileName) + f, err := os.Open(r.config.InputPath) if err != nil { r.config.Logger.Log(logger.Error, err.Error()) r.Stop() @@ -570,6 +704,7 @@ func (r *Revid) setupInputForFile() error { defer f.Close() // TODO(kortschak): Maybe we want a context.Context-aware parser that we can stop. + r.wg.Add(1) go r.processFrom(f, time.Second/time.Duration(r.config.FrameRate)) return nil } @@ -578,4 +713,5 @@ func (r *Revid) processFrom(read io.Reader, delay time.Duration) { r.config.Logger.Log(logger.Info, pkg+"reading input data") r.err <- r.lexTo(r.encoder, read, delay) r.config.Logger.Log(logger.Info, pkg+"finished reading input data") + r.wg.Done() } diff --git a/revid/senders.go b/revid/senders.go index a73523d5..9d326471 100644 --- a/revid/senders.go +++ b/revid/senders.go @@ -35,6 +35,7 @@ import ( "net" "os" "os/exec" + "strconv" "bitbucket.org/ausocean/av/rtmp" "bitbucket.org/ausocean/av/stream/mts" @@ -172,7 +173,7 @@ func (s *httpSender) extractMeta(r string) error { s.log(logger.Warning, pkg+"No timestamp in reply") } else { s.log(logger.Debug, fmt.Sprintf("%v got timestamp: %v", pkg, t)) - mts.MetaData.SetTimeStamp(uint64(t)) + mts.Meta.Add("ts", strconv.Itoa(t)) } // Extract location from reply @@ -181,7 +182,7 @@ func (s *httpSender) extractMeta(r string) error { s.log(logger.Warning, pkg+"No location in reply") } else { s.log(logger.Debug, fmt.Sprintf("%v got location: %v", pkg, g)) - mts.MetaData.SetLocation(g) + mts.Meta.Add("loc", g) } return nil @@ -272,7 +273,6 @@ func newRtmpSender(url string, timeout uint, retries int, log func(lvl int8, msg break } log(logger.Error, err.Error()) - conn.Close() if n < retries-1 { log(logger.Info, pkg+"retry rtmp connection") } diff --git a/rtmp/conn.go b/rtmp/conn.go index 7e1b3b15..4249a637 100644 --- a/rtmp/conn.go +++ b/rtmp/conn.go @@ -134,10 +134,10 @@ func Dial(url string, timeout uint, log Log) (*Conn, error) { // Close terminates the RTMP connection. // NB: Close is idempotent and the connection value is cleared completely. func (c *Conn) Close() error { - c.log(DebugLevel, pkg+"Conn.Close") if !c.isConnected() { return errNotConnected } + c.log(DebugLevel, pkg+"Conn.Close") if c.streamID > 0 { if c.link.protocol&featureWrite != 0 { sendFCUnpublish(c) diff --git a/stream/mts/encoder.go b/stream/mts/encoder.go index 60393285..6ac11cc3 100644 --- a/stream/mts/encoder.go +++ b/stream/mts/encoder.go @@ -30,9 +30,9 @@ package mts import ( "io" - "sync" "time" + "bitbucket.org/ausocean/av/stream/mts/meta" "bitbucket.org/ausocean/av/stream/mts/pes" "bitbucket.org/ausocean/av/stream/mts/psi" ) @@ -82,93 +82,21 @@ var ( }, }, } - - // standardPmtTimeLocation is a standard PMT with time and location - // descriptors, but time and location fields zeroed out. - standardPmtTimeLocation = psi.PSI{ - Pf: 0x00, - Tid: 0x02, - Ssi: true, - Sl: 0x3e, - Tss: &psi.TSS{ - Tide: 0x01, - V: 0, - Cni: true, - Sn: 0, - Lsn: 0, - Sd: &psi.PMT{ - Pcrpid: 0x0100, - Pil: psi.PmtTimeLocationPil, - Pd: []psi.Desc{ - { - Dt: psi.TimeDescTag, - Dl: psi.TimeDataSize, - Dd: make([]byte, psi.TimeDataSize), - }, - { - Dt: psi.LocationDescTag, - Dl: psi.LocationDataSize, - Dd: make([]byte, psi.LocationDataSize), - }, - }, - Essd: &psi.ESSD{ - St: 0x1b, - Epid: 0x0100, - Esil: 0x00, - }, - }, - }, - } ) const ( - psiSndCnt = 7 + psiInterval = 1 * time.Second ) -// timeLocation holds time and location data -type timeLocation struct { - mu sync.RWMutex - time uint64 - location string -} - -// SetTimeStamp sets the time field of a TimeLocation. -func (tl *timeLocation) SetTimeStamp(t uint64) { - tl.mu.Lock() - tl.time = t - tl.mu.Unlock() -} - -// GetTimeStamp returns the location of a TimeLocation. -func (tl *timeLocation) TimeStamp() uint64 { - tl.mu.RLock() - t := tl.time - tl.mu.RUnlock() - return t -} - -// SetLocation sets the location of a TimeLocation. -func (tl *timeLocation) SetLocation(l string) { - tl.mu.Lock() - tl.location = l - tl.mu.Unlock() -} - -// GetLocation returns the location of a TimeLocation. -func (tl *timeLocation) Location() string { - tl.mu.RLock() - l := tl.location - tl.mu.RUnlock() - return l -} - -// MetData will hold time and location data which may be set externally if -// this data is available. It is then inserted into mpegts packets outputted. -var MetaData timeLocation +// Meta allows addition of metadata to encoded mts from outside of this pkg. +// See meta pkg for usage. +// +// TODO: make this not global. +var Meta *meta.Data var ( patTable = standardPat.Bytes() - pmtTable = standardPmtTimeLocation.Bytes() + pmtTable = standardPmt.Bytes() ) const ( @@ -194,14 +122,15 @@ type Encoder struct { dst io.Writer clock time.Duration + lastTime time.Time frameInterval time.Duration ptsOffset time.Duration tsSpace [PacketSize]byte pesSpace [pes.MaxPesSize]byte - psiCount int - continuity map[int]byte + + psiLastTime time.Time } // NewEncoder returns an Encoder with the specified frame rate. @@ -233,12 +162,15 @@ const ( // generate handles the incoming data and generates equivalent mpegts packets - // sending them to the output channel. func (e *Encoder) Encode(nalu []byte) error { - if e.psiCount <= 0 { + now := time.Now() + if now.Sub(e.psiLastTime) > psiInterval { err := e.writePSI() if err != nil { return err } + e.psiLastTime = now } + // Prepare PES data. pesPkt := pes.Packet{ StreamID: streamID, @@ -269,7 +201,6 @@ func (e *Encoder) Encode(nalu []byte) error { pusi = false } _, err := e.dst.Write(pkt.Bytes(e.tsSpace[:PacketSize])) - e.psiCount-- if err != nil { return err } @@ -286,39 +217,32 @@ func (e *Encoder) writePSI() error { // Write PAT. patPkt := Packet{ PUSI: true, - PID: patPid, - CC: e.ccFor(patPid), - AFC: hasPayload, - Payload: patTable, + PID: PatPid, + CC: e.ccFor(PatPid), + AFC: HasPayload, + Payload: psi.AddPadding(patTable), } _, err := e.dst.Write(patPkt.Bytes(e.tsSpace[:PacketSize])) if err != nil { return err } - - // Update pmt table time and location. - err = psi.UpdateTime(pmtTable, MetaData.TimeStamp()) + pmtTable, err = updateMeta(pmtTable) if err != nil { return err } - err = psi.UpdateLocation(pmtTable, MetaData.Location()) - if err != nil { - return nil - } // Create mts packet from pmt table. pmtPkt := Packet{ PUSI: true, - PID: pmtPid, - CC: e.ccFor(pmtPid), - AFC: hasPayload, - Payload: pmtTable, + PID: PmtPid, + CC: e.ccFor(PmtPid), + AFC: HasPayload, + Payload: psi.AddPadding(pmtTable), } _, err = e.dst.Write(pmtPkt.Bytes(e.tsSpace[:PacketSize])) if err != nil { return err } - e.psiCount = psiSndCnt return nil } @@ -344,3 +268,11 @@ func (e *Encoder) ccFor(pid int) byte { e.continuity[pid] = (cc + 1) & continuityCounterMask return cc } + +// updateMeta adds/updates a metaData descriptor in the given psi bytes using data +// contained in the global Meta struct. +func updateMeta(b []byte) ([]byte, error) { + p := psi.PSIBytes(b) + err := p.AddDescriptor(psi.MetadataTag, Meta.Encode()) + return []byte(p), err +} diff --git a/stream/mts/meta/meta.go b/stream/mts/meta/meta.go new file mode 100644 index 00000000..481b5ae5 --- /dev/null +++ b/stream/mts/meta/meta.go @@ -0,0 +1,222 @@ +/* +NAME + meta.go + +DESCRIPTION + See Readme.md + +AUTHOR + Saxon Nelson-Milton + +LICENSE + meta.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package meta + +import ( + "encoding/binary" + "errors" + "strings" + "sync" +) + +// This is the headsize of our metadata string, +// which is encoded int the data body of a pmt descriptor. +const headSize = 4 + +const ( + majVer = 1 + minVer = 0 +) + +// Indices of bytes for uint16 metadata length. +const ( + dataLenIdx = 2 +) + +var ( + errKeyAbsent = errors.New("Key does not exist in map") + errInvalidMeta = errors.New("Invalid metadata given") + errUnexpectedMetaFormat = errors.New("Unexpected meta format") +) + +// Metadata provides functionality for the storage and encoding of metadata +// using a map. +type Data struct { + mu sync.RWMutex + data map[string]string + order []string + enc []byte +} + +// New returns a pointer to a new Metadata. +func New() *Data { + return &Data{ + data: make(map[string]string), + enc: []byte{ + 0x00, // Reserved byte + (majVer << 4) | minVer, // MS and LS versions + 0x00, // Data len byte1 + 0x00, // Data len byte2 + }, + } +} + +// NewWith creates a meta.Data and fills map with initial data given. If there +// is repeated key, then the latter overwrites the prior. +func NewWith(data [][2]string) *Data { + m := New() + m.order = make([]string, 0, len(data)) + for _, d := range data { + if _, exists := m.data[d[0]]; !exists { + m.order = append(m.order, d[0]) + } + m.data[d[0]] = d[1] + } + return m +} + +// Add adds metadata with key and val. +func (m *Data) Add(key, val string) { + m.mu.Lock() + defer m.mu.Unlock() + m.data[key] = val + for _, k := range m.order { + if k == key { + return + } + } + m.order = append(m.order, key) + return +} + +// All returns the a copy of the map containing the meta data. +func (m *Data) All() map[string]string { + m.mu.Lock() + cpy := make(map[string]string) + for k, v := range m.data { + cpy[k] = v + } + m.mu.Unlock() + return cpy +} + +// Get returns the meta data for the passed key. +func (m *Data) Get(key string) (val string, ok bool) { + m.mu.Lock() + val, ok = m.data[key] + m.mu.Unlock() + return +} + +// Delete deletes a meta entry in the map and returns error if it doesn’t exist. +func (m *Data) Delete(key string) { + m.mu.Lock() + defer m.mu.Unlock() + if _, ok := m.data[key]; ok { + delete(m.data, key) + for i, k := range m.order { + if k == key { + copy(m.order[:i], m.order[i+1:]) + m.order = m.order[:len(m.order)-1] + break + } + } + return + } + return +} + +// Encode takes the meta data map and encodes into a byte slice with header +// describing the version, length of data and data in TSV format. +func (m *Data) Encode() []byte { + m.enc = m.enc[:headSize] + + // Iterate over map and append entries, only adding tab if we're not on the + // last entry. + var entry string + for i, k := range m.order { + v := m.data[k] + entry += k + "=" + v + if i+1 < len(m.data) { + entry += "\t" + } + } + m.enc = append(m.enc, []byte(entry)...) + + // Calculate and set data length in encoded meta header. + dataLen := len(m.enc[headSize:]) + binary.BigEndian.PutUint16(m.enc[dataLenIdx:dataLenIdx+2], uint16(dataLen)) + return m.enc +} + +// Keys returns all keys in a slice of metadata d. +func Keys(d []byte) ([]string, error) { + m, err := GetAll(d) + if err != nil { + return nil, err + } + k := make([]string, len(m)) + for i, kv := range m { + k[i] = kv[0] + } + return k, nil +} + +// Get returns the value for the given key in d. +func Get(key string, d []byte) (string, error) { + err := checkMeta(d) + if err != nil { + return "", err + } + d = d[headSize:] + entries := strings.Split(string(d), "\t") + for _, entry := range entries { + kv := strings.Split(entry, "=") + if kv[0] == key { + return kv[1], nil + } + } + return "", errKeyAbsent +} + +// GetAll returns metadata keys and values from d. +func GetAll(d []byte) ([][2]string, error) { + err := checkMeta(d) + if err != nil { + return nil, err + } + d = d[headSize:] + entries := strings.Split(string(d), "\t") + all := make([][2]string, len(entries)) + for i, entry := range entries { + kv := strings.Split(entry, "=") + if len(kv) != 2 { + return nil, errUnexpectedMetaFormat + } + copy(all[i][:], kv) + } + return all, nil +} + +// checkHeader checks that a valid metadata header exists in the given data. +func checkMeta(d []byte) error { + if len(d) == 0 || d[0] != 0 || binary.BigEndian.Uint16(d[2:headSize]) != uint16(len(d[headSize:])) { + return errInvalidMeta + } + return nil +} diff --git a/stream/mts/meta/meta_test.go b/stream/mts/meta/meta_test.go new file mode 100644 index 00000000..38e4dbb6 --- /dev/null +++ b/stream/mts/meta/meta_test.go @@ -0,0 +1,204 @@ +/* +NAME + meta_test.go + +DESCRIPTION + See Readme.md + +AUTHOR + Saxon Nelson-Milton + +LICENSE + meta_test.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package meta + +import ( + "bytes" + "encoding/binary" + "reflect" + "testing" +) + +const ( + tstKey1 = "loc" + tstData1 = "a,b,c" + tstKey2 = "ts" + tstData2 = "12345678" + tstData3 = "d,e,f" +) + +// TestAddAndGet ensures that we can add metadata and then successfully get it. +func TestAddAndGet(t *testing.T) { + meta := New() + meta.Add(tstKey1, tstData1) + meta.Add(tstKey2, tstData2) + if data, ok := meta.Get(tstKey1); !ok { + t.Errorf("Could not get data for key: %v\n", tstKey1) + if data != tstData1 { + t.Error("Did not get expected data") + } + } + + if data, ok := meta.Get(tstKey2); !ok { + t.Errorf("Could not get data for key: %v", tstKey2) + if data != tstData2 { + t.Error("Did not get expected data") + } + } +} + +// TestUpdate checks that we can use Meta.Add to actually update metadata +// if it already exists in the Meta map. +func TestUpdate(t *testing.T) { + meta := New() + meta.Add(tstKey1, tstData1) + meta.Add(tstKey1, tstData3) + + if data, ok := meta.Get(tstKey1); !ok { + t.Errorf("Could not get data for key: %v\n", tstKey1) + if data != tstData2 { + t.Error(`Data did not correctly update for key "loc"`) + } + } +} + +// TestAll ensures we can get a correct map using Meta.All() after adding some data +func TestAll(t *testing.T) { + meta := New() + tstMap := map[string]string{ + tstKey1: tstData1, + tstKey2: tstData2, + } + + meta.Add(tstKey1, tstData1) + meta.Add(tstKey2, tstData2) + metaMap := meta.All() + + if !reflect.DeepEqual(metaMap, tstMap) { + t.Errorf("Map not correct. Got: %v, want: %v", metaMap, tstMap) + } +} + +// TestGetAbsentKey ensures that we get the expected error when we try to get with +// key that does not yet exist in the Meta map. +func TestGetAbsentKey(t *testing.T) { + meta := New() + + if _, ok := meta.Get(tstKey1); ok { + t.Error("Get for absent key incorrectly returned'ok'") + } +} + +// TestDelete ensures we can remove a data entry in the Meta map. +func TestDelete(t *testing.T) { + meta := New() + meta.Add(tstKey1, tstData1) + meta.Delete(tstKey1) + if _, ok := meta.Get(tstKey1); ok { + t.Error("Get incorrectly returned okay for absent key") + } +} + +// TestEncode checks that we're getting the correct byte slice from Meta.Encode(). +func TestEncode(t *testing.T) { + meta := New() + meta.Add(tstKey1, tstData1) + meta.Add(tstKey2, tstData2) + + dataLen := len(tstKey1+tstData1+tstKey2+tstData2) + 3 + header := [4]byte{ + 0x00, + 0x10, + } + binary.BigEndian.PutUint16(header[2:4], uint16(dataLen)) + expectedOut := append(header[:], []byte( + tstKey1+"="+tstData1+"\t"+ + tstKey2+"="+tstData2)...) + + got := meta.Encode() + if !bytes.Equal(expectedOut, got) { + t.Errorf("Did not get expected out. \nGot : %v, \nwant: %v\n", got, expectedOut) + } +} + +// TestGetFrom checks that we can correctly obtain a value for a partiular key +// from a string of metadata using the ReadFrom func. +func TestGetFrom(t *testing.T) { + tstMeta := append([]byte{0x00, 0x10, 0x00, 0x12}, "loc=a,b,c\tts=12345"...) + + tests := []struct { + key string + want string + }{ + { + "loc", + "a,b,c", + }, + { + "ts", + "12345", + }, + } + + for _, test := range tests { + got, err := Get(test.key, []byte(tstMeta)) + if err != nil { + t.Errorf("Unexpected err: %v\n", err) + } + if got != test.want { + t.Errorf("Did not get expected out. \nGot : %v, \nwant: %v\n", got, test.want) + } + } +} + +// TestGetAll checks that meta.GetAll can correctly get all metadata +// from descriptor data. +func TestGetAll(t *testing.T) { + tstMeta := append([]byte{0x00, 0x10, 0x00, 0x12}, "loc=a,b,c\tts=12345"...) + want := [][2]string{ + { + "loc", + "a,b,c", + }, + { + "ts", + "12345", + }, + } + got, err := GetAll(tstMeta) + if err != nil { + t.Errorf("Unexpected error: %v\n", err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("Did not get expected out. \nGot : %v, \nWant: %v\n", got, want) + } +} + +// TestKeys checks that we can successfully get keys from some metadata using +// the meta.Keys method. +func TestKeys(t *testing.T) { + tstMeta := append([]byte{0x00, 0x10, 0x00, 0x12}, "loc=a,b,c\tts=12345"...) + want := []string{"loc", "ts"} + got, err := Keys(tstMeta) + if err != nil { + t.Errorf("Unexpected error: %v\n", err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("Did not get expected out. \nGot : %v, \nWant: %v\n", got, want) + } +} diff --git a/stream/mts/metaEncode_test.go b/stream/mts/metaEncode_test.go new file mode 100644 index 00000000..e970b7c8 --- /dev/null +++ b/stream/mts/metaEncode_test.go @@ -0,0 +1,102 @@ +/* +NAME + metaEncode_test.go + +DESCRIPTION + See Readme.md + +AUTHOR + Saxon Nelson-Milton + +LICENSE + metaEncode_test.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package mts + +import ( + "bytes" + "testing" + + "bitbucket.org/ausocean/av/stream/mts/meta" + "bitbucket.org/ausocean/av/stream/mts/psi" +) + +const ( + errNotExpectedOut = "Unexpected output. \n Got : %v\n, Want: %v\n" + errUnexpectedErr = "Unexpected error: %v\n" +) + +const fps = 25 + +// TestMetaEncode1 checks that we can externally add a single metadata entry to +// the mts global Meta meta.Data struct and then successfully have the mts encoder +// write this to psi. +func TestMetaEncode1(t *testing.T) { + Meta = meta.New() + var b []byte + buf := bytes.NewBuffer(b) + e := NewEncoder(buf, fps) + Meta.Add("ts", "12345678") + if err := e.writePSI(); err != nil { + t.Errorf(errUnexpectedErr, err.Error()) + } + out := buf.Bytes() + got := out[PacketSize+4:] + + want := []byte{ + 0x00, 0x02, 0xb0, 0x23, 0x00, 0x01, 0xc1, 0x00, 0x00, 0xe1, 0x00, 0xf0, 0x11, + psi.MetadataTag, // Descriptor tag + 0x0f, // Length of bytes to follow + 0x00, 0x10, 0x00, 0x0b, 't', 's', '=', '1', '2', '3', '4', '5', '6', '7', '8', // timestamp + 0x1b, 0xe1, 0x00, 0xf0, 0x00, + } + want = psi.AddCrc(want) + want = psi.AddPadding(want) + if !bytes.Equal(got, want) { + t.Errorf(errNotExpectedOut, got, want) + } +} + +// TestMetaEncode2 checks that we can externally add two metadata entries to the +// Meta meta.Data global and then have the mts encoder successfully encode this +// into psi. +func TestMetaEncode2(t *testing.T) { + Meta = meta.New() + var b []byte + buf := bytes.NewBuffer(b) + e := NewEncoder(buf, fps) + Meta.Add("ts", "12345678") + Meta.Add("loc", "1234,4321,1234") + if err := e.writePSI(); err != nil { + t.Errorf(errUnexpectedErr, err.Error()) + } + out := buf.Bytes() + got := out[PacketSize+4:] + want := []byte{ + 0x00, 0x02, 0xb0, 0x36, 0x00, 0x01, 0xc1, 0x00, 0x00, 0xe1, 0x00, 0xf0, 0x24, + psi.MetadataTag, // Descriptor tag + 0x22, // Length of bytes to follow + 0x00, 0x10, 0x00, 0x1e, 't', 's', '=', '1', '2', '3', '4', '5', '6', '7', '8', '\t', // timestamp + 'l', 'o', 'c', '=', '1', '2', '3', '4', ',', '4', '3', '2', '1', ',', '1', '2', '3', '4', // location + 0x1b, 0xe1, 0x00, 0xf0, 0x00, + } + want = psi.AddCrc(want) + want = psi.AddPadding(want) + if !bytes.Equal(got, want) { + t.Errorf(errNotExpectedOut, got, want) + } +} diff --git a/stream/mts/mpegts.go b/stream/mts/mpegts.go index 0bef80d2..705f88de 100644 --- a/stream/mts/mpegts.go +++ b/stream/mts/mpegts.go @@ -30,13 +30,44 @@ package mts import ( "errors" + "fmt" ) +// General mpegts packet properties. const ( PacketSize = 188 PayloadSize = 176 ) +// Program ID for various types of ts packets. +const ( + SdtPid = 17 + PatPid = 0 + PmtPid = 4096 + VideoPid = 256 +) + +// StreamID is the id of the first stream. +const StreamID = 0xe0 + +// HeadSize is the size of an mpegts packet header. +const HeadSize = 4 + +// Consts relating to adaptation field. +const ( + AdaptationIdx = 4 // Index to the adaptation field (index of AFL). + AdaptationControlIdx = 3 // Index to octet with adaptation field control. + AdaptationFieldsIdx = AdaptationIdx + 1 // Adaptation field index is the index of the adaptation fields. + DefaultAdaptationSize = 2 // Default size of the adaptation field. + AdaptationControlMask = 0x30 // Mask for the adaptation field control in octet 3. +) + +// TODO: make this better - currently doesn't make sense. +const ( + HasPayload = 0x1 + HasAdaptationField = 0x2 +) + /* The below data struct encapsulates the fields of an MPEG-TS packet. Below is the formatting of an MPEG-TS packet for reference! @@ -127,20 +158,32 @@ type Packet struct { Payload []byte // Mpeg ts Payload } -// FindPMT will take a clip of mpegts and try to find a PMT table - if one +// FindPmt will take a clip of mpegts and try to find a PMT table - if one // is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. -func FindPMT(d []byte) (p []byte, i int, err error) { +func FindPmt(d []byte) ([]byte, int, error) { + return FindPid(d, PmtPid) +} + +// FindPat will take a clip of mpegts and try to find a PAT table - if one +// is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. +func FindPat(d []byte) ([]byte, int, error) { + return FindPid(d, PatPid) +} + +// FindPid will take a clip of mpegts and try to find a packet with given PID - if one +// is found, then it is returned along with its index, otherwise nil, -1 and an error is returned. +func FindPid(d []byte, pid uint16) (pkt []byte, i int, err error) { if len(d) < PacketSize { return nil, -1, errors.New("Mmpegts data not of valid length") } for i = 0; i < len(d); i += PacketSize { - pid := (uint16(d[i+1]&0x1f) << 8) | uint16(d[i+2]) - if pid == pmtPid { - p = d[i+4 : i+PacketSize] + p := (uint16(d[i+1]&0x1f) << 8) | uint16(d[i+2]) + if p == pid { + pkt = d[i+4 : i+PacketSize] return } } - return nil, -1, errors.New("Could not find pmt table in mpegts data") + return nil, -1, fmt.Errorf("could not find packet with pid: %d", pid) } // FillPayload takes a channel and fills the packets Payload field until the diff --git a/stream/mts/psi/crc.go b/stream/mts/psi/crc.go index a361307d..e0ac7deb 100644 --- a/stream/mts/psi/crc.go +++ b/stream/mts/psi/crc.go @@ -34,16 +34,16 @@ import ( ) // addCrc appends a crc table to a given psi table in bytes -func addCrc(out []byte) []byte { +func AddCrc(out []byte) []byte { t := make([]byte, len(out)+4) copy(t, out) - updateCrc(t) + UpdateCrc(t[1:]) return t } // updateCrc updates the crc of bytes slice, writing the checksum into the last four bytes. -func updateCrc(b []byte) { - crc32 := crc32_Update(0xffffffff, crc32_MakeTable(bits.Reverse32(crc32.IEEE)), b[1:len(b)-4]) +func UpdateCrc(b []byte) { + crc32 := crc32_Update(0xffffffff, crc32_MakeTable(bits.Reverse32(crc32.IEEE)), b[:len(b)-4]) binary.BigEndian.PutUint32(b[len(b)-4:], crc32) } diff --git a/stream/mts/psi/descriptor_test.go b/stream/mts/psi/descriptor_test.go new file mode 100644 index 00000000..94441277 --- /dev/null +++ b/stream/mts/psi/descriptor_test.go @@ -0,0 +1,322 @@ +/* +NAME + descriptor_test.go + +DESCRIPTION + See Readme.md + +AUTHOR + Saxon Nelson-Milton + +LICENSE + descriptor_test.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean) + + It is free software: you can redistribute it and/or modify them + under the terms of the GNU General Public License as published by the + Free Software Foundation, either version 3 of the License, or (at your + option) any later version. + + It is distributed in the hope that it will be useful, but WITHOUT + ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + for more details. + + You should have received a copy of the GNU General Public License + along with revid in gpl.txt. If not, see http://www.gnu.org/licenses. +*/ + +package psi + +import ( + "bytes" + "testing" +) + +const ( + errNotExpectedOut = "Did not get expected output: \ngot : %v, \nwant: %v" + errUnexpectedErr = "Unexpected error: %v\n" +) + +var ( + tstPsi1 = PSI{ + Pf: 0x00, + Tid: 0x02, + Ssi: true, + Sl: 0x1c, + Tss: &TSS{ + Tide: 0x01, + V: 0, + Cni: true, + Sn: 0, + Lsn: 0, + Sd: &PMT{ + Pcrpid: 0x0100, // wrong + Pil: 10, + Pd: []Desc{ + { + Dt: TimeDescTag, + Dl: TimeDataSize, + Dd: make([]byte, TimeDataSize), + }, + }, + Essd: &ESSD{ + St: 0x1b, + Epid: 0x0100, + Esil: 0x00, + }, + }, + }, + } + + tstPsi2 = PSI{ + Pf: 0x00, + Tid: 0x02, + Ssi: true, + Sl: 0x12, + Tss: &TSS{ + Tide: 0x01, + V: 0, + Cni: true, + Sn: 0, + Lsn: 0, + Sd: &PMT{ + Pcrpid: 0x0100, + Pil: 0, + Essd: &ESSD{ + St: 0x1b, + Epid: 0x0100, + Esil: 0x00, + }, + }, + }, + } + + tstPsi3 = PSI{ + Pf: 0x00, + Tid: 0x02, + Ssi: true, + Sl: 0x3e, + Tss: &TSS{ + Tide: 0x01, + V: 0, + Cni: true, + Sn: 0, + Lsn: 0, + Sd: &PMT{ + Pcrpid: 0x0100, + Pil: PmtTimeLocationPil, + Pd: []Desc{ + { + Dt: TimeDescTag, + Dl: TimeDataSize, + Dd: make([]byte, TimeDataSize), + }, + { + Dt: LocationDescTag, + Dl: LocationDataSize, + Dd: make([]byte, LocationDataSize), + }, + }, + Essd: &ESSD{ + St: 0x1b, + Epid: 0x0100, + Esil: 0x00, + }, + }, + }, + } +) + +var ( + pmtTimeBytesResizedBigger = []byte{ + 0x00, 0x02, 0xb0, 0x1e, 0x00, 0x01, 0xc1, 0x00, 0x00, 0xe1, 0x00, 0xf0, 0x0c, + TimeDescTag, // Descriptor tag + 0x0a, // Length of bytes to follow + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, // timestamp + 0x1b, 0xe1, 0x00, 0xf0, 0x00, + } + + pmtTimeBytesResizedSmaller = []byte{ + 0x00, 0x02, 0xb0, 0x1a, 0x00, 0x01, 0xc1, 0x00, 0x00, 0xe1, 0x00, 0xf0, 0x08, + TimeDescTag, // Descriptor tag + 0x06, // Length of bytes to follow + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, // timestamp + 0x1b, 0xe1, 0x00, 0xf0, 0x00, + } +) + +// TestHasDescriptorExists checks that PSIBytes.HasDescriptor performs as expected +// when the PSI we're interested in has the descriptor of interest. HasDescriptor +// should return the descriptor bytes. +// TODO: HasDescriptor also returns index of descriptor - we should check this. +func TestHasDescriptorExists(t *testing.T) { + p := PSIBytes(tstPsi3.Bytes()) + _, got := p.HasDescriptor(LocationDescTag) + want := []byte{ + LocationDescTag, + LocationDataSize, + } + want = append(want, make([]byte, LocationDataSize)...) + if !bytes.Equal(got, want) { + t.Errorf(errNotExpectedOut, got, want) + } +} + +// TestHasDescriptorAbsent checks that PSIBytes.HasDescriptor performs as expected +// when the PSI does not have the descriptor of interest. HasDescriptor should +// return a nil slice and a negative index. +// TODO: check index here as well. +func TestHasDescriptorAbsent(t *testing.T) { + p := PSIBytes(tstPsi3.Bytes()) + const fakeTag = 236 + _, got := p.HasDescriptor(fakeTag) + var want []byte + if !bytes.Equal(got, want) { + t.Errorf(errNotExpectedOut, got, want) + } +} + +// TestHasDescriptorNone checks that PSIBytes.HasDescriptor behaves as expected +// when the PSI does not have any descriptors. HasDescriptor should return a nil +// slice. +// TODO: again check index here. +func TestHasDescriptorNone(t *testing.T) { + p := PSIBytes(tstPsi2.Bytes()) + _, got := p.HasDescriptor(LocationDescTag) + var want []byte + if !bytes.Equal(got, want) { + t.Errorf(errNotExpectedOut, got, want) + } +} + +// TestProgramInfoLen checks that PSIBytes.ProgramInfoLen correctly extracts +// the program info length from a PSI. +func TestProgramInfoLen(t *testing.T) { + p := PSIBytes(tstPsi1.Bytes()) + got := p.ProgramInfoLen() + want := 10 + if got != want { + t.Errorf(errNotExpectedOut, got, want) + } +} + +// TestDescriptors checks that PSIBytes.descriptors correctly returns the descriptors +// from a PSI when descriptors exist. +func TestDescriptors(t *testing.T) { + p := PSIBytes(tstPsi1.Bytes()) + got := p.descriptors() + want := []byte{ + TimeDescTag, + TimeDataSize, + } + want = append(want, make([]byte, TimeDataSize)...) + if !bytes.Equal(got, want) { + t.Errorf(errNotExpectedOut, got, want) + } +} + +// TestDescriptors checks that PSIBYtes.desriptors correctly returns nil when +// we try to get descriptors from a psi without any descriptors. +func TestDescriptorsNone(t *testing.T) { + p := PSIBytes(tstPsi2.Bytes()) + got := p.descriptors() + var want []byte + if !bytes.Equal(got, want) { + t.Errorf(errNotExpectedOut, got, want) + } +} + +// TestCreateDescriptorEmpty checks that PSIBytes.createDescriptor correctly adds +// a descriptor to the descriptors list in a PSI when it has no descriptors already. +func TestCreateDescriptorEmpty(t *testing.T) { + got := PSIBytes(tstPsi2.Bytes()) + got.createDescriptor(TimeDescTag, make([]byte, TimeDataSize)) + UpdateCrc(got[1:]) + want := PSIBytes(tstPsi1.Bytes()) + if !bytes.Equal(want, got) { + t.Errorf(errNotExpectedOut, got, want) + } +} + +// TestCreateDescriptorNotEmpty checks that PSIBytes.createDescriptor correctly adds +// a descriptor to the descriptors list in a PSI when it already has one with +// a different tag. +func TestCreateDescriptorNotEmpty(t *testing.T) { + got := PSIBytes(tstPsi1.Bytes()) + got.createDescriptor(LocationDescTag, make([]byte, LocationDataSize)) + UpdateCrc(got[1:]) + want := PSIBytes(tstPsi3.Bytes()) + if !bytes.Equal(want, got) { + t.Errorf(errNotExpectedOut, got, want) + } +} + +// TestAddDescriptorEmpty checks that PSIBytes.AddDescriptor correctly adds a descriptor +// when there are no other descriptors present in the PSI. +func TestAddDescriptorEmpty(t *testing.T) { + got := PSIBytes(tstPsi2.Bytes()) + if err := got.AddDescriptor(TimeDescTag, make([]byte, TimeDataSize)); err != nil { + t.Errorf(errUnexpectedErr, err.Error()) + } + want := PSIBytes(tstPsi1.Bytes()) + if !bytes.Equal(got, want) { + t.Errorf(errNotExpectedOut, got, want) + } +} + +// TestAddDescriptorNonEmpty checks that PSIBytes.AddDescriptor correctly adds a +// descriptor when there is already a descriptor of a differing type in a PSI. +func TestAddDescriptorNonEmpty(t *testing.T) { + got := PSIBytes(tstPsi1.Bytes()) + if err := got.AddDescriptor(LocationDescTag, make([]byte, LocationDataSize)); err != nil { + t.Errorf(errUnexpectedErr, err.Error()) + } + want := PSIBytes(tstPsi3.Bytes()) + if !bytes.Equal(got, want) { + t.Errorf(errNotExpectedOut, got, want) + } +} + +// TestAddDescriptorUpdateSame checks that PSIBytes.AddDescriptor correctly updates data in a descriptor +// with the same given tag, with data being the same size. AddDescriptor should just copy new data into +// the descriptors data field. +func TestAddDescriptorUpdateSame(t *testing.T) { + newData := [8]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} + want := PSIBytes(tstPsi2.Bytes()) + want.createDescriptor(TimeDescTag, newData[:]) + got := PSIBytes(tstPsi1.Bytes()) + if err := got.AddDescriptor(TimeDescTag, newData[:]); err != nil { + t.Errorf(errUnexpectedErr, err.Error()) + } + if !bytes.Equal(got, want) { + t.Errorf(errNotExpectedOut, got, want) + } +} + +// TestAddDescriptorUpdateBigger checks that PSIBytes.AddDescriptor correctly resizes descriptor with same given tag +// to a bigger size and copies in new data. AddDescriptor should find descriptor with same tag, increase size of psi, +// shift data to make room for update descriptor, and then copy in the new data. +func TestAddDescriptorUpdateBigger(t *testing.T) { + got := PSIBytes(tstPsi1.Bytes()) + if err := got.AddDescriptor(TimeDescTag, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a}); err != nil { + t.Errorf(errUnexpectedErr, err.Error()) + } + want := AddCrc(pmtTimeBytesResizedBigger) + if !bytes.Equal(got, want) { + t.Errorf(errNotExpectedOut, got, want) + } +} + +// TestAddDescriptorUpdateSmaller checks that PSIBytes.AddDescriptor correctly resizes descriptor with same given tag +// in a psi to a smaller size and copies in new data. AddDescriptor should find tag with same descrtiptor, shift data +// after descriptor upwards, trim the psi to new size, and then copy in new data. +func TestAddDescriptorUpdateSmaller(t *testing.T) { + got := PSIBytes(tstPsi1.Bytes()) + if err := got.AddDescriptor(TimeDescTag, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}); err != nil { + t.Errorf(errUnexpectedErr, err.Error()) + } + want := AddCrc(pmtTimeBytesResizedSmaller) + if !bytes.Equal(got, want) { + t.Errorf(errNotExpectedOut, got, want) + } +} diff --git a/stream/mts/psi/helpers.go b/stream/mts/psi/helpers.go index 82d7ce72..b8bab6b5 100644 --- a/stream/mts/psi/helpers.go +++ b/stream/mts/psi/helpers.go @@ -64,15 +64,14 @@ func UpdateTime(dst []byte, t uint64) error { for i := range dst[TimeDataIndx : TimeDataIndx+TimeDataSize] { dst[i+TimeDataIndx] = ts[i] } - updateCrc(dst) + UpdateCrc(dst[1:]) return nil } // SyntaxSecLenFrom takes a byte slice representation of a psi and extracts // it's syntax section length -func SyntaxSecLenFrom(p []byte) (l uint8) { - l = uint8(p[syntaxSecLenIndx]) - crcSize - return +func SyntaxSecLenFrom(p []byte) int { + return int(((p[SyntaxSecLenIdx1] & SyntaxSecLenMask1) << 8) | p[SyntaxSecLenIdx2]) } // TimeFrom takes a byte slice representation of a psi-pmt and extracts it's @@ -112,7 +111,7 @@ func UpdateLocation(d []byte, s string) error { for i := range loc { loc[i] = 0 } - updateCrc(d) + UpdateCrc(d[1:]) return nil } @@ -127,7 +126,7 @@ func trimTo(d []byte, t byte) []byte { // addPadding adds an appropriate amount of padding to a pat or pmt table for // addition to an mpegts packet -func addPadding(d []byte) []byte { +func AddPadding(d []byte) []byte { t := make([]byte, PacketSize) copy(t, d) padding := t[len(d):] diff --git a/stream/mts/psi/psi.go b/stream/mts/psi/psi.go index 09028ba6..c93d3011 100644 --- a/stream/mts/psi/psi.go +++ b/stream/mts/psi/psi.go @@ -26,11 +26,16 @@ LICENSE package psi -const ( - PacketSize = 184 // packet size of a psi. +import ( + "errors" + + "github.com/Comcast/gots/psi" ) -// Lengths of section definitions +// PacketSize of psi (without mpegts header) +const PacketSize = 184 + +// Lengths of section definitions. const ( ESSDDefLen = 5 DescDefLen = 2 @@ -40,13 +45,14 @@ const ( PSIDefLen = 3 ) -// Table Type IDs +// Table Type IDs. const ( patID = 0x00 pmtID = 0x02 ) // Consts relating to time description +// TODO: remove this, we don't do metadata like this anymore. const ( TimeDescTag = 234 TimeTagIndx = 13 @@ -55,6 +61,7 @@ const ( ) // Consts relating to location description +// TODO: remove this, we don't do metadata like this anymore. const ( LocationDescTag = 235 LocationTagIndx = 23 @@ -62,10 +69,35 @@ const ( LocationDataSize = 32 // bytes ) -// Other misc consts +// crc hassh Size +const crcSize = 4 + +// Consts relating to syntax section. const ( - syntaxSecLenIndx = 3 - crcSize = 4 + TotalSyntaxSecLen = 180 + SyntaxSecLenIdx1 = 2 + SyntaxSecLenIdx2 = 3 + SyntaxSecLenMask1 = 0x03 + SectionLenMask1 = 0x03 +) + +// Consts relating to program info len. +const ( + ProgramInfoLenIdx1 = 11 + ProgramInfoLenIdx2 = 12 + ProgramInfoLenMask1 = 0x03 +) + +// DescriptorsIdx is the index that the descriptors start at. +const DescriptorsIdx = ProgramInfoLenIdx2 + 1 + +// MetadataTag is the descriptor tag used for metadata. +const MetadataTag = 0x26 + +// TODO: get rid of these - not a good idea. +type ( + PSIBytes []byte + Descriptor []byte ) // Program specific information @@ -135,8 +167,7 @@ func (p *PSI) Bytes() []byte { out[2] = 0x80 | 0x30 | (0x03 & byte(p.Sl>>8)) out[3] = byte(p.Sl) out = append(out, p.Tss.Bytes()...) - out = addCrc(out) - out = addPadding(out) + out = AddCrc(out) return out } @@ -205,3 +236,135 @@ func asByte(b bool) byte { } return 0x00 } + +// AddDescriptor adds or updates a descriptor in a PSI given a descriptor tag +// and data. If the psi is not a pmt, then an error is returned. If a descriptor +// with the given tag is not found in the psi, room is made and a descriptor with +// given tag and data is created. If a descriptor with the tag is found, the +// descriptor is resized as required and the new data is copied in. +func (p *PSIBytes) AddDescriptor(tag int, data []byte) error { + if psi.TableID(*p) != pmtID { + return errors.New("trying to add descriptor, but not pmt") + } + + i, desc := p.HasDescriptor(tag) + if desc == nil { + err := p.createDescriptor(tag, data) + return err + } + + oldDescLen := desc.len() + oldDataLen := int(desc[1]) + newDataLen := len(data) + newDescLen := 2 + newDataLen + delta := newDescLen - oldDescLen + + // If the old data length is more than the new data length, we need shift data + // after descriptor up, and then trim the psi. If the oldDataLen is less than + // new data then we need reseize psi and shift data down. If same do nothing. + switch { + case oldDataLen > newDataLen: + copy((*p)[i+newDescLen:], (*p)[i+oldDescLen:]) + *p = (*p)[:len(*p)+delta] + case oldDataLen < newDataLen: + tmp := make([]byte, len(*p)+delta) + copy(tmp, *p) + *p = tmp + copy((*p)[i+newDescLen:], (*p)[i+oldDescLen:]) + } + + // Copy in new data + (*p)[i+1] = byte(newDataLen) + copy((*p)[i+2:], data) + + newProgInfoLen := p.ProgramInfoLen() + delta + p.setProgInfoLen(newProgInfoLen) + newSectionLen := int(psi.SectionLength(*p)) + delta + p.setSectionLen(newSectionLen) + UpdateCrc((*p)[1:]) + return nil +} + +// HasDescriptor checks if a descriptor of the given tag exists in a PSI. If the descriptor +// of the given tag exists, an index of this descriptor, as well as the Descriptor is returned. +// If the descriptor of the given tag cannot be found, -1 and a nil slice is returned. +// +// TODO: check if pmt, return error if not ? +func (p *PSIBytes) HasDescriptor(tag int) (int, Descriptor) { + descs := p.descriptors() + if descs == nil { + return -1, nil + } + for i := 0; i < len(descs); i += 2 + int(descs[i+1]) { + if int(descs[i]) == tag { + return i + DescriptorsIdx, descs[i : i+2+int(descs[i+1])] + } + } + return -1, nil +} + +// createDescriptor creates a descriptor in a psi given a tag and data. It does so +// by resizing the psi, shifting existing data down and copying in new descriptor +// in new space. +func (p *PSIBytes) createDescriptor(tag int, data []byte) error { + curProgLen := p.ProgramInfoLen() + oldSyntaxSectionLen := SyntaxSecLenFrom(*p) + if TotalSyntaxSecLen-(oldSyntaxSectionLen+2+len(data)) <= 0 { + return errors.New("Not enough space in psi to create descriptor.") + } + dataLen := len(data) + newDescIdx := DescriptorsIdx + curProgLen + newDescLen := dataLen + 2 + + // Increase size of psi and copy data down to make room for new descriptor. + tmp := make([]byte, len(*p)+newDescLen) + copy(tmp, *p) + *p = tmp + copy((*p)[newDescIdx+newDescLen:], (*p)[newDescIdx:newDescIdx+newDescLen]) + // Set the tag, data len and data of the new desriptor. + (*p)[newDescIdx] = byte(tag) + (*p)[newDescIdx+1] = byte(dataLen) + copy((*p)[newDescIdx+2:newDescIdx+2+dataLen], data) + + // Set length fields and update the psi crc. + addedLen := dataLen + 2 + newProgInfoLen := curProgLen + addedLen + p.setProgInfoLen(newProgInfoLen) + newSyntaxSectionLen := int(oldSyntaxSectionLen) + addedLen + p.setSectionLen(newSyntaxSectionLen) + UpdateCrc((*p)[1:]) + + return nil +} + +// setProgInfoLen sets the program information length in a psi with a pmt. +func (p *PSIBytes) setProgInfoLen(l int) { + (*p)[ProgramInfoLenIdx1] &= 0xff ^ ProgramInfoLenMask1 + (*p)[ProgramInfoLenIdx1] |= byte(l>>8) & ProgramInfoLenMask1 + (*p)[ProgramInfoLenIdx2] = byte(l) +} + +// setSectionLen sets section length in a psi. +func (p *PSIBytes) setSectionLen(l int) { + (*p)[SyntaxSecLenIdx1] &= 0xff ^ SyntaxSecLenMask1 + (*p)[SyntaxSecLenIdx1] |= byte(l>>8) & SyntaxSecLenMask1 + (*p)[SyntaxSecLenIdx2] = byte(l) +} + +// descriptors returns the descriptors in a psi if they exist, otherwise +// a nil slice is returned. +func (p *PSIBytes) descriptors() []byte { + return (*p)[DescriptorsIdx : DescriptorsIdx+p.ProgramInfoLen()] +} + +// len returns the length of a descriptor in bytes. +func (d *Descriptor) len() int { + return int(2 + (*d)[1]) +} + +// ProgramInfoLen returns the program info length of a PSI. +// +// TODO: check if pmt - if not return 0 ? or -1 ? +func (p *PSIBytes) ProgramInfoLen() int { + return int((((*p)[ProgramInfoLenIdx1] & ProgramInfoLenMask1) << 8) | (*p)[ProgramInfoLenIdx2]) +} diff --git a/stream/mts/psi/psi_test.go b/stream/mts/psi/psi_test.go index e437066e..7e3a3104 100644 --- a/stream/mts/psi/psi_test.go +++ b/stream/mts/psi/psi_test.go @@ -282,7 +282,7 @@ var bytesTests = []struct { func TestBytes(t *testing.T) { for _, test := range bytesTests { got := test.input.Bytes() - if !bytes.Equal(got, addPadding(addCrc(test.want))) { + if !bytes.Equal(got, AddCrc(test.want)) { t.Errorf("unexpected error for test %v: got:%v want:%v", test.name, got, test.want) } @@ -301,7 +301,7 @@ func TestTimestampToBytes(t *testing.T) { func TestTimeUpdate(t *testing.T) { cpy := make([]byte, len(pmtTimeBytes1)) copy(cpy, pmtTimeBytes1) - cpy = addCrc(cpy) + cpy = AddCrc(cpy) err := UpdateTime(cpy, tstTime2) cpy = cpy[:len(cpy)-4] if err != nil { @@ -343,7 +343,7 @@ func TestLocationGet(t *testing.T) { func TestLocationUpdate(t *testing.T) { cpy := make([]byte, len(pmtWithMetaTst1)) copy(cpy, pmtWithMetaTst1) - cpy = addCrc(cpy) + cpy = AddCrc(cpy) err := UpdateLocation(cpy, locationTstStr2) cpy = cpy[:len(cpy)-4] if err != nil { diff --git a/stream/rtp/encoder.go b/stream/rtp/encoder.go index 20df9434..26016a23 100644 --- a/stream/rtp/encoder.go +++ b/stream/rtp/encoder.go @@ -40,7 +40,7 @@ const ( timestampFreq = 90000 // Hz mtsSize = 188 bufferSize = 1000 - sendLen = 7 * 188 + sendSize = 7 * 188 ) // Encoder implements io writer and provides functionality to wrap data into @@ -51,6 +51,7 @@ type Encoder struct { seqNo uint16 clock time.Duration frameInterval time.Duration + lastTime time.Time fps int buffer []byte pktSpace [defPktSize]byte @@ -72,13 +73,29 @@ func NewEncoder(dst io.Writer, fps int) *Encoder { // so that multiple layers of packetization can occur. func (e *Encoder) Write(data []byte) (int, error) { e.buffer = append(e.buffer, data...) - for len(e.buffer) >= sendLen { - e.Encode(e.buffer[:sendLen]) - e.buffer = e.buffer[sendLen:] + if len(e.buffer) < sendSize { + return len(data), nil } + buf := e.buffer + for len(buf) != 0 { + l := min(sendSize, len(buf)) + err := e.Encode(buf[:l]) + if err != nil { + return len(data), err + } + buf = buf[l:] + } + e.buffer = e.buffer[:0] return len(data), nil } +func min(a, b int) int { + if a < b { + return a + } + return b +} + // Encode takes a nalu unit and encodes it into an rtp packet and // writes to the io.Writer given in NewEncoder func (e *Encoder) Encode(payload []byte) error { diff --git a/stream/rtp/rtp.go b/stream/rtp/rtp.go index 47f4a91b..92192294 100644 --- a/stream/rtp/rtp.go +++ b/stream/rtp/rtp.go @@ -35,7 +35,7 @@ package rtp const ( rtpVer = 2 headSize = 3 * 4 // Header size of an rtp packet. - defPayloadSize = sendLen // Default payload size for the rtp packet. + defPayloadSize = sendSize // Default payload size for the rtp packet. defPktSize = headSize + defPayloadSize // Default packet size is header size + payload size. )