/*
NAME
  discontinuity.go

DESCRIPTION
  discontinuity.go provides functionality for detecting discontinuities in
  MPEG-TS and accounting for them by setting the discontinuity indicator in the
  adaptation field.

AUTHOR
  Saxon A. Nelson-Milton <saxon@ausocean.org>

LICENSE
  discontinuity.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean)

  It is free software: you can redistribute it and/or modify them
  under the terms of the GNU General Public License as published by the
  Free Software Foundation, either version 3 of the License, or (at your
  option) any later version.

  It is distributed in the hope that it will be useful, but WITHOUT
  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 for more details.

  You should have received a copy of the GNU General Public License
  in gpl.txt.  If not, see [GNU licenses](http://www.gnu.org/licenses).
*/

package mts

import (
	"fmt"

	"github.com/Comcast/gots/packet"
)

// DiscontinuityRepairer detects discontinuities in MPEG-TS by tracking the
// continuity counter (CC) expected next for each PID. It sets the
// discontinuity indicator when the PAT at the start of a clip carries an
// unexpected CC.
type DiscontinuityRepairer struct {
	// expCC maps a PID to the CC expected on its next packet. The value 16 is
	// a sentinel meaning that no CC has been recorded for that PID yet.
	expCC map[int]int
}

// NewDiscontinuityRepairer returns a pointer to a new DiscontinuityRepairer
// that tracks the PAT, PMT and video PIDs. Each PID starts with the sentinel
// expected CC of 16, meaning no CC has been seen for it yet.
func NewDiscontinuityRepairer() *DiscontinuityRepairer {
	const unused = 16
	tracked := [...]int{PatPid, PmtPid, VideoPid}
	expected := make(map[int]int, len(tracked))
	for _, pid := range tracked {
		expected[pid] = unused
	}
	return &DiscontinuityRepairer{expCC: expected}
}

// Failed should be called when sending a clip fails. It steps the expected
// PAT continuity counter back by one, wrapping from 0 to 15. This undoes the
// increment made at the end of Repair, so the expected CC once again matches
// the CC of the clip whose send failed.
func (dr *DiscontinuityRepairer) Failed() {
	prev := dr.expCC[PatPid] - 1
	dr.expCC[PatPid] = prev & 0x0f
}

// Repair checks the continuity counter (CC) of the first packet of the MPEG-TS
// clip d, which must be a PAT. If that CC is not the one expected:
//   - the packet's discontinuity indicator is set, and an adaptation field is
//     added if the packet has none;
//   - the modified packet is written back into d;
//   - the expected CC is resynchronised to the packet's CC.
//
// On success the expected PAT CC is then advanced by one, ready for the next
// clip.
//
// An error is returned if d is shorter than one packet, if its first packet is
// not a PAT, or if an adaptation field could not be added. In each of these
// cases neither d nor the expected CC is modified.
func (dr *DiscontinuityRepairer) Repair(d []byte) error {
	// Guard the slice below; a short clip would otherwise cause an
	// out-of-range panic.
	if len(d) < PacketSize {
		return fmt.Errorf("clip too short to repair: got %d bytes, need at least %d", len(d), PacketSize)
	}

	var pkt packet.Packet
	copy(pkt[:], d[:PacketSize])

	pid := pkt.PID()
	if pid != PatPid {
		// A malformed clip is a problem with the caller's input. Report it
		// instead of panicking and taking down the whole process.
		return fmt.Errorf("clip to repair must have PAT first, got PID %d", pid)
	}

	cc := pkt.ContinuityCounter()
	// The "used yet" flag is ignored on purpose. The unused sentinel 16 can
	// never equal a 4-bit CC, so for a repairer from NewDiscontinuityRepairer
	// the first clip seen is always flagged as a discontinuity.
	expect, _ := dr.ExpectedCC(pid)
	if cc != expect {
		if packet.ContainsAdaptationField(&pkt) {
			// NOTE(review): if this gots version's SetDiscontinuity returns an
			// error, it is discarded here — confirm.
			(*packet.AdaptationField)(&pkt).SetDiscontinuity(true)
		} else if err := addAdaptationField(&pkt, DiscontinuityIndicator(true)); err != nil {
			return err
		}
		dr.SetExpectedCC(pid, cc)
		copy(d[:PacketSize], pkt[:])
	}
	dr.IncExpectedCC(pid)
	return nil
}

// ExpectedCC returns the continuity counter (CC) expected next for pid.
//
// If no CC has been recorded for pid yet, it returns 16 and false. This covers
// both a pid that was never registered and a pid that still holds the sentinel
// value 16. Otherwise it returns the expected CC and true.
func (dr *DiscontinuityRepairer) ExpectedCC(pid int) (int, bool) {
	// A single comma-ok lookup. A missing key must not be reported as a
	// recorded CC of 0.
	cc, ok := dr.expCC[pid]
	if !ok || cc == 16 {
		return 16, false
	}
	return cc, true
}

// IncExpectedCC advances the expected continuity counter for pid by one,
// keeping only the low four bits so that 15 wraps to 0. Note that the unused
// sentinel 16 becomes 1. The repairer's map must be non-nil, as it is when
// built by NewDiscontinuityRepairer.
func (dr *DiscontinuityRepairer) IncExpectedCC(pid int) {
	next := dr.expCC[pid] + 1
	dr.expCC[pid] = next & 0x0f
}

// decExpectedCC steps the expected continuity counter for pid back by one,
// keeping only the low four bits so that 0 wraps to 15.
func (dr *DiscontinuityRepairer) decExpectedCC(pid int) {
	prev := dr.expCC[pid] - 1
	dr.expCC[pid] = prev & 0x0f
}

// SetExpectedCC records cc as the continuity counter expected next for pid.
// The value is stored exactly as given, with no masking, so passing 16 marks
// pid as unused again.
func (dr *DiscontinuityRepairer) SetExpectedCC(pid, cc int) {
	counters := dr.expCC
	counters[pid] = cc
}