Merged in rtmp-refactoring-4 (pull request #121)

Rtmp refactoring 4

Approved-by: Saxon Milton <saxon.milton@gmail.com>
Approved-by: Alan Noble <anoble@gmail.com>
This commit is contained in:
Alan Noble 2019-01-31 21:26:22 +00:00
commit 0f346ea523
7 changed files with 59 additions and 71 deletions

View File

@ -51,7 +51,7 @@ import (
const ( const (
typeNumber = 0x00 typeNumber = 0x00
typeBoolean = 0x01 typeBoolean = 0x01
typeString = 0x02 TypeString = 0x02
TypeObject = 0x03 TypeObject = 0x03
typeMovieClip = 0x04 typeMovieClip = 0x04
TypeNull = 0x05 TypeNull = 0x05
@ -93,7 +93,7 @@ type Property struct {
var ( var (
ErrShortBuffer = errors.New("amf: short buffer") // The supplied buffer was too short. ErrShortBuffer = errors.New("amf: short buffer") // The supplied buffer was too short.
ErrInvalidType = errors.New("amf: invalid type") // An invalid type was supplied to the encoder. ErrInvalidType = errors.New("amf: invalid type") // An invalid type was supplied to the encoder.
ErrUnexpectedType = errors.New("amf: unexpected end") // An unexpected type was encountered while decoding. ErrUnexpectedType = errors.New("amf: unexpected type") // An unexpected type was encountered while decoding.
ErrPropertyNotFound = errors.New("amf: property not found") // The requested property was not found. ErrPropertyNotFound = errors.New("amf: property not found") // The requested property was not found.
) )
@ -160,6 +160,7 @@ func EncodeInt32(buf []byte, val uint32) ([]byte, error) {
} }
// EncodeString encodes a string. // EncodeString encodes a string.
// Strings less than 65536 in length are encoded as TypeString, while longer strings are ecodeded as typeLongString.
func EncodeString(buf []byte, val string) ([]byte, error) { func EncodeString(buf []byte, val string) ([]byte, error) {
const typeSize = 1 const typeSize = 1
if len(val) < 65536 && len(val)+typeSize+binary.Size(int16(0)) > len(buf) { if len(val) < 65536 && len(val)+typeSize+binary.Size(int16(0)) > len(buf) {
@ -171,7 +172,7 @@ func EncodeString(buf []byte, val string) ([]byte, error) {
} }
if len(val) < 65536 { if len(val) < 65536 {
buf[0] = typeString buf[0] = TypeString
buf = buf[1:] buf = buf[1:]
binary.BigEndian.PutUint16(buf[:2], uint16(len(val))) binary.BigEndian.PutUint16(buf[:2], uint16(len(val)))
buf = buf[2:] buf = buf[2:]
@ -263,7 +264,7 @@ func EncodeProperty(prop *Property, buf []byte) ([]byte, error) {
return EncodeNumber(buf, prop.Number) return EncodeNumber(buf, prop.Number)
case typeBoolean: case typeBoolean:
return EncodeBoolean(buf, prop.Number != 0) return EncodeBoolean(buf, prop.Number != 0)
case typeString: case TypeString:
return EncodeString(buf, prop.String) return EncodeString(buf, prop.String)
case TypeNull: case TypeNull:
if len(buf) < 2 { if len(buf) < 2 {
@ -320,7 +321,7 @@ func DecodeProperty(prop *Property, buf []byte, decodeName bool) (int, error) {
prop.Number = float64(buf[0]) prop.Number = float64(buf[0])
buf = buf[1:] buf = buf[1:]
case typeString: case TypeString:
n := DecodeInt16(buf[:2]) n := DecodeInt16(buf[:2])
if len(buf) < int(n+2) { if len(buf) < int(n+2) {
return 0, ErrShortBuffer return 0, ErrShortBuffer
@ -354,7 +355,6 @@ func DecodeProperty(prop *Property, buf []byte, decodeName bool) (int, error) {
} }
// Encode encodes an Object into its AMF representation. // Encode encodes an Object into its AMF representation.
// This is the top-level encoding function and is typically the only function callers will need to use.
func Encode(obj *Object, buf []byte) ([]byte, error) { func Encode(obj *Object, buf []byte) ([]byte, error) {
if len(buf) < 5 { if len(buf) < 5 {
return nil, ErrShortBuffer return nil, ErrShortBuffer
@ -481,7 +481,7 @@ func (obj *Object) NumberProperty(name string, idx int) (float64, error) {
// StringProperty is a wrapper for Property that returns a String property's value, if any. // StringProperty is a wrapper for Property that returns a String property's value, if any.
func (obj *Object) StringProperty(name string, idx int) (string, error) { func (obj *Object) StringProperty(name string, idx int) (string, error) {
prop, err := obj.Property(name, idx, typeString) prop, err := obj.Property(name, idx, TypeString)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -58,7 +58,7 @@ func TestSanity(t *testing.T) {
// TestStrings tests string encoding and decoding. // TestStrings tests string encoding and decoding.
func TestStrings(t *testing.T) { func TestStrings(t *testing.T) {
// Short string encoding is as follows: // Short string encoding is as follows:
// enc[0] = data type (typeString) // enc[0] = data type (TypeString)
// end[1:3] = size // end[1:3] = size
// enc[3:] = data // enc[3:] = data
for _, s := range testStrings { for _, s := range testStrings {
@ -67,8 +67,8 @@ func TestStrings(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("EncodeString failed") t.Errorf("EncodeString failed")
} }
if buf[0] != typeString { if buf[0] != TypeString {
t.Errorf("Expected typeString, got %v", buf[0]) t.Errorf("Expected TypeString, got %v", buf[0])
} }
ds := DecodeString(buf[1:]) ds := DecodeString(buf[1:])
if s != ds { if s != ds {
@ -76,7 +76,7 @@ func TestStrings(t *testing.T) {
} }
} }
// Long string encoding is as follows: // Long string encoding is as follows:
// enc[0] = data type (typeString) // enc[0] = data type (TypeString)
// end[1:5] = size // end[1:5] = size
// enc[5:] = data // enc[5:] = data
s := string(make([]byte, 65536)) s := string(make([]byte, 65536))
@ -148,7 +148,7 @@ func TestProperties(t *testing.T) {
// Encode/decode string properties. // Encode/decode string properties.
enc = buf[:] enc = buf[:]
for i := range testStrings { for i := range testStrings {
enc, err = EncodeProperty(&Property{Type: typeString, String: testStrings[i]}, enc) enc, err = EncodeProperty(&Property{Type: TypeString, String: testStrings[i]}, enc)
if err != nil { if err != nil {
t.Errorf("EncodeProperty of string failed") t.Errorf("EncodeProperty of string failed")
} }
@ -235,7 +235,7 @@ func TestObject(t *testing.T) {
// Construct a more complicated object that includes a nested object. // Construct a more complicated object that includes a nested object.
var obj2 Object var obj2 Object
for i := range testStrings { for i := range testStrings {
obj2.Properties = append(obj2.Properties, Property{Type: typeString, String: testStrings[i]}) obj2.Properties = append(obj2.Properties, Property{Type: TypeString, String: testStrings[i]})
obj2.Properties = append(obj2.Properties, Property{Type: typeNumber, Number: float64(testNumbers[i])}) obj2.Properties = append(obj2.Properties, Property{Type: typeNumber, Number: float64(testNumbers[i])})
} }
obj2.Properties = append(obj2.Properties, Property{Type: TypeObject, Object: obj1}) obj2.Properties = append(obj2.Properties, Property{Type: TypeObject, Object: obj1})

View File

@ -121,20 +121,6 @@ func Dial(url string, timeout uint, log Log) (*Conn, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if c.link.app == "" {
return nil, errInvalidURL
}
if c.link.port == 0 {
switch {
case (c.link.protocol & featureSSL) != 0:
c.link.port = 433
c.log(FatalLevel, pkg+"SSL not supported")
case (c.link.protocol & featureHTTP) != 0:
c.link.port = 80
default:
c.link.port = 1935
}
}
c.link.url = rtmpProtocolStrings[c.link.protocol] + "://" + c.link.host + ":" + strconv.Itoa(int(c.link.port)) + "/" + c.link.app c.link.url = rtmpProtocolStrings[c.link.protocol] + "://" + c.link.host + ":" + strconv.Itoa(int(c.link.port)) + "/" + c.link.app
c.link.protocol |= featureWrite c.link.protocol |= featureWrite

View File

@ -81,7 +81,7 @@ const (
// 3: basic header (chunk type and stream ID) (1 byte) // 3: basic header (chunk type and stream ID) (1 byte)
var headerSizes = [...]int{12, 8, 4, 1} var headerSizes = [...]int{12, 8, 4, 1}
// packet defines an RTMP packet. // packet represents an RTMP packet.
type packet struct { type packet struct {
headerType uint8 headerType uint8
packetType uint8 packetType uint8
@ -90,7 +90,6 @@ type packet struct {
timestamp uint32 timestamp uint32
streamID uint32 streamID uint32
bodySize uint32 bodySize uint32
bytesRead uint32
buf []byte buf []byte
body []byte body []byte
} }
@ -179,7 +178,6 @@ func (pkt *packet) readFrom(c *Conn) error {
pkt.timestamp = amf.DecodeInt24(header[:3]) pkt.timestamp = amf.DecodeInt24(header[:3])
if size >= 6 { if size >= 6 {
pkt.bodySize = amf.DecodeInt24(header[3:6]) pkt.bodySize = amf.DecodeInt24(header[3:6])
pkt.bytesRead = 0
if size > 6 { if size > 6 {
pkt.packetType = header[6] pkt.packetType = header[6]
@ -201,25 +199,18 @@ func (pkt *packet) readFrom(c *Conn) error {
hSize += 4 hSize += 4
} }
if pkt.bodySize > 0 && pkt.body == nil { pkt.resize(pkt.bodySize, pkt.headerType)
pkt.resize(pkt.bodySize, (hbuf[0]&0xc0)>>6)
if pkt.bodySize > c.inChunkSize {
c.log(WarnLevel, pkg+"reading large packet", "size", int(pkt.bodySize))
} }
toRead := pkt.bodySize - pkt.bytesRead _, err = c.read(pkt.body[:pkt.bodySize])
chunkSize := c.inChunkSize
if toRead < chunkSize {
chunkSize = toRead
}
_, err = c.read(pkt.body[pkt.bytesRead:][:chunkSize])
if err != nil { if err != nil {
c.log(DebugLevel, pkg+"failed to read packet body", "error", err.Error()) c.log(DebugLevel, pkg+"failed to read packet body", "error", err.Error())
return err return err
} }
pkt.bytesRead += uint32(chunkSize)
// Keep the packet as a reference for other packets on this channel. // Keep the packet as a reference for other packets on this channel.
if c.channelsIn[pkt.channel] == nil { if c.channelsIn[pkt.channel] == nil {
c.channelsIn[pkt.channel] = &packet{} c.channelsIn[pkt.channel] = &packet{}
@ -237,15 +228,16 @@ func (pkt *packet) readFrom(c *Conn) error {
c.channelTimestamp[pkt.channel] = int32(pkt.timestamp) c.channelTimestamp[pkt.channel] = int32(pkt.timestamp)
c.channelsIn[pkt.channel].body = nil c.channelsIn[pkt.channel].body = nil
c.channelsIn[pkt.channel].bytesRead = 0
c.channelsIn[pkt.channel].hasAbsTimestamp = false c.channelsIn[pkt.channel].hasAbsTimestamp = false
return nil return nil
} }
// resize adjusts the packet's storage to accommodate a body of the given size and header type. // resize adjusts the packet's storage (if necessary) to accommodate a body of the given size and header type.
// When headerSizeAuto is specified, the header type is computed based on packet type. // When headerSizeAuto is specified, the header type is computed based on packet type.
func (pkt *packet) resize(size uint32, ht uint8) { func (pkt *packet) resize(size uint32, ht uint8) {
pkt.buf = make([]byte, fullHeaderSize+size) if cap(pkt.buf) < fullHeaderSize+int(size) {
pkt.buf = make([]byte, fullHeaderSize+size)
}
pkt.body = pkt.buf[fullHeaderSize:] pkt.body = pkt.buf[fullHeaderSize:]
if ht != headerSizeAuto { if ht != headerSizeAuto {
pkt.headerType = ht pkt.headerType = ht
@ -407,7 +399,7 @@ func (pkt *packet) writeTo(c *Conn, queue bool) error {
return nil return nil
} }
} else { } else {
// Send previously deferrd packet if combining it with the next one would exceed the chunk size. // Send previously deferred packet if combining it with the next one would exceed the chunk size.
if len(c.deferred)+size+hSize > chunkSize { if len(c.deferred)+size+hSize > chunkSize {
c.log(DebugLevel, pkg+"sending deferred packet separately", "size", len(c.deferred)) c.log(DebugLevel, pkg+"sending deferred packet separately", "size", len(c.deferred))
_, err := c.write(c.deferred) _, err := c.write(c.deferred)
@ -419,7 +411,7 @@ func (pkt *packet) writeTo(c *Conn, queue bool) error {
} }
// TODO(kortschak): Rewrite this horrific peice of premature optimisation. // TODO(kortschak): Rewrite this horrific peice of premature optimisation.
c.log(DebugLevel, pkg+"sending packet", "la", c.link.conn.LocalAddr(), "ra", c.link.conn.RemoteAddr(), "size", size) c.log(DebugLevel, pkg+"sending packet", "size", size, "la", c.link.conn.LocalAddr(), "ra", c.link.conn.RemoteAddr())
for size+hSize != 0 { for size+hSize != 0 {
if chunkSize > size { if chunkSize > size {
chunkSize = size chunkSize = size

View File

@ -41,7 +41,6 @@ import (
) )
// parseURL parses an RTMP URL (ok, technically it is lexing). // parseURL parses an RTMP URL (ok, technically it is lexing).
//
func parseURL(addr string) (protocol int32, host string, port uint16, app, playpath string, err error) { func parseURL(addr string) (protocol int32, host string, port uint16, app, playpath string, err error) {
u, err := url.Parse(addr) u, err := url.Parse(addr)
if err != nil { if err != nil {
@ -81,6 +80,9 @@ func parseURL(addr string) (protocol int32, host string, port uint16, app, playp
} }
elems := strings.SplitN(u.Path[1:], "/", 3) elems := strings.SplitN(u.Path[1:], "/", 3)
app = elems[0] app = elems[0]
if app == "" {
return protocol, host, port, app, playpath, errInvalidURL
}
playpath = elems[1] playpath = elems[1]
if len(elems) == 3 && len(elems[2]) != 0 { if len(elems) == 3 && len(elems[2]) != 0 {
playpath = path.Join(elems[1:]...) playpath = path.Join(elems[1:]...)
@ -97,5 +99,15 @@ func parseURL(addr string) (protocol int32, host string, port uint16, app, playp
} }
} }
switch {
case port != 0:
case (protocol & featureSSL) != 0:
return protocol, host, port, app, playpath, errUnimplemented // port = 433
case (protocol & featureHTTP) != 0:
port = 80
default:
port = 1935
}
return protocol, host, port, app, playpath, nil return protocol, host, port, app, playpath, nil
} }

View File

@ -174,6 +174,13 @@ func connect(c *Conn) error {
return err return err
} }
c.log(DebugLevel, pkg+"connected") c.log(DebugLevel, pkg+"connected")
defer func() {
if err != nil {
c.link.conn.Close()
}
}()
err = handshake(c) err = handshake(c)
if err != nil { if err != nil {
c.log(WarnLevel, pkg+"handshake failed", "error", err.Error()) c.log(WarnLevel, pkg+"handshake failed", "error", err.Error())
@ -185,12 +192,14 @@ func connect(c *Conn) error {
c.log(WarnLevel, pkg+"sendConnect failed", "error", err.Error()) c.log(WarnLevel, pkg+"sendConnect failed", "error", err.Error())
return err return err
} }
c.log(DebugLevel, pkg+"negotiating") c.log(DebugLevel, pkg+"negotiating")
var buf [256]byte
for !c.isPlaying { for !c.isPlaying {
pkt := packet{} pkt := packet{buf: buf[:]}
err = pkt.readFrom(c) err = pkt.readFrom(c)
if err != nil { if err != nil {
break return err
} }
switch pkt.packetType { switch pkt.packetType {
@ -199,14 +208,10 @@ func connect(c *Conn) error {
default: default:
err = handlePacket(c, &pkt) err = handlePacket(c, &pkt)
if err != nil { if err != nil {
break return err
} }
} }
} }
if !c.isPlaying {
return err
}
return nil return nil
} }
@ -276,26 +281,18 @@ func sendConnectPacket(c *Conn) error {
return err return err
} }
enc[0] = amf.TypeObject // required link info
enc = enc[1:] info := amf.Object{Properties: []amf.Property{
enc, err = amf.EncodeNamedString(enc, avApp, c.link.app) amf.Property{Type: amf.TypeString, Name: avApp, String: c.link.app},
if err != nil { amf.Property{Type: amf.TypeString, Name: avType, String: avNonprivate},
return err amf.Property{Type: amf.TypeString, Name: avTcUrl, String: c.link.url}},
} }
enc, err = amf.EncodeNamedString(enc, avType, avNonprivate) enc, err = amf.Encode(&info, enc)
if err != nil {
return err
}
enc, err = amf.EncodeNamedString(enc, avTcUrl, c.link.url)
if err != nil {
return err
}
enc, err = amf.EncodeInt24(enc, amf.TypeObjectEnd)
if err != nil { if err != nil {
return err return err
} }
// add auth string, if any // optional link auth info
if c.link.auth != "" { if c.link.auth != "" {
enc, err = amf.EncodeBoolean(enc, c.link.flags&linkAuth != 0) enc, err = amf.EncodeBoolean(enc, c.link.flags&linkAuth != 0)
if err != nil { if err != nil {

View File

@ -243,8 +243,9 @@ func TestFromFile(t *testing.T) {
} }
defer f.Close() defer f.Close()
rs := &rtmpSender{conn: c}
// Pass RTMP session, true for audio, true for video, and 25 FPS // Pass RTMP session, true for audio, true for video, and 25 FPS
flvEncoder, err := flv.NewEncoder(c, true, true, 25) flvEncoder, err := flv.NewEncoder(rs, true, true, 25)
if err != nil { if err != nil {
t.Fatalf("failed to create encoder: %v", err) t.Fatalf("failed to create encoder: %v", err)
} }