mts: removed repeated use of addPadding func

This commit is contained in:
saxon 2018-12-27 14:11:23 +10:30
parent b28861d690
commit c739b10f86
4 changed files with 25 additions and 25 deletions

View File

@ -131,7 +131,7 @@ func (e *Encoder) Encode(nalu []byte) error {
Data: nalu, Data: nalu,
HeaderLength: 5, HeaderLength: 5,
} }
buf := pesPkt.Bytes(e.pesSpace[:pes.MaxPesSize]) buf := pesPkt.Bytes(e.pesSpace[:pes.MaxPesLen])
pusi := true pusi := true
for len(buf) != 0 { for len(buf) != 0 {
@ -179,7 +179,7 @@ func (e *Encoder) writePSI() error {
PID: patPid, PID: patPid,
CC: e.ccFor(patPid), CC: e.ccFor(patPid),
AFC: hasPayload, AFC: hasPayload,
Payload: addPadding(patTable), Payload: patTable,
} }
_, err := e.dst.Write(patPkt.Bytes(e.tsSpace[:PktLen])) _, err := e.dst.Write(patPkt.Bytes(e.tsSpace[:PktLen]))
if err != nil { if err != nil {
@ -202,7 +202,7 @@ func (e *Encoder) writePSI() error {
PID: pmtPid, PID: pmtPid,
CC: e.ccFor(pmtPid), CC: e.ccFor(pmtPid),
AFC: hasPayload, AFC: hasPayload,
Payload: addPadding(pmtTable), Payload: pmtTable,
} }
_, err = e.dst.Write(pmtPkt.Bytes(e.tsSpace[:PktLen])) _, err = e.dst.Write(pmtPkt.Bytes(e.tsSpace[:PktLen]))
if err != nil { if err != nil {
@ -212,15 +212,6 @@ func (e *Encoder) writePSI() error {
return nil return nil
} }
// addPadding adds an appropriate amount of padding to a pat or pmt table for
// addition to an mpegts packet
func addPadding(d []byte) []byte {
for len(d) < psi.PktLen {
d = append(d, 0xff)
}
return d
}
// tick advances the clock one frame interval. // tick advances the clock one frame interval.
func (e *Encoder) tick() { func (e *Encoder) tick() {
e.clock += e.frameInterval e.clock += e.frameInterval
@ -239,7 +230,8 @@ func (e *Encoder) pcr() uint64 {
// ccFor returns the next continuity counter for pid. // ccFor returns the next continuity counter for pid.
func (e *Encoder) ccFor(pid int) byte { func (e *Encoder) ccFor(pid int) byte {
cc := e.continuity[pid] cc := e.continuity[pid]
const continuityCounterMask = 0xf // Continuity counter mask
e.continuity[pid] = (cc + 1) & continuityCounterMask const ccMask = 0xf
e.continuity[pid] = (cc + 1) & ccMask
return cc return cc
} }

View File

@ -143,3 +143,12 @@ func updateCrc(out []byte) []byte {
out[len(out)-1] = byte(crc32) out[len(out)-1] = byte(crc32)
return out return out
} }
// addPadding adds an appropriate amount of padding to a pat or pmt table for
// addition to an mpegts packet
func addPadding(d []byte) []byte {
for len(d) < PktLen {
d = append(d, 0xff)
}
return d
}

View File

@ -238,6 +238,7 @@ func (p *PSI) Bytes() []byte {
out[3] = byte(p.Sl) out[3] = byte(p.Sl)
out = append(out, p.Tss.Bytes()...) out = append(out, p.Tss.Bytes()...)
out = addCrc(out) out = addCrc(out)
out = addPadding(out)
return out return out
} }

View File

@ -91,10 +91,10 @@ var (
} }
// Bytes representing pmt with time1 and location1 // Bytes representing pmt with time1 and location1
pmtTimeLocationBytes1 = buildPmtTimeLocationBytes(locationTstStr1) pmtTimeLocBytes1 = buildPmtTimeLocBytes(locationTstStr1)
// bytes representing pmt with with time1 and location 2 // bytes representing pmt with with time1 and location 2
pmtTimeLocationBytes2 = buildPmtTimeLocationBytes(locationTstStr2) pmtTimeLocBytes2 = buildPmtTimeLocBytes(locationTstStr2)
) )
// bytesTests contains data for testing the Bytes() funcs for the PSI data struct // bytesTests contains data for testing the Bytes() funcs for the PSI data struct
@ -189,7 +189,7 @@ var bytesTests = []struct {
}, },
}, },
}, },
want: buildPmtTimeLocationBytes(locationTstStr1), want: buildPmtTimeLocBytes(locationTstStr1),
}, },
} }
@ -198,9 +198,7 @@ var bytesTests = []struct {
func TestBytes(t *testing.T) { func TestBytes(t *testing.T) {
for _, test := range bytesTests { for _, test := range bytesTests {
got := test.input.Bytes() got := test.input.Bytes()
// Remove crc32s if !bytes.Equal(got, addPadding(addCrc(test.want))) {
got = got[:len(got)-4]
if !bytes.Equal(got, test.want) {
t.Errorf("unexpected error for test %v: got:%v want:%v", test.name, got, t.Errorf("unexpected error for test %v: got:%v want:%v", test.name, got,
test.want) test.want)
} }
@ -259,23 +257,23 @@ func TestLocationGet(t *testing.T) {
// TestLocationUpdate checks to see if we can update the location string in a pmt correctly // TestLocationUpdate checks to see if we can update the location string in a pmt correctly
func TestLocationUpdate(t *testing.T) { func TestLocationUpdate(t *testing.T) {
cpy := make([]byte, len(pmtTimeLocationBytes1)) cpy := make([]byte, len(pmtTimeLocBytes1))
copy(cpy, pmtTimeLocationBytes1) copy(cpy, pmtTimeLocBytes1)
cpy = addCrc(cpy) cpy = addCrc(cpy)
err := UpdateLocation(cpy, locationTstStr2) err := UpdateLocation(cpy, locationTstStr2)
cpy = cpy[:len(cpy)-4] cpy = cpy[:len(cpy)-4]
if err != nil { if err != nil {
t.Errorf("Update time returned err: %v", err) t.Errorf("Update time returned err: %v", err)
} }
if !bytes.Equal(pmtTimeLocationBytes2, cpy) { if !bytes.Equal(pmtTimeLocBytes2, cpy) {
t.Errorf(errCmp, "TestLocationUpdate", pmtTimeLocationBytes2, cpy) t.Errorf(errCmp, "TestLocationUpdate", pmtTimeLocBytes2, cpy)
} }
} }
// buildPmtTimeLocationBytes is a helper function to help construct the byte slices // buildPmtTimeLocationBytes is a helper function to help construct the byte slices
// for pmts with time and location, as the location data field is 32 bytes, i.e. quite large // for pmts with time and location, as the location data field is 32 bytes, i.e. quite large
// to type out // to type out
func buildPmtTimeLocationBytes(tstStr string) []byte { func buildPmtTimeLocBytes(tstStr string) []byte {
return append(append(append(make([]byte, 0), pmtTimeLocationBytesPart1...), return append(append(append(make([]byte, 0), pmtTimeLocationBytesPart1...),
LocationStrBytes(tstStr)...), pmtTimeLocationBytesPart2...) LocationStrBytes(tstStr)...), pmtTimeLocationBytesPart2...)
} }