readN()/writeN() now Session.read()/write() respectfully.

This commit is contained in:
scruzin 2019-01-10 23:16:20 +10:30
parent 26e8133a6e
commit 6a8e78a256
3 changed files with 56 additions and 49 deletions

View File

@ -35,8 +35,8 @@ LICENSE
package rtmp package rtmp
import ( import (
"io"
"encoding/binary" "encoding/binary"
"io"
) )
const ( const (
@ -104,7 +104,7 @@ func readPacket(s *Session, pkt *packet) error {
var hbuf [RTMP_MAX_HEADER_SIZE]byte var hbuf [RTMP_MAX_HEADER_SIZE]byte
header := hbuf[:] header := hbuf[:]
err := readN(s, header[:1]) err := s.read(header[:1])
if err != nil { if err != nil {
s.log(DebugLevel, pkg+"failed to read packet header 1st byte", "error", err.Error()) s.log(DebugLevel, pkg+"failed to read packet header 1st byte", "error", err.Error())
if err == io.EOF { if err == io.EOF {
@ -118,7 +118,7 @@ func readPacket(s *Session, pkt *packet) error {
switch { switch {
case pkt.channel == 0: case pkt.channel == 0:
err = readN(s, header[:1]) err = s.read(header[:1])
if err != nil { if err != nil {
s.log(DebugLevel, pkg+"failed to read packet header 2nd byte", "error", err.Error()) s.log(DebugLevel, pkg+"failed to read packet header 2nd byte", "error", err.Error())
return err return err
@ -127,7 +127,7 @@ func readPacket(s *Session, pkt *packet) error {
pkt.channel = int32(header[0]) + 64 pkt.channel = int32(header[0]) + 64
case pkt.channel == 1: case pkt.channel == 1:
err = readN(s, header[:2]) err = s.read(header[:2])
if err != nil { if err != nil {
s.log(DebugLevel, pkg+"failed to read packet header 3rd byte", "error", err.Error()) s.log(DebugLevel, pkg+"failed to read packet header 3rd byte", "error", err.Error())
return err return err
@ -171,7 +171,7 @@ func readPacket(s *Session, pkt *packet) error {
size-- size--
if size > 0 { if size > 0 {
err = readN(s, header[:size]) err = s.read(header[:size])
if err != nil { if err != nil {
s.log(DebugLevel, pkg+"failed to read packet header", "error", err.Error()) s.log(DebugLevel, pkg+"failed to read packet header", "error", err.Error())
return err return err
@ -197,7 +197,7 @@ func readPacket(s *Session, pkt *packet) error {
extendedTimestamp := pkt.timestamp == 0xffffff extendedTimestamp := pkt.timestamp == 0xffffff
if extendedTimestamp { if extendedTimestamp {
err = readN(s, header[size:size+4]) err = s.read(header[size : size+4])
if err != nil { if err != nil {
s.log(DebugLevel, pkg+"failed to read extended timestamp", "error", err.Error()) s.log(DebugLevel, pkg+"failed to read extended timestamp", "error", err.Error())
return err return err
@ -224,7 +224,7 @@ func readPacket(s *Session, pkt *packet) error {
pkt.chunk.data = pkt.body[pkt.bytesRead : pkt.bytesRead+uint32(chunkSize)] pkt.chunk.data = pkt.body[pkt.bytesRead : pkt.bytesRead+uint32(chunkSize)]
} }
err = readN(s, pkt.body[pkt.bytesRead:][:chunkSize]) err = s.read(pkt.body[pkt.bytesRead:][:chunkSize])
if err != nil { if err != nil {
s.log(DebugLevel, pkg+"failed to read packet body", "error", err.Error()) s.log(DebugLevel, pkg+"failed to read packet body", "error", err.Error())
return err return err
@ -418,7 +418,7 @@ func sendPacket(s *Session, pkt *packet, queue bool) error {
s.log(DebugLevel, pkg+"sending packet", "size", size, "la", s.link.conn.LocalAddr(), "ra", s.link.conn.RemoteAddr()) s.log(DebugLevel, pkg+"sending packet", "size", size, "la", s.link.conn.LocalAddr(), "ra", s.link.conn.RemoteAddr())
if s.deferred != nil && len(s.deferred)+size+hSize > chunkSize { if s.deferred != nil && len(s.deferred)+size+hSize > chunkSize {
err := writeN(s, s.deferred) err := s.write(s.deferred)
if err != nil { if err != nil {
return err return err
} }
@ -441,7 +441,7 @@ func sendPacket(s *Session, pkt *packet, queue bool) error {
// Prepend the previously deferred packet and write it with the current one. // Prepend the previously deferred packet and write it with the current one.
bytes = append(s.deferred, bytes...) bytes = append(s.deferred, bytes...)
} }
err := writeN(s, bytes) err := s.write(bytes)
if err != nil { if err != nil {
return err return err
} }

View File

@ -38,7 +38,6 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"io"
"math/rand" "math/rand"
"net" "net"
"strconv" "strconv"
@ -276,40 +275,6 @@ func handlePacket(s *Session, pkt *packet) error {
return nil return nil
} }
func readN(s *Session, buf []byte) error {
err := s.link.conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(s.link.timeout)))
if err != nil {
return err
}
n, err := io.ReadFull(s.link.conn, buf)
if err != nil {
s.log(DebugLevel, pkg+"read failed", "error", err.Error())
return err
}
s.nBytesIn += int32(n)
if s.nBytesIn > (s.nBytesInSent + s.clientBW/10) {
err := sendBytesReceived(s)
if err != nil {
return err
}
}
return nil
}
func writeN(s *Session, buf []byte) error {
//ToDo: consider using a different timeout for writes than for reads
err := s.link.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(s.link.timeout)))
if err != nil {
return err
}
_, err = s.link.conn.Write(buf)
if err != nil {
s.log(WarnLevel, pkg+"write failed", "error", err.Error())
return err
}
return nil
}
func sendConnectPacket(s *Session) error { func sendConnectPacket(s *Session) error {
var pbuf [4096]byte var pbuf [4096]byte
pkt := packet{ pkt := packet{
@ -817,14 +782,14 @@ func handshake(s *Session) error {
clientsig[i] = byte(rand.Intn(256)) clientsig[i] = byte(rand.Intn(256))
} }
err := writeN(s, clientbuf[:]) err := s.write(clientbuf[:])
if err != nil { if err != nil {
return err return err
} }
s.log(DebugLevel, pkg+"handshake sent") s.log(DebugLevel, pkg+"handshake sent")
var typ [1]byte var typ [1]byte
err = readN(s, typ[:]) err = s.read(typ[:])
if err != nil { if err != nil {
return err return err
} }
@ -833,7 +798,7 @@ func handshake(s *Session) error {
if typ[0] != clientbuf[0] { if typ[0] != clientbuf[0] {
s.log(WarnLevel, pkg+"handshake type mismatch", "sent", clientbuf[0], "received", typ) s.log(WarnLevel, pkg+"handshake type mismatch", "sent", clientbuf[0], "received", typ)
} }
err = readN(s, serversig[:]) err = s.read(serversig[:])
if err != nil { if err != nil {
return err return err
} }
@ -843,12 +808,12 @@ func handshake(s *Session) error {
s.log(DebugLevel, pkg+"server uptime", "uptime", suptime) s.log(DebugLevel, pkg+"server uptime", "uptime", suptime)
// 2nd part of handshake // 2nd part of handshake
err = writeN(s, serversig[:]) err = s.write(serversig[:])
if err != nil { if err != nil {
return err return err
} }
err = readN(s, serversig[:]) err = s.read(serversig[:])
if err != nil { if err != nil {
return err return err
} }

View File

@ -33,6 +33,11 @@ LICENSE
*/ */
package rtmp package rtmp
import (
"io"
"time"
)
// Session holds the state for an RTMP session. // Session holds the state for an RTMP session.
type Session struct { type Session struct {
url string url string
@ -183,6 +188,43 @@ func (s *Session) Write(data []byte) (int, error) {
return len(data), nil return len(data), nil
} }
// I/O functions
// read from an RTMP connection.
func (s *Session) read(buf []byte) error {
err := s.link.conn.SetReadDeadline(time.Now().Add(time.Second * time.Duration(s.link.timeout)))
if err != nil {
return err
}
n, err := io.ReadFull(s.link.conn, buf)
if err != nil {
s.log(DebugLevel, pkg+"read failed", "error", err.Error())
return err
}
s.nBytesIn += int32(n)
if s.nBytesIn > (s.nBytesInSent + s.clientBW/10) {
err := sendBytesReceived(s)
if err != nil {
return err
}
}
return nil
}
// write to an RTMP connection.
func (s *Session) write(buf []byte) error {
//ToDo: consider using a different timeout for writes than for reads
err := s.link.conn.SetWriteDeadline(time.Now().Add(time.Second * time.Duration(s.link.timeout)))
if err != nil {
return err
}
_, err = s.link.conn.Write(buf)
if err != nil {
s.log(WarnLevel, pkg+"write failed", "error", err.Error())
return err
}
return nil
}
// isConnected returns true if the RTMP connection is up. // isConnected returns true if the RTMP connection is up.
func (s *Session) isConnected() bool { func (s *Session) isConnected() bool {
return s.link.conn != nil return s.link.conn != nil