// Mirror of https://bitbucket.org/ausocean/av.git (271 lines, 6.4 KiB, Go).

package h264dec
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
|
|
"bitbucket.org/ausocean/av/codec/h264/h264dec/bits"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// H264Reader holds the state used while reading an H.264 byte stream: the
// source stream, the NAL units and video streams found so far, and the bytes
// buffered from the source.
type H264Reader struct {
	// IsStarted is not read or written by any live code in this file.
	// NOTE(review): presumably meant to record whether the leading zero byte
	// and start code have been skipped (see the commented-out block in
	// readNalUnit) — confirm.
	IsStarted bool

	// Stream is the source that BufferToReader and Discard read from.
	Stream io.Reader

	// NalUnits gets one bit reader appended per NAL unit found by readNalUnit.
	NalUnits []*bits.BitReader

	// VideoStreams is built by Start: each SPS starts a new entry, and PPS and
	// slice NAL units attach to the last entry.
	VideoStreams []*VideoStream

	// DebugFile, when non-nil, receives a copy of every chunk buffered by
	// BufferToReader.
	DebugFile *os.File

	// bytes accumulates the bytes buffered from Stream by BufferToReader.
	bytes []byte

	// byteOffset counts the bytes taken from Stream by BufferToReader and
	// Discard.
	byteOffset int

	// Embedded bit reader. NOTE(review): readNalUnit calls r.BytesRead(), which
	// presumably resolves to this embedded reader's method; nothing visible
	// here initializes it — confirm it is set before use.
	*bits.BitReader
}
|
|
|
|
func (h *H264Reader) BufferToReader(cntBytes int) error {
|
|
buf := make([]byte, cntBytes)
|
|
if _, err := h.Stream.Read(buf); err != nil {
|
|
logger.Printf("error: while reading %d bytes: %v\n", cntBytes, err)
|
|
return err
|
|
}
|
|
h.bytes = append(h.bytes, buf...)
|
|
if h.DebugFile != nil {
|
|
h.DebugFile.Write(buf)
|
|
}
|
|
h.byteOffset += cntBytes
|
|
return nil
|
|
}
|
|
|
|
func (h *H264Reader) Discard(cntBytes int) error {
|
|
buf := make([]byte, cntBytes)
|
|
if _, err := h.Stream.Read(buf); err != nil {
|
|
logger.Printf("error: while discarding %d bytes: %v\n", cntBytes, err)
|
|
return err
|
|
}
|
|
h.byteOffset += cntBytes
|
|
return nil
|
|
}
|
|
|
|
// bitVal interprets bits as an unsigned binary number, most significant bit
// first, and returns its value. Only elements equal to 1 set a bit; every
// other element value is treated as 0. An empty or nil slice yields 0.
func bitVal(bits []int) int {
	v := 0
	for _, b := range bits {
		v <<= 1
		if b == 1 {
			v |= 1
		}
	}
	return v
}
|
|
|
|
func (h *H264Reader) Start() {
|
|
for {
|
|
// TODO: need to handle error from this.
|
|
nalUnit, _, _ := h.readNalUnit()
|
|
switch nalUnit.Type {
|
|
case naluTypeSPS:
|
|
// TODO: handle this error
|
|
sps, _ := NewSPS(nalUnit.rbsp, false)
|
|
h.VideoStreams = append(
|
|
h.VideoStreams,
|
|
&VideoStream{SPS: sps},
|
|
)
|
|
case naluTypePPS:
|
|
videoStream := h.VideoStreams[len(h.VideoStreams)-1]
|
|
// TODO: handle this error
|
|
videoStream.PPS, _ = NewPPS(nil, videoStream.SPS.ChromaFormat)
|
|
case naluTypeSliceIDRPicture:
|
|
fallthrough
|
|
case naluTypeSliceNonIDRPicture:
|
|
videoStream := h.VideoStreams[len(h.VideoStreams)-1]
|
|
logger.Printf("info: frame number %d\n", len(videoStream.Slices))
|
|
// TODO: handle this error
|
|
sliceContext, _ := NewSliceContext(videoStream, nalUnit, nalUnit.RBSP(), true)
|
|
videoStream.Slices = append(videoStream.Slices, sliceContext)
|
|
}
|
|
}
|
|
}
|
|
|
|
// readNalUnit is intended to read the next NAL unit from the stream. It
// buffers bytes one at a time until a start sequence is found, then keeps
// buffering until the next start sequence, and finally builds a NalUnit and a
// bit reader for it, appending the reader to r.NalUnits.
//
// NOTE(review): this function is unfinished and does not currently work:
//   - isStartSequence is always called with nil. It returns false for input
//     shorter than InitialNALU (assuming InitialNALU is non-empty), so both
//     loops only ever exit when BufferToReader fails.
//   - A failed read is reported as (nil, nil, nil) rather than as an error,
//     so callers must check the *NalUnit for nil.
//   - Offsets are taken from r.BytesRead(), while BufferToReader advances
//     r.byteOffset — confirm which is intended.
//   - bits.NewBitReader and NewNalUnit are passed nil/0 rather than the
//     buffered NAL bytes.
func (r *H264Reader) readNalUnit() (*NalUnit, *bits.BitReader, error) {
	// Read to start of NAL
	logger.Printf("debug: Seeking NAL %d start\n", len(r.NalUnits))

	// TODO: Fix this. isStartSequence(nil) is always false, so this loop only
	// exits through the error return below; presumably the buffered bytes
	// were meant to be checked instead — confirm.
	for !isStartSequence(nil) {
		if err := r.BufferToReader(1); err != nil {
			// TODO: should this return an error here.
			return nil, nil, nil
		}
	}

	/*
		if !r.IsStarted {
			logger.Printf("debug: skipping initial NAL zero byte spaces\n")
			r.LogStreamPosition()
			// Annex B.2 Step 1
			if err := r.Discard(1); err != nil {
				logger.Printf("error: while discarding empty byte (Annex B.2:1): %v\n", err)
				return nil
			}
			if err := r.Discard(2); err != nil {
				logger.Printf("error: while discarding start code prefix one 3bytes (Annex B.2:2): %v\n", err)
				return nil
			}
		}
	*/

	startOffset := r.BytesRead()
	logger.Printf("debug: Seeking next NAL start\n")

	// Read to start of next NAL.
	// so is sampled before each BufferToReader call; the so == startOffset
	// condition is meant to force reading past the start sequence just found.
	// NOTE(review): nothing visible here changes BytesRead when BufferToReader
	// runs; if it does not change, this loop (like the one above) only exits
	// via the error return — confirm.
	so := r.BytesRead()
	for so == startOffset || !isStartSequence(nil) {
		so = r.BytesRead()
		if err := r.BufferToReader(1); err != nil {
			// TODO: should this return an error here?
			return nil, nil, nil
		}
	}

	// logger.Printf("debug: PreRewind %#v\n", r.Bytes())
	// Rewind back the length of the start sequence
	// r.RewindBytes(4)
	// logger.Printf("debug: PostRewind %#v\n", r.Bytes())

	endOffset := r.BytesRead()
	logger.Printf("debug: found NAL unit with %d bytes from %d to %d\n", endOffset-startOffset, startOffset, endOffset)

	// NOTE(review): the bit reader is built from nil rather than from the NAL
	// unit's bytes.
	nalUnitReader := bits.NewBitReader(nil)
	r.NalUnits = append(r.NalUnits, nalUnitReader)

	// TODO: this should really take an io.Reader rather than []byte. Need to fix nil
	// once this is fixed.
	nalUnit, err := NewNalUnit(nil, 0)
	if err != nil {
		return nil, nil, errors.Wrap(err, "cannot create new nal unit")
	}

	return nalUnit, nalUnitReader, nil
}
|
|
|
|
func isStartSequence(packet []byte) bool {
|
|
if len(packet) < len(InitialNALU) {
|
|
return false
|
|
}
|
|
naluSegment := packet[len(packet)-4:]
|
|
for i := range InitialNALU {
|
|
if naluSegment[i] != InitialNALU[i] {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func isStartCodeOnePrefix(buf []byte) bool {
|
|
for i, b := range buf {
|
|
if i < 2 && b != byte(0) {
|
|
return false
|
|
}
|
|
// byte 3 may be 0 or 1
|
|
if i == 3 && b != byte(0) || b != byte(1) {
|
|
return false
|
|
}
|
|
}
|
|
logger.Printf("debug: found start code one prefix byte\n")
|
|
return true
|
|
}
|
|
|
|
// isEmpty3Byte reports whether the last three bytes of buf are all zero.
// Buffers shorter than three bytes never qualify.
func isEmpty3Byte(buf []byte) bool {
	n := len(buf)
	return n >= 3 && buf[n-3] == 0 && buf[n-2] == 0 && buf[n-1] == 0
}
|
|
|
|
// moreRBSPData reports whether there is more RBSP data at br's current
// position. It returns false when the next bit appears to be the stop bit
// that ends the RBSP, and true otherwise.
//
// All lookahead is done with PeekBits. NOTE(review): the repeated br.Off()
// calls assume PeekBits does not move br's position — confirm in package bits.
//
// The first of these checks that applies decides the result:
//  1. The next bit cannot be peeked: false.
//  2. The next bit is 0: true.
//  3. The rest of the current byte cannot be peeked: false. It is not a 1
//     followed only by 0s: true.
//  4. No bit beyond the current byte can be peeked: false.
//  5. The current byte plus the next 24 bits cannot be peeked: true. The
//     24 bits after the current byte are 0x000001: false.
//  6. The current byte plus the next 32 bits cannot be peeked: true. The
//     32 bits after the current byte are 0x00000001: false.
//  7. Otherwise: true.
func moreRBSPData(br *bits.BitReader) bool {
	// If we get an error then we must be at the end of the NAL unit or end of
	// stream, so return false.
	b, err := br.PeekBits(1)
	if err != nil {
		return false
	}

	// If b is not 1, then we don't have a stop bit and therefore there is more
	// data so return true.
	if b == 0 {
		return true
	}

	// If we have a stop bit followed only by zeros up to the end of the current
	// byte then keep checking; otherwise this 1 is not the stop bit and there
	// is more data.
	b, err = br.PeekBits(8 - br.Off())
	if err != nil {
		return false
	}
	rem := 0x01 << uint(7-br.Off())
	if int(b) != rem {
		return true
	}

	// If we try to read another bit but get EOF then we must be at the end of the
	// NAL or stream.
	_, err = br.PeekBits(9 - br.Off())
	if err != nil {
		return false
	}

	// Do we have some trailing 0 bits, and then a 24-bit start code? If so,
	// there must not be any more RBSP data left.
	// If we get an error from the Peek, then there must not be another NAL, and
	// there must be some more RBSP, because trailing bits do not extend past the
	// byte in which the stop bit is found.
	b, err = br.PeekBits(8 - br.Off() + 24)
	if err != nil {
		return true
	}
	// Expected pattern: stop bit, zeros to the end of the byte, then 0x000001.
	rem = (0x01 << uint((7-br.Off())+24)) | 0x01
	if int(b) == rem {
		return false
	}

	// Similar check to above, but this time checking for 32-bit start code.
	b, err = br.PeekBits(8 - br.Off() + 32)
	if err != nil {
		return true
	}
	// Expected pattern: stop bit, zeros to the end of the byte, then 0x00000001.
	rem = (0x01 << uint((7-br.Off())+32)) | 0x01
	if int(b) == rem {
		return false
	}

	return true
}
|
|
|
|
// field describes an unsigned integer syntax element for readFields: n bits
// are read and stored in *loc, and name identifies the element in errors.
type field struct {
	loc  *int   // destination for the decoded value
	name string // element name, used in error messages
	n    int    // number of bits to read
}
|
|
|
|
func readFields(br *bits.BitReader, fields []field) error {
|
|
for _, f := range fields {
|
|
b, err := br.ReadBits(f.n)
|
|
if err != nil {
|
|
return errors.Wrap(err, fmt.Sprintf("could not read %s", f.name))
|
|
}
|
|
*f.loc = int(b)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// flag describes a single-bit syntax element for readFlags: *loc is set to
// true when the bit read is 1, and name identifies the element in errors.
type flag struct {
	loc  *bool  // destination for the decoded flag
	name string // element name, used in error messages
}
|
|
|
|
func readFlags(br *bits.BitReader, flags []flag) error {
|
|
for _, f := range flags {
|
|
b, err := br.ReadBits(1)
|
|
if err != nil {
|
|
return errors.Wrap(err, fmt.Sprintf("could not read %s", f.name))
|
|
}
|
|
*f.loc = b == 1
|
|
}
|
|
return nil
|
|
}
|