/*
NAME
  senders.go

DESCRIPTION
  See Readme.md

AUTHORS
  Saxon A. Nelson-Milton <saxon@ausocean.org>
  Alan Noble <alan@ausocean.org>

LICENSE
  revid is Copyright (C) 2017-2018 the Australian Ocean Lab (AusOcean)

  It is free software: you can redistribute it and/or modify it
  under the terms of the GNU General Public License as published by the
  Free Software Foundation, either version 3 of the License, or (at your
  option) any later version.

  It is distributed in the hope that it will be useful, but WITHOUT
  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
  for more details.

  You should have received a copy of the GNU General Public License
  along with revid in gpl.txt.  If not, see http://www.gnu.org/licenses.
*/

package revid

import (
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"sync"
	"time"

	"github.com/Comcast/gots/packet"

	"bitbucket.org/ausocean/av/container/mts"
	"bitbucket.org/ausocean/av/protocol/rtmp"
	"bitbucket.org/ausocean/av/protocol/rtp"
	"bitbucket.org/ausocean/av/revid/config"
	"bitbucket.org/ausocean/iot/pi/netsender"
	"bitbucket.org/ausocean/utils/logger"
	"bitbucket.org/ausocean/utils/ring"
)

// Log is a logging function type taking a level, a message and any number of
// additional parameters. It is used by the multiSender.
type Log func(level int8, message string, params ...interface{})

// Sender ring buffer read timeouts and sizing limit.
//
// rtmpRBReadTimeout and mtsRBReadTimeout are the timeouts passed to the ring
// buffer's Next method by the rtmpSender and mtsSender output routines
// respectively. maxBuffLen is the total size budget used when a sender
// re-creates its ring buffer: the number of elements is maxBuffLen divided by
// the adjusted element size.
const (
	rtmpRBReadTimeout = 1 * time.Second
	mtsRBReadTimeout  = 1 * time.Second
	maxBuffLen        = 50000000
)

// Element sizes most recently chosen when a sender re-created its ring buffer
// after a write failed with ring.ErrTooLong. These are package-level variables,
// so each value is shared by every sender of the corresponding kind.
var (
	adjustedRTMPRBElementSize int
	adjustedMTSRBElementSize  int
)

// httpSender provides an implementation of io.Writer to perform sends to an
// HTTP destination.
//
// Fields:
//   - client: the netsender.Sender used to make pins and send requests.
//   - log: logging function called with a level, a message and further args.
//   - report: callback invoked with the number of bytes sent after each
//     successful send.
type httpSender struct {
	client *netsender.Sender
	log    func(lvl int8, msg string, args ...interface{})
	report func(sent int)
}

// newHTTPSender returns a pointer to a new httpSender that sends using the
// netsender client ns, logs with log and reports sent byte counts via report.
func newHTTPSender(ns *netsender.Sender, log func(lvl int8, msg string, args ...interface{}), report func(sent int)) *httpSender {
	s := new(httpSender)
	s.client = ns
	s.log = log
	s.report = report
	return s
}

// Write implements io.Writer. It sends d using httpSend and, on success,
// passes len(d) to the report callback. The returned count is always len(d);
// the returned error is whatever httpSend returned.
func (s *httpSender) Write(d []byte) (int, error) {
	s.log(logger.Debug, "HTTP sending")
	size := len(d)
	if err := httpSend(d, s.client, s.log); err != nil {
		s.log(logger.Debug, "bad send", "error", err)
		return size, err
	}
	s.log(logger.Debug, "good send", "len", size)
	s.report(size)
	return size, nil
}

// Close implements io.Closer. It performs no work and always returns nil.
func (s *httpSender) Close() error {
	return nil
}

// httpSend sends d to the netsender service using client. Pins are made from
// the client's "ip" parameter, and d is attached to the first pin named "V0"
// (as "video/mp2t") or "S0" (as "audio/x-wav"). If neither pin is present
// nothing is sent and nil is returned. Otherwise a recv request is sent with
// all of the pins and the reply is handed to extractMeta.
func httpSend(d []byte, client *netsender.Sender, log func(lvl int8, msg string, args ...interface{})) error {
	inputs := client.Param("ip")
	log(logger.Debug, "making pins, and sending recv request", "ip", inputs)
	pins := netsender.MakePins(inputs, "V,S")

	// Locate the first pin able to carry the payload.
	target := -1
	for i := range pins {
		if name := pins[i].Name; name == "V0" || name == "S0" {
			target = i
			break
		}
	}

	// Only send if "V0" or "S0" are configured as input.
	if target < 0 {
		return nil
	}

	if pins[target].Name == "V0" {
		pins[target].MimeType = "video/mp2t"
	} else {
		pins[target].MimeType = "audio/x-wav"
	}
	pins[target].Value = len(d)
	pins[target].Data = d

	reply, _, err := client.Send(netsender.RequestRecv, pins)
	if err != nil {
		return err
	}
	log(logger.Debug, "good request", "reply", reply)
	return extractMeta(reply, log)
}

// extractMeta looks at the reply r and extracts any time or location data,
// which is then used to update the time and location information held by the
// mts package. A reply that cannot be decoded is ignored, and missing fields
// are only logged, so extractMeta always returns nil.
func extractMeta(r string, log func(lvl int8, msg string, args ...interface{})) error {
	dec, decErr := netsender.NewJSONDecoder(r)
	if decErr != nil {
		// NOTE(review): the decode failure is dropped silently; confirm this
		// best-effort behaviour is intended.
		return nil
	}

	// The timestamp is only of interest while mts.RealTime remains unset.
	if !mts.RealTime.IsSet() {
		if ts, err := dec.Int("ts"); err == nil {
			log(logger.Debug, "got timestamp", "ts", ts)
			mts.RealTime.Set(time.Unix(int64(ts), 0))
		} else {
			log(logger.Warning, "No timestamp in reply")
		}
	}

	// Record the location if the reply carries one.
	loc, err := dec.String("ll")
	if err != nil {
		log(logger.Debug, "No location in reply")
		return nil
	}
	log(logger.Debug, fmt.Sprintf("got location: %v", loc))
	mts.Meta.Add(mts.LocationKey, loc)
	return nil
}

// fileSender implements loadSender for a local file destination.
//
// Fields:
//   - file: the output file currently being written to; set by Write.
//   - data: appears unused by the code in this file.
//   - multiFile: when true, Write creates a new file for every call.
//   - path: prefix of each output file name; the creation time is appended.
//   - init: true until the first output file has been created.
//   - log: logger used for debug messages.
type fileSender struct {
	file      *os.File
	data      []byte
	multiFile bool
	path      string
	init      bool
	log       config.Logger
}

// newFileSender returns a new fileSender that logs to l and names its output
// files using path as a prefix. Setting multiFile true will write a new file
// for each write to this sender. No file is created until the first Write, and
// the returned error is always nil.
func newFileSender(l config.Logger, path string, multiFile bool) (*fileSender, error) {
	s := &fileSender{init: true}
	s.log = l
	s.path = path
	s.multiFile = multiFile
	return s, nil
}

// Write implements io.Writer. On the first call, and on every call when the
// sender is in multi-file mode, a new output file named path followed by the
// current time is created and d is written to it; otherwise d is written to
// the file already open. Any previously open file is closed once its
// replacement has been created, so that multi-file mode does not leak file
// descriptors.
//
// It returns the number of bytes written and any error from creating or
// writing the file.
func (s *fileSender) Write(d []byte) (int, error) {
	if s.init || s.multiFile {
		fileName := s.path + time.Now().String()
		s.log.Debug("creating new output file", "init", s.init, "multiFile", s.multiFile, "fileName", fileName)
		f, err := os.Create(fileName)
		if err != nil {
			return 0, fmt.Errorf("could not create file to write media to: %w", err)
		}

		// Release the file being replaced. A failure here does not affect the
		// new file, so it is logged rather than returned.
		if s.file != nil {
			if err := s.file.Close(); err != nil {
				s.log.Debug("could not close previous output file", "error", err)
			}
		}
		s.file = f
		s.init = false
	}
	s.log.Debug("writing output file", "len(d)", len(d))
	return s.file.Write(d)
}

// Close implements io.Closer. It closes the current output file and returns
// any error from doing so. If no file has been opened yet, because Write was
// never called, there is nothing to close and nil is returned.
func (s *fileSender) Close() error {
	if s.file == nil {
		return nil
	}
	return s.file.Close()
}

// mtsSender implements io.WriteCloser and provides sending capability specifically
// for use with MPEGTS packetization. It handles the construction of appropriately
// sized clips based on clip duration and PSI. It also accounts for
// discontinuities by setting the discontinuity indicator for the first packet of a clip.
//
// Fields:
//   - dst: destination that clips are written to by the output routine.
//   - buf: data accumulated for the clip currently being built.
//   - ring: ring buffer that passes clips from Write to the output routine.
//   - next: copy of the most recently written data; it is appended to buf on
//     the following call to Write.
//   - pkt: appears unused by the code in this file.
//   - repairer: applied to each clip before it is written to dst.
//   - curPid: PID read from the most recently written data.
//   - clipDur: minimum time between clips being written to the ring buffer.
//   - prev: time at which the last clip was written to the ring buffer.
//   - done: closed by Close to stop the output routine.
//   - log: logging function.
//   - wg: used by Close to wait for the output routine to return.
type mtsSender struct {
	dst      io.WriteCloser
	buf      []byte
	ring     *ring.Buffer
	next     []byte
	pkt      packet.Packet
	repairer *mts.DiscontinuityRepairer
	curPid   int
	clipDur  time.Duration
	prev     time.Time
	done     chan struct{}
	log      func(lvl int8, msg string, args ...interface{})
	wg       sync.WaitGroup
}

// newMTSSender returns a new mtsSender that writes clips of at least clipDur
// to dst, using rb to pass clips to its output routine. The output routine is
// started as a goroutine before newMTSSender returns; Close stops it.
func newMTSSender(dst io.WriteCloser, log func(lvl int8, msg string, args ...interface{}), rb *ring.Buffer, clipDur time.Duration) *mtsSender {
	log(logger.Debug, "setting up mtsSender", "clip duration", int(clipDur))

	sender := &mtsSender{done: make(chan struct{})}
	sender.dst = dst
	sender.log = log
	sender.ring = rb
	sender.clipDur = clipDur
	sender.repairer = mts.NewDiscontinuityRepairer()

	sender.wg.Add(1)
	go sender.output()
	return sender
}

// output is the mtsSender's data handling routine. It loops until s.done is
// closed, each time taking a clip from the ring buffer, passing it through the
// discontinuity repairer and writing it to dst. A clip that fails repair is
// discarded. A clip whose write fails is kept and retried on the next
// iteration after notifying the repairer of the failure.
func (s *mtsSender) output() {
	var cur *ring.Chunk
	for {
		// Check for termination before doing any further work.
		select {
		case <-s.done:
			s.log(logger.Info, "terminating sender output routine")
			s.wg.Done()
			return
		default:
		}

		// With no clip in hand we're ready to get another from the ring
		// buffer. Whatever the outcome we loop again, so that s.done is
		// re-checked before the clip is handled.
		if cur == nil {
			var err error
			cur, err = s.ring.Next(mtsRBReadTimeout)
			if err == ring.ErrTimeout {
				s.log(logger.Debug, "mtsSender: ring buffer read timeout")
			} else if err != nil && err != io.EOF {
				s.log(logger.Error, "unexpected error", "error", err.Error())
			}
			continue
		}

		if err := s.repairer.Repair(cur.Bytes()); err != nil {
			cur.Close()
			cur = nil
			continue
		}

		s.log(logger.Debug, "mtsSender: writing")
		if _, err := s.dst.Write(cur.Bytes()); err != nil {
			s.log(logger.Debug, "failed write, repairing MTS", "error", err)
			s.repairer.Failed()
			continue
		}
		s.log(logger.Debug, "good write")

		cur.Close()
		cur = nil
	}
}

// Write implements io.Writer. The data from the previous call is appended to
// the clip under construction and a copy of d is held back as the next data.
// When at least clipDur has passed since the last clip, d carries the PAT PID
// and the clip is non-empty, the clip is written to the ring buffer for the
// output routine to send, and the clip is then reset.
//
// If the ring buffer rejects the clip with ring.ErrTooLong, the ring buffer is
// replaced by one with elements twice the size of the clip. The clip is
// dropped whenever the ring buffer write fails.
//
// An error is returned only if d is shorter than one MTS packet; otherwise
// len(d) and nil are returned.
//
// NOTE(review): s.ring is reassigned here while the output routine reads it,
// with no synchronisation visible in this file; confirm this is safe.
func (s *mtsSender) Write(d []byte) (int, error) {
	if len(d) < mts.PacketSize {
		return 0, errors.New("do not have full MTS packet")
	}

	if s.next != nil {
		s.log(logger.Debug, "appending packet to clip")
		s.buf = append(s.buf, s.next...)
	}

	held := make([]byte, len(d))
	copy(held, d)
	s.next = held

	// The error is ignored; the length of d has already been checked above.
	pid, _ := mts.PID(held)
	s.curPid = int(pid)

	elapsed := time.Since(s.prev)
	s.log(logger.Debug, "checking send conditions", "curDuration", int(elapsed), "sendDur", int(s.clipDur), "curPID", s.curPid, "len", len(s.buf))
	if elapsed < s.clipDur || s.curPid != mts.PatPid || len(s.buf) == 0 {
		return len(d), nil
	}

	s.log(logger.Debug, "writing clip to ring buffer for sending", "size", len(s.buf))
	s.prev = time.Now()
	n, err := s.ring.Write(s.buf)
	if err != nil {
		s.log(logger.Warning, "ringBuffer write error", "error", err.Error(), "n", n, "size", len(s.buf), "rb element size", adjustedMTSRBElementSize)
		if err == ring.ErrTooLong {
			adjustedMTSRBElementSize = len(s.buf) * 2
			numElements := maxBuffLen / adjustedMTSRBElementSize
			s.ring = ring.NewBuffer(numElements, adjustedMTSRBElementSize, 5*time.Second)
			s.log(logger.Info, "adjusted MTS ring buffer element size", "new size", adjustedMTSRBElementSize, "num elements", numElements, "size(MB)", numElements*adjustedMTSRBElementSize)
		}
	} else {
		s.ring.Flush()
	}
	s.buf = s.buf[:0]
	return len(d), nil
}

// Close implements io.Closer. It signals the output routine to stop by closing
// s.done and blocks until that routine has returned. dst is not closed here.
// Close always returns nil. Calling Close a second time would close s.done
// again, which panics.
func (s *mtsSender) Close() error {
	s.log(logger.Debug, "closing sender output routine")
	close(s.done)
	s.wg.Wait()
	s.log(logger.Info, "sender output routine closed")
	return nil
}

// rtmpSender implements loadSender for a native RTMP destination.
//
// Fields:
//   - conn: the current RTMP connection; nil if no dial has succeeded yet.
//   - url: the RTMP URL that is dialled.
//   - retries: maximum number of dial attempts per (re)connection.
//   - log: logging function.
//   - ring: ring buffer that passes data from Write to the output routine.
//   - done: closed by Close to stop the output routine.
//   - wg: used by Close to wait for the output routine to return.
//   - report: callback invoked by Write with the number of bytes written.
type rtmpSender struct {
	conn    *rtmp.Conn
	url     string
	retries int
	log     func(lvl int8, msg string, args ...interface{})
	ring    *ring.Buffer
	done    chan struct{}
	wg      sync.WaitGroup
	report  func(sent int)
}

// newRtmpSender returns a new rtmpSender for the given url. It dials the url
// up to retries times, stopping at the first success. A sender is returned and
// its output routine started even if every dial fails; in that case the
// returned error is the last dial error and the output routine will re-dial
// when it next has data to send.
func newRtmpSender(url string, retries int, rb *ring.Buffer, log func(lvl int8, msg string, args ...interface{}), report func(sent int)) (*rtmpSender, error) {
	var (
		conn *rtmp.Conn
		err  error
	)
	for attempt := 1; attempt <= retries; attempt++ {
		conn, err = rtmp.Dial(url, log)
		if err == nil {
			break
		}
		log(logger.Error, "dial error", "error", err)
		// Only announce a retry when another attempt remains.
		if attempt != retries {
			log(logger.Info, "retrying dial")
		}
	}

	sender := &rtmpSender{done: make(chan struct{})}
	sender.conn = conn
	sender.url = url
	sender.retries = retries
	sender.log = log
	sender.ring = rb
	sender.report = report

	sender.wg.Add(1)
	go sender.output()
	return sender, err
}

// output is the rtmpSender's data handling routine. It loops until s.done is
// closed, each time taking data from the ring buffer and writing it to the
// RTMP connection, re-dialling first if there is no connection. A write that
// fails with rtmp.ErrInvalidFlvTag is treated like a successful write; on any
// other write error the connection is restarted and the same data is retried
// on the next iteration.
func (s *rtmpSender) output() {
	var cur *ring.Chunk
	for {
		// Check for termination before doing any further work.
		select {
		case <-s.done:
			s.log(logger.Info, "terminating sender output routine")
			s.wg.Done()
			return
		default:
		}

		// With no chunk in hand we're ready to get another from the ring
		// buffer. Whatever the outcome we loop again, so that s.done is
		// re-checked before the chunk is handled.
		if cur == nil {
			var err error
			cur, err = s.ring.Next(rtmpRBReadTimeout)
			if err == ring.ErrTimeout {
				s.log(logger.Debug, "rtmpSender: ring buffer read timeout")
			} else if err != nil && err != io.EOF {
				s.log(logger.Error, "unexpected error", "error", err.Error())
			}
			continue
		}

		if s.conn == nil {
			s.log(logger.Warning, "no rtmp connection, re-dialing")
			if err := s.restart(); err != nil {
				s.log(logger.Warning, "could not restart connection", "error", err)
				continue
			}
		}

		_, err := s.conn.Write(cur.Bytes())
		if err != nil && err != rtmp.ErrInvalidFlvTag {
			s.log(logger.Warning, "send error, re-dialing", "error", err)
			if err = s.restart(); err != nil {
				s.log(logger.Warning, "could not restart connection", "error", err)
			}
			continue
		}
		s.log(logger.Debug, "good write to conn")

		cur.Close()
		cur = nil
	}
}

// Write implements io.Writer. It writes d to the ring buffer for the output
// routine to send. If the ring buffer rejects d with ring.ErrTooLong, the ring
// buffer is replaced by one with elements twice the size of d. The report
// callback is called with len(d) and len(d), nil is returned whether or not
// the ring buffer write succeeded.
func (s *rtmpSender) Write(d []byte) (int, error) {
	s.log(logger.Debug, "writing to ring buffer")
	if _, err := s.ring.Write(d); err != nil {
		s.log(logger.Warning, "ring buffer write error", "error", err.Error())
		if err == ring.ErrTooLong {
			adjustedRTMPRBElementSize = len(d) * 2
			numElements := maxBuffLen / adjustedRTMPRBElementSize
			s.ring = ring.NewBuffer(numElements, adjustedRTMPRBElementSize, 5*time.Second)
			s.log(logger.Info, "adjusted RTMP ring buffer element size", "new size", adjustedRTMPRBElementSize, "num elements", numElements, "size(MB)", numElements*adjustedRTMPRBElementSize)
		}
	} else {
		s.ring.Flush()
		s.log(logger.Debug, "good ring buffer write", "len", len(d))
	}
	s.report(len(d))
	return len(d), nil
}

// restart closes the current connection, ignoring any error from doing so, and
// dials s.url up to s.retries times, stopping at the first success. It returns
// the error from the last dial attempt. If s.retries is not positive no dial
// is attempted, s.conn is left unchanged and nil is returned.
func (s *rtmpSender) restart() error {
	s.close()

	var err error
	for dials := 0; dials < s.retries; dials++ {
		s.log(logger.Debug, "dialing", "dials", dials)
		if s.conn, err = rtmp.Dial(s.url, s.log); err == nil {
			break
		}
		s.log(logger.Error, "dial error", "error", err)
		// Only announce a retry when another attempt remains.
		if dials != s.retries-1 {
			s.log(logger.Info, "retry rtmp connection")
		}
	}
	return err
}

// Close implements io.Closer. It signals the output routine to stop by closing
// s.done (when non-nil), blocks until that routine has returned and then
// closes the RTMP connection, returning any error from doing so.
func (s *rtmpSender) Close() error {
	s.log(logger.Debug, "closing output routine")
	if s.done != nil {
		close(s.done)
	}
	// Wait for the output routine so the connection is not closed under it.
	s.wg.Wait()
	s.log(logger.Info, "output routine closed")
	return s.close()
}

// close closes the RTMP connection, if there is one, and returns any error
// from doing so. It returns nil when there is no connection.
func (s *rtmpSender) close() error {
	s.log(logger.Debug, "closing connection")
	if s.conn != nil {
		return s.conn.Close()
	}
	return nil
}

// rtpSender implements loadSender for a native udp destination with rtp packetization.
//
// Fields:
//   - log: logging function.
//   - encoder: RTP encoder that writes to the UDP connection.
//   - conn: the UDP connection dialled by newRtpSender, retained so that Close
//     can release it.
//   - data: copy of the data passed to the most recent Write.
//   - report: callback invoked by Write with the number of bytes written.
//
// TODO: Write restart func for rtpSender
type rtpSender struct {
	log     func(lvl int8, msg string, args ...interface{})
	encoder *rtp.Encoder
	conn    net.Conn
	data    []byte
	report  func(sent int)
}

// newRtpSender dials the UDP address addr and returns an rtpSender whose RTP
// encoder writes to that connection at the given fps. It returns an error if
// the address cannot be dialled.
func newRtpSender(addr string, log func(lvl int8, msg string, args ...interface{}), fps uint, report func(sent int)) (*rtpSender, error) {
	conn, err := net.Dial("udp", addr)
	if err != nil {
		return nil, err
	}
	s := &rtpSender{
		log:     log,
		encoder: rtp.NewEncoder(conn, int(fps)),
		conn:    conn,
		report:  report,
	}
	return s, nil
}

// Write implements io.Writer. A copy of d is passed to the RTP encoder. An
// encoder error is logged rather than returned; the report callback is called
// with len(d) and len(d), nil is returned in all cases.
func (s *rtpSender) Write(d []byte) (int, error) {
	s.data = make([]byte, len(d))
	copy(s.data, d)
	_, err := s.encoder.Write(s.data)
	if err != nil {
		// Log the error under an "error" key, matching the other senders.
		s.log(logger.Warning, "rtpSender: write error", "error", err.Error())
	}
	s.report(len(d))
	return len(d), nil
}

// Close implements io.Closer. It closes the UDP connection dialled by
// newRtpSender and returns any error from doing so. It returns nil if there is
// no connection to close, which includes any call after the first.
func (s *rtpSender) Close() error {
	if s.conn == nil {
		return nil
	}
	err := s.conn.Close()
	s.conn = nil
	return err
}