From 998d41c96fe140e6c2b6dd9412ee64cf099d2271 Mon Sep 17 00:00:00 2001 From: scruzin Date: Sat, 19 Jan 2019 13:12:24 +1030 Subject: [PATCH] Session now Conn, init() moved into Dial(), and connectStream() merged into connect(). --- rtmp/rtmp.go | 248 ++++++++++++++++++++++----------------------------- 1 file changed, 109 insertions(+), 139 deletions(-) diff --git a/rtmp/rtmp.go b/rtmp/rtmp.go index 37460f07..fe4c04d6 100644 --- a/rtmp/rtmp.go +++ b/rtmp/rtmp.go @@ -162,79 +162,49 @@ var ( errUnimplemented = errors.New("rtmp: unimplemented feature") ) -// init initialises the Session link -func (s *Session) init() (err error) { - s.link.protocol, s.link.host, s.link.port, s.link.app, s.link.playpath, err = parseURL(s.url) - if err != nil { - return err - } - if s.link.app == "" { - return errInvalidURL - } - if s.link.port == 0 { - switch { - case (s.link.protocol & featureSSL) != 0: - s.link.port = 433 - s.log(FatalLevel, pkg+"SSL not supported") - case (s.link.protocol & featureHTTP) != 0: - s.link.port = 80 - default: - s.link.port = 1935 - } - } - s.link.url = rtmpProtocolStrings[s.link.protocol] + "://" + s.link.host + ":" + strconv.Itoa(int(s.link.port)) + "/" + s.link.app - s.link.protocol |= featureWrite - return nil -} - // connect establishes an RTMP connection. -func connect(s *Session) error { - addr, err := net.ResolveTCPAddr("tcp4", s.link.host+":"+strconv.Itoa(int(s.link.port))) +func connect(c *Conn) error { + addr, err := net.ResolveTCPAddr("tcp4", c.link.host+":"+strconv.Itoa(int(c.link.port))) if err != nil { return err } - s.link.conn, err = net.DialTCP("tcp4", nil, addr) + c.link.conn, err = net.DialTCP("tcp4", nil, addr) if err != nil { - s.log(WarnLevel, pkg+"dial failed", "error", err.Error()) + c.log(WarnLevel, pkg+"dial failed", "error", err.Error()) return err } - s.log(DebugLevel, pkg+"connected") - err = handshake(s) + c.log(DebugLevel, pkg+"connected") + err = handshake(c) if err != nil { - s.log(WarnLevel, pkg+"handshake failed", "error", err.Error()) + c.log(WarnLevel, pkg+"handshake failed", "error", err.Error()) return err } - s.log(DebugLevel, pkg+"handshaked") - err = sendConnectPacket(s) + c.log(DebugLevel, pkg+"handshaked") + err = sendConnectPacket(c) if err != nil { - s.log(WarnLevel, pkg+"sendConnect failed", "error", err.Error()) + c.log(WarnLevel, pkg+"sendConnect failed", "error", err.Error()) return err } - return nil -} - -// connectStream reads a packet and handles it -func connectStream(s *Session) error { - var err error - for !s.isPlaying { + c.log(DebugLevel, pkg+"negotiating") + for !c.isPlaying { pkt := packet{} - err = pkt.readFrom(s) + err = pkt.readFrom(c) if err != nil { break } switch pkt.packetType { case packetTypeAudio, packetTypeVideo, packetTypeInfo: - s.log(WarnLevel, pkg+"got packet before play; ignoring", "type", pkt.packetType) + c.log(WarnLevel, pkg+"got packet before play; ignoring", "type", pkt.packetType) default: - err = handlePacket(s, &pkt) + err = handlePacket(c, &pkt) if err != nil { break } } } - if !s.isPlaying { + if !c.isPlaying { return err } return nil @@ -242,50 +212,50 @@ func connectStream(s *Session) error { // handlePacket handles a packet that the client has received. // NB: Unsupported packet types are logged fatally. -func handlePacket(s *Session, pkt *packet) error { +func handlePacket(c *Conn, pkt *packet) error { if pkt.bodySize < 4 { return errInvalidBody } switch pkt.packetType { case packetTypeChunkSize: - s.inChunkSize = amf.DecodeInt32(pkt.body[:4]) - s.log(DebugLevel, pkg+"set inChunkSize", "size", int(s.inChunkSize)) + c.inChunkSize = amf.DecodeInt32(pkt.body[:4]) + c.log(DebugLevel, pkg+"set inChunkSize", "size", int(c.inChunkSize)) case packetTypeBytesReadReport: - s.log(DebugLevel, pkg+"received packetTypeBytesReadReport") + c.log(DebugLevel, pkg+"received packetTypeBytesReadReport") case packetTypeServerBW: - s.serverBW = amf.DecodeInt32(pkt.body[:4]) - s.log(DebugLevel, pkg+"set serverBW", "size", int(s.serverBW)) + c.serverBW = amf.DecodeInt32(pkt.body[:4]) + c.log(DebugLevel, pkg+"set serverBW", "size", int(c.serverBW)) case packetTypeClientBW: - s.clientBW = amf.DecodeInt32(pkt.body[:4]) - s.log(DebugLevel, pkg+"set clientBW", "size", int(s.clientBW)) + c.clientBW = amf.DecodeInt32(pkt.body[:4]) + c.log(DebugLevel, pkg+"set clientBW", "size", int(c.clientBW)) if pkt.bodySize > 4 { - s.clientBW2 = pkt.body[4] - s.log(DebugLevel, pkg+"set clientBW2", "size", int(s.clientBW2)) + c.clientBW2 = pkt.body[4] + c.log(DebugLevel, pkg+"set clientBW2", "size", int(c.clientBW2)) } else { - s.clientBW2 = 0xff + c.clientBW2 = 0xff } case packetTypeInvoke: - err := handleInvoke(s, pkt.body[:pkt.bodySize]) + err := handleInvoke(c, pkt.body[:pkt.bodySize]) if err != nil { - s.log(WarnLevel, pkg+"unexpected error from handleInvoke", "error", err.Error()) + c.log(WarnLevel, pkg+"unexpected error from handleInvoke", "error", err.Error()) return err } case packetTypeControl, packetTypeAudio, packetTypeVideo, packetTypeFlashVideo, packetTypeFlexMessage, packetTypeInfo: - s.log(FatalLevel, pkg+"unsupported packet type "+strconv.Itoa(int(pkt.packetType))) + c.log(FatalLevel, pkg+"unsupported packet type "+strconv.Itoa(int(pkt.packetType))) default: - s.log(WarnLevel, pkg+"unknown packet type", "type", pkt.packetType) + c.log(WarnLevel, pkg+"unknown packet type", "type", pkt.packetType) } return nil } -func sendConnectPacket(s *Session) error { +func sendConnectPacket(c *Conn) error { var pbuf [4096]byte pkt := packet{ channel: chanControl, @@ -300,15 +270,15 @@ func sendConnectPacket(s *Session) error { if err != nil { return err } - s.numInvokes += 1 - enc, err = amf.EncodeNumber(enc, float64(s.numInvokes)) + c.numInvokes += 1 + enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { return err } enc[0] = amf.TypeObject enc = enc[1:] - enc, err = amf.EncodeNamedString(enc, avApp, s.link.app) + enc, err = amf.EncodeNamedString(enc, avApp, c.link.app) if err != nil { return err } @@ -316,7 +286,7 @@ func sendConnectPacket(s *Session) error { if err != nil { return err } - enc, err = amf.EncodeNamedString(enc, avTcUrl, s.link.url) + enc, err = amf.EncodeNamedString(enc, avTcUrl, c.link.url) if err != nil { return err } @@ -326,12 +296,12 @@ func sendConnectPacket(s *Session) error { } // add auth string, if any - if s.link.auth != "" { - enc, err = amf.EncodeBoolean(enc, s.link.flags&linkAuth != 0) + if c.link.auth != "" { + enc, err = amf.EncodeBoolean(enc, c.link.flags&linkAuth != 0) if err != nil { return err } - enc, err = amf.EncodeString(enc, s.link.auth) + enc, err = amf.EncodeString(enc, c.link.auth) if err != nil { return err } @@ -339,10 +309,10 @@ func sendConnectPacket(s *Session) error { pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(s, true) // response expected + return pkt.writeTo(c, true) // response expected } -func sendCreateStream(s *Session) error { +func sendCreateStream(c *Conn) error { var pbuf [256]byte pkt := packet{ channel: chanControl, @@ -357,8 +327,8 @@ func sendCreateStream(s *Session) error { if err != nil { return err } - s.numInvokes++ - enc, err = amf.EncodeNumber(enc, float64(s.numInvokes)) + c.numInvokes++ + enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { return err } @@ -367,10 +337,10 @@ func sendCreateStream(s *Session) error { pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(s, true) // response expected + return pkt.writeTo(c, true) // response expected } -func sendReleaseStream(s *Session) error { +func sendReleaseStream(c *Conn) error { var pbuf [1024]byte pkt := packet{ channel: chanControl, @@ -385,23 +355,23 @@ func sendReleaseStream(s *Session) error { if err != nil { return err } - s.numInvokes++ - enc, err = amf.EncodeNumber(enc, float64(s.numInvokes)) + c.numInvokes++ + enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { return err } enc[0] = amf.TypeNull enc = enc[1:] - enc, err = amf.EncodeString(enc, s.link.playpath) + enc, err = amf.EncodeString(enc, c.link.playpath) if err != nil { return err } pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(s, false) + return pkt.writeTo(c, false) } -func sendFCPublish(s *Session) error { +func sendFCPublish(c *Conn) error { var pbuf [1024]byte pkt := packet{ channel: chanControl, @@ -416,24 +386,24 @@ func sendFCPublish(s *Session) error { if err != nil { return err } - s.numInvokes++ - enc, err = amf.EncodeNumber(enc, float64(s.numInvokes)) + c.numInvokes++ + enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { return err } enc[0] = amf.TypeNull enc = enc[1:] - enc, err = amf.EncodeString(enc, s.link.playpath) + enc, err = amf.EncodeString(enc, c.link.playpath) if err != nil { return err } pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(s, false) + return pkt.writeTo(c, false) } -func sendFCUnpublish(s *Session) error { +func sendFCUnpublish(c *Conn) error { var pbuf [1024]byte pkt := packet{ channel: chanControl, @@ -448,24 +418,24 @@ func sendFCUnpublish(s *Session) error { if err != nil { return err } - s.numInvokes++ - enc, err = amf.EncodeNumber(enc, float64(s.numInvokes)) + c.numInvokes++ + enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { return err } enc[0] = amf.TypeNull enc = enc[1:] - enc, err = amf.EncodeString(enc, s.link.playpath) + enc, err = amf.EncodeString(enc, c.link.playpath) if err != nil { return err } pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(s, false) + return pkt.writeTo(c, false) } -func sendPublish(s *Session) error { +func sendPublish(c *Conn) error { var pbuf [1024]byte pkt := packet{ channel: chanSource, @@ -480,14 +450,14 @@ func sendPublish(s *Session) error { if err != nil { return err } - s.numInvokes++ - enc, err = amf.EncodeNumber(enc, float64(s.numInvokes)) + c.numInvokes++ + enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { return err } enc[0] = amf.TypeNull enc = enc[1:] - enc, err = amf.EncodeString(enc, s.link.playpath) + enc, err = amf.EncodeString(enc, c.link.playpath) if err != nil { return err } @@ -498,10 +468,10 @@ func sendPublish(s *Session) error { pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(s, true) // response expected + return pkt.writeTo(c, true) // response expected } -func sendDeleteStream(s *Session, dStreamId float64) error { +func sendDeleteStream(c *Conn, dStreamId float64) error { var pbuf [256]byte pkt := packet{ channel: chanControl, @@ -516,8 +486,8 @@ func sendDeleteStream(s *Session, dStreamId float64) error { if err != nil { return err } - s.numInvokes++ - enc, err = amf.EncodeNumber(enc, float64(s.numInvokes)) + c.numInvokes++ + enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { return err } @@ -529,11 +499,11 @@ func sendDeleteStream(s *Session, dStreamId float64) error { } pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(s, false) + return pkt.writeTo(c, false) } // sendBytesReceived tells the server how many bytes the client has received. -func sendBytesReceived(s *Session) error { +func sendBytesReceived(c *Conn) error { var pbuf [256]byte pkt := packet{ channel: chanBytesRead, @@ -544,18 +514,18 @@ func sendBytesReceived(s *Session) error { } enc := pkt.body - s.nBytesInSent = s.nBytesIn + c.nBytesInSent = c.nBytesIn - enc, err := amf.EncodeInt32(enc, s.nBytesIn) + enc, err := amf.EncodeInt32(enc, c.nBytesIn) if err != nil { return err } pkt.bodySize = 4 - return pkt.writeTo(s, false) + return pkt.writeTo(c, false) } -func sendCheckBW(s *Session) error { +func sendCheckBW(c *Conn) error { var pbuf [256]byte pkt := packet{ channel: chanControl, @@ -570,8 +540,8 @@ func sendCheckBW(s *Session) error { if err != nil { return err } - s.numInvokes++ - enc, err = amf.EncodeNumber(enc, float64(s.numInvokes)) + c.numInvokes++ + enc, err = amf.EncodeNumber(enc, float64(c.numInvokes)) if err != nil { return err } @@ -580,7 +550,7 @@ func sendCheckBW(s *Session) error { pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - return pkt.writeTo(s, false) + return pkt.writeTo(c, false) } func eraseMethod(m []method, i int) []method { @@ -590,8 +560,8 @@ func eraseMethod(m []method, i int) []method { } // int handleInvoke handles a packet invoke request -// Side effects: s.isPlaying set to true upon avNetStreamPublish_Start -func handleInvoke(s *Session, body []byte) error { +// Side effects: c.isPlaying set to true upon avNetStreamPublish_Start +func handleInvoke(c *Conn, body []byte) error { if body[0] != 0x02 { return errInvalidBody } @@ -610,37 +580,37 @@ func handleInvoke(s *Session, body []byte) error { return err } - s.log(DebugLevel, pkg+"invoking method "+meth) + c.log(DebugLevel, pkg+"invoking method "+meth) switch meth { case av_result: - if (s.link.protocol & featureWrite) == 0 { + if (c.link.protocol & featureWrite) == 0 { return errNotWritable } var methodInvoked string - for i, m := range s.methodCalls { + for i, m := range c.methodCalls { if float64(m.num) == txn { methodInvoked = m.name - s.methodCalls = eraseMethod(s.methodCalls, i) + c.methodCalls = eraseMethod(c.methodCalls, i) break } } if methodInvoked == "" { - s.log(WarnLevel, pkg+"received result without matching request", "id", txn) + c.log(WarnLevel, pkg+"received result without matching request", "id", txn) return nil } - s.log(DebugLevel, pkg+"received result for "+methodInvoked) + c.log(DebugLevel, pkg+"received result for "+methodInvoked) switch methodInvoked { case avConnect: - err := sendReleaseStream(s) + err := sendReleaseStream(c) if err != nil { return err } - err = sendFCPublish(s) + err = sendFCPublish(c) if err != nil { return err } - err = sendCreateStream(s) + err = sendCreateStream(c) if err != nil { return err } @@ -650,18 +620,18 @@ func handleInvoke(s *Session, body []byte) error { if err != nil { return err } - s.streamID = int32(n) - err = sendPublish(s) + c.streamID = int32(n) + err = sendPublish(c) if err != nil { return err } default: - s.log(FatalLevel, pkg+"unexpected method invoked"+methodInvoked) + c.log(FatalLevel, pkg+"unexpected method invoked"+methodInvoked) } case avOnBWDone: - err := sendCheckBW(s) + err := sendCheckBW(c) if err != nil { return err } @@ -679,27 +649,27 @@ func handleInvoke(s *Session, body []byte) error { if err != nil { return err } - s.log(DebugLevel, pkg+"onStatus", "code", code, "level", level) + c.log(DebugLevel, pkg+"onStatus", "code", code, "level", level) if code != avNetStreamPublish_Start { - s.log(ErrorLevel, pkg+"unexpected response "+code) + c.log(ErrorLevel, pkg+"unexpected response "+code) return errUnimplemented } - s.log(DebugLevel, pkg+"playing") - s.isPlaying = true - for i, m := range s.methodCalls { + c.log(DebugLevel, pkg+"playing") + c.isPlaying = true + for i, m := range c.methodCalls { if m.name == avPublish { - s.methodCalls = eraseMethod(s.methodCalls, i) + c.methodCalls = eraseMethod(c.methodCalls, i) } } default: - s.log(FatalLevel, pkg+"unsuppoted method "+meth) + c.log(FatalLevel, pkg+"unsuppoted method "+meth) } return nil } -func handshake(s *Session) error { +func handshake(c *Conn) error { var clientbuf [signatureSize + 1]byte clientsig := clientbuf[1:] @@ -712,44 +682,44 @@ func handshake(s *Session) error { clientsig[i] = byte(rand.Intn(256)) } - _, err := s.write(clientbuf[:]) + _, err := c.write(clientbuf[:]) if err != nil { return err } - s.log(DebugLevel, pkg+"handshake sent") + c.log(DebugLevel, pkg+"handshake sent") var typ [1]byte - _, err = s.read(typ[:]) + _, err = c.read(typ[:]) if err != nil { return err } - s.log(DebugLevel, pkg+"handshake received") + c.log(DebugLevel, pkg+"handshake received") if typ[0] != clientbuf[0] { - s.log(WarnLevel, pkg+"handshake type mismatch", "sent", clientbuf[0], "received", typ) + c.log(WarnLevel, pkg+"handshake type mismatch", "sent", clientbuf[0], "received", typ) } - _, err = s.read(serversig[:]) + _, err = c.read(serversig[:]) if err != nil { return err } // decode server response suptime := binary.BigEndian.Uint32(serversig[:4]) - s.log(DebugLevel, pkg+"server uptime", "uptime", suptime) + c.log(DebugLevel, pkg+"server uptime", "uptime", suptime) // 2nd part of handshake - _, err = s.write(serversig[:]) + _, err = c.write(serversig[:]) if err != nil { return err } - _, err = s.read(serversig[:]) + _, err = c.read(serversig[:]) if err != nil { return err } if !bytes.Equal(serversig[:signatureSize], clientbuf[1:signatureSize+1]) { - s.log(WarnLevel, pkg+"signature mismatch", "serversig", serversig[:signatureSize], "clientsig", clientbuf[1:signatureSize+1]) + c.log(WarnLevel, pkg+"signature mismatch", "serversig", serversig[:signatureSize], "clientsig", clientbuf[1:signatureSize+1]) } return nil }