diff --git a/stream/mts/discontinuity.go b/stream/mts/discontinuity.go index d1c6165a..0c6f6071 100644 --- a/stream/mts/discontinuity.go +++ b/stream/mts/discontinuity.go @@ -34,18 +34,24 @@ import "github.com/Comcast/gots/packet" // discontinuityRepairer provides function to detect discontinuities in mpegts // and set the discontinuity indicator as appropriate. type DiscontinuityRepairer struct { - expCC uint8 + expCC map[int]int } // NewDiscontinuityRepairer returns a pointer to a new discontinuityRepairer. func NewDiscontinuityRepairer() *DiscontinuityRepairer { - return &DiscontinuityRepairer{expCC: 16} + return &DiscontinuityRepairer{ + expCC: map[int]int{ + PatPid: 16, + PmtPid: 16, + VideoPid: 16, + }, + } } // Failed is to be called in the case of a failed send. This will decrement the // expectedCC so that it aligns with the failed chunks cc. func (dr *DiscontinuityRepairer) Failed() { - dr.decExpectedCC() + dr.decExpectedCC(PatPid) } // Repair takes a clip of mpegts and checks that the first packet, which should @@ -55,14 +61,15 @@ func (dr *DiscontinuityRepairer) Repair(d []byte) error { var pkt [PacketSize]byte copy(pkt[:], d[:PacketSize]) p := (*packet.Packet)(&pkt) - if p.PID() != PatPid { + pid := p.PID() + if pid != PatPid { panic("Clip to repair must have PAT first") } cc := p.ContinuityCounter() - expect, exists := dr.expectedCC(3) - dr.incExpectedCC() + expect, exists := dr.expectedCC(pid) + dr.incExpectedCC(pid) if !exists { - dr.setExpectedCC(uint8(cc)) + dr.setExpectedCC(pid, cc) } else if cc != int(expect) { if packet.ContainsAdaptationField(p) { (*packet.AdaptationField)(p).SetDiscontinuity(true) @@ -72,7 +79,7 @@ func (dr *DiscontinuityRepairer) Repair(d []byte) error { return err } } - dr.setExpectedCC(uint8(cc)) + dr.setExpectedCC(pid, cc) copy(d[:PacketSize], pkt[:]) } return nil @@ -80,24 +87,24 @@ func (dr *DiscontinuityRepairer) Repair(d []byte) error { // expectedCC returns the expected cc. If the cc hasn't been used yet, then 16 // and false is returned. -func (dr *DiscontinuityRepairer) expectedCC(pid int) (byte, bool) { - if dr.expCC == 16 { +func (dr *DiscontinuityRepairer) expectedCC(pid int) (int, bool) { + if dr.expCC[pid] == 16 { return 16, false } - return dr.expCC, true + return dr.expCC[pid], true } // incExpectedCC increments the expected cc. -func (dr *DiscontinuityRepairer) incExpectedCC() { - dr.expCC = (dr.expCC + 1) & 0xf +func (dr *DiscontinuityRepairer) incExpectedCC(pid int) { + dr.expCC[pid] = (dr.expCC[pid] + 1) & 0xf } // decExpectedCC decrements the expected cc. -func (dr *DiscontinuityRepairer) decExpectedCC() { - dr.expCC = (dr.expCC - 1) & 0xf +func (dr *DiscontinuityRepairer) decExpectedCC(pid int) { + dr.expCC[pid] = (dr.expCC[pid] - 1) & 0xf } // setExpectedCC sets the expected cc. -func (dr *DiscontinuityRepairer) setExpectedCC(cc uint8) { - dr.expCC = cc +func (dr *DiscontinuityRepairer) setExpectedCC(pid, cc int) { + dr.expCC[pid] = cc }