diff --git a/protocol/rtmp/amf/amf.go b/protocol/rtmp/amf/amf.go index 02a9d005..e1145c30 100644 --- a/protocol/rtmp/amf/amf.go +++ b/protocol/rtmp/amf/amf.go @@ -43,6 +43,7 @@ package amf import ( "encoding/binary" "errors" + "fmt" "math" ) @@ -163,11 +164,7 @@ func EncodeInt32(buf []byte, val uint32) ([]byte, error) { // Strings less than 65536 in length are encoded as TypeString, while longer strings are ecodeded as typeLongString. func EncodeString(buf []byte, val string) ([]byte, error) { const typeSize = 1 - if len(val) < 65536 && len(val)+typeSize+binary.Size(int16(0)) > len(buf) { - return nil, ErrShortBuffer - } - - if len(val)+typeSize+binary.Size(uint32(0)) > len(buf) { + if len(val) < 65536 && len(val)+typeSize+binary.Size(int16(0)) > len(buf) || len(val)+typeSize+binary.Size(uint32(0)) > len(buf) { return nil, ErrShortBuffer } @@ -222,7 +219,11 @@ func EncodeNamedString(buf []byte, key, val string) ([]byte, error) { binary.BigEndian.PutUint16(buf[:2], uint16(len(key))) buf = buf[2:] copy(buf, key) - return EncodeString(buf[len(key):], val) + b, err := EncodeString(buf[len(key):], val) + if err != nil { + return nil, fmt.Errorf("could not encode string: %w", err) + } + return b, nil } // EncodeNamedNumber encodes a named number, where key is the name and val is the number value. @@ -233,7 +234,11 @@ func EncodeNamedNumber(buf []byte, key string, val float64) ([]byte, error) { binary.BigEndian.PutUint16(buf[:2], uint16(len(key))) buf = buf[2:] copy(buf, key) - return EncodeNumber(buf[len(key):], val) + b, err := EncodeNumber(buf[len(key):], val) + if err != nil { + return nil, fmt.Errorf("could not encode number: %w", err) + } + return b, nil } // EncodeNamedNumber encodes a named boolean, where key is the name and val is the boolean value. @@ -244,14 +249,18 @@ func EncodeNamedBoolean(buf []byte, key string, val bool) ([]byte, error) { binary.BigEndian.PutUint16(buf[:2], uint16(len(key))) buf = buf[2:] copy(buf, key) - return EncodeBoolean(buf[len(key):], val) + b, err := EncodeBoolean(buf[len(key):], val) + if err != nil { + return nil, fmt.Errorf("could not encode boolean: %w", err) + } + return b, nil } // EncodeProperty encodes a property. func EncodeProperty(prop *Property, buf []byte) ([]byte, error) { if prop.Type != TypeNull && prop.Name != "" { if len(buf) < 2+len(prop.Name) { - return nil, ErrShortBuffer + return nil, fmt.Errorf("not type null, short buffer: %w", ErrShortBuffer) } binary.BigEndian.PutUint16(buf[:2], uint16(len(prop.Name))) buf = buf[2:] @@ -261,23 +270,47 @@ func EncodeProperty(prop *Property, buf []byte) ([]byte, error) { switch prop.Type { case typeNumber: - return EncodeNumber(buf, prop.Number) + b, err := EncodeNumber(buf, prop.Number) + if err != nil { + return nil, fmt.Errorf("could not encode number: %w", err) + } + return b, nil case typeBoolean: - return EncodeBoolean(buf, prop.Number != 0) + b, err := EncodeBoolean(buf, prop.Number != 0) + if err != nil { + return nil, fmt.Errorf("could not encode boolean: %w", err) + } + return b, nil case TypeString: - return EncodeString(buf, prop.String) + b, err := EncodeString(buf, prop.String) + if err != nil { + return nil, fmt.Errorf("could not encode string: %w", err) + } + return b, nil case TypeNull: if len(buf) < 2 { - return nil, ErrShortBuffer + return nil, fmt.Errorf("type null, short buffer: %w", ErrShortBuffer) } buf[0] = TypeNull buf = buf[1:] case TypeObject: - return Encode(&prop.Object, buf) + b, err := Encode(&prop.Object, buf) + if err != nil { + return nil, fmt.Errorf("could not encode: %w", err) + } + return b, nil case typeEcmaArray: - return EncodeEcmaArray(&prop.Object, buf) + b, err := EncodeEcmaArray(&prop.Object, buf) + if err != nil { + return nil, fmt.Errorf("could not encode ecma array: %w", err) + } + return b, nil case typeStrictArray: - return EncodeArray(&prop.Object, buf) + b, err := EncodeArray(&prop.Object, buf) + if err != nil { + return nil, fmt.Errorf("could not encode array: %w", err) + } + return b, nil default: return nil, ErrInvalidType } @@ -294,7 +327,7 @@ func DecodeProperty(prop *Property, buf []byte, decodeName bool) (int, error) { } n := DecodeInt16(buf[:2]) if int(n) > len(buf)-2 { - return 0, ErrShortBuffer + return 0, fmt.Errorf("short buffer after decode of int 16: %w", ErrShortBuffer) } prop.Name = DecodeString(buf) @@ -309,14 +342,14 @@ func DecodeProperty(prop *Property, buf []byte, decodeName bool) (int, error) { switch prop.Type { case typeNumber: if len(buf) < 8 { - return 0, ErrShortBuffer + return 0, fmt.Errorf("type number short buffer: %w", ErrShortBuffer) } prop.Number = DecodeNumber(buf[:8]) buf = buf[8:] case typeBoolean: if len(buf) < 1 { - return 0, ErrShortBuffer + return 0, fmt.Errorf("type boolean short buffer: %w", ErrShortBuffer) } prop.Number = float64(buf[0]) buf = buf[1:] @@ -324,7 +357,7 @@ func DecodeProperty(prop *Property, buf []byte, decodeName bool) (int, error) { case TypeString: n := DecodeInt16(buf[:2]) if len(buf) < int(n+2) { - return 0, ErrShortBuffer + return 0, fmt.Errorf("type string: %w", ErrShortBuffer) } prop.String = DecodeString(buf) buf = buf[2+n:] @@ -332,7 +365,7 @@ func DecodeProperty(prop *Property, buf []byte, decodeName bool) (int, error) { case TypeObject: n, err := Decode(&prop.Object, buf, true) if err != nil { - return 0, err + return 0, fmt.Errorf("could not decode type object: %w", err) } buf = buf[n:] @@ -343,7 +376,7 @@ func DecodeProperty(prop *Property, buf []byte, decodeName bool) (int, error) { buf = buf[4:] n, err := Decode(&prop.Object, buf, true) if err != nil { - return 0, err + return 0, fmt.Errorf("could not decode type ecma array: %w", err) } buf = buf[n:] @@ -367,14 +400,19 @@ func Encode(obj *Object, buf []byte) ([]byte, error) { var err error buf, err = EncodeProperty(&obj.Properties[i], buf) if err != nil { - return nil, err + return nil, fmt.Errorf("could not encode property no. %d: %w", i, err) } } if len(buf) < 3 { - return nil, ErrShortBuffer + return nil, fmt.Errorf("short buffer after property encoding: %w", ErrShortBuffer) } - return EncodeInt24(buf, TypeObjectEnd) + + b, err := EncodeInt24(buf, TypeObjectEnd) + if err != nil { + return nil, fmt.Errorf("could not encode int 24: %w", err) + } + return b, err } // EncodeEcmaArray encodes an ECMA array. @@ -392,14 +430,20 @@ func EncodeEcmaArray(obj *Object, buf []byte) ([]byte, error) { var err error buf, err = EncodeProperty(&obj.Properties[i], buf) if err != nil { - return nil, err + return nil, fmt.Errorf("could not encode property no. %d: %w", i, err) } } if len(buf) < 3 { - return nil, ErrShortBuffer + return nil, fmt.Errorf("short buffer after property encoding: %w", ErrShortBuffer) } - return EncodeInt24(buf, TypeObjectEnd) + + b, err := EncodeInt24(buf, TypeObjectEnd) + if err != nil { + return nil, fmt.Errorf("could not encode int 24: %w", err) + } + + return b, nil } // EncodeArray encodes an array. @@ -417,7 +461,7 @@ func EncodeArray(obj *Object, buf []byte) ([]byte, error) { var err error buf, err = EncodeProperty(&obj.Properties[i], buf) if err != nil { - return nil, err + return nil, fmt.Errorf("could not encode property no. %d: %w", i, err) } } @@ -437,7 +481,7 @@ func Decode(obj *Object, buf []byte, decodeName bool) (int, error) { var prop Property n, err := DecodeProperty(&prop, buf, decodeName) if err != nil { - return 0, err + return 0, fmt.Errorf("could not decode property: %w", err) } buf = buf[n:] obj.Properties = append(obj.Properties, prop) diff --git a/protocol/rtmp/conn.go b/protocol/rtmp/conn.go index 56464578..c0077ebe 100644 --- a/protocol/rtmp/conn.go +++ b/protocol/rtmp/conn.go @@ -138,14 +138,14 @@ func Dial(url string, log Log, options ...func(*Conn) error) (*Conn, error) { var err error c.link.protocol, c.link.host, c.link.port, c.link.app, c.link.playpath, err = parseURL(url) if err != nil { - return nil, err + return nil, fmt.Errorf("could not parse url: %w",err) } c.link.url = rtmpProtocolStrings[c.link.protocol] + "://" + c.link.host + ":" + strconv.Itoa(int(c.link.port)) + "/" + c.link.app c.link.protocol |= featureWrite err = connect(&c) if err != nil { - return nil, err + return nil, fmt.Errorf("could not connect: %w",err) } return &c, nil } @@ -159,11 +159,20 @@ func (c *Conn) Close() error { c.log(DebugLevel, pkg+"Conn.Close") if c.streamID > 0 { if c.link.protocol&featureWrite != 0 { - sendFCUnpublish(c) + err := sendFCUnpublish(c) + if err != nil { + return fmt.Errorf("could not send fc unpublish: %w",err) + } + } + err := sendDeleteStream(c, float64(c.streamID)) + if err != nil { + return fmt.Errorf("could not send delete stream: %w",err) } - sendDeleteStream(c, float64(c.streamID)) } - c.link.conn.Close() + err := c.link.conn.Close() + if err != nil { + return fmt.Errorf("could not close link conn: %w",err) + } *c = Conn{} return nil } @@ -192,7 +201,7 @@ func (c *Conn) Write(data []byte) (int, error) { copy(pkt.body, data[flvTagheaderSize:flvTagheaderSize+pkt.bodySize]) err := pkt.writeTo(c, false) if err != nil { - return 0, err + return 0, fmt.Errorf("could not write packet to connection: %w",err) } return len(data), nil } @@ -205,18 +214,18 @@ func (c *Conn) Write(data []byte) (int, error) { func (c *Conn) read(buf []byte) (int, error) { err := c.link.conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(c.link.timeout))) if err != nil { - return 0, err + return 0, fmt.Errorf("could not set read deadline: %w",err) } n, err := io.ReadFull(c.link.conn, buf) if err != nil { c.log(DebugLevel, pkg+"read failed", "error", err.Error()) - return 0, err + return 0, fmt.Errorf("could not read conn: %w",err) } c.nBytesIn += uint32(n) if c.nBytesIn > (c.nBytesInSent + c.clientBW/10) { err := sendBytesReceived(c) if err != nil { - return n, err // NB: we still read n bytes, even though send bytes failed + return n, fmt.Errorf("could not send bytes received: %w",err) // NB: we still read n bytes, even though send bytes failed } } return n, nil @@ -227,12 +236,12 @@ func (c *Conn) write(buf []byte) (int, error) { //ToDo: consider using a different timeout for writes than for reads err := c.link.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(c.link.timeout))) if err != nil { - return 0, err + return 0, fmt.Errorf("could not set write deadline: %w",err) } n, err := c.link.conn.Write(buf) if err != nil { c.log(WarnLevel, pkg+"write failed", "error", err.Error()) - return 0, err + return 0, fmt.Errorf("could not write to conn: %w",err) } return n, nil } diff --git a/protocol/rtmp/packet.go b/protocol/rtmp/packet.go index dda53d35..128751a7 100644 --- a/protocol/rtmp/packet.go +++ b/protocol/rtmp/packet.go @@ -36,6 +36,7 @@ package rtmp import ( "encoding/binary" + "fmt" "io" "bitbucket.org/ausocean/av/protocol/rtmp/amf" @@ -90,10 +91,15 @@ type packet struct { timestamp uint32 streamID uint32 bodySize uint32 + bytesRead uint32 buf []byte body []byte } +func (pkt *packet) isReady() bool { + return pkt.bytesRead == pkt.bodySize +} + // readFrom reads a packet from the RTMP connection. func (pkt *packet) readFrom(c *Conn) error { var hbuf [fullHeaderSize]byte @@ -105,7 +111,7 @@ func (pkt *packet) readFrom(c *Conn) error { if err == io.EOF { c.log(WarnLevel, pkg+"EOF error; connection likely terminated") } - return err + return fmt.Errorf("failed to read packet header 1st byte: %w", err) } pkt.headerType = (header[0] & 0xc0) >> 6 pkt.channel = int32(header[0] & 0x3f) @@ -116,7 +122,7 @@ func (pkt *packet) readFrom(c *Conn) error { _, err = c.read(header[:1]) if err != nil { c.log(DebugLevel, pkg+"failed to read packet header 2nd byte", "error", err.Error()) - return err + return fmt.Errorf("failed to read packet header second byte: %w", err) } header = header[1:] pkt.channel = int32(header[0]) + 64 @@ -125,7 +131,7 @@ func (pkt *packet) readFrom(c *Conn) error { _, err = c.read(header[:2]) if err != nil { c.log(DebugLevel, pkg+"failed to read packet header 3rd byte", "error", err.Error()) - return err + return fmt.Errorf("failed to read packet header 3rd byte: %w", err) } header = header[2:] pkt.channel = int32(binary.BigEndian.Uint16(header[:2])) + 64 @@ -169,13 +175,14 @@ func (pkt *packet) readFrom(c *Conn) error { _, err = c.read(header[:size]) if err != nil { c.log(DebugLevel, pkg+"failed to read packet header", "error", err.Error()) - return err + return fmt.Errorf("failed to read packet header: %w", err) } } hSize := len(hbuf) - len(header) + size if size >= 3 { pkt.timestamp = amf.DecodeInt24(header[:3]) + pkt.bytesRead = 0 if size >= 6 { pkt.bodySize = amf.DecodeInt24(header[3:6]) @@ -193,7 +200,7 @@ func (pkt *packet) readFrom(c *Conn) error { _, err = c.read(header[size : size+4]) if err != nil { c.log(DebugLevel, pkg+"failed to read extended timestamp", "error", err.Error()) - return err + return fmt.Errorf("failed to read extended timestamp: %w", err) } pkt.timestamp = amf.DecodeInt32(header[size : size+4]) hSize += 4 @@ -205,12 +212,24 @@ func (pkt *packet) readFrom(c *Conn) error { c.log(WarnLevel, pkg+"reading large packet", "size", int(pkt.bodySize)) } - _, err = c.read(pkt.body[:pkt.bodySize]) + nToRead := pkt.bodySize - pkt.bytesRead + nChunk := c.inChunkSize + if nToRead < nChunk { + nChunk = nToRead + } + + n, err := c.read(pkt.body[pkt.bytesRead : pkt.bytesRead+nChunk]) if err != nil { c.log(DebugLevel, pkg+"failed to read packet body", "error", err.Error()) - return err + return fmt.Errorf("failed to read packet body: %w", err) } + if uint32(n) != nChunk { + return fmt.Errorf("did not read correct number of bytes, read: %d, expected: %d", n, nChunk) + } + + pkt.bytesRead += nChunk + // Keep the packet as a reference for other packets on this channel. if c.channelsIn[pkt.channel] == nil { c.channelsIn[pkt.channel] = &packet{} @@ -221,14 +240,17 @@ func (pkt *packet) readFrom(c *Conn) error { c.channelsIn[pkt.channel].timestamp = 0xffffff } - if !pkt.hasAbsTimestamp { - // Timestamps seem to always be relative. - pkt.timestamp += uint32(c.channelTimestamp[pkt.channel]) - } - c.channelTimestamp[pkt.channel] = int32(pkt.timestamp) + if pkt.isReady() { + if !pkt.hasAbsTimestamp { + // Timestamps seem to always be relative. + pkt.timestamp += uint32(c.channelTimestamp[pkt.channel]) + } + c.channelTimestamp[pkt.channel] = int32(pkt.timestamp) + + c.channelsIn[pkt.channel].body = nil + c.channelsIn[pkt.channel].hasAbsTimestamp = false + } - c.channelsIn[pkt.channel].body = nil - c.channelsIn[pkt.channel].hasAbsTimestamp = false return nil } @@ -404,7 +426,7 @@ func (pkt *packet) writeTo(c *Conn, queue bool) error { c.log(DebugLevel, pkg+"sending deferred packet separately", "size", len(c.deferred)) _, err := c.write(c.deferred) if err != nil { - return err + return fmt.Errorf("could not write deferred packet: %w", err) } c.deferred = nil } @@ -424,7 +446,7 @@ func (pkt *packet) writeTo(c *Conn, queue bool) error { } _, err := c.write(bytes) if err != nil { - return err + return fmt.Errorf("could not write combined packet: %w", err) } c.deferred = nil diff --git a/protocol/rtmp/parseurl.go b/protocol/rtmp/parseurl.go index 132c9acd..855d5ff4 100644 --- a/protocol/rtmp/parseurl.go +++ b/protocol/rtmp/parseurl.go @@ -34,17 +34,25 @@ LICENSE package rtmp import ( + "errors" + "fmt" "net/url" "path" "strconv" "strings" ) +// Errors. +var ( + errInvalidPath = errors.New("invalid url path") + errInvalidElements = errors.New("invalid url elements") +) + // parseURL parses an RTMP URL (ok, technically it is lexing). func parseURL(addr string) (protocol int32, host string, port uint16, app, playpath string, err error) { u, err := url.Parse(addr) if err != nil { - return protocol, host, port, app, playpath, err + return protocol, host, port, app, playpath, fmt.Errorf("could not parse to url value: %w", err) } switch u.Scheme { @@ -63,24 +71,24 @@ func parseURL(addr string) (protocol int32, host string, port uint16, app, playp case "rtmpts": protocol = protoRTMPTS default: - return protocol, host, port, app, playpath, errUnknownScheme + return protocol, host, port, app, playpath, fmt.Errorf("unknown scheme: %s", u.Scheme) } host = u.Host if p := u.Port(); p != "" { pi, err := strconv.Atoi(p) if err != nil { - return protocol, host, port, app, playpath, err + return protocol, host, port, app, playpath, fmt.Errorf("could convert port to integer: %w", err) } port = uint16(pi) } if len(u.Path) < 1 || !path.IsAbs(u.Path) { - return protocol, host, port, app, playpath, errInvalidURL + return protocol, host, port, app, playpath, errInvalidPath } elems := strings.SplitN(u.Path[1:], "/", 3) if len(elems) < 2 || elems[0] == "" || elems[1] == "" { - return protocol, host, port, app, playpath, errInvalidURL + return protocol, host, port, app, playpath, errInvalidElements } app = elems[0] playpath = path.Join(elems[1:]...) @@ -106,7 +114,7 @@ func parseURL(addr string) (protocol int32, host string, port uint16, app, playp switch { case port != 0: case (protocol & featureSSL) != 0: - return protocol, host, port, app, playpath, errUnimplemented // port = 433 + return protocol, host, port, app, playpath, errors.New("ssl not implemented") case (protocol & featureHTTP) != 0: port = 80 default: diff --git a/protocol/rtmp/parseurl_test.go b/protocol/rtmp/parseurl_test.go index f6501901..90bfd099 100644 --- a/protocol/rtmp/parseurl_test.go +++ b/protocol/rtmp/parseurl_test.go @@ -41,19 +41,19 @@ var parseURLTests = []struct { }{ { url: "rtmp://addr", - wantErr: errInvalidURL, + wantErr: errInvalidPath, }, { url: "rtmp://addr/", - wantErr: errInvalidURL, + wantErr: errInvalidElements, }, { url: "rtmp://addr/live2", - wantErr: errInvalidURL, + wantErr: errInvalidElements, }, { url: "rtmp://addr/live2/", - wantErr: errInvalidURL, + wantErr: errInvalidElements, }, { url: "rtmp://addr/appname/key", diff --git a/protocol/rtmp/rtmp.go b/protocol/rtmp/rtmp.go index dcaefe42..c7e4efc0 100644 --- a/protocol/rtmp/rtmp.go +++ b/protocol/rtmp/rtmp.go @@ -38,6 +38,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "math/rand" "net" "strconv" @@ -164,14 +165,15 @@ var ( // connect establishes an RTMP connection. func connect(c *Conn) error { - addr, err := net.ResolveTCPAddr("tcp4", c.link.host+":"+strconv.Itoa(int(c.link.port))) + addrStr := c.link.host + ":" + strconv.Itoa(int(c.link.port)) + addr, err := net.ResolveTCPAddr("tcp4", addrStr) if err != nil { - return err + return fmt.Errorf("could not resolve tcp address (%s):%w", addrStr, err) } c.link.conn, err = net.DialTCP("tcp4", nil, addr) if err != nil { c.log(WarnLevel, pkg+"dial failed", "error", err.Error()) - return err + return fmt.Errorf("could not dial tcp: %w", err) } c.log(DebugLevel, pkg+"connected") @@ -184,33 +186,42 @@ func connect(c *Conn) error { err = handshake(c) if err != nil { c.log(WarnLevel, pkg+"handshake failed", "error", err.Error()) - return err + return fmt.Errorf("could not handshake: %w", err) } c.log(DebugLevel, pkg+"handshaked") err = sendConnectPacket(c) if err != nil { c.log(WarnLevel, pkg+"sendConnect failed", "error", err.Error()) - return err + return fmt.Errorf("could not send connect packet: %w", err) } c.log(DebugLevel, pkg+"negotiating") var buf [256]byte + pkt := packet{buf: buf[:]} for !c.isPlaying { - pkt := packet{buf: buf[:]} err = pkt.readFrom(c) if err != nil { - return err + return fmt.Errorf("could not read from packet: %w", err) } - switch pkt.packetType { - case packetTypeAudio, packetTypeVideo, packetTypeInfo: - c.log(WarnLevel, pkg+"got packet before play; ignoring", "type", pkt.packetType) - default: - err = handlePacket(c, &pkt) - if err != nil { - return err + if pkt.isReady() { + if pkt.bodySize == 0 { + continue } + + switch pkt.packetType { + case packetTypeAudio, packetTypeVideo, packetTypeInfo: + c.log(WarnLevel, pkg+"got packet before play; ignoring", "type", pkt.packetType) + default: + err = handlePacket(c, &pkt) + if err != nil { + return fmt.Errorf("could not handle packet: %w", err) + } + } + pkt = packet{buf: buf[:]} } + + } return nil } @@ -248,7 +259,7 @@ func handlePacket(c *Conn, pkt *packet) error { err := handleInvoke(c, pkt.body[:pkt.bodySize]) if err != nil { c.log(WarnLevel, pkg+"unexpected error from handleInvoke", "error", err.Error()) - return err + return fmt.Errorf("could not handle invoke: %w", err) } case packetTypeControl, packetTypeAudio, packetTypeVideo, packetTypeFlashVideo, packetTypeFlexMessage, packetTypeInfo: @@ -273,12 +284,12 @@ func sendConnectPacket(c *Conn) error { enc, err := amf.EncodeString(enc, avConnect) if err != nil { - return err + return fmt.Errorf("could not encode string: %w", err) } c.numInvokes += 1 enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { - return err + return fmt.Errorf("could not encode number of invokes: %w", err) } // required link info @@ -289,24 +300,29 @@ func sendConnectPacket(c *Conn) error { } enc, err = amf.Encode(&info, enc) if err != nil { - return err + return fmt.Errorf("could not encode info: %w", err) } // optional link auth info if c.link.auth != "" { enc, err = amf.EncodeBoolean(enc, c.link.flags&linkAuth != 0) if err != nil { - return err + return fmt.Errorf("could not encode link auth bool: %w", err) } enc, err = amf.EncodeString(enc, c.link.auth) if err != nil { - return err + return fmt.Errorf("could not encode link auth string: %w", err) } } pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(c, true) // response expected + err = pkt.writeTo(c, true) + if err != nil { + return fmt.Errorf("could not write packet: %w", err) + } + + return nil } func sendCreateStream(c *Conn) error { @@ -322,19 +338,23 @@ func sendCreateStream(c *Conn) error { enc, err := amf.EncodeString(enc, avCreatestream) if err != nil { - return err + return fmt.Errorf("could not encode av create stream token: %w", err) } c.numInvokes++ enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { - return err + return fmt.Errorf("could not encode number of invokes: %w", err) } enc[0] = amf.TypeNull enc = enc[1:] pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(c, true) // response expected + err = pkt.writeTo(c, true) + if err != nil { + return fmt.Errorf("could not write packet: %w", err) + } + return nil } func sendReleaseStream(c *Conn) error { @@ -350,22 +370,27 @@ func sendReleaseStream(c *Conn) error { enc, err := amf.EncodeString(enc, avReleasestream) if err != nil { - return err + return fmt.Errorf("could not encode av release stream token: %w", err) } c.numInvokes++ enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { - return err + return fmt.Errorf("could not encode number of invokes: %w", err) } enc[0] = amf.TypeNull enc = enc[1:] enc, err = amf.EncodeString(enc, c.link.playpath) if err != nil { - return err + return fmt.Errorf("could not encode playpath: %w", err) } pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(c, false) + err = pkt.writeTo(c, false) + if err != nil { + return fmt.Errorf("could not write packet: %w", err) + } + + return nil } func sendFCPublish(c *Conn) error { @@ -381,23 +406,28 @@ func sendFCPublish(c *Conn) error { enc, err := amf.EncodeString(enc, avFCPublish) if err != nil { - return err + return fmt.Errorf("could not encode av fc publish token: %w", err) } c.numInvokes++ enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { - return err + return fmt.Errorf("could not encode number of invokes: %w", err) } enc[0] = amf.TypeNull enc = enc[1:] enc, err = amf.EncodeString(enc, c.link.playpath) if err != nil { - return err + return fmt.Errorf("could not encode playpath: %w", err) } pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(c, false) + err = pkt.writeTo(c, false) + if err != nil { + return fmt.Errorf("could not write packet: %w", err) + } + + return nil } func sendFCUnpublish(c *Conn) error { @@ -413,23 +443,28 @@ func sendFCUnpublish(c *Conn) error { enc, err := amf.EncodeString(enc, avFCUnpublish) if err != nil { - return err + return fmt.Errorf("could not encode av fc unpublish token: %w", err) } c.numInvokes++ enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { - return err + return fmt.Errorf("could not encode number of invokes: %w", err) } enc[0] = amf.TypeNull enc = enc[1:] enc, err = amf.EncodeString(enc, c.link.playpath) if err != nil { - return err + return fmt.Errorf("could not encode link playpath: %w", err) } pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(c, false) + err = pkt.writeTo(c, false) + if err != nil { + return fmt.Errorf("could not write packet: %w", err) + } + + return nil } func sendPublish(c *Conn) error { @@ -445,27 +480,32 @@ func sendPublish(c *Conn) error { enc, err := amf.EncodeString(enc, avPublish) if err != nil { - return err + return fmt.Errorf("could not encode av publish token: %w", err) } c.numInvokes++ enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { - return err + return fmt.Errorf("could not encode number of invokes: %w", err) } enc[0] = amf.TypeNull enc = enc[1:] enc, err = amf.EncodeString(enc, c.link.playpath) if err != nil { - return err + return fmt.Errorf("could not encode link playpath: %w", err) } enc, err = amf.EncodeString(enc, avLive) if err != nil { - return err + return fmt.Errorf("could not encode av live token: %w", err) } pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(c, true) // response expected + err = pkt.writeTo(c, true) + if err != nil { + return fmt.Errorf("could not write packet: %w", err) + } + + return nil } func sendDeleteStream(c *Conn, streamID float64) error { @@ -481,22 +521,27 @@ func sendDeleteStream(c *Conn, streamID float64) error { enc, err := amf.EncodeString(enc, avDeletestream) if err != nil { - return err + return fmt.Errorf("could not encode av delete stream token: %w", err) } c.numInvokes++ enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { - return err + return fmt.Errorf("could not encode number of invokes: %w", err) } enc[0] = amf.TypeNull enc = enc[1:] enc, err = amf.EncodeNumber(enc, streamID) if err != nil { - return err + return fmt.Errorf("could not encode stream id: %w", err) } pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(c, false) + err = pkt.writeTo(c, false) + if err != nil { + return fmt.Errorf("could not write packet: %w", err) + } + + return nil } // sendBytesReceived tells the server how many bytes the client has received. @@ -515,11 +560,16 @@ func sendBytesReceived(c *Conn) error { enc, err := amf.EncodeInt32(enc, c.nBytesIn) if err != nil { - return err + return fmt.Errorf("could not encode number of bytes in: %w", err) } pkt.bodySize = 4 - return pkt.writeTo(c, false) + err = pkt.writeTo(c, false) + if err != nil { + return fmt.Errorf("could not write packet: %w", err) + } + + return nil } func sendCheckBW(c *Conn) error { @@ -535,19 +585,24 @@ func sendCheckBW(c *Conn) error { enc, err := amf.EncodeString(enc, av_checkbw) if err != nil { - return err + return fmt.Errorf("could not encode av check bw token: %w", err) } c.numInvokes++ enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { - return err + return fmt.Errorf("could not encode number of invokes: %w", err) } enc[0] = amf.TypeNull enc = enc[1:] pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(c, false) + err = pkt.writeTo(c, false) + if err != nil { + return fmt.Errorf("could not write packet: %w", err) + } + + return nil } func eraseMethod(m []method, i int) []method { @@ -565,16 +620,16 @@ func handleInvoke(c *Conn, body []byte) error { var obj amf.Object _, err := amf.Decode(&obj, body, false) if err != nil { - return err + return fmt.Errorf("could not decode: %w", err) } meth, err := obj.StringProperty("", 0) if err != nil { - return err + return fmt.Errorf("could not get value of string property meth: %w", err) } txn, err := obj.NumberProperty("", 1) if err != nil { - return err + return fmt.Errorf("could not get value of number property txn: %w", err) } c.log(DebugLevel, pkg+"invoking method "+meth) @@ -601,26 +656,26 @@ func handleInvoke(c *Conn, body []byte) error { case avConnect: err := sendReleaseStream(c) if err != nil { - return err + return fmt.Errorf("could not send release stream: %w", err) } err = sendFCPublish(c) if err != nil { - return err + return fmt.Errorf("could not send fc publish: %w", err) } err = sendCreateStream(c) if err != nil { - return err + return fmt.Errorf("could not send create stream: %w", err) } case avCreatestream: n, err := obj.NumberProperty("", 3) if err != nil { - return err + return fmt.Errorf("could not get value for stream id number property: %w",err) } c.streamID = uint32(n) err = sendPublish(c) if err != nil { - return err + return fmt.Errorf("could not send publish: %w",err) } default: @@ -630,27 +685,27 @@ func handleInvoke(c *Conn, body []byte) error { case avOnBWDone: err := sendCheckBW(c) if err != nil { - return err + return fmt.Errorf("could not send check bw: %w", err) } case avOnStatus: obj2, err := obj.ObjectProperty("", 3) if err != nil { - return err + return fmt.Errorf("could not get object property value for obj2: %w", err) } code, err := obj2.StringProperty(avCode, -1) if err != nil { - return err + return fmt.Errorf("could not get string property value for code: %w", err) } level, err := obj2.StringProperty(avLevel, -1) if err != nil { - return err + return fmt.Errorf("could not get string property value for level: %w", err) } c.log(DebugLevel, pkg+"onStatus", "code", code, "level", level) if code != avNetStreamPublish_Start { c.log(ErrorLevel, pkg+"unexpected response "+code) - return errUnimplemented + return fmt.Errorf("unimplemented code: %v", code) } c.log(DebugLevel, pkg+"playing") c.isPlaying = true @@ -681,14 +736,14 @@ func handshake(c *Conn) error { _, err := c.write(clientbuf[:]) if err != nil { - return err + return fmt.Errorf("could not write handshake: %w", err) } c.log(DebugLevel, pkg+"handshake sent") var typ [1]byte _, err = c.read(typ[:]) if err != nil { - return err + return fmt.Errorf("could not read handshake: %w", err) } c.log(DebugLevel, pkg+"handshake received") @@ -697,7 +752,7 @@ func handshake(c *Conn) error { } _, err = c.read(serversig[:]) if err != nil { - return err + return fmt.Errorf("could not read server signal: %w", err) } // decode server response @@ -707,12 +762,12 @@ func handshake(c *Conn) error { // 2nd part of handshake _, err = c.write(serversig[:]) if err != nil { - return err + return fmt.Errorf("could not write part 2 of handshake: %w", err) } _, err = c.read(serversig[:]) if err != nil { - return err + return fmt.Errorf("could not read part 2 of handshake: %w", err) } if !bytes.Equal(serversig[:signatureSize], clientbuf[1:signatureSize+1]) {