/*
NAME
  extract.go

DESCRIPTION
  extract.go provides an extracter to get access units from an RTP stream.

AUTHOR
  Saxon Nelson-Milton <saxon@ausocean.org>

LICENSE
  Copyright (C) 2017 the Australian Ocean Lab (AusOcean)

  It is free software: you can redistribute it and/or modify them
  under the terms of the GNU General Public License as published by the
  Free Software Foundation, either version 3 of the License, or (at your
  option) any later version.

  It is distributed in the hope that it will be useful, but WITHOUT
  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
  for more details.

  You should have received a copy of the GNU General Public License
  along with revid in gpl.txt. If not, see http://www.gnu.org/licenses.
*/
package h264

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
	"time"

	"bitbucket.org/ausocean/av/protocol/rtp"
)

// NAL types (from https://tools.ietf.org/html/rfc6184#page-13)
const (
	// Single nal units bounds.
	typeSingleNALULowBound  = 1
	typeSingleNALUHighBound = 23

	// Single-time aggregation packets.
	typeSTAPA = 24
	typeSTAPB = 25

	// Multi-time aggregation packets.
	typeMTAP16 = 26
	typeMTAP24 = 27

	// Fragmentation packets.
	typeFUA = 28
	typeFUB = 29
)

// Min NAL lengths.
const (
	minSingleNALLen = 1
	minSTAPALen     = 4
	minFUALen       = 2
)

// Buffer sizes.
const (
	maxAUSize  = 100000 // Max access unit size in bytes.
	maxRTPSize = 1500   // Max ethernet transmission unit in bytes.
)

// Extracter is an extracter for extracting H264 access units from RTP stream.
type Extracter struct {
	buf  *bytes.Buffer // Holds the current access unit.
	frag bool          // Indicates if we're currently dealing with a fragmentation packet.
}

// NewExtracter returns a new Extracter.
func NewExtracter() *Extracter {
	return &Extracter{
		buf: bytes.NewBuffer(make([]byte, 0, maxAUSize))}
}

// Extract extracts H264 access units from an RTP stream. This function
// expects that each read from src will provide a single RTP packet.
func (e *Extracter) Extract(dst io.Writer, src io.Reader, delay time.Duration) error {
	buf := make([]byte, maxRTPSize)
	for {
		n, err := src.Read(buf)
		switch err {
		case nil: // Do nothing.
		case io.EOF:
			return nil
		default:
			return fmt.Errorf("source read error: %v\n", err)
		}

		// Get payload from RTP packet.
		payload, err := rtp.Payload(buf[:n])
		if err != nil {
			return fmt.Errorf("could not get RTP payload, failed with err: %v\n", err)
		}

		nalType := payload[0] & 0x1f

		// If not currently fragmented then we ignore current write.
		if e.frag && nalType != typeFUA {
			e.buf.Reset()
			e.frag = false
			continue
		}

		if typeSingleNALULowBound <= nalType && nalType <= typeSingleNALUHighBound {
			// If len too small, ignore.
			if len(payload) < minSingleNALLen {
				continue
			}
			e.writeWithPrefix(payload)
		} else {
			switch nalType {
			case typeSTAPA:
				e.handleSTAPA(payload)
			case typeFUA:
				e.handleFUA(payload)
			case typeSTAPB:
				panic("STAP-B type unsupported")
			case typeMTAP16:
				panic("MTAP16 type unsupported")
			case typeMTAP24:
				panic("MTAP24 type unsupported")
			case typeFUB:
				panic("FU-B type unsupported")
			default:
				panic("unsupported type")
			}
		}

		markerIsSet, err := rtp.Marker(buf[:n])
		if err != nil {
			return fmt.Errorf("could not get marker bit, failed with err: %v\n", err)
		}

		if markerIsSet {
			e.buf.WriteTo(dst)
			e.buf.Reset()
		}
	}
	return nil
}

// handleSTAPA parses NAL units from an aggregation packet and writes
// them to the Extracter's buffer buf.
func (e *Extracter) handleSTAPA(d []byte) {
	// If the length is too small, ignore.
	if len(d) < minSTAPALen {
		return
	}

	for i := 1; i < len(d); {
		size := int(binary.BigEndian.Uint16(d[i:]))

		// Skip over NAL unit size.
		const sizeOfFieldLen = 2
		i += sizeOfFieldLen

		// Get the NALU.
		nalu := d[i : i+size]
		i += size
		e.writeWithPrefix(nalu)
	}
}

// handleFUA parses NAL units from fragmentation packets and writes
// them to the Extracter's buf.
func (e *Extracter) handleFUA(d []byte) {
	// If length is too small, ignore.
	if len(d) < minFUALen {
		return
	}

	// Get start and end indiciators from FU header.
	const FUHeadIdx = 1
	start := d[FUHeadIdx]&0x80 != 0
	end := d[FUHeadIdx]&0x40 != 0

	// If start, form new header, skip FU indicator only and set first byte to
	// new header. Otherwise, skip over both FU indicator and FU header.
	if start {
		newHead := (d[0] & 0xe0) | (d[1] & 0x1f)
		d = d[1:]
		d[0] = newHead
		if end {
			panic("bad fragmentation packet")
		}
		e.frag = true
		e.writeWithPrefix(d)
	} else {
		d = d[2:]
		if end {
			e.frag = false
		}
		e.writeNoPrefix(d)
	}
}

// write writes a NAL unit to the Extracter's buf in byte stream format using the
// start code.
func (e *Extracter) writeWithPrefix(d []byte) {
	const prefix = "\x00\x00\x00\x01"
	e.buf.Write([]byte(prefix))
	e.buf.Write(d)
}

// writeNoPrefix writes data to the Extracter's buf. This is used for non start
// fragmentations of a NALU.
func (e *Extracter) writeNoPrefix(d []byte) {
	e.buf.Write(d)
}