/*
NAME
  rtmp.go

DESCRIPTION
  RTMP command functionality.

AUTHORS
  Saxon Nelson-Milton <saxon@ausocean.org>
  Dan Kortschak <dan@ausocean.org>
  Alan Noble <alan@ausocean.org>

LICENSE
  rtmp.go is Copyright (C) 2017-2019 the Australian Ocean Lab (AusOcean)

  It is free software: you can redistribute it and/or modify them
  under the terms of the GNU General Public License as published by the
  Free Software Foundation, either version 3 of the License, or (at your
  option) any later version.

  It is distributed in the hope that it will be useful, but WITHOUT
  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
  for more details.

  You should have received a copy of the GNU General Public License
  along with revid in gpl.txt. If not, see http://www.gnu.org/licenses.

  Derived from librtmp under the GNU Lesser General Public License 2.1
  Copyright (C) 2005-2008 Team XBMC http://www.xbmc.org
  Copyright (C) 2008-2009 Andrej Stepanchuk
  Copyright (C) 2009-2010 Howard Chu
*/

package rtmp

import (
	"bytes"
	"encoding/binary"
	"errors"
	"math/rand"
	"net"
	"strconv"
	"time"

	"bitbucket.org/ausocean/av/protocol/rtmp/amf"
)

const (
	// pkg is prepended to every log message emitted by this file.
	pkg = "rtmp:"

	// signatureSize is the length in bytes of each handshake signature block.
	signatureSize = 1536

	// fullHeaderSize is the number of bytes reserved in front of the packet
	// body in each send buffer built by this file.
	fullHeaderSize = 12
)

// Link flags; each flag occupies its own bit.
const (
	linkAuth     = 1 << iota // using auth param
	linkLive                 // stream is live
	linkSWF                  // do SWF verification - not implemented
	linkPlaylist             // send playlist before play - not implemented
	linkBufx                 // toggle stream on BufferEmpty msg - not implemented
)

// Protocol features; each feature occupies its own bit.
const (
	featureHTTP   = 1 << iota // not implemented
	featureEncode             // not implemented
	featureSSL                // not implemented
	featureMFP                // not implemented
	featureWrite              // publish, not play
	featureHTTP2              // server-side RTMPT - not implemented
)

// RTMP protocols, each expressed as a combination of the feature bits above.
const (
	protoRTMP   = 0
	protoRTMPE  = featureEncode
	protoRTMPT  = featureHTTP
	protoRTMPS  = featureSSL
	protoRTMPTE = protoRTMPT | protoRTMPE
	protoRTMPTS = protoRTMPT | protoRTMPS
	protoRTMFP  = featureMFP
)

// RTMP tokens (lexemes), grouped by the role they play in messages.
// NB: Underscores are deliberately preserved in const names where they exist in the corresponding tokens.
const (
	// Command names.
	av_checkbw        = "_checkbw"
	av_onbwcheck      = "_onbwcheck"
	av_onbwdone       = "_onbwdone"
	av_result         = "_result"
	avClose           = "close"
	avConnect         = "connect"
	avCreatestream    = "createStream"
	avDeletestream    = "deleteStream"
	avFCPublish       = "FCPublish"
	avFCUnpublish     = "FCUnpublish"
	avOnBWDone        = "onBWDone"
	avOnFCSubscribe   = "onFCSubscribe"
	avOnFCUnsubscribe = "onFCUnsubscribe"
	avOnStatus        = "onStatus"
	avPing            = "ping"
	avPlay            = "play"
	avPlaylist_ready  = "playlist_ready"
	avPublish         = "publish"
	avReleasestream   = "releaseStream"
	avSecureToken     = "secureToken"
	avSet_playlist    = "set_playlist"

	// Object property names.
	avApp            = "app"
	avAudioCodecs    = "audioCodecs"
	avCapabilities   = "capabilities"
	avCode           = "code"
	avFlashver       = "flashVer"
	avFpad           = "fpad"
	avLevel          = "level"
	avObjectEncoding = "objectEncoding"
	avPageUrl        = "pageUrl"
	avSwfUrl         = "swfUrl"
	avTcUrl          = "tcUrl"
	avType           = "type"
	avVideoCodecs    = "videoCodecs"
	avVideoFunction  = "videoFunction"

	// Property values.
	avLive       = "live"
	avNonprivate = "nonprivate"

	// Status codes.
	avNetConnectionConnectInvalidApp = "NetConnection.Connect.InvalidApp"
	avNetStreamFailed                = "NetStream.Failed"
	avNetStreamPauseNotify           = "NetStream.Pause.Notify"
	avNetStreamPlayComplete          = "NetStream.Play.Complete"
	avNetStreamPlayFailed            = "NetStream.Play.Failed"
	avNetStreamPlayPublishNotify     = "NetStream.Play.PublishNotify"
	avNetStreamPlayStart             = "NetStream.Play.Start"
	avNetStreamPlayStop              = "NetStream.Play.Stop"
	avNetStreamPlayStreamNotFound    = "NetStream.Play.StreamNotFound"
	avNetStreamPlayUnpublishNotify   = "NetStream.Play.UnpublishNotify"
	avNetStreamPublish_Start         = "NetStream.Publish.Start"
	avNetStreamSeekNotify            = "NetStream.Seek.Notify"
)

// rtmpProtocolStrings holds RTMP protocol names. Index i holds the name of the
// protocol whose proto* constant has value i; indices 6 and 7 correspond to no
// protocol and are left empty.
var rtmpProtocolStrings = [...]string{
	0: "rtmp",   // protoRTMP
	1: "rtmpt",  // protoRTMPT
	2: "rtmpe",  // protoRTMPE
	3: "rtmpte", // protoRTMPTE
	4: "rtmps",  // protoRTMPS
	5: "rtmpts", // protoRTMPTS
	8: "rtmfp",  // protoRTMFP
}

// RTMP errors. Every message carries the "rtmp: " prefix.
var (
	// Connection state.
	errConnected    = errors.New("rtmp: already connected")
	errNotConnected = errors.New("rtmp: not connected")
	errNotWritable  = errors.New("rtmp: connection not writable")

	// URL handling.
	errUnknownScheme = errors.New("rtmp: unknown scheme")
	errInvalidURL    = errors.New("rtmp: invalid URL")

	// Malformed data.
	errInvalidHeader = errors.New("rtmp: invalid header")
	errInvalidBody   = errors.New("rtmp: invalid body")
	ErrInvalidFlvTag = errors.New("rtmp: invalid FLV tag")

	// Requests for functionality this package does not provide.
	errUnimplemented = errors.New("rtmp: unimplemented feature")
)

// connect establishes an RTMP connection for c.
//
// It resolves and dials c.link.host:c.link.port over IPv4 TCP, performs the
// handshake, sends the connect command, and then reads and dispatches packets
// until c.isPlaying becomes true. Audio, video and info packets that arrive
// during this negotiation are logged and ignored; everything else goes to
// handlePacket. If any step after a successful dial fails, the TCP connection
// is closed before the error is returned.
func connect(c *Conn) (err error) {
	hostport := c.link.host + ":" + strconv.Itoa(int(c.link.port))
	addr, err := net.ResolveTCPAddr("tcp4", hostport)
	if err != nil {
		return err
	}
	c.link.conn, err = net.DialTCP("tcp4", nil, addr)
	if err != nil {
		c.log(WarnLevel, pkg+"dial failed", "error", err.Error())
		return err
	}
	c.log(DebugLevel, pkg+"connected")

	// err is the named result, so this observes whatever error is finally
	// returned, regardless of which statement produced it.
	defer func() {
		if err != nil {
			c.link.conn.Close()
		}
	}()

	if err = handshake(c); err != nil {
		c.log(WarnLevel, pkg+"handshake failed", "error", err.Error())
		return err
	}
	c.log(DebugLevel, pkg+"handshaked")

	if err = sendConnectPacket(c); err != nil {
		c.log(WarnLevel, pkg+"sendConnect failed", "error", err.Error())
		return err
	}

	c.log(DebugLevel, pkg+"negotiating")
	var scratch [256]byte
	for !c.isPlaying {
		pkt := packet{buf: scratch[:]}
		if err = pkt.readFrom(c); err != nil {
			return err
		}
		if pkt.packetType == packetTypeAudio || pkt.packetType == packetTypeVideo || pkt.packetType == packetTypeInfo {
			c.log(WarnLevel, pkg+"got packet before play; ignoring", "type", pkt.packetType)
			continue
		}
		if err = handlePacket(c, &pkt); err != nil {
			return err
		}
	}
	return nil
}

// handlePacket processes a packet received by the client, dispatching on its
// type:
//   - chunk size, server bandwidth and client bandwidth packets set
//     c.inChunkSize, c.serverBW and c.clientBW from the first four body bytes
//     (client bandwidth also sets c.clientBW2 from the fifth byte, or 0xff if
//     there is none);
//   - bytes-read reports are only logged;
//   - invoke packets are passed to handleInvoke;
//   - types this client does not support are logged at FatalLevel, and
//     unknown types at WarnLevel.
//
// It returns errInvalidBody if the body is shorter than four bytes, or the
// error from handleInvoke.
func handlePacket(c *Conn, pkt *packet) error {
	if pkt.bodySize < 4 {
		return errInvalidBody
	}

	switch typ := pkt.packetType; typ {
	case packetTypeChunkSize:
		c.inChunkSize = amf.DecodeInt32(pkt.body[:4])
		c.log(DebugLevel, pkg+"set inChunkSize", "size", int(c.inChunkSize))

	case packetTypeBytesReadReport:
		c.log(DebugLevel, pkg+"received packetTypeBytesReadReport")

	case packetTypeServerBW:
		c.serverBW = amf.DecodeInt32(pkt.body[:4])
		c.log(DebugLevel, pkg+"set serverBW", "size", int(c.serverBW))

	case packetTypeClientBW:
		c.clientBW = amf.DecodeInt32(pkt.body[:4])
		c.log(DebugLevel, pkg+"set clientBW", "size", int(c.clientBW))
		if pkt.bodySize <= 4 {
			c.clientBW2 = 0xff
		} else {
			c.clientBW2 = pkt.body[4]
			c.log(DebugLevel, pkg+"set clientBW2", "size", int(c.clientBW2))
		}

	case packetTypeInvoke:
		if err := handleInvoke(c, pkt.body[:pkt.bodySize]); err != nil {
			c.log(WarnLevel, pkg+"unexpected error from handleInvoke", "error", err.Error())
			return err
		}

	case packetTypeControl, packetTypeAudio, packetTypeVideo, packetTypeFlashVideo, packetTypeFlexMessage, packetTypeInfo:
		c.log(FatalLevel, pkg+"unsupported packet type "+strconv.Itoa(int(typ)))

	default:
		c.log(WarnLevel, pkg+"unknown packet type", "type", typ)
	}
	return nil
}

// sendConnectPacket sends the "connect" command on the control channel using a
// large header.
//
// The AMF body is built in this order:
//   - the command name;
//   - the next transaction number (c.numInvokes is incremented here);
//   - an object holding the app, type and tcUrl link properties;
//   - when c.link.auth is non-empty, a boolean reflecting the linkAuth flag
//     followed by the auth string.
//
// writeTo is called with true because a response is expected. Any encoding or
// write error is returned.
func sendConnectPacket(c *Conn) error {
	var pbuf [4096]byte
	pkt := packet{
		channel:    chanControl,
		headerType: headerSizeLarge,
		packetType: packetTypeInvoke,
		buf:        pbuf[:],
		body:       pbuf[fullHeaderSize:],
	}
	enc := pkt.body

	enc, err := amf.EncodeString(enc, avConnect)
	if err != nil {
		return err
	}
	c.numInvokes++
	enc, err = amf.EncodeNumber(enc, float64(c.numInvokes))
	if err != nil {
		return err
	}

	// Required link info.
	info := amf.Object{Properties: []amf.Property{
		{Type: amf.TypeString, Name: avApp, String: c.link.app},
		{Type: amf.TypeString, Name: avType, String: avNonprivate},
		{Type: amf.TypeString, Name: avTcUrl, String: c.link.url},
	}}
	enc, err = amf.Encode(&info, enc)
	if err != nil {
		return err
	}

	// Optional link auth info.
	if c.link.auth != "" {
		enc, err = amf.EncodeBoolean(enc, c.link.flags&linkAuth != 0)
		if err != nil {
			return err
		}
		enc, err = amf.EncodeString(enc, c.link.auth)
		if err != nil {
			return err
		}
	}

	// enc is the unused tail of the body, so the encoded length is the
	// difference between the full body capacity and what remains.
	pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc))

	return pkt.writeTo(c, true) // response expected
}

// sendCreateStream sends a "createStream" command on the control channel using
// a medium header. The AMF body holds the command name, the next transaction
// number (c.numInvokes is incremented here) and a null. writeTo is called with
// true because a response is expected.
func sendCreateStream(c *Conn) error {
	var buf [256]byte
	body := buf[fullHeaderSize:]
	pkt := packet{
		channel:    chanControl,
		headerType: headerSizeMedium,
		packetType: packetTypeInvoke,
		buf:        buf[:],
		body:       body,
	}

	rest, err := amf.EncodeString(body, avCreatestream)
	if err != nil {
		return err
	}
	c.numInvokes++
	if rest, err = amf.EncodeNumber(rest, float64(c.numInvokes)); err != nil {
		return err
	}
	rest[0] = amf.TypeNull
	rest = rest[1:]

	// rest is the unused tail of body; what precedes it is the encoded payload.
	pkt.bodySize = uint32(len(body) - len(rest))
	return pkt.writeTo(c, true) // response expected
}

// sendReleaseStream sends a "releaseStream" command for c.link.playpath on the
// control channel using a medium header. The AMF body holds the command name,
// the next transaction number (c.numInvokes is incremented here), a null and
// the play path. writeTo is called with false.
func sendReleaseStream(c *Conn) error {
	var buf [1024]byte
	body := buf[fullHeaderSize:]
	pkt := packet{
		channel:    chanControl,
		headerType: headerSizeMedium,
		packetType: packetTypeInvoke,
		buf:        buf[:],
		body:       body,
	}

	rest, err := amf.EncodeString(body, avReleasestream)
	if err != nil {
		return err
	}
	c.numInvokes++
	if rest, err = amf.EncodeNumber(rest, float64(c.numInvokes)); err != nil {
		return err
	}
	rest[0] = amf.TypeNull
	if rest, err = amf.EncodeString(rest[1:], c.link.playpath); err != nil {
		return err
	}

	pkt.bodySize = uint32(len(body) - len(rest))
	return pkt.writeTo(c, false)
}

// sendFCPublish sends an "FCPublish" command for c.link.playpath on the control
// channel using a medium header. The AMF body holds the command name, the next
// transaction number (c.numInvokes is incremented here), a null and the play
// path. writeTo is called with false.
func sendFCPublish(c *Conn) error {
	var buf [1024]byte
	body := buf[fullHeaderSize:]
	pkt := packet{
		channel:    chanControl,
		headerType: headerSizeMedium,
		packetType: packetTypeInvoke,
		buf:        buf[:],
		body:       body,
	}

	rest, err := amf.EncodeString(body, avFCPublish)
	if err != nil {
		return err
	}
	c.numInvokes++
	if rest, err = amf.EncodeNumber(rest, float64(c.numInvokes)); err != nil {
		return err
	}
	rest[0] = amf.TypeNull
	if rest, err = amf.EncodeString(rest[1:], c.link.playpath); err != nil {
		return err
	}

	pkt.bodySize = uint32(len(body) - len(rest))
	return pkt.writeTo(c, false)
}

// sendFCUnpublish sends an "FCUnpublish" command for c.link.playpath on the
// control channel using a medium header. The AMF body holds the command name,
// the next transaction number (c.numInvokes is incremented here), a null and
// the play path. writeTo is called with false.
func sendFCUnpublish(c *Conn) error {
	var buf [1024]byte
	body := buf[fullHeaderSize:]
	pkt := packet{
		channel:    chanControl,
		headerType: headerSizeMedium,
		packetType: packetTypeInvoke,
		buf:        buf[:],
		body:       body,
	}

	rest, err := amf.EncodeString(body, avFCUnpublish)
	if err != nil {
		return err
	}
	c.numInvokes++
	if rest, err = amf.EncodeNumber(rest, float64(c.numInvokes)); err != nil {
		return err
	}
	rest[0] = amf.TypeNull
	if rest, err = amf.EncodeString(rest[1:], c.link.playpath); err != nil {
		return err
	}

	pkt.bodySize = uint32(len(body) - len(rest))
	return pkt.writeTo(c, false)
}

// sendPublish sends a "publish" command for c.link.playpath on the source
// channel using a large header. The AMF body holds the command name, the next
// transaction number (c.numInvokes is incremented here), a null, the play path
// and the publishing type "live". writeTo is called with true because a
// response is expected.
func sendPublish(c *Conn) error {
	var buf [1024]byte
	body := buf[fullHeaderSize:]
	pkt := packet{
		channel:    chanSource,
		headerType: headerSizeLarge,
		packetType: packetTypeInvoke,
		buf:        buf[:],
		body:       body,
	}

	rest, err := amf.EncodeString(body, avPublish)
	if err != nil {
		return err
	}
	c.numInvokes++
	if rest, err = amf.EncodeNumber(rest, float64(c.numInvokes)); err != nil {
		return err
	}
	rest[0] = amf.TypeNull
	if rest, err = amf.EncodeString(rest[1:], c.link.playpath); err != nil {
		return err
	}
	if rest, err = amf.EncodeString(rest, avLive); err != nil {
		return err
	}

	pkt.bodySize = uint32(len(body) - len(rest))
	return pkt.writeTo(c, true) // response expected
}

// sendDeleteStream sends a "deleteStream" command for streamID on the control
// channel using a medium header. The AMF body holds the command name, the next
// transaction number (c.numInvokes is incremented here), a null and the stream
// ID. writeTo is called with false.
func sendDeleteStream(c *Conn, streamID float64) error {
	var buf [256]byte
	body := buf[fullHeaderSize:]
	pkt := packet{
		channel:    chanControl,
		headerType: headerSizeMedium,
		packetType: packetTypeInvoke,
		buf:        buf[:],
		body:       body,
	}

	rest, err := amf.EncodeString(body, avDeletestream)
	if err != nil {
		return err
	}
	c.numInvokes++
	if rest, err = amf.EncodeNumber(rest, float64(c.numInvokes)); err != nil {
		return err
	}
	rest[0] = amf.TypeNull
	if rest, err = amf.EncodeNumber(rest[1:], streamID); err != nil {
		return err
	}

	pkt.bodySize = uint32(len(body) - len(rest))
	return pkt.writeTo(c, false)
}

// sendBytesReceived tells the server how many bytes the client has received.
//
// It sends a bytes-read report on the bytes-read channel using a medium header.
// The 4-byte body is c.nBytesIn encoded with amf.EncodeInt32. c.nBytesInSent is
// set to c.nBytesIn before encoding. writeTo is called with false. Any encoding
// or write error is returned.
func sendBytesReceived(c *Conn) error {
	var pbuf [256]byte
	pkt := packet{
		channel:    chanBytesRead,
		headerType: headerSizeMedium,
		packetType: packetTypeBytesReadReport,
		buf:        pbuf[:],
		body:       pbuf[fullHeaderSize:],
	}

	c.nBytesInSent = c.nBytesIn

	// The body size is fixed at four bytes, so the remaining-buffer slice
	// returned by the encoder is not needed; only its error matters.
	if _, err := amf.EncodeInt32(pkt.body, c.nBytesIn); err != nil {
		return err
	}
	pkt.bodySize = 4

	return pkt.writeTo(c, false)
}

// sendCheckBW sends a "_checkbw" command on the control channel using a large
// header. The AMF body holds the command name, the next transaction number
// (c.numInvokes is incremented here) and a null. writeTo is called with false.
func sendCheckBW(c *Conn) error {
	var buf [256]byte
	body := buf[fullHeaderSize:]
	pkt := packet{
		channel:    chanControl,
		headerType: headerSizeLarge,
		packetType: packetTypeInvoke,
		buf:        buf[:],
		body:       body,
	}

	rest, err := amf.EncodeString(body, av_checkbw)
	if err != nil {
		return err
	}
	c.numInvokes++
	if rest, err = amf.EncodeNumber(rest, float64(c.numInvokes)); err != nil {
		return err
	}
	rest[0] = amf.TypeNull
	rest = rest[1:]

	pkt.bodySize = uint32(len(body) - len(rest))
	return pkt.writeTo(c, false)
}

// eraseMethod removes m[i] in place, keeping the remaining elements in order,
// and returns the shortened slice, which shares m's backing array. The slot
// vacated at the end of m is reset to the zero method so it holds no stale
// data. It panics if i is not a valid index into m.
func eraseMethod(m []method, i int) []method {
	// append shifts m[i+1:] down over m[i] without reallocating, because
	// m[:i] already has enough capacity.
	shortened := append(m[:i], m[i+1:]...)
	m[len(m)-1] = method{}
	return shortened
}

// handleInvoke handles the body of an invoke packet.
//
// The body must begin with an AMF string marker. It is decoded as an AMF
// object whose first property is the method name and whose second is the
// transaction number. The following methods are handled:
//   - _result: the pending call with the same transaction number is removed
//     from c.methodCalls. A result for connect triggers releaseStream,
//     FCPublish and createStream. A result for createStream records
//     c.streamID and sends publish.
//   - onBWDone: a _checkbw command is sent.
//   - onStatus: any code other than NetStream.Publish.Start yields
//     errUnimplemented. Otherwise c.isPlaying is set to true and the first
//     pending publish call is removed from c.methodCalls.
//
// Other method names are logged at FatalLevel.
//
// It returns errInvalidBody for an empty body or one not starting with a
// string, errNotWritable for a _result on a connection without featureWrite,
// and otherwise any decode or send error encountered.
func handleInvoke(c *Conn, body []byte) error {
	if len(body) == 0 || body[0] != amf.TypeString {
		return errInvalidBody
	}
	var obj amf.Object
	_, err := amf.Decode(&obj, body, false)
	if err != nil {
		return err
	}

	meth, err := obj.StringProperty("", 0)
	if err != nil {
		return err
	}
	txn, err := obj.NumberProperty("", 1)
	if err != nil {
		return err
	}

	c.log(DebugLevel, pkg+"invoking method "+meth)
	switch meth {
	case av_result:
		if (c.link.protocol & featureWrite) == 0 {
			return errNotWritable
		}
		var methodInvoked string
		for i, m := range c.methodCalls {
			if float64(m.num) == txn {
				methodInvoked = m.name
				c.methodCalls = eraseMethod(c.methodCalls, i)
				break
			}
		}
		if methodInvoked == "" {
			c.log(WarnLevel, pkg+"received result without matching request", "id", txn)
			return nil
		}
		c.log(DebugLevel, pkg+"received result for "+methodInvoked)

		switch methodInvoked {
		case avConnect:
			if err := sendReleaseStream(c); err != nil {
				return err
			}
			if err := sendFCPublish(c); err != nil {
				return err
			}
			if err := sendCreateStream(c); err != nil {
				return err
			}

		case avCreatestream:
			n, err := obj.NumberProperty("", 3)
			if err != nil {
				return err
			}
			c.streamID = uint32(n)
			if err := sendPublish(c); err != nil {
				return err
			}

		default:
			c.log(FatalLevel, pkg+"unexpected method invoked "+methodInvoked)
		}

	case avOnBWDone:
		if err := sendCheckBW(c); err != nil {
			return err
		}

	case avOnStatus:
		obj2, err := obj.ObjectProperty("", 3)
		if err != nil {
			return err
		}
		code, err := obj2.StringProperty(avCode, -1)
		if err != nil {
			return err
		}
		level, err := obj2.StringProperty(avLevel, -1)
		if err != nil {
			return err
		}
		c.log(DebugLevel, pkg+"onStatus", "code", code, "level", level)

		if code != avNetStreamPublish_Start {
			c.log(ErrorLevel, pkg+"unexpected response "+code)
			return errUnimplemented
		}
		c.log(DebugLevel, pkg+"playing")
		c.isPlaying = true
		// As in the _result handling above, remove only the first matching
		// call. Stopping immediately also avoids continuing to range over the
		// stale slice header after eraseMethod has shifted its elements, which
		// would skip adjacent entries.
		for i, m := range c.methodCalls {
			if m.name == avPublish {
				c.methodCalls = eraseMethod(c.methodCalls, i)
				break
			}
		}

	default:
		c.log(FatalLevel, pkg+"unsupported method "+meth)
	}
	return nil
}

// handshake performs the client side of the plain RTMP handshake over c:
//  1. send C0 (one byte) and C1 (signatureSize bytes: a big-endian
//     millisecond timestamp, four zero bytes, then random filler);
//  2. read S0 and warn if it differs from the byte sent, then read S1 and
//     log the server uptime held in its first four bytes;
//  3. echo S1 back to the server as C2;
//  4. read S2 and warn if it does not equal C1.
//
// Mismatches are only logged; an error is returned only when a read or write
// fails.
func handshake(c *Conn) error {
	var clientbuf [signatureSize + 1]byte
	clientsig := clientbuf[1:]

	// C0. NOTE(review): this reuses chanControl as the handshake version byte,
	// presumably because both are 0x03 — confirm.
	clientbuf[0] = chanControl

	// C1. Bytes 4-8 stay zero: the array is zero-initialized.
	binary.BigEndian.PutUint32(clientsig, uint32(time.Now().UnixNano()/1000000))
	for i := 8; i < signatureSize; i++ {
		clientsig[i] = byte(rand.Intn(256))
	}

	if _, err := c.write(clientbuf[:]); err != nil {
		return err
	}
	c.log(DebugLevel, pkg+"handshake sent")

	// S0.
	var typ [1]byte
	if _, err := c.read(typ[:]); err != nil {
		return err
	}
	c.log(DebugLevel, pkg+"handshake received")
	if typ[0] != clientbuf[0] {
		c.log(WarnLevel, pkg+"handshake type mismatch", "sent", clientbuf[0], "received", typ[0])
	}

	// S1.
	var serversig [signatureSize]byte
	if _, err := c.read(serversig[:]); err != nil {
		return err
	}
	suptime := binary.BigEndian.Uint32(serversig[:4])
	c.log(DebugLevel, pkg+"server uptime", "uptime", suptime)

	// C2: echo S1.
	if _, err := c.write(serversig[:]); err != nil {
		return err
	}

	// S2: compared with C1; a mismatch is logged but not treated as an error.
	if _, err := c.read(serversig[:]); err != nil {
		return err
	}
	if !bytes.Equal(serversig[:], clientsig) {
		c.log(WarnLevel, pkg+"signature mismatch", "serversig", serversig[:], "clientsig", clientsig)
	}
	return nil
}