Merge branch 'master' into turbidity-probe

merge latest rtmp changes into turbidity probe
This commit is contained in:
Russell Stanley 2022-03-24 12:06:08 +10:30
commit 24d77b3b65
6 changed files with 273 additions and 135 deletions

View File

@ -43,6 +43,7 @@ package amf
import ( import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"math" "math"
) )
@ -163,11 +164,7 @@ func EncodeInt32(buf []byte, val uint32) ([]byte, error) {
// Strings less than 65536 in length are encoded as TypeString, while longer strings are ecodeded as typeLongString. // Strings less than 65536 in length are encoded as TypeString, while longer strings are ecodeded as typeLongString.
func EncodeString(buf []byte, val string) ([]byte, error) { func EncodeString(buf []byte, val string) ([]byte, error) {
const typeSize = 1 const typeSize = 1
if len(val) < 65536 && len(val)+typeSize+binary.Size(int16(0)) > len(buf) { if len(val) < 65536 && len(val)+typeSize+binary.Size(int16(0)) > len(buf) || len(val)+typeSize+binary.Size(uint32(0)) > len(buf) {
return nil, ErrShortBuffer
}
if len(val)+typeSize+binary.Size(uint32(0)) > len(buf) {
return nil, ErrShortBuffer return nil, ErrShortBuffer
} }
@ -222,7 +219,11 @@ func EncodeNamedString(buf []byte, key, val string) ([]byte, error) {
binary.BigEndian.PutUint16(buf[:2], uint16(len(key))) binary.BigEndian.PutUint16(buf[:2], uint16(len(key)))
buf = buf[2:] buf = buf[2:]
copy(buf, key) copy(buf, key)
return EncodeString(buf[len(key):], val) b, err := EncodeString(buf[len(key):], val)
if err != nil {
return nil, fmt.Errorf("could not encode string: %w", err)
}
return b, nil
} }
// EncodeNamedNumber encodes a named number, where key is the name and val is the number value. // EncodeNamedNumber encodes a named number, where key is the name and val is the number value.
@ -233,7 +234,11 @@ func EncodeNamedNumber(buf []byte, key string, val float64) ([]byte, error) {
binary.BigEndian.PutUint16(buf[:2], uint16(len(key))) binary.BigEndian.PutUint16(buf[:2], uint16(len(key)))
buf = buf[2:] buf = buf[2:]
copy(buf, key) copy(buf, key)
return EncodeNumber(buf[len(key):], val) b, err := EncodeNumber(buf[len(key):], val)
if err != nil {
return nil, fmt.Errorf("could not encode number: %w", err)
}
return b, nil
} }
// EncodeNamedNumber encodes a named boolean, where key is the name and val is the boolean value. // EncodeNamedNumber encodes a named boolean, where key is the name and val is the boolean value.
@ -244,14 +249,18 @@ func EncodeNamedBoolean(buf []byte, key string, val bool) ([]byte, error) {
binary.BigEndian.PutUint16(buf[:2], uint16(len(key))) binary.BigEndian.PutUint16(buf[:2], uint16(len(key)))
buf = buf[2:] buf = buf[2:]
copy(buf, key) copy(buf, key)
return EncodeBoolean(buf[len(key):], val) b, err := EncodeBoolean(buf[len(key):], val)
if err != nil {
return nil, fmt.Errorf("could not encode boolean: %w", err)
}
return b, nil
} }
// EncodeProperty encodes a property. // EncodeProperty encodes a property.
func EncodeProperty(prop *Property, buf []byte) ([]byte, error) { func EncodeProperty(prop *Property, buf []byte) ([]byte, error) {
if prop.Type != TypeNull && prop.Name != "" { if prop.Type != TypeNull && prop.Name != "" {
if len(buf) < 2+len(prop.Name) { if len(buf) < 2+len(prop.Name) {
return nil, ErrShortBuffer return nil, fmt.Errorf("not type null, short buffer: %w", ErrShortBuffer)
} }
binary.BigEndian.PutUint16(buf[:2], uint16(len(prop.Name))) binary.BigEndian.PutUint16(buf[:2], uint16(len(prop.Name)))
buf = buf[2:] buf = buf[2:]
@ -261,23 +270,47 @@ func EncodeProperty(prop *Property, buf []byte) ([]byte, error) {
switch prop.Type { switch prop.Type {
case typeNumber: case typeNumber:
return EncodeNumber(buf, prop.Number) b, err := EncodeNumber(buf, prop.Number)
if err != nil {
return nil, fmt.Errorf("could not encode number: %w", err)
}
return b, nil
case typeBoolean: case typeBoolean:
return EncodeBoolean(buf, prop.Number != 0) b, err := EncodeBoolean(buf, prop.Number != 0)
if err != nil {
return nil, fmt.Errorf("could not encode boolean: %w", err)
}
return b, nil
case TypeString: case TypeString:
return EncodeString(buf, prop.String) b, err := EncodeString(buf, prop.String)
if err != nil {
return nil, fmt.Errorf("could not encode string: %w", err)
}
return b, nil
case TypeNull: case TypeNull:
if len(buf) < 2 { if len(buf) < 2 {
return nil, ErrShortBuffer return nil, fmt.Errorf("type null, short buffer: %w", ErrShortBuffer)
} }
buf[0] = TypeNull buf[0] = TypeNull
buf = buf[1:] buf = buf[1:]
case TypeObject: case TypeObject:
return Encode(&prop.Object, buf) b, err := Encode(&prop.Object, buf)
if err != nil {
return nil, fmt.Errorf("could not encode: %w", err)
}
return b, nil
case typeEcmaArray: case typeEcmaArray:
return EncodeEcmaArray(&prop.Object, buf) b, err := EncodeEcmaArray(&prop.Object, buf)
if err != nil {
return nil, fmt.Errorf("could not encode ecma array: %w", err)
}
return b, nil
case typeStrictArray: case typeStrictArray:
return EncodeArray(&prop.Object, buf) b, err := EncodeArray(&prop.Object, buf)
if err != nil {
return nil, fmt.Errorf("could not encode array: %w", err)
}
return b, nil
default: default:
return nil, ErrInvalidType return nil, ErrInvalidType
} }
@ -294,7 +327,7 @@ func DecodeProperty(prop *Property, buf []byte, decodeName bool) (int, error) {
} }
n := DecodeInt16(buf[:2]) n := DecodeInt16(buf[:2])
if int(n) > len(buf)-2 { if int(n) > len(buf)-2 {
return 0, ErrShortBuffer return 0, fmt.Errorf("short buffer after decode of int 16: %w", ErrShortBuffer)
} }
prop.Name = DecodeString(buf) prop.Name = DecodeString(buf)
@ -309,14 +342,14 @@ func DecodeProperty(prop *Property, buf []byte, decodeName bool) (int, error) {
switch prop.Type { switch prop.Type {
case typeNumber: case typeNumber:
if len(buf) < 8 { if len(buf) < 8 {
return 0, ErrShortBuffer return 0, fmt.Errorf("type number short buffer: %w", ErrShortBuffer)
} }
prop.Number = DecodeNumber(buf[:8]) prop.Number = DecodeNumber(buf[:8])
buf = buf[8:] buf = buf[8:]
case typeBoolean: case typeBoolean:
if len(buf) < 1 { if len(buf) < 1 {
return 0, ErrShortBuffer return 0, fmt.Errorf("type boolean short buffer: %w", ErrShortBuffer)
} }
prop.Number = float64(buf[0]) prop.Number = float64(buf[0])
buf = buf[1:] buf = buf[1:]
@ -324,7 +357,7 @@ func DecodeProperty(prop *Property, buf []byte, decodeName bool) (int, error) {
case TypeString: case TypeString:
n := DecodeInt16(buf[:2]) n := DecodeInt16(buf[:2])
if len(buf) < int(n+2) { if len(buf) < int(n+2) {
return 0, ErrShortBuffer return 0, fmt.Errorf("type string: %w", ErrShortBuffer)
} }
prop.String = DecodeString(buf) prop.String = DecodeString(buf)
buf = buf[2+n:] buf = buf[2+n:]
@ -332,7 +365,7 @@ func DecodeProperty(prop *Property, buf []byte, decodeName bool) (int, error) {
case TypeObject: case TypeObject:
n, err := Decode(&prop.Object, buf, true) n, err := Decode(&prop.Object, buf, true)
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("could not decode type object: %w", err)
} }
buf = buf[n:] buf = buf[n:]
@ -343,7 +376,7 @@ func DecodeProperty(prop *Property, buf []byte, decodeName bool) (int, error) {
buf = buf[4:] buf = buf[4:]
n, err := Decode(&prop.Object, buf, true) n, err := Decode(&prop.Object, buf, true)
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("could not decode type ecma array: %w", err)
} }
buf = buf[n:] buf = buf[n:]
@ -367,14 +400,19 @@ func Encode(obj *Object, buf []byte) ([]byte, error) {
var err error var err error
buf, err = EncodeProperty(&obj.Properties[i], buf) buf, err = EncodeProperty(&obj.Properties[i], buf)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("could not encode property no. %d: %w", i, err)
} }
} }
if len(buf) < 3 { if len(buf) < 3 {
return nil, ErrShortBuffer return nil, fmt.Errorf("short buffer after property encoding: %w", ErrShortBuffer)
} }
return EncodeInt24(buf, TypeObjectEnd)
b, err := EncodeInt24(buf, TypeObjectEnd)
if err != nil {
return nil, fmt.Errorf("could not encode int 24: %w", err)
}
return b, err
} }
// EncodeEcmaArray encodes an ECMA array. // EncodeEcmaArray encodes an ECMA array.
@ -392,14 +430,20 @@ func EncodeEcmaArray(obj *Object, buf []byte) ([]byte, error) {
var err error var err error
buf, err = EncodeProperty(&obj.Properties[i], buf) buf, err = EncodeProperty(&obj.Properties[i], buf)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("could not encode property no. %d: %w", i, err)
} }
} }
if len(buf) < 3 { if len(buf) < 3 {
return nil, ErrShortBuffer return nil, fmt.Errorf("short buffer after property encoding: %w", ErrShortBuffer)
} }
return EncodeInt24(buf, TypeObjectEnd)
b, err := EncodeInt24(buf, TypeObjectEnd)
if err != nil {
return nil, fmt.Errorf("could not encode int 24: %w", err)
}
return b, nil
} }
// EncodeArray encodes an array. // EncodeArray encodes an array.
@ -417,7 +461,7 @@ func EncodeArray(obj *Object, buf []byte) ([]byte, error) {
var err error var err error
buf, err = EncodeProperty(&obj.Properties[i], buf) buf, err = EncodeProperty(&obj.Properties[i], buf)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("could not encode property no. %d: %w", i, err)
} }
} }
@ -437,7 +481,7 @@ func Decode(obj *Object, buf []byte, decodeName bool) (int, error) {
var prop Property var prop Property
n, err := DecodeProperty(&prop, buf, decodeName) n, err := DecodeProperty(&prop, buf, decodeName)
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("could not decode property: %w", err)
} }
buf = buf[n:] buf = buf[n:]
obj.Properties = append(obj.Properties, prop) obj.Properties = append(obj.Properties, prop)

View File

@ -138,14 +138,14 @@ func Dial(url string, log Log, options ...func(*Conn) error) (*Conn, error) {
var err error var err error
c.link.protocol, c.link.host, c.link.port, c.link.app, c.link.playpath, err = parseURL(url) c.link.protocol, c.link.host, c.link.port, c.link.app, c.link.playpath, err = parseURL(url)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("could not parse url: %w",err)
} }
c.link.url = rtmpProtocolStrings[c.link.protocol] + "://" + c.link.host + ":" + strconv.Itoa(int(c.link.port)) + "/" + c.link.app c.link.url = rtmpProtocolStrings[c.link.protocol] + "://" + c.link.host + ":" + strconv.Itoa(int(c.link.port)) + "/" + c.link.app
c.link.protocol |= featureWrite c.link.protocol |= featureWrite
err = connect(&c) err = connect(&c)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("could not connect: %w",err)
} }
return &c, nil return &c, nil
} }
@ -159,11 +159,20 @@ func (c *Conn) Close() error {
c.log(DebugLevel, pkg+"Conn.Close") c.log(DebugLevel, pkg+"Conn.Close")
if c.streamID > 0 { if c.streamID > 0 {
if c.link.protocol&featureWrite != 0 { if c.link.protocol&featureWrite != 0 {
sendFCUnpublish(c) err := sendFCUnpublish(c)
if err != nil {
return fmt.Errorf("could not send fc unpublish: %w",err)
}
}
err := sendDeleteStream(c, float64(c.streamID))
if err != nil {
return fmt.Errorf("could not send delete stream: %w",err)
} }
sendDeleteStream(c, float64(c.streamID))
} }
c.link.conn.Close() err := c.link.conn.Close()
if err != nil {
return fmt.Errorf("could not close link conn: %w",err)
}
*c = Conn{} *c = Conn{}
return nil return nil
} }
@ -192,7 +201,7 @@ func (c *Conn) Write(data []byte) (int, error) {
copy(pkt.body, data[flvTagheaderSize:flvTagheaderSize+pkt.bodySize]) copy(pkt.body, data[flvTagheaderSize:flvTagheaderSize+pkt.bodySize])
err := pkt.writeTo(c, false) err := pkt.writeTo(c, false)
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("could not write packet to connection: %w",err)
} }
return len(data), nil return len(data), nil
} }
@ -205,18 +214,18 @@ func (c *Conn) Write(data []byte) (int, error) {
func (c *Conn) read(buf []byte) (int, error) { func (c *Conn) read(buf []byte) (int, error) {
err := c.link.conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(c.link.timeout))) err := c.link.conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(c.link.timeout)))
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("could not set read deadline: %w",err)
} }
n, err := io.ReadFull(c.link.conn, buf) n, err := io.ReadFull(c.link.conn, buf)
if err != nil { if err != nil {
c.log(DebugLevel, pkg+"read failed", "error", err.Error()) c.log(DebugLevel, pkg+"read failed", "error", err.Error())
return 0, err return 0, fmt.Errorf("could not read conn: %w",err)
} }
c.nBytesIn += uint32(n) c.nBytesIn += uint32(n)
if c.nBytesIn > (c.nBytesInSent + c.clientBW/10) { if c.nBytesIn > (c.nBytesInSent + c.clientBW/10) {
err := sendBytesReceived(c) err := sendBytesReceived(c)
if err != nil { if err != nil {
return n, err // NB: we still read n bytes, even though send bytes failed return n, fmt.Errorf("could not send bytes received: %w",err) // NB: we still read n bytes, even though send bytes failed
} }
} }
return n, nil return n, nil
@ -227,12 +236,12 @@ func (c *Conn) write(buf []byte) (int, error) {
//ToDo: consider using a different timeout for writes than for reads //ToDo: consider using a different timeout for writes than for reads
err := c.link.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(c.link.timeout))) err := c.link.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(c.link.timeout)))
if err != nil { if err != nil {
return 0, err return 0, fmt.Errorf("could not set write deadline: %w",err)
} }
n, err := c.link.conn.Write(buf) n, err := c.link.conn.Write(buf)
if err != nil { if err != nil {
c.log(WarnLevel, pkg+"write failed", "error", err.Error()) c.log(WarnLevel, pkg+"write failed", "error", err.Error())
return 0, err return 0, fmt.Errorf("could not write to conn: %w",err)
} }
return n, nil return n, nil
} }

View File

@ -36,6 +36,7 @@ package rtmp
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"bitbucket.org/ausocean/av/protocol/rtmp/amf" "bitbucket.org/ausocean/av/protocol/rtmp/amf"
@ -90,10 +91,15 @@ type packet struct {
timestamp uint32 timestamp uint32
streamID uint32 streamID uint32
bodySize uint32 bodySize uint32
bytesRead uint32
buf []byte buf []byte
body []byte body []byte
} }
func (pkt *packet) isReady() bool {
return pkt.bytesRead == pkt.bodySize
}
// readFrom reads a packet from the RTMP connection. // readFrom reads a packet from the RTMP connection.
func (pkt *packet) readFrom(c *Conn) error { func (pkt *packet) readFrom(c *Conn) error {
var hbuf [fullHeaderSize]byte var hbuf [fullHeaderSize]byte
@ -105,7 +111,7 @@ func (pkt *packet) readFrom(c *Conn) error {
if err == io.EOF { if err == io.EOF {
c.log(WarnLevel, pkg+"EOF error; connection likely terminated") c.log(WarnLevel, pkg+"EOF error; connection likely terminated")
} }
return err return fmt.Errorf("failed to read packet header 1st byte: %w", err)
} }
pkt.headerType = (header[0] & 0xc0) >> 6 pkt.headerType = (header[0] & 0xc0) >> 6
pkt.channel = int32(header[0] & 0x3f) pkt.channel = int32(header[0] & 0x3f)
@ -116,7 +122,7 @@ func (pkt *packet) readFrom(c *Conn) error {
_, err = c.read(header[:1]) _, err = c.read(header[:1])
if err != nil { if err != nil {
c.log(DebugLevel, pkg+"failed to read packet header 2nd byte", "error", err.Error()) c.log(DebugLevel, pkg+"failed to read packet header 2nd byte", "error", err.Error())
return err return fmt.Errorf("failed to read packet header second byte: %w", err)
} }
header = header[1:] header = header[1:]
pkt.channel = int32(header[0]) + 64 pkt.channel = int32(header[0]) + 64
@ -125,7 +131,7 @@ func (pkt *packet) readFrom(c *Conn) error {
_, err = c.read(header[:2]) _, err = c.read(header[:2])
if err != nil { if err != nil {
c.log(DebugLevel, pkg+"failed to read packet header 3rd byte", "error", err.Error()) c.log(DebugLevel, pkg+"failed to read packet header 3rd byte", "error", err.Error())
return err return fmt.Errorf("failed to read packet header 3rd byte: %w", err)
} }
header = header[2:] header = header[2:]
pkt.channel = int32(binary.BigEndian.Uint16(header[:2])) + 64 pkt.channel = int32(binary.BigEndian.Uint16(header[:2])) + 64
@ -169,13 +175,14 @@ func (pkt *packet) readFrom(c *Conn) error {
_, err = c.read(header[:size]) _, err = c.read(header[:size])
if err != nil { if err != nil {
c.log(DebugLevel, pkg+"failed to read packet header", "error", err.Error()) c.log(DebugLevel, pkg+"failed to read packet header", "error", err.Error())
return err return fmt.Errorf("failed to read packet header: %w", err)
} }
} }
hSize := len(hbuf) - len(header) + size hSize := len(hbuf) - len(header) + size
if size >= 3 { if size >= 3 {
pkt.timestamp = amf.DecodeInt24(header[:3]) pkt.timestamp = amf.DecodeInt24(header[:3])
pkt.bytesRead = 0
if size >= 6 { if size >= 6 {
pkt.bodySize = amf.DecodeInt24(header[3:6]) pkt.bodySize = amf.DecodeInt24(header[3:6])
@ -193,7 +200,7 @@ func (pkt *packet) readFrom(c *Conn) error {
_, err = c.read(header[size : size+4]) _, err = c.read(header[size : size+4])
if err != nil { if err != nil {
c.log(DebugLevel, pkg+"failed to read extended timestamp", "error", err.Error()) c.log(DebugLevel, pkg+"failed to read extended timestamp", "error", err.Error())
return err return fmt.Errorf("failed to read extended timestamp: %w", err)
} }
pkt.timestamp = amf.DecodeInt32(header[size : size+4]) pkt.timestamp = amf.DecodeInt32(header[size : size+4])
hSize += 4 hSize += 4
@ -205,12 +212,24 @@ func (pkt *packet) readFrom(c *Conn) error {
c.log(WarnLevel, pkg+"reading large packet", "size", int(pkt.bodySize)) c.log(WarnLevel, pkg+"reading large packet", "size", int(pkt.bodySize))
} }
_, err = c.read(pkt.body[:pkt.bodySize]) nToRead := pkt.bodySize - pkt.bytesRead
nChunk := c.inChunkSize
if nToRead < nChunk {
nChunk = nToRead
}
n, err := c.read(pkt.body[pkt.bytesRead : pkt.bytesRead+nChunk])
if err != nil { if err != nil {
c.log(DebugLevel, pkg+"failed to read packet body", "error", err.Error()) c.log(DebugLevel, pkg+"failed to read packet body", "error", err.Error())
return err return fmt.Errorf("failed to read packet body: %w", err)
} }
if uint32(n) != nChunk {
return fmt.Errorf("did not read correct number of bytes, read: %d, expected: %d", n, nChunk)
}
pkt.bytesRead += nChunk
// Keep the packet as a reference for other packets on this channel. // Keep the packet as a reference for other packets on this channel.
if c.channelsIn[pkt.channel] == nil { if c.channelsIn[pkt.channel] == nil {
c.channelsIn[pkt.channel] = &packet{} c.channelsIn[pkt.channel] = &packet{}
@ -221,14 +240,17 @@ func (pkt *packet) readFrom(c *Conn) error {
c.channelsIn[pkt.channel].timestamp = 0xffffff c.channelsIn[pkt.channel].timestamp = 0xffffff
} }
if !pkt.hasAbsTimestamp { if pkt.isReady() {
// Timestamps seem to always be relative. if !pkt.hasAbsTimestamp {
pkt.timestamp += uint32(c.channelTimestamp[pkt.channel]) // Timestamps seem to always be relative.
} pkt.timestamp += uint32(c.channelTimestamp[pkt.channel])
c.channelTimestamp[pkt.channel] = int32(pkt.timestamp) }
c.channelTimestamp[pkt.channel] = int32(pkt.timestamp)
c.channelsIn[pkt.channel].body = nil
c.channelsIn[pkt.channel].hasAbsTimestamp = false
}
c.channelsIn[pkt.channel].body = nil
c.channelsIn[pkt.channel].hasAbsTimestamp = false
return nil return nil
} }
@ -404,7 +426,7 @@ func (pkt *packet) writeTo(c *Conn, queue bool) error {
c.log(DebugLevel, pkg+"sending deferred packet separately", "size", len(c.deferred)) c.log(DebugLevel, pkg+"sending deferred packet separately", "size", len(c.deferred))
_, err := c.write(c.deferred) _, err := c.write(c.deferred)
if err != nil { if err != nil {
return err return fmt.Errorf("could not write deferred packet: %w", err)
} }
c.deferred = nil c.deferred = nil
} }
@ -424,7 +446,7 @@ func (pkt *packet) writeTo(c *Conn, queue bool) error {
} }
_, err := c.write(bytes) _, err := c.write(bytes)
if err != nil { if err != nil {
return err return fmt.Errorf("could not write combined packet: %w", err)
} }
c.deferred = nil c.deferred = nil

View File

@ -34,17 +34,25 @@ LICENSE
package rtmp package rtmp
import ( import (
"errors"
"fmt"
"net/url" "net/url"
"path" "path"
"strconv" "strconv"
"strings" "strings"
) )
// Errors.
var (
errInvalidPath = errors.New("invalid url path")
errInvalidElements = errors.New("invalid url elements")
)
// parseURL parses an RTMP URL (ok, technically it is lexing). // parseURL parses an RTMP URL (ok, technically it is lexing).
func parseURL(addr string) (protocol int32, host string, port uint16, app, playpath string, err error) { func parseURL(addr string) (protocol int32, host string, port uint16, app, playpath string, err error) {
u, err := url.Parse(addr) u, err := url.Parse(addr)
if err != nil { if err != nil {
return protocol, host, port, app, playpath, err return protocol, host, port, app, playpath, fmt.Errorf("could not parse to url value: %w", err)
} }
switch u.Scheme { switch u.Scheme {
@ -63,24 +71,24 @@ func parseURL(addr string) (protocol int32, host string, port uint16, app, playp
case "rtmpts": case "rtmpts":
protocol = protoRTMPTS protocol = protoRTMPTS
default: default:
return protocol, host, port, app, playpath, errUnknownScheme return protocol, host, port, app, playpath, fmt.Errorf("unknown scheme: %s", u.Scheme)
} }
host = u.Host host = u.Host
if p := u.Port(); p != "" { if p := u.Port(); p != "" {
pi, err := strconv.Atoi(p) pi, err := strconv.Atoi(p)
if err != nil { if err != nil {
return protocol, host, port, app, playpath, err return protocol, host, port, app, playpath, fmt.Errorf("could convert port to integer: %w", err)
} }
port = uint16(pi) port = uint16(pi)
} }
if len(u.Path) < 1 || !path.IsAbs(u.Path) { if len(u.Path) < 1 || !path.IsAbs(u.Path) {
return protocol, host, port, app, playpath, errInvalidURL return protocol, host, port, app, playpath, errInvalidPath
} }
elems := strings.SplitN(u.Path[1:], "/", 3) elems := strings.SplitN(u.Path[1:], "/", 3)
if len(elems) < 2 || elems[0] == "" || elems[1] == "" { if len(elems) < 2 || elems[0] == "" || elems[1] == "" {
return protocol, host, port, app, playpath, errInvalidURL return protocol, host, port, app, playpath, errInvalidElements
} }
app = elems[0] app = elems[0]
playpath = path.Join(elems[1:]...) playpath = path.Join(elems[1:]...)
@ -106,7 +114,7 @@ func parseURL(addr string) (protocol int32, host string, port uint16, app, playp
switch { switch {
case port != 0: case port != 0:
case (protocol & featureSSL) != 0: case (protocol & featureSSL) != 0:
return protocol, host, port, app, playpath, errUnimplemented // port = 433 return protocol, host, port, app, playpath, errors.New("ssl not implemented")
case (protocol & featureHTTP) != 0: case (protocol & featureHTTP) != 0:
port = 80 port = 80
default: default:

View File

@ -41,19 +41,19 @@ var parseURLTests = []struct {
}{ }{
{ {
url: "rtmp://addr", url: "rtmp://addr",
wantErr: errInvalidURL, wantErr: errInvalidPath,
}, },
{ {
url: "rtmp://addr/", url: "rtmp://addr/",
wantErr: errInvalidURL, wantErr: errInvalidElements,
}, },
{ {
url: "rtmp://addr/live2", url: "rtmp://addr/live2",
wantErr: errInvalidURL, wantErr: errInvalidElements,
}, },
{ {
url: "rtmp://addr/live2/", url: "rtmp://addr/live2/",
wantErr: errInvalidURL, wantErr: errInvalidElements,
}, },
{ {
url: "rtmp://addr/appname/key", url: "rtmp://addr/appname/key",

View File

@ -38,6 +38,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"math/rand" "math/rand"
"net" "net"
"strconv" "strconv"
@ -164,14 +165,15 @@ var (
// connect establishes an RTMP connection. // connect establishes an RTMP connection.
func connect(c *Conn) error { func connect(c *Conn) error {
addr, err := net.ResolveTCPAddr("tcp4", c.link.host+":"+strconv.Itoa(int(c.link.port))) addrStr := c.link.host + ":" + strconv.Itoa(int(c.link.port))
addr, err := net.ResolveTCPAddr("tcp4", addrStr)
if err != nil { if err != nil {
return err return fmt.Errorf("could not resolve tcp address (%s):%w", addrStr, err)
} }
c.link.conn, err = net.DialTCP("tcp4", nil, addr) c.link.conn, err = net.DialTCP("tcp4", nil, addr)
if err != nil { if err != nil {
c.log(WarnLevel, pkg+"dial failed", "error", err.Error()) c.log(WarnLevel, pkg+"dial failed", "error", err.Error())
return err return fmt.Errorf("could not dial tcp: %w", err)
} }
c.log(DebugLevel, pkg+"connected") c.log(DebugLevel, pkg+"connected")
@ -184,33 +186,42 @@ func connect(c *Conn) error {
err = handshake(c) err = handshake(c)
if err != nil { if err != nil {
c.log(WarnLevel, pkg+"handshake failed", "error", err.Error()) c.log(WarnLevel, pkg+"handshake failed", "error", err.Error())
return err return fmt.Errorf("could not handshake: %w", err)
} }
c.log(DebugLevel, pkg+"handshaked") c.log(DebugLevel, pkg+"handshaked")
err = sendConnectPacket(c) err = sendConnectPacket(c)
if err != nil { if err != nil {
c.log(WarnLevel, pkg+"sendConnect failed", "error", err.Error()) c.log(WarnLevel, pkg+"sendConnect failed", "error", err.Error())
return err return fmt.Errorf("could not send connect packet: %w", err)
} }
c.log(DebugLevel, pkg+"negotiating") c.log(DebugLevel, pkg+"negotiating")
var buf [256]byte var buf [256]byte
pkt := packet{buf: buf[:]}
for !c.isPlaying { for !c.isPlaying {
pkt := packet{buf: buf[:]}
err = pkt.readFrom(c) err = pkt.readFrom(c)
if err != nil { if err != nil {
return err return fmt.Errorf("could not read from packet: %w", err)
} }
switch pkt.packetType { if pkt.isReady() {
case packetTypeAudio, packetTypeVideo, packetTypeInfo: if pkt.bodySize == 0 {
c.log(WarnLevel, pkg+"got packet before play; ignoring", "type", pkt.packetType) continue
default:
err = handlePacket(c, &pkt)
if err != nil {
return err
} }
switch pkt.packetType {
case packetTypeAudio, packetTypeVideo, packetTypeInfo:
c.log(WarnLevel, pkg+"got packet before play; ignoring", "type", pkt.packetType)
default:
err = handlePacket(c, &pkt)
if err != nil {
return fmt.Errorf("could not handle packet: %w", err)
}
}
pkt = packet{buf: buf[:]}
} }
} }
return nil return nil
} }
@ -248,7 +259,7 @@ func handlePacket(c *Conn, pkt *packet) error {
err := handleInvoke(c, pkt.body[:pkt.bodySize]) err := handleInvoke(c, pkt.body[:pkt.bodySize])
if err != nil { if err != nil {
c.log(WarnLevel, pkg+"unexpected error from handleInvoke", "error", err.Error()) c.log(WarnLevel, pkg+"unexpected error from handleInvoke", "error", err.Error())
return err return fmt.Errorf("could not handle invoke: %w", err)
} }
case packetTypeControl, packetTypeAudio, packetTypeVideo, packetTypeFlashVideo, packetTypeFlexMessage, packetTypeInfo: case packetTypeControl, packetTypeAudio, packetTypeVideo, packetTypeFlashVideo, packetTypeFlexMessage, packetTypeInfo:
@ -273,12 +284,12 @@ func sendConnectPacket(c *Conn) error {
enc, err := amf.EncodeString(enc, avConnect) enc, err := amf.EncodeString(enc, avConnect)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode string: %w", err)
} }
c.numInvokes += 1 c.numInvokes += 1
enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) enc, err = amf.EncodeNumber(enc, float64(c.numInvokes))
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode number of invokes: %w", err)
} }
// required link info // required link info
@ -289,24 +300,29 @@ func sendConnectPacket(c *Conn) error {
} }
enc, err = amf.Encode(&info, enc) enc, err = amf.Encode(&info, enc)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode info: %w", err)
} }
// optional link auth info // optional link auth info
if c.link.auth != "" { if c.link.auth != "" {
enc, err = amf.EncodeBoolean(enc, c.link.flags&linkAuth != 0) enc, err = amf.EncodeBoolean(enc, c.link.flags&linkAuth != 0)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode link auth bool: %w", err)
} }
enc, err = amf.EncodeString(enc, c.link.auth) enc, err = amf.EncodeString(enc, c.link.auth)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode link auth string: %w", err)
} }
} }
pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc))
return pkt.writeTo(c, true) // response expected err = pkt.writeTo(c, true)
if err != nil {
return fmt.Errorf("could not write packet: %w", err)
}
return nil
} }
func sendCreateStream(c *Conn) error { func sendCreateStream(c *Conn) error {
@ -322,19 +338,23 @@ func sendCreateStream(c *Conn) error {
enc, err := amf.EncodeString(enc, avCreatestream) enc, err := amf.EncodeString(enc, avCreatestream)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode av create stream token: %w", err)
} }
c.numInvokes++ c.numInvokes++
enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) enc, err = amf.EncodeNumber(enc, float64(c.numInvokes))
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode number of invokes: %w", err)
} }
enc[0] = amf.TypeNull enc[0] = amf.TypeNull
enc = enc[1:] enc = enc[1:]
pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc))
return pkt.writeTo(c, true) // response expected err = pkt.writeTo(c, true)
if err != nil {
return fmt.Errorf("could not write packet: %w", err)
}
return nil
} }
func sendReleaseStream(c *Conn) error { func sendReleaseStream(c *Conn) error {
@ -350,22 +370,27 @@ func sendReleaseStream(c *Conn) error {
enc, err := amf.EncodeString(enc, avReleasestream) enc, err := amf.EncodeString(enc, avReleasestream)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode av release stream token: %w", err)
} }
c.numInvokes++ c.numInvokes++
enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) enc, err = amf.EncodeNumber(enc, float64(c.numInvokes))
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode number of invokes: %w", err)
} }
enc[0] = amf.TypeNull enc[0] = amf.TypeNull
enc = enc[1:] enc = enc[1:]
enc, err = amf.EncodeString(enc, c.link.playpath) enc, err = amf.EncodeString(enc, c.link.playpath)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode playpath: %w", err)
} }
pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc))
return pkt.writeTo(c, false) err = pkt.writeTo(c, false)
if err != nil {
return fmt.Errorf("could not write packet: %w", err)
}
return nil
} }
func sendFCPublish(c *Conn) error { func sendFCPublish(c *Conn) error {
@ -381,23 +406,28 @@ func sendFCPublish(c *Conn) error {
enc, err := amf.EncodeString(enc, avFCPublish) enc, err := amf.EncodeString(enc, avFCPublish)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode av fc publish token: %w", err)
} }
c.numInvokes++ c.numInvokes++
enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) enc, err = amf.EncodeNumber(enc, float64(c.numInvokes))
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode number of invokes: %w", err)
} }
enc[0] = amf.TypeNull enc[0] = amf.TypeNull
enc = enc[1:] enc = enc[1:]
enc, err = amf.EncodeString(enc, c.link.playpath) enc, err = amf.EncodeString(enc, c.link.playpath)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode playpath: %w", err)
} }
pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc))
return pkt.writeTo(c, false) err = pkt.writeTo(c, false)
if err != nil {
return fmt.Errorf("could not write packet: %w", err)
}
return nil
} }
func sendFCUnpublish(c *Conn) error { func sendFCUnpublish(c *Conn) error {
@ -413,23 +443,28 @@ func sendFCUnpublish(c *Conn) error {
enc, err := amf.EncodeString(enc, avFCUnpublish) enc, err := amf.EncodeString(enc, avFCUnpublish)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode av fc unpublish token: %w", err)
} }
c.numInvokes++ c.numInvokes++
enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) enc, err = amf.EncodeNumber(enc, float64(c.numInvokes))
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode number of invokes: %w", err)
} }
enc[0] = amf.TypeNull enc[0] = amf.TypeNull
enc = enc[1:] enc = enc[1:]
enc, err = amf.EncodeString(enc, c.link.playpath) enc, err = amf.EncodeString(enc, c.link.playpath)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode link playpath: %w", err)
} }
pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc))
return pkt.writeTo(c, false) err = pkt.writeTo(c, false)
if err != nil {
return fmt.Errorf("could not write packet: %w", err)
}
return nil
} }
func sendPublish(c *Conn) error { func sendPublish(c *Conn) error {
@ -445,27 +480,32 @@ func sendPublish(c *Conn) error {
enc, err := amf.EncodeString(enc, avPublish) enc, err := amf.EncodeString(enc, avPublish)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode av publish token: %w", err)
} }
c.numInvokes++ c.numInvokes++
enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) enc, err = amf.EncodeNumber(enc, float64(c.numInvokes))
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode number of invokes: %w", err)
} }
enc[0] = amf.TypeNull enc[0] = amf.TypeNull
enc = enc[1:] enc = enc[1:]
enc, err = amf.EncodeString(enc, c.link.playpath) enc, err = amf.EncodeString(enc, c.link.playpath)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode link playpath: %w", err)
} }
enc, err = amf.EncodeString(enc, avLive) enc, err = amf.EncodeString(enc, avLive)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode av live token: %w", err)
} }
pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc))
return pkt.writeTo(c, true) // response expected err = pkt.writeTo(c, true)
if err != nil {
return fmt.Errorf("could not write packet: %w", err)
}
return nil
} }
func sendDeleteStream(c *Conn, streamID float64) error { func sendDeleteStream(c *Conn, streamID float64) error {
@ -481,22 +521,27 @@ func sendDeleteStream(c *Conn, streamID float64) error {
enc, err := amf.EncodeString(enc, avDeletestream) enc, err := amf.EncodeString(enc, avDeletestream)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode av delete stream token: %w", err)
} }
c.numInvokes++ c.numInvokes++
enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) enc, err = amf.EncodeNumber(enc, float64(c.numInvokes))
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode number of invokes: %w", err)
} }
enc[0] = amf.TypeNull enc[0] = amf.TypeNull
enc = enc[1:] enc = enc[1:]
enc, err = amf.EncodeNumber(enc, streamID) enc, err = amf.EncodeNumber(enc, streamID)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode stream id: %w", err)
} }
pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc))
return pkt.writeTo(c, false) err = pkt.writeTo(c, false)
if err != nil {
return fmt.Errorf("could not write packet: %w", err)
}
return nil
} }
// sendBytesReceived tells the server how many bytes the client has received. // sendBytesReceived tells the server how many bytes the client has received.
@ -515,11 +560,16 @@ func sendBytesReceived(c *Conn) error {
enc, err := amf.EncodeInt32(enc, c.nBytesIn) enc, err := amf.EncodeInt32(enc, c.nBytesIn)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode number of bytes in: %w", err)
} }
pkt.bodySize = 4 pkt.bodySize = 4
return pkt.writeTo(c, false) err = pkt.writeTo(c, false)
if err != nil {
return fmt.Errorf("could not write packet: %w", err)
}
return nil
} }
func sendCheckBW(c *Conn) error { func sendCheckBW(c *Conn) error {
@ -535,19 +585,24 @@ func sendCheckBW(c *Conn) error {
enc, err := amf.EncodeString(enc, av_checkbw) enc, err := amf.EncodeString(enc, av_checkbw)
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode av check bw token: %w", err)
} }
c.numInvokes++ c.numInvokes++
enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) enc, err = amf.EncodeNumber(enc, float64(c.numInvokes))
if err != nil { if err != nil {
return err return fmt.Errorf("could not encode number of invokes: %w", err)
} }
enc[0] = amf.TypeNull enc[0] = amf.TypeNull
enc = enc[1:] enc = enc[1:]
pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc))
return pkt.writeTo(c, false) err = pkt.writeTo(c, false)
if err != nil {
return fmt.Errorf("could not write packet: %w", err)
}
return nil
} }
func eraseMethod(m []method, i int) []method { func eraseMethod(m []method, i int) []method {
@ -565,16 +620,16 @@ func handleInvoke(c *Conn, body []byte) error {
var obj amf.Object var obj amf.Object
_, err := amf.Decode(&obj, body, false) _, err := amf.Decode(&obj, body, false)
if err != nil { if err != nil {
return err return fmt.Errorf("could not decode: %w", err)
} }
meth, err := obj.StringProperty("", 0) meth, err := obj.StringProperty("", 0)
if err != nil { if err != nil {
return err return fmt.Errorf("could not get value of string property meth: %w", err)
} }
txn, err := obj.NumberProperty("", 1) txn, err := obj.NumberProperty("", 1)
if err != nil { if err != nil {
return err return fmt.Errorf("could not get value of number property txn: %w", err)
} }
c.log(DebugLevel, pkg+"invoking method "+meth) c.log(DebugLevel, pkg+"invoking method "+meth)
@ -601,26 +656,26 @@ func handleInvoke(c *Conn, body []byte) error {
case avConnect: case avConnect:
err := sendReleaseStream(c) err := sendReleaseStream(c)
if err != nil { if err != nil {
return err return fmt.Errorf("could not send release stream: %w", err)
} }
err = sendFCPublish(c) err = sendFCPublish(c)
if err != nil { if err != nil {
return err return fmt.Errorf("could not send fc publish: %w", err)
} }
err = sendCreateStream(c) err = sendCreateStream(c)
if err != nil { if err != nil {
return err return fmt.Errorf("could not send create stream: %w", err)
} }
case avCreatestream: case avCreatestream:
n, err := obj.NumberProperty("", 3) n, err := obj.NumberProperty("", 3)
if err != nil { if err != nil {
return err return fmt.Errorf("could not get value for stream id number property: %w",err)
} }
c.streamID = uint32(n) c.streamID = uint32(n)
err = sendPublish(c) err = sendPublish(c)
if err != nil { if err != nil {
return err return fmt.Errorf("could not send publish: %w",err)
} }
default: default:
@ -630,27 +685,27 @@ func handleInvoke(c *Conn, body []byte) error {
case avOnBWDone: case avOnBWDone:
err := sendCheckBW(c) err := sendCheckBW(c)
if err != nil { if err != nil {
return err return fmt.Errorf("could not send check bw: %w", err)
} }
case avOnStatus: case avOnStatus:
obj2, err := obj.ObjectProperty("", 3) obj2, err := obj.ObjectProperty("", 3)
if err != nil { if err != nil {
return err return fmt.Errorf("could not get object property value for obj2: %w", err)
} }
code, err := obj2.StringProperty(avCode, -1) code, err := obj2.StringProperty(avCode, -1)
if err != nil { if err != nil {
return err return fmt.Errorf("could not get string property value for code: %w", err)
} }
level, err := obj2.StringProperty(avLevel, -1) level, err := obj2.StringProperty(avLevel, -1)
if err != nil { if err != nil {
return err return fmt.Errorf("could not get string property value for level: %w", err)
} }
c.log(DebugLevel, pkg+"onStatus", "code", code, "level", level) c.log(DebugLevel, pkg+"onStatus", "code", code, "level", level)
if code != avNetStreamPublish_Start { if code != avNetStreamPublish_Start {
c.log(ErrorLevel, pkg+"unexpected response "+code) c.log(ErrorLevel, pkg+"unexpected response "+code)
return errUnimplemented return fmt.Errorf("unimplemented code: %v", code)
} }
c.log(DebugLevel, pkg+"playing") c.log(DebugLevel, pkg+"playing")
c.isPlaying = true c.isPlaying = true
@ -681,14 +736,14 @@ func handshake(c *Conn) error {
_, err := c.write(clientbuf[:]) _, err := c.write(clientbuf[:])
if err != nil { if err != nil {
return err return fmt.Errorf("could not write handshake: %w", err)
} }
c.log(DebugLevel, pkg+"handshake sent") c.log(DebugLevel, pkg+"handshake sent")
var typ [1]byte var typ [1]byte
_, err = c.read(typ[:]) _, err = c.read(typ[:])
if err != nil { if err != nil {
return err return fmt.Errorf("could not read handshake: %w", err)
} }
c.log(DebugLevel, pkg+"handshake received") c.log(DebugLevel, pkg+"handshake received")
@ -697,7 +752,7 @@ func handshake(c *Conn) error {
} }
_, err = c.read(serversig[:]) _, err = c.read(serversig[:])
if err != nil { if err != nil {
return err return fmt.Errorf("could not read server signal: %w", err)
} }
// decode server response // decode server response
@ -707,12 +762,12 @@ func handshake(c *Conn) error {
// 2nd part of handshake // 2nd part of handshake
_, err = c.write(serversig[:]) _, err = c.write(serversig[:])
if err != nil { if err != nil {
return err return fmt.Errorf("could not write part 2 of handshake: %w", err)
} }
_, err = c.read(serversig[:]) _, err = c.read(serversig[:])
if err != nil { if err != nil {
return err return fmt.Errorf("could not read part 2 of handshake: %w", err)
} }
if !bytes.Equal(serversig[:signatureSize], clientbuf[1:signatureSize+1]) { if !bytes.Equal(serversig[:signatureSize], clientbuf[1:signatureSize+1]) {