/*
NAME
  mpegts.go - provides a data structure that encapsulates the properties of an
  MPEG-TS packet, along with functions for manipulating such packets.

DESCRIPTION
  See Readme.md

AUTHOR
  Saxon A. Nelson-Milton <saxon.milton@gmail.com>

LICENSE
  mpegts.go is Copyright (C) 2017 the Australian Ocean Lab (AusOcean)

  It is free software: you can redistribute it and/or modify them
  under the terms of the GNU General Public License as published by the
  Free Software Foundation, either version 3 of the License, or (at your
  option) any later version.

  It is distributed in the hope that it will be useful, but WITHOUT
  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 for more details.

  You should have received a copy of the GNU General Public License
  along with revid in gpl.txt.  If not, see [GNU licenses](http://www.gnu.org/licenses).
*/

package mts

import (
	"errors"
	"fmt"

	"github.com/Comcast/gots/packet"
)

// General MPEG-TS packet properties.
//
// PacketSize is the size, in octets, of an MPEG-TS packet.
//
// PayloadSize is the budget from which Packet.FillPayload subtracts the
// header and adaptation field octets in order to size a packet's payload.
// NOTE(review): this is smaller than PacketSize-HeadSize (184), so packets
// built via FillPayload leave some space to adaptation field stuffing —
// confirm the intended value.
const (
	PacketSize  = 188
	PayloadSize = 176
)

// Program IDs (PIDs) used for the various types of TS packets: the Service
// Description Table, the Program Association Table, the Program Map Table
// and the video stream.
const (
	SdtPid   = 17
	PatPid   = 0
	PmtPid   = 4096
	VideoPid = 256
)

// StreamID is the stream ID of the first stream.
// NOTE(review): presumably the PES stream_id (0xe0 is the first MPEG video
// stream ID) — confirm against the PES encoder.
const StreamID = 0xe0

// HeadSize is the size, in octets, of an MPEG-TS packet header.
const HeadSize = 4

// Constants relating to the adaptation field. All indices are octet offsets
// from the start of an MPEG-TS packet.
const (
	AdaptationIdx              = 4                 // Index of the adaptation field length (AFL) octet.
	AdaptationControlIdx       = 3                 // Index of the octet holding the adaptation field control (AFC) bits.
	AdaptationFieldsIdx        = AdaptationIdx + 1 // Index of the first octet after AFL, i.e. the adaptation flags octet.
	DefaultAdaptationSize      = 2                 // Size of a default adaptation field: the AFL octet plus one flags octet.
	AdaptationControlMask      = 0x30              // Mask for the AFC bits within octet AdaptationControlIdx.
	DefaultAdaptationBodySize  = 1                 // AFL value of a default adaptation field (only the flags octet follows).
	DiscontinuityIndicatorMask = 0x80              // Mask for the discontinuity indicator within octet DiscontinuityIndicatorIdx.
	DiscontinuityIndicatorIdx  = AdaptationIdx + 1 // Index of the octet holding the discontinuity indicator (the flags octet).
)

// HasPayload and HasAdaptationField are the bits of the 2-bit adaptation
// field control (AFC) value. An adaptation field is present when
// AFC&HasAdaptationField != 0 (Packet.Bytes writes one for AFC values 2 and
// 3). Per the MPEG-TS specification, a payload is present when
// AFC&HasPayload != 0.
// TODO: consider giving AFC values a dedicated type.
const (
	HasPayload         = 0x1
	HasAdaptationField = 0x2
)

/*
Packet encapsulates the fields of an MPEG-TS packet. The layout of an MPEG-TS
packet is shown below for reference. Within each octet, bit 0 is the most
significant bit and bit 7 the least significant.

The adaptation field (octet 4 onwards, up to the payload) is only present
when AFC indicates one. The optional fields that follow octet 5 correspond to
the flags in octet 5 (PCRF, OPCRF, SPF, TPDF, AFEF). See Bytes for how a
Packet is encoded.

                          MPEG-TS Packet Formatting
============================================================================
| octet no | bit 0 | bit 1 | bit 2 | bit 3 | bit 4 | bit 5 | bit 6 | bit 7 |
============================================================================
| octet 0  | sync byte (0x47)                                              |
----------------------------------------------------------------------------
| octet 1  | TEI   | PUSI  | Prior | PID                                   |
----------------------------------------------------------------------------
| octet 2  | PID cont.                                                     |
----------------------------------------------------------------------------
| octet 3  | TSC           | AFC           | CC                            |
----------------------------------------------------------------------------
| octet 4  | AFL                                                           |
----------------------------------------------------------------------------
| octet 5  | DI    | RAI   | ESPI  | PCRF  | OPCRF | SPF   | TPDF  | AFEF  |
----------------------------------------------------------------------------
| optional | PCR (48 bits => 6 bytes)                                      |
----------------------------------------------------------------------------
| -        | PCR cont.                                                     |
----------------------------------------------------------------------------
| -        | PCR cont.                                                     |
----------------------------------------------------------------------------
| -        | PCR cont.                                                     |
----------------------------------------------------------------------------
| -        | PCR cont.                                                     |
----------------------------------------------------------------------------
| -        | PCR cont.                                                     |
----------------------------------------------------------------------------
| optional | OPCR (48 bits => 6 bytes)                                     |
----------------------------------------------------------------------------
| -        | OPCR cont.                                                    |
----------------------------------------------------------------------------
| -        | OPCR cont.                                                    |
----------------------------------------------------------------------------
| -        | OPCR cont.                                                    |
----------------------------------------------------------------------------
| -        | OPCR cont.                                                    |
----------------------------------------------------------------------------
| -        | OPCR cont.                                                    |
----------------------------------------------------------------------------
| optional | SC                                                            |
----------------------------------------------------------------------------
| optional | TPDL                                                          |
----------------------------------------------------------------------------
| optional | TPD (variable length)                                         |
----------------------------------------------------------------------------
| -        | ...                                                           |
----------------------------------------------------------------------------
| optional | Extension (variable length)                                   |
----------------------------------------------------------------------------
| -        | ...                                                           |
----------------------------------------------------------------------------
| optional | Stuffing (variable length)                                    |
----------------------------------------------------------------------------
| -        | ...                                                           |
----------------------------------------------------------------------------
| optional | Payload (variable length)                                     |
----------------------------------------------------------------------------
| -        | ...                                                           |
----------------------------------------------------------------------------
*/
type Packet struct {
	TEI      bool   // Transport error indicator.
	PUSI     bool   // Payload unit start indicator.
	Priority bool   // Transport priority indicator.
	PID      uint16 // Packet identifier (13 bits on the wire).
	TSC      byte   // Transport scrambling control (2 bits).
	AFC      byte   // Adaptation field control (2 bits).
	CC       byte   // Continuity counter (4 bits).
	DI       bool   // Discontinuity indicator.
	RAI      bool   // Random access indicator.
	ESPI     bool   // Elementary stream priority indicator.
	PCRF     bool   // PCR flag: the PCR field is present.
	OPCRF    bool   // OPCR flag: the OPCR field is present.
	SPF      bool   // Splicing point flag: the splice countdown (SC) is present.
	TPDF     bool   // Transport private data flag: TPDL and TPD are present.
	AFEF     bool   // Adaptation field extension flag.
	PCR      uint64 // Program clock reference.
	OPCR     uint64 // Original program clock reference.
	SC       byte   // Splice countdown.
	TPDL     byte   // Transport private data length.
	TPD      []byte // Transport private data.
	Ext      []byte // Adaptation field extension.
	Payload  []byte // MPEG-TS payload.
}

// FindPmt searches the MPEG-TS clip d for the first packet carrying the
// Program Map Table (PID PmtPid). On success it returns the octets of that
// packet following its 4-octet header, together with the packet's byte
// offset in d. Otherwise it returns nil, -1 and an error.
func FindPmt(d []byte) ([]byte, int, error) {
	pmt, offset, err := FindPid(d, PmtPid)
	return pmt, offset, err
}

// FindPat searches the MPEG-TS clip d for the first packet carrying the
// Program Association Table (PID PatPid). On success it returns the octets of
// that packet following its 4-octet header, together with the packet's byte
// offset in d. Otherwise it returns nil, -1 and an error.
func FindPat(d []byte) ([]byte, int, error) {
	pat, offset, err := FindPid(d, PatPid)
	return pat, offset, err
}

// FindPid searches the MPEG-TS clip d, packet by packet, for the first packet
// whose PID equals pid. If one is found, the octets of that packet following
// its 4-octet header are returned along with the packet's byte offset in d.
// Otherwise nil, -1 and an error are returned.
//
// d is expected to start on a packet boundary. Any trailing partial packet
// (when len(d) is not a multiple of PacketSize) is ignored rather than read
// past the end of d.
func FindPid(d []byte, pid uint16) (pkt []byte, i int, err error) {
	if len(d) < PacketSize {
		return nil, -1, errors.New("mpegts data not of valid length")
	}
	// Only examine complete packets, so that neither the PID lookup nor the
	// returned slice can index beyond len(d).
	for i = 0; i+PacketSize <= len(d); i += PacketSize {
		// The 13-bit PID is the low 5 bits of octet 1 followed by octet 2.
		p := uint16(d[i+1]&0x1f)<<8 | uint16(d[i+2])
		if p == pid {
			// Skip the 4-octet header.
			return d[i+4 : i+PacketSize], i, nil
		}
	}
	return nil, -1, fmt.Errorf("could not find packet with pid: %d", pid)
}

// FillPayload copies as much of data into p.Payload as will fit and returns
// the number of octets copied.
//
// The space available is PayloadSize, less 6 octets (the 4-octet header plus
// the adaptation field length and flags octets that Bytes writes). It is also
// reduced by the optional adaptation fields that Bytes will write for p:
//   - PCR and OPCR, 6 octets each;
//   - the splice countdown, 1 octet;
//   - transport private data, a 1-octet length plus len(p.TPD), when TPDF is set;
//   - the adaptation field extension, len(p.Ext).
// If those fields exhaust the budget, no payload is copied.
//
// p.Payload is sized to the number of octets actually copied rather than to
// the full space available. A short final chunk is therefore not padded with
// zero octets; when p carries an adaptation field, Bytes fills the remainder
// with stuffing instead.
func (p *Packet) FillPayload(data []byte) int {
	used := 6 + asInt(p.PCRF)*6 + asInt(p.OPCRF)*6 + asInt(p.SPF) +
		asInt(p.TPDF)*(1+len(p.TPD)) + len(p.Ext)
	space := PayloadSize - used
	if space < 0 {
		// The adaptation field alone exceeds the budget; leave no room for
		// payload rather than panicking on a negative make length.
		space = 0
	}
	n := len(data)
	if n > space {
		n = space
	}
	p.Payload = make([]byte, n)
	return copy(p.Payload, data)
}

// asInt converts b to an int: 1 when b is true, 0 otherwise. It lets flag
// fields be multiplied into octet counts.
func asInt(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}

// asByte converts b to a single-bit byte, 1 when b is true and 0 otherwise,
// ready to be shifted into position within a header octet.
func asByte(b bool) byte {
	switch {
	case b:
		return 1
	default:
		return 0
	}
}

// Bytes encodes p as an MPEG-TS packet and returns the encoded octets.
//
// buf may be supplied to avoid an allocation. If its capacity is at least
// PacketSize its backing array is reused and overwritten; otherwise a new
// buffer is allocated. Callers should use the returned slice, not buf.
//
// The 4-octet header is always written. If p.AFC is 2 or 3, an adaptation
// field follows. It contains the length octet, the flags octet, whichever
// optional fields p's flags enable (PCR, OPCR, splice countdown, transport
// private data), p.Ext, and enough 0xFF stuffing octets to bring the whole
// packet to PacketSize octets. p.Payload is written last.
//
// Header fields are masked to their bit widths, so an out-of-range value
// cannot corrupt neighbouring fields. Bytes does not truncate: if the
// adaptation field and payload are too large, the result exceeds PacketSize.
// Without an adaptation field, no stuffing is added.
func (p *Packet) Bytes(buf []byte) []byte {
	if cap(buf) < PacketSize {
		buf = make([]byte, 0, PacketSize)
	}
	buf = buf[:0]

	afc := p.AFC & 0x3
	buf = append(buf,
		0x47, // Sync byte.
		asByte(p.TEI)<<7|asByte(p.PUSI)<<6|asByte(p.Priority)<<5|byte(p.PID>>8)&0x1f,
		byte(p.PID),
		(p.TSC&0x3)<<6|afc<<4|p.CC&0xf,
	)

	if afc == 2 || afc == 3 {
		// Octets of the adaptation field after its length octet, excluding
		// stuffing. This is the flags octet plus each optional field that is
		// actually written below: PCR, OPCR, splice countdown, private data
		// (length octet plus data) and extension.
		body := 1 + asInt(p.PCRF)*6 + asInt(p.OPCRF)*6 + asInt(p.SPF) +
			asInt(p.TPDF)*(1+len(p.TPD)) + len(p.Ext)

		// Pad so that header (4) + length octet (1) + body + stuffing +
		// payload == PacketSize.
		stuffing := PacketSize - 5 - body - len(p.Payload)
		if stuffing < 0 {
			stuffing = 0
		}

		buf = append(buf,
			byte(body+stuffing), // Adaptation field length.
			asByte(p.DI)<<7|asByte(p.RAI)<<6|asByte(p.ESPI)<<5|
				asByte(p.PCRF)<<4|asByte(p.OPCRF)<<3|asByte(p.SPF)<<2|
				asByte(p.TPDF)<<1|asByte(p.AFEF),
		)

		if p.PCRF {
			// The low 33 bits of p.PCR occupy the top 33 of the 48 bits.
			// The 15 bits that follow them are written as zero.
			pcr := p.PCR << 15
			for shift := 40; shift >= 0; shift -= 8 {
				buf = append(buf, byte(pcr>>uint(shift)))
			}
		}
		if p.OPCRF {
			// NOTE(review): unlike PCR, OPCR is not shifted; its low 48 bits
			// are written as-is. Confirm callers supply it already encoded.
			for shift := 40; shift >= 0; shift -= 8 {
				buf = append(buf, byte(p.OPCR>>uint(shift)))
			}
		}
		if p.SPF {
			buf = append(buf, p.SC)
		}
		if p.TPDF {
			buf = append(buf, p.TPDL)
			buf = append(buf, p.TPD...)
		}
		buf = append(buf, p.Ext...)
		for ; stuffing > 0; stuffing-- {
			buf = append(buf, 0xff)
		}
	}
	return append(buf, p.Payload...)
}

// Option modifies a packet in place. addAdaptationField applies the Options
// passed to it after inserting and resetting a default adaptation field.
type Option func(p *packet.Packet)

// addAdaptationField inserts a default adaptation field into p and then
// applies options to the packet, in order.
//
// Everything after the header is shifted along by DefaultAdaptationSize
// octets to make room, so the last DefaultAdaptationSize octets of the packet
// are discarded. The adaptation field control bits are then set, and the new
// field is reset to its default length with all flags cleared.
//
// An error is returned, leaving p unchanged, if p already contains an
// adaptation field.
func addAdaptationField(p *packet.Packet, options ...Option) error {
	if packet.ContainsAdaptationField(p) {
		return errors.New("Adaptation field is already present in packet")
	}

	// Create space for the adaptation field.
	copy(p[HeadSize+DefaultAdaptationSize:], p[HeadSize:len(p)-DefaultAdaptationSize])

	// Set both adaptation field control bits (AFC = 3). Setting every bit
	// under the mask makes clearing them first unnecessary.
	p[AdaptationControlIdx] |= AdaptationControlMask

	// Default the adaptation field. This should not fail now that AFC marks
	// an adaptation field as present, but surface any error rather than
	// silently dropping it.
	if err := resetAdaptation(p); err != nil {
		return fmt.Errorf("could not reset adaptation field: %w", err)
	}

	// Apply any options that have been passed.
	for _, option := range options {
		option(p)
	}
	return nil
}

// resetAdaptation restores p's adaptation field to its default state. It sets
// the adaptation field length octet to DefaultAdaptationBodySize and clears
// the flags octet that follows it. If p has no adaptation field, it returns
// an error and leaves p unchanged.
func resetAdaptation(p *packet.Packet) error {
	if packet.ContainsAdaptationField(p) {
		p[AdaptationIdx], p[AdaptationIdx+1] = DefaultAdaptationBodySize, 0x00
		return nil
	}
	return errors.New("No adaptation field in this packet")
}

// DiscontinuityIndicator returns an Option that sets the discontinuity
// indicator bit of a packet when f is true and clears it when f is false.
// All other bits of that octet are left untouched.
// NOTE(review): the Option writes octet DiscontinuityIndicatorIdx without
// checking that an adaptation field is present. It is meant to be applied
// via addAdaptationField, which inserts one first.
func DiscontinuityIndicator(f bool) Option {
	return func(p *packet.Packet) {
		flags := p[DiscontinuityIndicatorIdx] &^ DiscontinuityIndicatorMask
		if f {
			flags |= DiscontinuityIndicatorMask
		}
		p[DiscontinuityIndicatorIdx] = flags
	}
}