Session.read()/write() both now return (int, error).

This commit is contained in:
scruzin 2019-01-11 10:35:20 +10:30
parent 67cc591dd2
commit 22b76b5bda
3 changed files with 25 additions and 25 deletions

View File

@ -106,7 +106,7 @@ func (pkt *packet) read(s *Session) error {
var hbuf [fullHeaderSize]byte var hbuf [fullHeaderSize]byte
header := hbuf[:] header := hbuf[:]
err := s.read(header[:1]) _, err := s.read(header[:1])
if err != nil { if err != nil {
s.log(DebugLevel, pkg+"failed to read packet header 1st byte", "error", err.Error()) s.log(DebugLevel, pkg+"failed to read packet header 1st byte", "error", err.Error())
if err == io.EOF { if err == io.EOF {
@ -120,7 +120,7 @@ func (pkt *packet) read(s *Session) error {
switch { switch {
case pkt.channel == 0: case pkt.channel == 0:
err = s.read(header[:1]) _, err = s.read(header[:1])
if err != nil { if err != nil {
s.log(DebugLevel, pkg+"failed to read packet header 2nd byte", "error", err.Error()) s.log(DebugLevel, pkg+"failed to read packet header 2nd byte", "error", err.Error())
return err return err
@ -129,7 +129,7 @@ func (pkt *packet) read(s *Session) error {
pkt.channel = int32(header[0]) + 64 pkt.channel = int32(header[0]) + 64
case pkt.channel == 1: case pkt.channel == 1:
err = s.read(header[:2]) _, err = s.read(header[:2])
if err != nil { if err != nil {
s.log(DebugLevel, pkg+"failed to read packet header 3rd byte", "error", err.Error()) s.log(DebugLevel, pkg+"failed to read packet header 3rd byte", "error", err.Error())
return err return err
@ -173,7 +173,7 @@ func (pkt *packet) read(s *Session) error {
size-- size--
if size > 0 { if size > 0 {
err = s.read(header[:size]) _, err = s.read(header[:size])
if err != nil { if err != nil {
s.log(DebugLevel, pkg+"failed to read packet header", "error", err.Error()) s.log(DebugLevel, pkg+"failed to read packet header", "error", err.Error())
return err return err
@ -199,7 +199,7 @@ func (pkt *packet) read(s *Session) error {
extendedTimestamp := pkt.timestamp == 0xffffff extendedTimestamp := pkt.timestamp == 0xffffff
if extendedTimestamp { if extendedTimestamp {
err = s.read(header[size : size+4]) _, err = s.read(header[size : size+4])
if err != nil { if err != nil {
s.log(DebugLevel, pkg+"failed to read extended timestamp", "error", err.Error()) s.log(DebugLevel, pkg+"failed to read extended timestamp", "error", err.Error())
return err return err
@ -227,7 +227,7 @@ func (pkt *packet) read(s *Session) error {
pkt.chunk.data = pkt.body[pkt.bytesRead : pkt.bytesRead+uint32(chunkSize)] pkt.chunk.data = pkt.body[pkt.bytesRead : pkt.bytesRead+uint32(chunkSize)]
} }
err = s.read(pkt.body[pkt.bytesRead:][:chunkSize]) _, err = s.read(pkt.body[pkt.bytesRead:][:chunkSize])
if err != nil { if err != nil {
s.log(DebugLevel, pkg+"failed to read packet body", "error", err.Error()) s.log(DebugLevel, pkg+"failed to read packet body", "error", err.Error())
return err return err
@ -261,7 +261,7 @@ func (pkt *packet) read(s *Session) error {
return nil return nil
} }
// resize adjusts the packet's storage to accommodate a body of the given size. // resize adjusts the packet's storage to accommodate a body of the given size and header type.
func (pkt *packet) resize(size uint32, ht uint8) { func (pkt *packet) resize(size uint32, ht uint8) {
buf := make([]byte, fullHeaderSize+size) buf := make([]byte, fullHeaderSize+size)
pkt.header = buf pkt.header = buf
@ -286,6 +286,7 @@ func (pkt *packet) resize(size uint32, ht uint8) {
} }
// write sends a packet. // write sends a packet.
// When queue is true, we expect a response to this request and cache the method on s.methodCalls.
func (pkt *packet) write(s *Session, queue bool) error { func (pkt *packet) write(s *Session, queue bool) error {
if pkt.body == nil { if pkt.body == nil {
return errInvalidBody return errInvalidBody
@ -426,7 +427,7 @@ func (pkt *packet) write(s *Session, queue bool) error {
// Send previously deferrd packet if combining it with the next one would exceed the chunk size. // Send previously deferrd packet if combining it with the next one would exceed the chunk size.
if len(s.deferred)+size+hSize > chunkSize { if len(s.deferred)+size+hSize > chunkSize {
s.log(DebugLevel, pkg+"sending deferred packet separately", "size", len(s.deferred)) s.log(DebugLevel, pkg+"sending deferred packet separately", "size", len(s.deferred))
err := s.write(s.deferred) _, err := s.write(s.deferred)
if err != nil { if err != nil {
return err return err
} }
@ -447,7 +448,7 @@ func (pkt *packet) write(s *Session, queue bool) error {
s.log(DebugLevel, pkg+"combining deferred packet", "size", len(s.deferred)) s.log(DebugLevel, pkg+"combining deferred packet", "size", len(s.deferred))
bytes = append(s.deferred, bytes...) bytes = append(s.deferred, bytes...)
} }
err := s.write(bytes) _, err := s.write(bytes)
if err != nil { if err != nil {
return err return err
} }

View File

@ -606,7 +606,6 @@ func sendDeleteStream(s *Session, dStreamId float64) error {
} }
pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc)) pkt.bodySize = uint32((len(pbuf) - fullHeaderSize) - len(enc))
/* no response expected */
return pkt.write(s, false) return pkt.write(s, false)
} }
@ -816,14 +815,14 @@ func handshake(s *Session) error {
clientsig[i] = byte(rand.Intn(256)) clientsig[i] = byte(rand.Intn(256))
} }
err := s.write(clientbuf[:]) _, err := s.write(clientbuf[:])
if err != nil { if err != nil {
return err return err
} }
s.log(DebugLevel, pkg+"handshake sent") s.log(DebugLevel, pkg+"handshake sent")
var typ [1]byte var typ [1]byte
err = s.read(typ[:]) _, err = s.read(typ[:])
if err != nil { if err != nil {
return err return err
} }
@ -832,7 +831,7 @@ func handshake(s *Session) error {
if typ[0] != clientbuf[0] { if typ[0] != clientbuf[0] {
s.log(WarnLevel, pkg+"handshake type mismatch", "sent", clientbuf[0], "received", typ) s.log(WarnLevel, pkg+"handshake type mismatch", "sent", clientbuf[0], "received", typ)
} }
err = s.read(serversig[:]) _, err = s.read(serversig[:])
if err != nil { if err != nil {
return err return err
} }
@ -842,12 +841,12 @@ func handshake(s *Session) error {
s.log(DebugLevel, pkg+"server uptime", "uptime", suptime) s.log(DebugLevel, pkg+"server uptime", "uptime", suptime)
// 2nd part of handshake // 2nd part of handshake
err = s.write(serversig[:]) _, err = s.write(serversig[:])
if err != nil { if err != nil {
return err return err
} }
err = s.read(serversig[:]) _, err = s.read(serversig[:])
if err != nil { if err != nil {
return err return err
} }

View File

@ -203,39 +203,39 @@ func (s *Session) Write(data []byte) (int, error) {
// read from an RTMP connection. Sends a bytes received message if the // read from an RTMP connection. Sends a bytes received message if the
// number of bytes received (nBytesIn) is greater than the number sent // number of bytes received (nBytesIn) is greater than the number sent
// (nBytesInSent) by 10% of the bandwidth. // (nBytesInSent) by 10% of the bandwidth.
func (s *Session) read(buf []byte) error { func (s *Session) read(buf []byte) (int, error) {
err := s.link.conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(s.link.timeout))) err := s.link.conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(s.link.timeout)))
if err != nil { if err != nil {
return err return 0, err
} }
n, err := io.ReadFull(s.link.conn, buf) n, err := io.ReadFull(s.link.conn, buf)
if err != nil { if err != nil {
s.log(DebugLevel, pkg+"read failed", "error", err.Error()) s.log(DebugLevel, pkg+"read failed", "error", err.Error())
return err return 0, err
} }
s.nBytesIn += int32(n) s.nBytesIn += int32(n)
if s.nBytesIn > (s.nBytesInSent + s.clientBW/10) { if s.nBytesIn > (s.nBytesInSent + s.clientBW/10) {
err := sendBytesReceived(s) err := sendBytesReceived(s)
if err != nil { if err != nil {
return err return n, err // NB: we still read n bytes, even though send bytes failed
} }
} }
return nil return n, nil
} }
// write to an RTMP connection. // write to an RTMP connection.
func (s *Session) write(buf []byte) error { func (s *Session) write(buf []byte) (int, error) {
//ToDo: consider using a different timeout for writes than for reads //ToDo: consider using a different timeout for writes than for reads
err := s.link.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(s.link.timeout))) err := s.link.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(s.link.timeout)))
if err != nil { if err != nil {
return err return 0, err
} }
_, err = s.link.conn.Write(buf) n, err := s.link.conn.Write(buf)
if err != nil { if err != nil {
s.log(WarnLevel, pkg+"write failed", "error", err.Error()) s.log(WarnLevel, pkg+"write failed", "error", err.Error())
return err return 0, err
} }
return nil return n, nil
} }
// isConnected returns true if the RTMP connection is up. // isConnected returns true if the RTMP connection is up.