diff --git a/stream/mts/encoder.go b/stream/mts/encoder.go index 57422a3f..2cafc167 100644 --- a/stream/mts/encoder.go +++ b/stream/mts/encoder.go @@ -131,7 +131,7 @@ func (e *Encoder) Encode(nalu []byte) error { Data: nalu, HeaderLength: 5, } - buf := pesPkt.Bytes(e.pesSpace[:pes.MaxPesSize]) + buf := pesPkt.Bytes(e.pesSpace[:pes.MaxPesLen]) pusi := true for len(buf) != 0 { @@ -179,7 +179,7 @@ func (e *Encoder) writePSI() error { PID: patPid, CC: e.ccFor(patPid), AFC: hasPayload, - Payload: addPadding(patTable), + Payload: patTable, } _, err := e.dst.Write(patPkt.Bytes(e.tsSpace[:PktLen])) if err != nil { @@ -202,7 +202,7 @@ func (e *Encoder) writePSI() error { PID: pmtPid, CC: e.ccFor(pmtPid), AFC: hasPayload, - Payload: addPadding(pmtTable), + Payload: pmtTable, } _, err = e.dst.Write(pmtPkt.Bytes(e.tsSpace[:PktLen])) if err != nil { @@ -212,15 +212,6 @@ func (e *Encoder) writePSI() error { return nil } -// addPadding adds an appropriate amount of padding to a pat or pmt table for -// addition to an mpegts packet -func addPadding(d []byte) []byte { - for len(d) < psi.PktLen { - d = append(d, 0xff) - } - return d -} - // tick advances the clock one frame interval. func (e *Encoder) tick() { e.clock += e.frameInterval @@ -239,7 +230,8 @@ func (e *Encoder) pcr() uint64 { // ccFor returns the next continuity counter for pid. func (e *Encoder) ccFor(pid int) byte { cc := e.continuity[pid] - const continuityCounterMask = 0xf - e.continuity[pid] = (cc + 1) & continuityCounterMask + // Continuity counter mask + const ccMask = 0xf + e.continuity[pid] = (cc + 1) & ccMask return cc } diff --git a/stream/mts/psi/op.go b/stream/mts/psi/op.go index e78cdac7..d71dbf61 100644 --- a/stream/mts/psi/op.go +++ b/stream/mts/psi/op.go @@ -143,3 +143,12 @@ func updateCrc(out []byte) []byte { out[len(out)-1] = byte(crc32) return out } + +// addPadding adds an appropriate amount of padding to a pat or pmt table for +// addition to an mpegts packet +func addPadding(d []byte) []byte { + for len(d) < PktLen { + d = append(d, 0xff) + } + return d +} diff --git a/stream/mts/psi/psi.go b/stream/mts/psi/psi.go index 4db81a59..3aa8bcd4 100644 --- a/stream/mts/psi/psi.go +++ b/stream/mts/psi/psi.go @@ -238,6 +238,7 @@ func (p *PSI) Bytes() []byte { out[3] = byte(p.Sl) out = append(out, p.Tss.Bytes()...) out = addCrc(out) + out = addPadding(out) return out } diff --git a/stream/mts/psi/psi_test.go b/stream/mts/psi/psi_test.go index 86c19f60..d6d6d1cc 100644 --- a/stream/mts/psi/psi_test.go +++ b/stream/mts/psi/psi_test.go @@ -91,10 +91,10 @@ var ( } // Bytes representing pmt with time1 and location1 - pmtTimeLocationBytes1 = buildPmtTimeLocationBytes(locationTstStr1) + pmtTimeLocBytes1 = buildPmtTimeLocBytes(locationTstStr1) // bytes representing pmt with with time1 and location 2 - pmtTimeLocationBytes2 = buildPmtTimeLocationBytes(locationTstStr2) + pmtTimeLocBytes2 = buildPmtTimeLocBytes(locationTstStr2) ) // bytesTests contains data for testing the Bytes() funcs for the PSI data struct @@ -189,7 +189,7 @@ var bytesTests = []struct { }, }, }, - want: buildPmtTimeLocationBytes(locationTstStr1), + want: buildPmtTimeLocBytes(locationTstStr1), }, } @@ -198,9 +198,7 @@ var bytesTests = []struct { func TestBytes(t *testing.T) { for _, test := range bytesTests { got := test.input.Bytes() - // Remove crc32s - got = got[:len(got)-4] - if !bytes.Equal(got, test.want) { + if !bytes.Equal(got, addPadding(addCrc(test.want))) { t.Errorf("unexpected error for test %v: got:%v want:%v", test.name, got, test.want) } @@ -259,23 +257,23 @@ func TestLocationGet(t *testing.T) { // TestLocationUpdate checks to see if we can update the location string in a pmt correctly func TestLocationUpdate(t *testing.T) { - cpy := make([]byte, len(pmtTimeLocationBytes1)) - copy(cpy, pmtTimeLocationBytes1) + cpy := make([]byte, len(pmtTimeLocBytes1)) + copy(cpy, pmtTimeLocBytes1) cpy = addCrc(cpy) err := UpdateLocation(cpy, locationTstStr2) cpy = cpy[:len(cpy)-4] if err != nil { t.Errorf("Update time returned err: %v", err) } - if !bytes.Equal(pmtTimeLocationBytes2, cpy) { - t.Errorf(errCmp, "TestLocationUpdate", pmtTimeLocationBytes2, cpy) + if !bytes.Equal(pmtTimeLocBytes2, cpy) { + t.Errorf(errCmp, "TestLocationUpdate", pmtTimeLocBytes2, cpy) } } // buildPmtTimeLocationBytes is a helper function to help construct the byte slices // for pmts with time and location, as the location data field is 32 bytes, i.e. quite large // to type out -func buildPmtTimeLocationBytes(tstStr string) []byte { +func buildPmtTimeLocBytes(tstStr string) []byte { return append(append(append(make([]byte, 0), pmtTimeLocationBytesPart1...), LocationStrBytes(tstStr)...), pmtTimeLocationBytesPart2...) }