diff --git a/revid/revid_test.go b/revid/revid_test.go index 7ab2f0a8..b2dc87b1 100644 --- a/revid/revid_test.go +++ b/revid/revid_test.go @@ -8,8 +8,6 @@ import ( "runtime" "testing" - "bitbucket.org/ausocean/av/stream/flv" - "bitbucket.org/ausocean/av/stream/mts" "bitbucket.org/ausocean/iot/pi/netsender" ) @@ -71,6 +69,32 @@ func (tl *testLogger) Log(level int8, msg string, params ...interface{}) { } } +// tstMtsEncoder emulates the mts.Encoder to the extent of the dst field. +// This will allow access to the dst to check that it has been set corrctly. +type tstMtsEncoder struct { + dst io.Writer +} + +// newTstMtsEncoder returns a pointer to a newTsMtsEncoder. +func newTstMtsEncoder(dst io.Writer, fps int) io.Writer { + return &tstMtsEncoder{dst: dst} +} + +func (e *tstMtsEncoder) Write(d []byte) (int, error) { return 0, nil } + +// tstFlvEncoder emulates the flv.Encoder to the extent of the dst field. +// This will allow access to the dst to check that it has been set corrctly. +type tstFlvEncoder struct { + dst io.Writer +} + +// newTstFlvEncoder returns a pointer to a new tstFlvEncoder. +func newTstFlvEncoder(dst io.Writer, fps int) (io.Writer, error) { + return &tstFlvEncoder{dst: dst}, nil +} + +func (e *tstFlvEncoder) Write(d []byte) (int, error) { return 0, nil } + // TestResetEncoderSenderSetup checks that revid.reset() correctly sets up the // revid.encoder slice and the senders the encoders write to. func TestResetEncoderSenderSetup(t *testing.T) { @@ -163,10 +187,10 @@ func TestResetEncoderSenderSetup(t *testing.T) { // typeOfEncoder will return the type of encoder implementing stream.Encoder. typeOfEncoder := func(i io.Writer) (string, error) { - if _, ok := i.(*mts.Encoder); ok { + if _, ok := i.(*tstMtsEncoder); ok { return mtsEncoderStr, nil } - if _, ok := i.(*flv.Encoder); ok { + if _, ok := i.(*tstFlvEncoder); ok { return flvEncoderStr, nil } return "", errors.New("unknown Encoder type") @@ -194,9 +218,15 @@ func TestResetEncoderSenderSetup(t *testing.T) { // Go through our test cases. for testNum, test := range tests { // Create a new config and reset revid with it. - const dummyUrl = "rtmp://dummy" - newConfig := Config{Logger: &testLogger{}, Outputs: test.outputs, RtmpUrl: dummyUrl} - err := rv.reset(newConfig) + const dummyURL = "rtmp://dummy" + c := Config{Logger: &testLogger{}, Outputs: test.outputs, RtmpUrl: dummyURL} + err := rv.setConfig(c) + if err != nil { + t.Fatalf("unexpected error: %v for test %v", err, testNum) + } + + // This logic is what we want to check. + err = rv.setupPipeline(newTstMtsEncoder, newTstFlvEncoder) if err != nil { t.Fatalf("unexpected error: %v for test %v", err, testNum) } @@ -224,11 +254,17 @@ func TestResetEncoderSenderSetup(t *testing.T) { } } if idx == -1 { - t.Errorf("encoder %v isn't expected in test %v", encoderType, testNum) + t.Fatalf("encoder %v isn't expected in test %v", encoderType, testNum) } // Now check that this encoder has correct number of destinations (senders). - ms := e.GetDst() + var ms io.Writer + switch encoderType { + case mtsEncoderStr: + ms = e.(*tstMtsEncoder).dst + case flvEncoderStr: + ms = e.(*tstFlvEncoder).dst + } senders := []loadSender(ms.(multiSender)) got = len(senders) want = len(test.encoders[idx].destinations)