diff --git a/revid/revid.go b/revid/revid.go index 5506f9a2..2994831b 100644 --- a/revid/revid.go +++ b/revid/revid.go @@ -177,7 +177,8 @@ func (r *Revid) setConfig(config Config) error { return nil } -func (r *Revid) setupPipeline(mtsEnc func(io.Writer, int) io.Writer, flvEnc func(io.Writer, int) (io.Writer, error)) error { +func (r *Revid) setupPipeline(mtsEnc func(io.Writer, int) io.Writer, + flvEnc func(io.Writer, int) (io.Writer, error), multiWriter func(...io.Writer) io.Writer) error { r.buffer = (*buffer)(ring.NewBuffer(ringBufferSize, ringBufferElementSize, writeTimeout)) r.encoder = make([]io.Writer, 0) @@ -220,8 +221,8 @@ func (r *Revid) setupPipeline(mtsEnc func(io.Writer, int) io.Writer, flvEnc func // encoder to revid's encoder slice, and give this encoder the mtsSenders // as a destination. if len(mtsSenders) != 0 { - ms := newMultiSender(mtsSenders, r.config.Logger.Log) - e := mtsEnc(ms, int(r.config.FrameRate)) + mw := multiWriter(mtsSenders...) + e := mtsEnc(mw, int(r.config.FrameRate)) r.encoder = append(r.encoder, e) } @@ -229,8 +230,8 @@ func (r *Revid) setupPipeline(mtsEnc func(io.Writer, int) io.Writer, flvEnc func // encoder to revid's encoder slice, and give this encoder the flvSenders // as a destination. if len(flvSenders) != 0 { - ms := newMultiSender(flvSenders, r.config.Logger.Log) - e, err := flvEnc(ms, int(r.config.FrameRate)) + mw := multiWriter(flvSenders...) + e, err := flvEnc(mw, int(r.config.FrameRate)) if err != nil { return err } @@ -278,7 +279,7 @@ func (r *Revid) reset(config Config) error { return err } - err = r.setupPipeline(newMtsEncoder, newFlvEncoder) + err = r.setupPipeline(newMtsEncoder, newFlvEncoder, io.MultiWriter) if err != nil { return err } diff --git a/revid/revid_test.go b/revid/revid_test.go index fbd13e72..ff54d128 100644 --- a/revid/revid_test.go +++ b/revid/revid_test.go @@ -94,6 +94,20 @@ func newTstFlvEncoder(dst io.Writer, fps int) (io.Writer, error) { func (e *tstFlvEncoder) Write(d []byte) (int, error) { return 0, nil } +// dummyMultiWriter emulates the MultiWriter provided by std lib, so that we +// can access the destinations. +type dummyMultiWriter struct { + dst []io.Writer +} + +func newDummyMultiWriter(dst ...io.Writer) io.Writer { + return &dummyMultiWriter{ + dst: dst, + } +} + +func (w *dummyMultiWriter) Write(d []byte) (int, error) { return 0, nil } + // TestResetEncoderSenderSetup checks that revid.reset() correctly sets up the // revid.encoder slice and the senders the encoders write to. func TestResetEncoderSenderSetup(t *testing.T) { @@ -200,7 +214,7 @@ func TestResetEncoderSenderSetup(t *testing.T) { } // This logic is what we want to check. - err = rv.setupPipeline(newTstMtsEncoder, newTstFlvEncoder) + err = rv.setupPipeline(newTstMtsEncoder, newTstFlvEncoder, newDummyMultiWriter) if err != nil { t.Fatalf("unexpected error: %v for test %v", err, testNum) } @@ -237,7 +251,7 @@ func TestResetEncoderSenderSetup(t *testing.T) { ms = e.(*tstFlvEncoder).dst } - senders := ms.(*multiSender).dst + senders := ms.(*dummyMultiWriter).dst got = len(senders) want = len(test.encoders[idx].destinations) if got != want { diff --git a/revid/senders.go b/revid/senders.go index 831ebf17..7e4415b4 100644 --- a/revid/senders.go +++ b/revid/senders.go @@ -48,40 +48,6 @@ import ( // Log is used by the multiSender. type Log func(level int8, message string, params ...interface{}) -// multiSender implements io.Writer. It provides the capacity to send to multiple -// senders from a single Write call. -type multiSender struct { - dst []io.Writer - log Log -} - -// newMultiSender returns a pointer to a new multiSender. -func newMultiSender(senders []io.Writer, log Log) *multiSender { - return &multiSender{ - dst: senders, - log: log, - } -} - -// Write implements io.Writer. This will call load (with the passed slice), send -// and release on all senders of multiSender. -func (s *multiSender) Write(d []byte) (int, error) { - for i, sender := range s.dst { - _, err := sender.Write(d) - if err != nil { - s.log(logger.Warning, pkg+"send failed", "sender", i, "error", err) - } - } - for _, sender := range s.dst { - s, ok := sender.(loadSender) - if !ok { - panic("sender is not a loadSender") - } - s.release() - } - return len(d), nil -} - // minimalHttpSender implements Sender for posting HTTP to netreceiver or vidgrind. type httpSender struct { client *netsender.Sender diff --git a/revid/senders_test.go b/revid/senders_test.go index b990c336..b9073afe 100644 --- a/revid/senders_test.go +++ b/revid/senders_test.go @@ -31,8 +31,6 @@ package revid import ( "errors" "fmt" - "io" - "sync" "testing" "time" @@ -247,78 +245,3 @@ func TestMtsSenderDiscontinuity(t *testing.T) { t.Fatalf("Did not get discontinuity indicator for PAT") } } - -// dummyLoadSender is a loadSender implementation that allows us to simulate -// the behaviour of a loadSender and check that it performas as expected. -type dummyLoadSender struct { - data []byte - buf [][]byte - failOnSend bool - failHandled bool - retry bool - mu sync.Mutex -} - -// newDummyLoadSender returns a pointer to a new dummyLoadSender. -func newDummyLoadSender(fail bool, retry bool) *dummyLoadSender { - return &dummyLoadSender{failOnSend: fail, failHandled: true, retry: retry} -} - -func (s *dummyLoadSender) Write(d []byte) (int, error) { - if !s.getFailOnSend() { - s.buf = append(s.buf, s.data) - return len(d), nil - } - s.failHandled = false - return 0, errSendFailed -} - -func (s *dummyLoadSender) getFailOnSend() bool { - s.mu.Lock() - defer s.mu.Unlock() - return s.failOnSend -} - -// release sets dummyLoadSender's data slice to nil. data can be checked to see -// if release has been called at the right time. -func (s *dummyLoadSender) release() { - s.data = nil -} - -func (s *dummyLoadSender) close() error { return nil } - -// handleSendFail simply sets the failHandled flag to true. This can be checked -// to see if handleSendFail has been called by the multiSender at the right time. -func (s *dummyLoadSender) handleSendFail(err error) error { - s.failHandled = true - return nil -} - -func (s *dummyLoadSender) retrySend() bool { return s.retry } - -// TestMultiSenderWrite checks that we can do basic writing to multiple senders -// using the multiSender. -func TestMultiSenderWrite(t *testing.T) { - senders := []io.Writer{ - newDummyLoadSender(false, false), - newDummyLoadSender(false, false), - newDummyLoadSender(false, false), - } - ms := newMultiSender(senders, log) - - // Perform some multiSender writes. - const noOfWrites = 5 - for i := byte(0); i < noOfWrites; i++ { - ms.Write([]byte{i}) - } - - // Check that the senders got the data correctly from the writes. - for i := byte(0); i < noOfWrites; i++ { - for j, dest := range ms.dst { - got := dest.(*dummyLoadSender).buf[i][0] - if got != i { - t.Errorf("Did not get expected result for sender: %v. \nGot: %v\nWant: %v\n", j, got, i) - } - } - } -}