Merged in remove-load-and-send (pull request #178)

revid: remove load and send methods for senders

Approved-by: kortschak <dan@kortschak.io>
This commit is contained in:
Saxon Milton 2019-04-03 04:05:10 +00:00
commit 2d15e98445
4 changed files with 121 additions and 397 deletions

View File

@ -177,10 +177,11 @@ func (r *Revid) setConfig(config Config) error {
return nil return nil
} }
func (r *Revid) setupPipeline(mtsEnc func(io.Writer, int) io.Writer, flvEnc func(io.Writer, int) (io.Writer, error)) error { // setupPipeline constructs a data pipeline.
func (r *Revid) setupPipeline(mtsEnc, flvEnc func(dst io.Writer, rate int) (io.Writer, error), multiWriter func(...io.Writer) io.Writer) error {
r.buffer = (*buffer)(ring.NewBuffer(ringBufferSize, ringBufferElementSize, writeTimeout)) r.buffer = (*buffer)(ring.NewBuffer(ringBufferSize, ringBufferElementSize, writeTimeout))
r.encoder = make([]io.Writer, 0) r.encoder = r.encoder[:0]
// mtsSenders will hold the senders the require MPEGTS encoding, and flvSenders // mtsSenders will hold the senders the require MPEGTS encoding, and flvSenders
// will hold senders that require FLV encoding. // will hold senders that require FLV encoding.
@ -193,7 +194,7 @@ func (r *Revid) setupPipeline(mtsEnc func(io.Writer, int) io.Writer, flvEnc func
for _, out := range r.config.Outputs { for _, out := range r.config.Outputs {
switch out { switch out {
case Http: case Http:
w = newMtsSender(newMinimalHttpSender(r.ns, r.config.Logger.Log), nil) w = newMtsSender(newHttpSender(r.ns, r.config.Logger.Log), nil)
mtsSenders = append(mtsSenders, w) mtsSenders = append(mtsSenders, w)
case Rtp: case Rtp:
w, err := newRtpSender(r.config.RtpAddress, r.config.Logger.Log, r.config.FrameRate) w, err := newRtpSender(r.config.RtpAddress, r.config.Logger.Log, r.config.FrameRate)
@ -220,8 +221,8 @@ func (r *Revid) setupPipeline(mtsEnc func(io.Writer, int) io.Writer, flvEnc func
// encoder to revid's encoder slice, and give this encoder the mtsSenders // encoder to revid's encoder slice, and give this encoder the mtsSenders
// as a destination. // as a destination.
if len(mtsSenders) != 0 { if len(mtsSenders) != 0 {
ms := newMultiSender(mtsSenders, r.config.Logger.Log) mw := multiWriter(mtsSenders...)
e := mtsEnc(ms, int(r.config.FrameRate)) e, _ := mtsEnc(mw, int(r.config.FrameRate))
r.encoder = append(r.encoder, e) r.encoder = append(r.encoder, e)
} }
@ -229,8 +230,8 @@ func (r *Revid) setupPipeline(mtsEnc func(io.Writer, int) io.Writer, flvEnc func
// encoder to revid's encoder slice, and give this encoder the flvSenders // encoder to revid's encoder slice, and give this encoder the flvSenders
// as a destination. // as a destination.
if len(flvSenders) != 0 { if len(flvSenders) != 0 {
ms := newMultiSender(flvSenders, r.config.Logger.Log) mw := multiWriter(flvSenders...)
e, err := flvEnc(ms, int(r.config.FrameRate)) e, err := flvEnc(mw, int(r.config.FrameRate))
if err != nil { if err != nil {
return err return err
} }
@ -257,9 +258,9 @@ func (r *Revid) setupPipeline(mtsEnc func(io.Writer, int) io.Writer, flvEnc func
return nil return nil
} }
func newMtsEncoder(dst io.Writer, fps int) io.Writer { func newMtsEncoder(dst io.Writer, fps int) (io.Writer, error) {
e := mts.NewEncoder(dst, float64(fps)) e := mts.NewEncoder(dst, float64(fps))
return e return e, nil
} }
func newFlvEncoder(dst io.Writer, fps int) (io.Writer, error) { func newFlvEncoder(dst io.Writer, fps int) (io.Writer, error) {
@ -278,7 +279,7 @@ func (r *Revid) reset(config Config) error {
return err return err
} }
err = r.setupPipeline(newMtsEncoder, newFlvEncoder) err = r.setupPipeline(newMtsEncoder, newFlvEncoder, io.MultiWriter)
if err != nil { if err != nil {
return err return err
} }

View File

@ -75,8 +75,8 @@ type tstMtsEncoder struct {
} }
// newTstMtsEncoder returns a pointer to a newTsMtsEncoder. // newTstMtsEncoder returns a pointer to a newTsMtsEncoder.
func newTstMtsEncoder(dst io.Writer, fps int) io.Writer { func newTstMtsEncoder(dst io.Writer, fps int) (io.Writer, error) {
return &tstMtsEncoder{dst: dst} return &tstMtsEncoder{dst: dst}, nil
} }
func (e *tstMtsEncoder) Write(d []byte) (int, error) { return 0, nil } func (e *tstMtsEncoder) Write(d []byte) (int, error) { return 0, nil }
@ -92,7 +92,21 @@ func newTstFlvEncoder(dst io.Writer, fps int) (io.Writer, error) {
return &tstFlvEncoder{dst: dst}, nil return &tstFlvEncoder{dst: dst}, nil
} }
func (e *tstFlvEncoder) Write(d []byte) (int, error) { return 0, nil } func (e *tstFlvEncoder) Write(d []byte) (int, error) { return len(d), nil }
// dummyMultiWriter emulates the MultiWriter provided by std lib, so that we
// can access the destinations.
type dummyMultiWriter struct {
dst []io.Writer
}
func newDummyMultiWriter(dst ...io.Writer) io.Writer {
return &dummyMultiWriter{
dst: dst,
}
}
func (w *dummyMultiWriter) Write(d []byte) (int, error) { return len(d), nil }
// TestResetEncoderSenderSetup checks that revid.reset() correctly sets up the // TestResetEncoderSenderSetup checks that revid.reset() correctly sets up the
// revid.encoder slice and the senders the encoders write to. // revid.encoder slice and the senders the encoders write to.
@ -200,7 +214,7 @@ func TestResetEncoderSenderSetup(t *testing.T) {
} }
// This logic is what we want to check. // This logic is what we want to check.
err = rv.setupPipeline(newTstMtsEncoder, newTstFlvEncoder) err = rv.setupPipeline(newTstMtsEncoder, newTstFlvEncoder, newDummyMultiWriter)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v for test %v", err, testNum) t.Fatalf("unexpected error: %v for test %v", err, testNum)
} }
@ -237,7 +251,7 @@ func TestResetEncoderSenderSetup(t *testing.T) {
ms = e.(*tstFlvEncoder).dst ms = e.(*tstFlvEncoder).dst
} }
senders := ms.(*multiSender).dst senders := ms.(*dummyMultiWriter).dst
got = len(senders) got = len(senders)
want = len(test.encoders[idx].destinations) want = len(test.encoders[idx].destinations)
if got != want { if got != want {

View File

@ -45,228 +45,17 @@ import (
"bitbucket.org/ausocean/utils/logger" "bitbucket.org/ausocean/utils/logger"
) )
// Sender is intended to provided functionality for the sending of a byte slice
// to an implemented destination.
type Sender interface {
// send takes the bytes slice d and sends to a particular destination as
// implemented.
send(d []byte) error
}
// Log is used by the multiSender. // Log is used by the multiSender.
type Log func(level int8, message string, params ...interface{}) type Log func(level int8, message string, params ...interface{})
// multiSender implements io.Writer. It provides the capacity to send to multiple // httpSender provides an implemntation of io.Writer to perform sends to a http
// senders from a single Write call. // destination.
type multiSender struct { type httpSender struct {
dst []io.Writer
log Log
}
// newMultiSender returns a pointer to a new multiSender.
func newMultiSender(senders []io.Writer, log Log) *multiSender {
return &multiSender{
dst: senders,
log: log,
}
}
// Write implements io.Writer. This will call load (with the passed slice), send
// and release on all senders of multiSender.
func (s *multiSender) Write(d []byte) (int, error) {
for i, sender := range s.dst {
_, err := sender.Write(d)
if err != nil {
s.log(logger.Warning, pkg+"send failed", "sender", i, "error", err)
}
}
for _, sender := range s.dst {
s, ok := sender.(loadSender)
if !ok {
panic("sender is not a loadSender")
}
s.release()
}
return len(d), nil
}
// minimalHttpSender implements Sender for posting HTTP to netreceiver or vidgrind.
type minimalHttpSender struct {
client *netsender.Sender client *netsender.Sender
log func(lvl int8, msg string, args ...interface{}) log func(lvl int8, msg string, args ...interface{})
} }
// newMinimalHttpSender returns a pointer to a new minimalHttpSender. // newMinimalHttpSender returns a pointer to a new minimalHttpSender.
func newMinimalHttpSender(ns *netsender.Sender, log func(lvl int8, msg string, args ...interface{})) *minimalHttpSender {
return &minimalHttpSender{
client: ns,
log: log,
}
}
// send takes a bytes slice d and sends to http using s' http client.
func (s *minimalHttpSender) send(d []byte) error {
return httpSend(d, s.client, s.log)
}
// loadSender is a destination to send a *ring.Chunk to.
// When a loadSender has finished using the *ring.Chunk
// it must be Closed.
type loadSender interface {
// load assigns the *ring.Chunk to the loadSender.
// The load call may fail, but must not mutate the
// the chunk.
load(d []byte) error
// send performs a destination-specific send
// operation. It must not mutate the chunk.
send() error
// release releases the *ring.Chunk.
release()
// close cleans up after use of the loadSender.
close() error
}
// restart is an optional interface for loadSenders that
// can restart their connection.
type restarter interface {
restart() error
}
// fileSender implements loadSender for a local file destination.
type fileSender struct {
file *os.File
data []byte
}
// Write implements io.Writer.
func (s *fileSender) Write(d []byte) (int, error) {
err := s.load(d)
if err != nil {
return 0, err
}
err = s.send()
if err != nil {
return len(d), err
}
return len(d), nil
}
func newFileSender(path string) (io.Writer, error) {
f, err := os.Create(path)
if err != nil {
return nil, err
}
return &fileSender{file: f}, nil
}
func (s *fileSender) load(d []byte) error {
s.data = d
return nil
}
func (s *fileSender) send() error {
_, err := s.file.Write(s.data)
return err
}
func (s *fileSender) release() {}
func (s *fileSender) close() error { return s.file.Close() }
// mtsSender implements loadSender and provides sending capability specifically
// for use with MPEGTS packetization. It handles the construction of appropriately
// lengthed clips based on PSI. It also fixes accounts for discontinuities by
// setting the discontinuity indicator for the first packet of a clip.
type mtsSender struct {
sender Sender
buf []byte
next []byte
pkt packet.Packet
failed bool
discarded bool
repairer *mts.DiscontinuityRepairer
curPid int
}
// Write implements io.Writer.
func (s *mtsSender) Write(d []byte) (int, error) {
return write(s, d)
}
// newMtsSender returns a new mtsSender.
func newMtsSender(s Sender, log func(lvl int8, msg string, args ...interface{})) *mtsSender {
return &mtsSender{
sender: s,
repairer: mts.NewDiscontinuityRepairer(),
}
}
// load takes a *ring.Chunk and assigns to s.next, also grabbing it's pid and
// assigning to s.curPid. s.next if exists is also appended to the sender buf.
func (s *mtsSender) load(d []byte) error {
if s.next != nil {
s.buf = append(s.buf, s.next...)
}
bytes := make([]byte, len(d))
copy(bytes, d)
s.next = bytes
copy(s.pkt[:], bytes)
s.curPid = s.pkt.PID()
return nil
}
// send checks the currently loaded paackets PID; if it is a PAT then what is in
// the mtsSenders buffer is fixed and sent.
func (ms *mtsSender) send() error {
if ms.curPid == mts.PatPid && len(ms.buf) > 0 {
err := ms.fixAndSend()
if err != nil {
return err
}
ms.buf = ms.buf[:0]
}
return nil
}
// fixAndSend checks for discontinuities in the senders buffer and then sends.
// If a discontinuity is found the PAT packet at the start of the clip has it's
// discontintuity indicator set to true.
func (ms *mtsSender) fixAndSend() error {
err := ms.repairer.Repair(ms.buf)
if err == nil {
err = ms.sender.send(ms.buf)
if err == nil {
return nil
}
}
ms.failed = true
ms.repairer.Failed()
return err
}
func (s *mtsSender) close() error { return nil }
// release will set the s.fail flag to false and clear the buffer if
// the previous send was a fail. The currently loaded chunk is also closed.
func (s *mtsSender) release() {
if s.failed {
s.buf = s.buf[:0]
s.failed = false
}
}
// httpSender implements loadSender for posting HTTP to NetReceiver
type httpSender struct {
client *netsender.Sender
log func(lvl int8, msg string, args ...interface{})
data []byte
}
func newHttpSender(ns *netsender.Sender, log func(lvl int8, msg string, args ...interface{})) *httpSender { func newHttpSender(ns *netsender.Sender, log func(lvl int8, msg string, args ...interface{})) *httpSender {
return &httpSender{ return &httpSender{
client: ns, client: ns,
@ -274,13 +63,9 @@ func newHttpSender(ns *netsender.Sender, log func(lvl int8, msg string, args ...
} }
} }
func (s *httpSender) load(d []byte) error { // Write implements io.Writer.
s.data = d func (s *httpSender) Write(d []byte) (int, error) {
return nil return len(d), httpSend(d, s.client, s.log)
}
func (s *httpSender) send() error {
return httpSend(s.data, s.client, s.log)
} }
func httpSend(d []byte, client *netsender.Sender, log func(lvl int8, msg string, args ...interface{})) error { func httpSend(d []byte, client *netsender.Sender, log func(lvl int8, msg string, args ...interface{})) error {
@ -338,9 +123,72 @@ func extractMeta(r string, log func(lvl int8, msg string, args ...interface{}))
return nil return nil
} }
func (s *httpSender) release() {} // fileSender implements loadSender for a local file destination.
type fileSender struct {
file *os.File
data []byte
}
func (s *httpSender) close() error { return nil } func newFileSender(path string) (io.Writer, error) {
f, err := os.Create(path)
if err != nil {
return nil, err
}
return &fileSender{file: f}, nil
}
// Write implements io.Writer.
func (s *fileSender) Write(d []byte) (int, error) {
return s.file.Write(d)
}
func (s *fileSender) close() error { return s.file.Close() }
// mtsSender implements loadSender and provides sending capability specifically
// for use with MPEGTS packetization. It handles the construction of appropriately
// lengthed clips based on PSI. It also fixes accounts for discontinuities by
// setting the discontinuity indicator for the first packet of a clip.
type mtsSender struct {
dst io.Writer
buf []byte
next []byte
pkt packet.Packet
repairer *mts.DiscontinuityRepairer
curPid int
}
// newMtsSender returns a new mtsSender.
func newMtsSender(dst io.Writer, log func(lvl int8, msg string, args ...interface{})) *mtsSender {
return &mtsSender{
dst: dst,
repairer: mts.NewDiscontinuityRepairer(),
}
}
// Write implements io.Writer.
func (s *mtsSender) Write(d []byte) (int, error) {
if s.next != nil {
s.buf = append(s.buf, s.next...)
}
bytes := make([]byte, len(d))
copy(bytes, d)
s.next = bytes
copy(s.pkt[:], bytes)
s.curPid = s.pkt.PID()
if s.curPid == mts.PatPid && len(s.buf) > 0 {
err := s.repairer.Repair(s.buf)
if err == nil {
_, err = s.dst.Write(s.buf)
if err == nil {
goto done
}
}
s.repairer.Failed()
done:
s.buf = s.buf[:0]
}
return len(d), nil
}
// rtmpSender implements loadSender for a native RTMP destination. // rtmpSender implements loadSender for a native RTMP destination.
type rtmpSender struct { type rtmpSender struct {
@ -354,8 +202,6 @@ type rtmpSender struct {
data []byte data []byte
} }
var _ restarter = (*rtmpSender)(nil)
func newRtmpSender(url string, timeout uint, retries int, log func(lvl int8, msg string, args ...interface{})) (*rtmpSender, error) { func newRtmpSender(url string, timeout uint, retries int, log func(lvl int8, msg string, args ...interface{})) (*rtmpSender, error) {
var conn *rtmp.Conn var conn *rtmp.Conn
var err error var err error
@ -381,27 +227,16 @@ func newRtmpSender(url string, timeout uint, retries int, log func(lvl int8, msg
// Write implements io.Writer. // Write implements io.Writer.
func (s *rtmpSender) Write(d []byte) (int, error) { func (s *rtmpSender) Write(d []byte) (int, error) {
return write(s, d)
}
func (s *rtmpSender) load(d []byte) error {
s.data = d
return nil
}
func (s *rtmpSender) send() error {
if s.conn == nil { if s.conn == nil {
return errors.New("no rtmp connection, cannot write") return 0, errors.New("no rtmp connection, cannot write")
} }
_, err := s.conn.Write(s.data) _, err := s.conn.Write(d)
if err != nil { if err != nil {
err = s.restart() err = s.restart()
} }
return err return len(d), err
} }
func (s *rtmpSender) release() {}
func (s *rtmpSender) restart() error { func (s *rtmpSender) restart() error {
s.close() s.close()
var err error var err error
@ -433,11 +268,6 @@ type rtpSender struct {
data []byte data []byte
} }
// Write implements io.Writer.
func (s *rtpSender) Write(d []byte) (int, error) {
return write(s, d)
}
func newRtpSender(addr string, log func(lvl int8, msg string, args ...interface{}), fps uint) (*rtpSender, error) { func newRtpSender(addr string, log func(lvl int8, msg string, args ...interface{}), fps uint) (*rtpSender, error) {
conn, err := net.Dial("udp", addr) conn, err := net.Dial("udp", addr)
if err != nil { if err != nil {
@ -450,30 +280,7 @@ func newRtpSender(addr string, log func(lvl int8, msg string, args ...interface{
return s, nil return s, nil
} }
func (s *rtpSender) load(d []byte) error { // Write implements io.Writer.
s.data = make([]byte, len(d)) func (s *rtpSender) Write(d []byte) (int, error) {
copy(s.data, d) return s.encoder.Write(s.data)
return nil
}
func (s *rtpSender) close() error { return nil }
func (s *rtpSender) release() {}
func (s *rtpSender) send() error {
_, err := s.encoder.Write(s.data)
return err
}
// write wraps the load and send method for loadSenders.
func write(s loadSender, d []byte) (int, error) {
err := s.load(d)
if err != nil {
return 0, err
}
err = s.send()
if err != nil {
return len(d), err
}
return len(d), nil
} }

View File

@ -31,8 +31,6 @@ package revid
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"sync"
"testing" "testing"
"time" "time"
@ -59,25 +57,26 @@ var (
// sender simulates sending of video data, creating discontinuities if // sender simulates sending of video data, creating discontinuities if
// testDiscontinuities is set to true. // testDiscontinuities is set to true.
type sender struct { type destination struct {
buf [][]byte buf [][]byte
testDiscontinuities bool testDiscontinuities bool
discontinuityAt int discontinuityAt int
currentPkt int currentPkt int
} }
// send takes d and neglects if testDiscontinuities is true, returning an error, // Write implements io.Writer.
// Write takes d and neglects if testDiscontinuities is true, returning an error,
// otherwise d is appended to senders buf. // otherwise d is appended to senders buf.
func (ts *sender) send(d []byte) error { func (ts *destination) Write(d []byte) (int, error) {
if ts.testDiscontinuities && ts.currentPkt == ts.discontinuityAt { if ts.testDiscontinuities && ts.currentPkt == ts.discontinuityAt {
ts.currentPkt++ ts.currentPkt++
return errSendFailed return 0, errSendFailed
} }
cpy := make([]byte, len(d)) cpy := make([]byte, len(d))
copy(cpy, d) copy(cpy, d)
ts.buf = append(ts.buf, cpy) ts.buf = append(ts.buf, cpy)
ts.currentPkt++ ts.currentPkt++
return nil return len(d), nil
} }
// log implements the required logging func for some of the structs in use // log implements the required logging func for some of the structs in use
@ -109,8 +108,8 @@ func TestMtsSenderSegment(t *testing.T) {
mts.Meta = meta.New() mts.Meta = meta.New()
// Create ringBuffer, sender, loadsender and the MPEGTS encoder. // Create ringBuffer, sender, loadsender and the MPEGTS encoder.
tstSender := &sender{} tstDst := &destination{}
loadSender := newMtsSender(tstSender, log) loadSender := newMtsSender(tstDst, log)
rb := ring.NewBuffer(rbSize, rbElementSize, wTimeout) rb := ring.NewBuffer(rbSize, rbElementSize, wTimeout)
encoder := mts.NewEncoder((*buffer)(rb), 25) encoder := mts.NewEncoder((*buffer)(rb), 25)
@ -131,22 +130,16 @@ func TestMtsSenderSegment(t *testing.T) {
break break
} }
err = loadSender.load(next.Bytes()) _, err = loadSender.Write(next.Bytes())
if err != nil { if err != nil {
t.Fatalf("Unexpected err: %v\n", err) t.Fatalf("Unexpected err: %v\n", err)
} }
err = loadSender.send()
if err != nil {
t.Fatalf("Unexpected err: %v\n", err)
}
loadSender.release()
next.Close() next.Close()
next = nil next = nil
} }
} }
result := tstSender.buf result := tstDst.buf
expectData := 0 expectData := 0
for clipNo, clip := range result { for clipNo, clip := range result {
t.Logf("Checking clip: %v\n", clipNo) t.Logf("Checking clip: %v\n", clipNo)
@ -199,8 +192,8 @@ func TestMtsSenderDiscontinuity(t *testing.T) {
// Create ringBuffer sender, loadSender and the MPEGTS encoder. // Create ringBuffer sender, loadSender and the MPEGTS encoder.
const clipWithDiscontinuity = 3 const clipWithDiscontinuity = 3
tstSender := &sender{testDiscontinuities: true, discontinuityAt: clipWithDiscontinuity} tstDst := &destination{testDiscontinuities: true, discontinuityAt: clipWithDiscontinuity}
loadSender := newMtsSender(tstSender, log) loadSender := newMtsSender(tstDst, log)
rb := ring.NewBuffer(rbSize, rbElementSize, wTimeout) rb := ring.NewBuffer(rbSize, rbElementSize, wTimeout)
encoder := mts.NewEncoder((*buffer)(rb), 25) encoder := mts.NewEncoder((*buffer)(rb), 25)
@ -220,19 +213,16 @@ func TestMtsSenderDiscontinuity(t *testing.T) {
break break
} }
err = loadSender.load(next.Bytes()) _, err = loadSender.Write(next.Bytes())
if err != nil { if err != nil {
t.Fatalf("Unexpected err: %v\n", err) t.Fatalf("Unexpected err: %v\n", err)
} }
loadSender.send()
loadSender.release()
next.Close() next.Close()
next = nil next = nil
} }
} }
result := tstSender.buf result := tstDst.buf
// First check that we have less clips as expected. // First check that we have less clips as expected.
expectedLen := (((noOfPacketsToWrite/psiSendCount)*2 + noOfPacketsToWrite) / psiSendCount) - 1 expectedLen := (((noOfPacketsToWrite/psiSendCount)*2 + noOfPacketsToWrite) / psiSendCount) - 1
@ -255,91 +245,3 @@ func TestMtsSenderDiscontinuity(t *testing.T) {
t.Fatalf("Did not get discontinuity indicator for PAT") t.Fatalf("Did not get discontinuity indicator for PAT")
} }
} }
// dummyLoadSender is a loadSender implementation that allows us to simulate
// the behaviour of a loadSender and check that it performas as expected.
type dummyLoadSender struct {
data []byte
buf [][]byte
failOnSend bool
failHandled bool
retry bool
mu sync.Mutex
}
// newDummyLoadSender returns a pointer to a new dummyLoadSender.
func newDummyLoadSender(fail bool, retry bool) *dummyLoadSender {
return &dummyLoadSender{failOnSend: fail, failHandled: true, retry: retry}
}
func (s *dummyLoadSender) Write(d []byte) (int, error) {
return write(s, d)
}
// load takes a byte slice and assigns it to the dummyLoadSenders data slice.
func (s *dummyLoadSender) load(d []byte) error {
s.data = d
return nil
}
// send will append to dummyLoadSender's buf slice, only if failOnSend is false.
// If failOnSend is set to true, we expect that data sent won't be written to
// the buf simulating a failed send.
func (s *dummyLoadSender) send() error {
if !s.getFailOnSend() {
s.buf = append(s.buf, s.data)
return nil
}
s.failHandled = false
return errSendFailed
}
func (s *dummyLoadSender) getFailOnSend() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.failOnSend
}
// release sets dummyLoadSender's data slice to nil. data can be checked to see
// if release has been called at the right time.
func (s *dummyLoadSender) release() {
s.data = nil
}
func (s *dummyLoadSender) close() error { return nil }
// handleSendFail simply sets the failHandled flag to true. This can be checked
// to see if handleSendFail has been called by the multiSender at the right time.
func (s *dummyLoadSender) handleSendFail(err error) error {
s.failHandled = true
return nil
}
func (s *dummyLoadSender) retrySend() bool { return s.retry }
// TestMultiSenderWrite checks that we can do basic writing to multiple senders
// using the multiSender.
func TestMultiSenderWrite(t *testing.T) {
senders := []io.Writer{
newDummyLoadSender(false, false),
newDummyLoadSender(false, false),
newDummyLoadSender(false, false),
}
ms := newMultiSender(senders, log)
// Perform some multiSender writes.
const noOfWrites = 5
for i := byte(0); i < noOfWrites; i++ {
ms.Write([]byte{i})
}
// Check that the senders got the data correctly from the writes.
for i := byte(0); i < noOfWrites; i++ {
for j, dest := range ms.dst {
got := dest.(*dummyLoadSender).buf[i][0]
if got != i {
t.Errorf("Did not get expected result for sender: %v. \nGot: %v\nWant: %v\n", j, got, i)
}
}
}
}