From 22b76b5bda7d95c73c097f03d61749a23cc20047 Mon Sep 17 00:00:00 2001 From: scruzin Date: Fri, 11 Jan 2019 10:35:20 +1030 Subject: [PATCH] Session.read()/write() both now return (int, error). --- rtmp/packet.go | 19 ++++++++++--------- rtmp/rtmp.go | 11 +++++------ rtmp/session.go | 20 ++++++++++---------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/rtmp/packet.go b/rtmp/packet.go index 3c50b59d..dac3d803 100644 --- a/rtmp/packet.go +++ b/rtmp/packet.go @@ -106,7 +106,7 @@ func (pkt *packet) read(s *Session) error { var hbuf [fullHeaderSize]byte header := hbuf[:] - err := s.read(header[:1]) + _, err := s.read(header[:1]) if err != nil { s.log(DebugLevel, pkg+"failed to read packet header 1st byte", "error", err.Error()) if err == io.EOF { @@ -120,7 +120,7 @@ func (pkt *packet) read(s *Session) error { switch { case pkt.channel == 0: - err = s.read(header[:1]) + _, err = s.read(header[:1]) if err != nil { s.log(DebugLevel, pkg+"failed to read packet header 2nd byte", "error", err.Error()) return err @@ -129,7 +129,7 @@ func (pkt *packet) read(s *Session) error { pkt.channel = int32(header[0]) + 64 case pkt.channel == 1: - err = s.read(header[:2]) + _, err = s.read(header[:2]) if err != nil { s.log(DebugLevel, pkg+"failed to read packet header 3rd byte", "error", err.Error()) return err @@ -173,7 +173,7 @@ func (pkt *packet) read(s *Session) error { size-- if size > 0 { - err = s.read(header[:size]) + _, err = s.read(header[:size]) if err != nil { s.log(DebugLevel, pkg+"failed to read packet header", "error", err.Error()) return err @@ -199,7 +199,7 @@ func (pkt *packet) read(s *Session) error { extendedTimestamp := pkt.timestamp == 0xffffff if extendedTimestamp { - err = s.read(header[size : size+4]) + _, err = s.read(header[size : size+4]) if err != nil { s.log(DebugLevel, pkg+"failed to read extended timestamp", "error", err.Error()) return err @@ -227,7 +227,7 @@ func (pkt *packet) read(s *Session) error { pkt.chunk.data = pkt.body[pkt.bytesRead : pkt.bytesRead+uint32(chunkSize)] } - err = s.read(pkt.body[pkt.bytesRead:][:chunkSize]) + _, err = s.read(pkt.body[pkt.bytesRead:][:chunkSize]) if err != nil { s.log(DebugLevel, pkg+"failed to read packet body", "error", err.Error()) return err @@ -261,7 +261,7 @@ func (pkt *packet) read(s *Session) error { return nil } -// resize adjusts the packet's storage to accommodate a body of the given size. +// resize adjusts the packet's storage to accommodate a body of the given size and header type. func (pkt *packet) resize(size uint32, ht uint8) { buf := make([]byte, fullHeaderSize+size) pkt.header = buf @@ -286,6 +286,7 @@ func (pkt *packet) resize(size uint32, ht uint8) { } // write sends a packet. +// When queue is true, we expect a response to this request and cache the method on s.methodCalls. func (pkt *packet) write(s *Session, queue bool) error { if pkt.body == nil { return errInvalidBody @@ -426,7 +427,7 @@ func (pkt *packet) write(s *Session, queue bool) error { // Send previously deferrd packet if combining it with the next one would exceed the chunk size. if len(s.deferred)+size+hSize > chunkSize { s.log(DebugLevel, pkg+"sending deferred packet separately", "size", len(s.deferred)) - err := s.write(s.deferred) + _, err := s.write(s.deferred) if err != nil { return err } @@ -447,7 +448,7 @@ func (pkt *packet) write(s *Session, queue bool) error { s.log(DebugLevel, pkg+"combining deferred packet", "size", len(s.deferred)) bytes = append(s.deferred, bytes...) } - err := s.write(bytes) + _, err := s.write(bytes) if err != nil { return err } diff --git a/rtmp/rtmp.go b/rtmp/rtmp.go index f48e9e53..4b334587 100644 --- a/rtmp/rtmp.go +++ b/rtmp/rtmp.go @@ -606,7 +606,6 @@ func sendDeleteStream(s *Session, dStreamId float64) error { } pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) - /* no response expected */ return pkt.write(s, false) } @@ -816,14 +815,14 @@ func handshake(s *Session) error { clientsig[i] = byte(rand.Intn(256)) } - err := s.write(clientbuf[:]) + _, err := s.write(clientbuf[:]) if err != nil { return err } s.log(DebugLevel, pkg+"handshake sent") var typ [1]byte - err = s.read(typ[:]) + _, err = s.read(typ[:]) if err != nil { return err } @@ -832,7 +831,7 @@ func handshake(s *Session) error { if typ[0] != clientbuf[0] { s.log(WarnLevel, pkg+"handshake type mismatch", "sent", clientbuf[0], "received", typ) } - err = s.read(serversig[:]) + _, err = s.read(serversig[:]) if err != nil { return err } @@ -842,12 +841,12 @@ func handshake(s *Session) error { s.log(DebugLevel, pkg+"server uptime", "uptime", suptime) // 2nd part of handshake - err = s.write(serversig[:]) + _, err = s.write(serversig[:]) if err != nil { return err } - err = s.read(serversig[:]) + _, err = s.read(serversig[:]) if err != nil { return err } diff --git a/rtmp/session.go b/rtmp/session.go index 8ba30d1c..eb97c938 100644 --- a/rtmp/session.go +++ b/rtmp/session.go @@ -203,39 +203,39 @@ func (s *Session) Write(data []byte) (int, error) { // read from an RTMP connection. Sends a bytes received message if the // number of bytes received (nBytesIn) is greater than the number sent // (nBytesInSent) by 10% of the bandwidth. -func (s *Session) read(buf []byte) error { +func (s *Session) read(buf []byte) (int, error) { err := s.link.conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(s.link.timeout))) if err != nil { - return err + return 0, err } n, err := io.ReadFull(s.link.conn, buf) if err != nil { s.log(DebugLevel, pkg+"read failed", "error", err.Error()) - return err + return 0, err } s.nBytesIn += int32(n) if s.nBytesIn > (s.nBytesInSent + s.clientBW/10) { err := sendBytesReceived(s) if err != nil { - return err + return n, err // NB: we still read n bytes, even though send bytes failed } } - return nil + return n, nil } // write to an RTMP connection. -func (s *Session) write(buf []byte) error { +func (s *Session) write(buf []byte) (int, error) { //ToDo: consider using a different timeout for writes than for reads err := s.link.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(s.link.timeout))) if err != nil { - return err + return 0, err } - _, err = s.link.conn.Write(buf) + n, err := s.link.conn.Write(buf) if err != nil { s.log(WarnLevel, pkg+"write failed", "error", err.Error()) - return err + return 0, err } - return nil + return n, nil } // isConnected returns true if the RTMP connection is up.