/*
NAME
  mtsSender_test.go

DESCRIPTION
  mtsSender_test.go contains tests that validate the functionality of the
  mtsSender in senders.go. Tests include checks that the mtsSender is
  segmenting sends correctly, and also that it can correct discontinuities.

AUTHORS
  Saxon A. Nelson-Milton <saxon@ausocean.org>

LICENSE
  mtsSender_test.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean)

  It is free software: you can redistribute it and/or modify them
  under the terms of the GNU General Public License as published by the
  Free Software Foundation, either version 3 of the License, or (at your
  option) any later version.

  It is distributed in the hope that it will be useful, but WITHOUT
  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
 for more details.

  You should have received a copy of the GNU General Public License
  in gpl.txt.  If not, see http://www.gnu.org/licenses.
*/
package revid

import (
	"errors"
	"testing"
	"time"

	"github.com/Comcast/gots/packet"
	"github.com/Comcast/gots/pes"

	"bitbucket.org/ausocean/av/container/mts"
	"bitbucket.org/ausocean/av/container/mts/meta"
	"bitbucket.org/ausocean/utils/logger"
)

// Ring buffer geometry: rbSize is the ring buffer size and rbElementSize the
// size of each of its elements.
const (
	rbSize        = 100
	rbElementSize = 150000
)

// Timeouts applied to ring buffer writes (wTimeout) and reads (rTimeout).
const (
	wTimeout = 10 * time.Millisecond
	rTimeout = 10 * time.Millisecond
)

// errSendFailed is the error returned by destination.Write for the call it
// has been configured to fail.
var errSendFailed = errors.New("send failed")

// destination stands in for the mtsSender's real destination. Each Write call
// is treated as one clip. A chosen call can be made to fail and another can
// be made to sleep first, so that failed and delayed sends can be emulated.
type destination struct {
	buf        [][]byte      // Copies of the clips accepted by Write, in order.
	testFails  bool          // Enables the failure of the Write call at index failAt.
	failAt     int           // Index of the Write call that fails when testFails is set.
	currentPkt int           // Number of Write calls so far; the index of the next call.
	t          *testing.T    // Test used for logging.
	sendDelay  time.Duration // How long the Write call at index delayAt sleeps.
	delayAt    int           // Index of the Write call to delay; 0 disables delaying.
}

// Write handles one clip. Calls are indexed from 0 by currentPkt, which every
// call advances. The call at index delayAt (when non-zero) first sleeps for
// sendDelay. The call at index failAt (when testFails is set) records nothing
// and returns errSendFailed. Every other call stores a private copy of clip
// in buf and reports the full length as written.
func (dst *destination) Write(clip []byte) (int, error) {
	dst.t.Log("writing clip to destination")

	idx := dst.currentPkt
	dst.currentPkt++

	if dst.delayAt != 0 && idx == dst.delayAt {
		time.Sleep(dst.sendDelay)
	}

	if dst.testFails && idx == dst.failAt {
		dst.t.Log("failed send")
		return 0, errSendFailed
	}

	// Copy, because the caller may reuse its buffer after Write returns.
	stored := make([]byte, len(clip))
	copy(stored, clip)
	dst.buf = append(dst.buf, stored)
	return len(clip), nil
}

// Close has nothing to release and always returns nil.
func (dst *destination) Close() error {
	return nil
}

// dummyLogger adapts a *testing.T to the logging function signature passed to
// newMtsSender, so that log output is routed through the testing package.
type dummyLogger testing.T

// log writes msg to the test log, prefixed with the name of lvl and followed
// by each element of args. Levels not listed below get an empty name, giving
// a ": msg" prefix.
func (dl *dummyLogger) log(lvl int8, msg string, args ...interface{}) {
	var l string
	switch lvl {
	case logger.Warning:
		l = "warning"
	case logger.Debug:
		l = "debug"
	case logger.Info:
		l = "info"
	case logger.Error:
		l = "error"
	case logger.Fatal:
		l = "fatal"
	}

	// Pass the message and every arg as separate operands rather than
	// building a format string. This way each arg is printed individually
	// (args must be spread, not passed as one slice), and a '%' inside msg is
	// never misread as a formatting verb. T.Log formats like fmt.Sprintln:
	// operands are space-separated and the line is terminated for us.
	operands := make([]interface{}, 0, len(args)+1)
	operands = append(operands, l+": "+msg)
	operands = append(operands, args...)
	(*testing.T)(dl).Log(operands...)
}

// TestMtsSenderSegment ensures that the mtsSender correctly segments data into
// clips based on the positioning of PSI in the mts encoder's output stream.
func TestMtsSenderSegment(t *testing.T) {
	mts.Meta = meta.New()

	// Create the destination, the mtsSender and the MPEG-TS encoder.
	tstDst := &destination{t: t}
	sender := newMtsSender(tstDst, (*dummyLogger)(t).log, ringBufferSize, ringBufferElementSize, writeTimeout)
	encoder := mts.NewEncoder(sender, 25, mts.Video)

	// Turn time based PSI writing off for the encoder and instead write PSI
	// every psiSendCount packets.
	const psiSendCount = 10
	encoder.TimeBasedPsi(false, psiSendCount)

	// Write the packets to the encoder, which will in turn write to the mtsSender.
	// Payload will just be packet number.
	t.Log("writing packets")
	const noOfPacketsToWrite = 100
	for i := 0; i < noOfPacketsToWrite; i++ {
		encoder.Write([]byte{byte(i)})
	}

	// Give the mtsSender some time to finish up and then Close it.
	time.Sleep(10 * time.Millisecond)
	sender.Close()

	// Check the data. With no clips at all every check below would pass
	// vacuously, so require at least one.
	result := tstDst.buf
	if len(result) == 0 {
		t.Fatal("destination received no clips")
	}
	expectData := 0
	for clipNo, clip := range result {
		t.Logf("Checking clip: %v", clipNo)

		// Check that the clip is of expected length.
		clipLen := len(clip)
		if clipLen != psiSendCount*mts.PacketSize {
			t.Fatalf("Clip %v is not correct length. Got: %v Want: %v\n Clip: %v", clipNo, clipLen, psiSendCount*mts.PacketSize, clip)
		}

		// Also check that the first packet is a PAT.
		var pkt packet.Packet
		copy(pkt[:], clip[:mts.PacketSize])
		if pid := pkt.PID(); pid != mts.PatPid {
			t.Fatalf("First packet of clip %v is not pat, but rather: %v", clipNo, pid)
		}

		// Check that the clip data is okay.
		t.Log("checking clip data")
		for i := 0; i < len(clip); i += mts.PacketSize {
			copy(pkt[:], clip[i:i+mts.PacketSize])
			if pkt.PID() != mts.VideoPid {
				continue
			}
			t.Log("got video PID")
			payload, err := pkt.Payload()
			if err != nil {
				t.Fatalf("Unexpected err: %v", err)
			}

			// Parse PES from the MTS payload. Named hdr so the pes package
			// is not shadowed.
			hdr, err := pes.NewPESHeader(payload)
			if err != nil {
				t.Fatalf("Unexpected err: %v", err)
			}
			data := hdr.Data()
			if len(data) == 0 {
				t.Fatalf("PES in clip %v, pktNoInClip %v carries no data", clipNo, i/mts.PacketSize)
			}

			// The i-th encoder write carried the single byte i, so the video
			// payloads must count up from 0.
			if got := data[0]; got != byte(expectData) {
				t.Errorf("Did not get expected pkt data. ClipNo: %v, pktNoInClip: %v, Got: %v, want: %v", clipNo, i/mts.PacketSize, got, expectData)
			}
			expectData++
		}
	}
}

// TestMtsSenderFailedSend checks that a failed send is correctly handled by
// the mtsSender. The mtsSender should try to send the same clip again.
func TestMtsSenderFailedSend(t *testing.T) {
	mts.Meta = meta.New()

	// Create the destination, the mtsSender and the mtsEncoder. The
	// destination rejects its Write call with index clipToFailAt (0-based).
	const clipToFailAt = 3
	tstDst := &destination{t: t, testFails: true, failAt: clipToFailAt}
	sender := newMtsSender(tstDst, (*dummyLogger)(t).log, ringBufferSize, ringBufferElementSize, writeTimeout)
	encoder := mts.NewEncoder(sender, 25, mts.Video)

	// Turn time based PSI writing off for encoder and send PSI every 10 packets.
	const psiSendCount = 10
	encoder.TimeBasedPsi(false, psiSendCount)

	// Write the packets to the encoder, which will in turn write to the mtsSender.
	// Payload will just be packet number.
	t.Log("writing packets")
	const noOfPacketsToWrite = 100
	for i := 0; i < noOfPacketsToWrite; i++ {
		encoder.Write([]byte{byte(i)})
	}

	// Give the mtsSender some time to finish up and then Close it.
	time.Sleep(10 * time.Millisecond)
	sender.Close()

	// The destination only records successful writes. Unless more than
	// clipToFailAt clips arrived, nothing was delivered after the injected
	// failure, and the resend behaviour under test was never checked.
	result := tstDst.buf
	if len(result) <= clipToFailAt {
		t.Fatalf("destination received %v clips; need more than %v to observe the resend after the failed send", len(result), clipToFailAt)
	}

	// Check that we have data as expected. A dropped (rather than resent)
	// clip shows up as a gap in the payload sequence.
	expectData := 0
	for clipNo, clip := range result {
		t.Logf("Checking clip: %v", clipNo)

		// Check that the clip is of expected length.
		clipLen := len(clip)
		if clipLen != psiSendCount*mts.PacketSize {
			t.Fatalf("Clip %v is not correct length. Got: %v Want: %v\n Clip: %v", clipNo, clipLen, psiSendCount*mts.PacketSize, clip)
		}

		// Also check that the first packet is a PAT.
		var pkt packet.Packet
		copy(pkt[:], clip[:mts.PacketSize])
		if pid := pkt.PID(); pid != mts.PatPid {
			t.Fatalf("First packet of clip %v is not pat, but rather: %v", clipNo, pid)
		}

		// Check that the clip data is okay.
		t.Log("checking clip data")
		for i := 0; i < len(clip); i += mts.PacketSize {
			copy(pkt[:], clip[i:i+mts.PacketSize])
			if pkt.PID() != mts.VideoPid {
				continue
			}
			t.Log("got video PID")
			payload, err := pkt.Payload()
			if err != nil {
				t.Fatalf("Unexpected err: %v", err)
			}

			// Parse PES from the MTS payload. Named hdr so the pes package
			// is not shadowed.
			hdr, err := pes.NewPESHeader(payload)
			if err != nil {
				t.Fatalf("Unexpected err: %v", err)
			}
			data := hdr.Data()
			if len(data) == 0 {
				t.Fatalf("PES in clip %v, pktNoInClip %v carries no data", clipNo, i/mts.PacketSize)
			}

			// The i-th encoder write carried the single byte i, so the video
			// payloads must count up from 0 without gaps.
			if got := data[0]; got != byte(expectData) {
				t.Errorf("Did not get expected pkt data. ClipNo: %v, pktNoInClip: %v, Got: %v, want: %v", clipNo, i/mts.PacketSize, got, expectData)
			}
			expectData++
		}
	}
}

// TestMtsSenderDiscontinuity checks that a discontinuity in a stream is
// correctly handled by the mtsSender. A discontinuity is caused by overflowing
// the mtsSender's ringBuffer. It is expected that the next clip seen has the
// discontinuity indicator applied.
func TestMtsSenderDiscontinuity(t *testing.T) {
	mts.Meta = meta.New()

	// Create destination, the mtsSender and the mtsEncoder. The sender is
	// given a ring buffer size of 1, and the destination sleeps on Write call
	// clipToDelay. The intent is that the ring buffer overflows while the
	// destination is busy.
	const clipToDelay = 3
	tstDst := &destination{t: t, sendDelay: 10 * time.Millisecond, delayAt: clipToDelay}
	sender := newMtsSender(tstDst, (*dummyLogger)(t).log, 1, ringBufferElementSize, writeTimeout)
	encoder := mts.NewEncoder(sender, 25, mts.Video)

	// Turn time based PSI writing off for encoder.
	const psiSendCount = 10
	encoder.TimeBasedPsi(false, psiSendCount)

	// Write the packets to the encoder, which will in turn write to the mtsSender.
	// Payload will just be packet number.
	const noOfPacketsToWrite = 100
	for i := 0; i < noOfPacketsToWrite; i++ {
		encoder.Write([]byte{byte(i)})
	}

	// Give mtsSender time to finish up then Close.
	time.Sleep(100 * time.Millisecond)
	sender.Close()

	// Check the data. With no clips at all the loop below would pass
	// vacuously, so require at least one.
	//
	// NOTE(review): nothing below requires that a discontinuity actually
	// occurred; if the overflow does not happen, only CC continuity is checked.
	result := tstDst.buf
	if len(result) == 0 {
		t.Fatal("destination received no clips")
	}
	expectedCC := 0
	for clipNo, clip := range result {
		t.Logf("Checking clip: %v", clipNo)

		// Check that the clip is of expected length.
		clipLen := len(clip)
		if clipLen != psiSendCount*mts.PacketSize {
			t.Fatalf("Clip %v is not correct length. Got: %v Want: %v\n Clip: %v", clipNo, clipLen, psiSendCount*mts.PacketSize, clip)
		}

		// Also check that the first packet is a PAT.
		var pkt packet.Packet
		copy(pkt[:], clip[:mts.PacketSize])
		if pid := pkt.PID(); pid != mts.PatPid {
			t.Fatalf("First packet of clip %v is not pat, but rather: %v", clipNo, pid)
		}

		// Get the discontinuity indicator. The error is deliberately ignored:
		// discon is used as-is even when the PAT's adaptation field could not
		// be read. NOTE(review): this assumes Discontinuity reports false in
		// that case; confirm against gots.
		discon, _ := (*packet.AdaptationField)(&pkt).Discontinuity()

		// Check the continuity counter. A jump in the PAT's CC marks a
		// discontinuity, which must be flagged by the indicator. Conversely,
		// the indicator must not be set on a continuous clip other than the
		// first.
		cc := pkt.ContinuityCounter()
		if cc != expectedCC {
			t.Log("discontinuity found")
			expectedCC = cc
			if !discon {
				t.Errorf("discontinuity indicator not set where expected for clip: %v", clipNo)
			}
		} else if discon && clipNo != 0 {
			t.Errorf("did not expect discontinuity indicator to be set for clip: %v", clipNo)
		}

		// The CC is a 4-bit field, so it wraps after 15.
		expectedCC = (expectedCC + 1) & 0xf
	}
}