diff --git a/container/mts/audio_test.go b/container/mts/audio_test.go index d03bd14e..36df7df8 100644 --- a/container/mts/audio_test.go +++ b/container/mts/audio_test.go @@ -47,8 +47,8 @@ func TestEncodePcm(t *testing.T) { buf := bytes.NewBuffer(b) sampleRate := 48000 sampleSize := 2 - writeSize := 16000 - writeFreq := float64(sampleRate*sampleSize) / float64(writeSize) + blockSize := 16000 + writeFreq := float64(sampleRate*sampleSize) / float64(blockSize) e := NewEncoder(buf, writeFreq, Audio) inPath := "../../../test/test-data/av/input/sweep_400Hz_20000Hz_-3dBFS_5s_48khz.pcm" @@ -58,19 +58,33 @@ func TestEncodePcm(t *testing.T) { } // Encode pcm to mts and get the resulting bytes. - _, err = e.Write(inPcm) - if err != nil { - log.Fatal(err) + for i := 0; i < len(inPcm); i += blockSize { + if len(inPcm)-i < blockSize { + block := inPcm[i:] + _, err = e.Write(block) + if err != nil { + log.Fatal(err) + } + } else { + block := inPcm[i : i+blockSize] + _, err = e.Write(block) + if err != nil { + log.Fatal(err) + } + } } clip := buf.Bytes() // Decode the mts packets to extract the original data var pkt packet.Packet - pesPacket := make([]byte, 0, writeSize) + pesPacket := make([]byte, 0, blockSize) got := make([]byte, 0, len(inPcm)) i := 0 - for i+PacketSize <= len(clip) { + if i+PacketSize <= len(clip) { copy(pkt[:], clip[i:i+PacketSize]) + } + + for i+PacketSize <= len(clip) { if pkt.PID() == audioPid { if pkt.PayloadUnitStartIndicator() { payload, err := pkt.Payload() @@ -78,18 +92,28 @@ func TestEncodePcm(t *testing.T) { t.Fatalf("Unexpected err: %v\n", err) } pesPacket = append(pesPacket, payload...) + i += PacketSize - first := true - for (first || !pkt.PayloadUnitStartIndicator()) && i+PacketSize <= len(clip) { - first = false + if i+PacketSize <= len(clip) { copy(pkt[:], clip[i:i+PacketSize]) - payload, err := pkt.Payload() + } + + for (!pkt.PayloadUnitStartIndicator()) && i+PacketSize <= len(clip) { + payload, err = pkt.Payload() if err != nil { t.Fatalf("Unexpected err: %v\n", err) } - pesPacket = append(pesPacket, payload...) + i += PacketSize + if i+PacketSize <= len(clip) { + copy(pkt[:], clip[i:i+PacketSize]) + } + } + } else { + i += PacketSize + if i+PacketSize <= len(clip) { + copy(pkt[:], clip[i:i+PacketSize]) } } pesHeader, err := pes.NewPESHeader(pesPacket) @@ -97,8 +121,12 @@ func TestEncodePcm(t *testing.T) { t.Fatalf("Unexpected err: %v\n", err) } got = append(got, pesHeader.Data()...) + pesPacket = pesPacket[:0] } else { i += PacketSize + if i+PacketSize <= len(clip) { + copy(pkt[:], clip[i:i+PacketSize]) + } } } diff --git a/container/mts/encoder.go b/container/mts/encoder.go index 5824ca15..34cfc8de 100644 --- a/container/mts/encoder.go +++ b/container/mts/encoder.go @@ -29,6 +29,7 @@ LICENSE package mts import ( + "fmt" "io" "time" @@ -206,6 +207,9 @@ func (e *Encoder) TimeBasedPsi(b bool, sendCount int) { // Write implements io.Writer. Write takes raw h264 and encodes into mpegts, // then sending it to the encoder's io.Writer destination. func (e *Encoder) Write(data []byte) (int, error) { + if len(data) > pes.MaxPesSize { + return 0, fmt.Errorf("data size too large (Max is %v): %v", pes.MaxPesSize, len(data)) + } now := time.Now() if (e.timeBasedPsi && (now.Sub(e.psiLastTime) > psiInterval)) || (!e.timeBasedPsi && (e.pktCount >= e.psiSendCount)) { e.pktCount = 0