diff --git a/codec/h264/decode/nalunit.go b/codec/h264/decode/nalunit.go index 9969cf91..2adf3165 100644 --- a/codec/h264/decode/nalunit.go +++ b/codec/h264/decode/nalunit.go @@ -1,23 +1,12 @@ package h264 import ( - "bitbucket.org/ausocean/av/codec/h264/decode/bits" - "github.com/pkg/errors" -) + "io" -type NalUnit struct { - NumBytes int - ForbiddenZeroBit int - RefIdc int - Type int - SvcExtensionFlag int - Avc3dExtensionFlag int - SVCExtension *SVCExtension - ThreeDAVCExtension *ThreeDAVCExtension - MVCExtension *MVCExtension - EmulationPreventionThreeByte byte - rbsp []byte -} + "github.com/pkg/errors" + + "bitbucket.org/ausocean/av/codec/h264/decode/bits" +) type MVCExtension struct { NonIdrFlag bool @@ -187,70 +176,80 @@ func NewSVCExtension(br *bits.BitReader) (*SVCExtension, error) { return e, nil } -func (n *NalUnit) RBSP() []byte { - return n.rbsp +type NALUnit struct { + ForbiddenZeroBit int + RefIdc int + Type int + SVCExtensionFlag int + AVC3dExtensionFlag int + SVCExtension *SVCExtension + ThreeDAVCExtension *ThreeDAVCExtension + MVCExtension *MVCExtension + EmulationPreventionThreeByte byte + RBSP []byte } -func NewNalUnit(frame []byte, numBytesInNal int) (*NalUnit, error) { - logger.Printf("debug: reading %d byte NAL\n", numBytesInNal) - nalUnit := NalUnit{ - NumBytes: numBytesInNal, - HeaderBytes: 1, - } - // TODO: pass in actual io.Reader to NewBitReader - br := bits.NewBitReader(nil) +func NewNALUnit(br *bits.BitReader) (*NALUnit, error) { + n := &NALUnit{} err := readFields(br, []field{ - {&nalUnit.ForbiddenZeroBit, "ForbiddenZeroBit", 1}, - {&nalUnit.RefIdc, "NalRefIdc", 2}, - {&nalUnit.Type, "NalUnitType", 5}, + {&n.ForbiddenZeroBit, "ForbiddenZeroBit", 1}, + {&n.RefIdc, "NalRefIdc", 2}, + {&n.Type, "NalUnitType", 5}, }) if err != nil { return nil, err } - if nalUnit.Type == 14 || nalUnit.Type == 20 || nalUnit.Type == 21 { - if nalUnit.Type != 21 { - b, err := br.ReadBits(1) - if err != nil { - return nil, errors.Wrap(err, "could not read SvcExtensionFlag") - } - nalUnit.SvcExtensionFlag = int(b) - } else { - b, err := br.ReadBits(1) - if err != nil { - return nil, errors.Wrap(err, "could not read Avc3dExtensionFlag") - } - nalUnit.Avc3dExtensionFlag = int(b) - } - if nalUnit.SvcExtensionFlag == 1 { - NalUnitHeaderSvcExtension(&nalUnit, br) - nalUnit.HeaderBytes += 3 - } else if nalUnit.Avc3dExtensionFlag == 1 { - NalUnitHeader3davcExtension(&nalUnit, br) - nalUnit.HeaderBytes += 2 - } else { - NalUnitHeaderMvcExtension(&nalUnit, br) - nalUnit.HeaderBytes += 3 + var headBytes, rbspBytes = 1, 0 + // TODO: use consts for the NAL types here + if n.Type == 14 || n.Type == 20 || n.Type == 21 { + if n.Type != 21 { + n.SVCExtensionFlag, err = br.ReadBool() + if err != nil { + return nil, errors.Wrap(err, "could not read SVCExtensionFlag") + } + } else { + n.AVC3DExtensionFlag, err = br.ReadBool() + if err != nil { + return nil, errors.Wrap(err, "could not read AVC3DExtensionFlag") + } + } + if n.SVCExtensionFlag { + n.SVCExtension, err = NewSVCExtension(br) + if err != nil { + return nil, errors.Wrap(err, "could not parse SVCExtension") + } + n.HeaderBytes += 3 + } else if n.AVC3DExtensionFlag { + n.ThreeDAVCExtension, err = NewThreeDAVCExtension(br) + if err != nil { + return nil, errors.Wrap(err, "could not parse ThreeDAVCExtension") + } + n.HeaderBytes += 2 + } else { + n.MVCExtension, err = NewMVCExtension(br) + if err != nil { + return nil, errors.Wrap(err, "could not parse MVCExtension") + } + n.HeaderBytes += 3 } } - logger.Printf("debug: found %d byte header. Reading body\n", nalUnit.HeaderBytes) - for i := nalUnit.HeaderBytes; i < nalUnit.NumBytes; i++ { + for moreRBSPData(br) { next3Bytes, err := br.PeekBits(24) - if err != nil { - logger.Printf("error: while reading next 3 NAL bytes: %v\n", err) - break + if err != nil && errors.Cause(err) != io.EOF { + return nil, errors.Wrap("could not Peek next 3 bytes") } - // Little odd, the err above and the i+2 check might be synonyms - if i+2 < nalUnit.NumBytes && next3Bytes == 0x000003 { + + if next3Bytes == 0x000003 { for j := 0; j < 2; j++ { rbspByte, err := br.ReadBits(8) if err != nil { return nil, errors.Wrap(err, "could not read rbspByte") } - nalUnit.rbsp = append(nalUnit.rbsp, byte(rbspByte)) + n.rbsp = append(n.rbsp, byte(rbspByte)) } i += 2 @@ -259,18 +258,15 @@ func NewNalUnit(frame []byte, numBytesInNal int) (*NalUnit, error) { if err != nil { return nil, errors.Wrap(err, "could not read eptByte") } - nalUnit.EmulationPreventionThreeByte = byte(eptByte) + n.EmulationPreventionThreeByte = byte(eptByte) } else { - if b, err := br.ReadBits(8); err == nil { - nalUnit.rbsp = append(nalUnit.rbsp, byte(b)) - } else { - logger.Printf("error: while reading byte %d of %d nal bytes: %v\n", i, nalUnit.NumBytes, err) - break + b, err := br.ReadBits(8) + if err != nil { + return nil, errors.Wrap(err, "could not read RBSP byte") } + n.rbsp = append(n.rbsp, byte(b)) } } - // nalUnit.rbsp = frame[nalUnit.HeaderBytes:] - logger.Printf("info: decoded %s NAL with %d RBSP bytes\n", NALUnitType[nalUnit.Type], len(nalUnit.rbsp)) - return &nalUnit, nil + return n, nil }