// av/codec/h264/decode/nalunit.go

package h264
import (
	"bytes"

	"bitbucket.org/ausocean/av/codec/h264/decode/bits"
	"github.com/pkg/errors"
)
// NalUnit holds the parsed contents of a single H.264 NAL unit: the one-byte
// NAL unit header, any header extension present for nal_unit_type 14, 20 or
// 21, and the RBSP payload (exposed through the RBSP method).
type NalUnit struct {
	// NumBytes is the total size of the NAL unit in bytes, header included.
	NumBytes int

	// Fields of the one-byte NAL unit header.
	ForbiddenZeroBit int
	RefIdc           int
	Type             int

	// Extension selector flags. SvcExtensionFlag is read for nal_unit_type 14
	// and 20; Avc3dExtensionFlag is read for nal_unit_type 21.
	SvcExtensionFlag   int
	Avc3dExtensionFlag int

	// Parsed SVC and 3D-AVC header extensions, if present.
	SVCExtension       *SVCExtension
	ThreeDAVCExtension *ThreeDAVCExtension

	// Fields of the MVC header extension (nal_unit_header_mvc_extension),
	// written by NalUnitHeaderMvcExtension. They are ints, like the other
	// readFields targets in this struct.
	NonIdrFlag     int
	PriorityId     int
	ViewId         int
	TemporalId     int
	AnchorPicFlag  int
	InterViewFlag  int
	ReservedOneBit int

	// EmulationPreventionThreeByte is the most recent
	// emulation_prevention_three_byte consumed while extracting the RBSP.
	EmulationPreventionThreeByte byte

	// rbsp is the payload with emulation prevention bytes removed.
	rbsp []byte
}
// ThreeDAVCExtension holds the fields of the 3D-AVC NAL unit header extension
// (nal_unit_header_3davc_extension), as read by NewThreeDAVCExtension.
// Together with the preceding 1-bit extension flag, these fields occupy 2 bytes.
type ThreeDAVCExtension struct {
	// ViewIdx is an 8-bit view index.
	ViewIdx int
	// DepthFlag is a 1-bit flag.
	DepthFlag bool
	// NonIdrFlag is a 1-bit flag.
	NonIdrFlag bool
	// TemporalID is a 3-bit temporal layer identifier.
	TemporalID int
	// AnchorPicFlag is a 1-bit flag.
	AnchorPicFlag bool
	// InterViewFlag is a 1-bit flag.
	InterViewFlag bool
}
// NewThreeDAVCExtension reads a 3D-AVC NAL unit header extension from br, in
// bitstream order: view_idx (8 bits), depth_flag, non_idr_flag, temporal_id
// (3 bits), anchor_pic_flag and inter_view_flag.
//
// Reading stops at the first failure, and a wrapped error naming the field
// that could not be read is returned with a nil extension.
func NewThreeDAVCExtension(br *bits.BitReader) (*ThreeDAVCExtension, error) {
	var (
		ext ThreeDAVCExtension
		err error
	)
	if ext.ViewIdx, err = br.ReadBitsInt(8); err != nil {
		return nil, errors.Wrap(err, "could not read ViewIdx")
	}
	if ext.DepthFlag, err = br.ReadBool(); err != nil {
		return nil, errors.Wrap(err, "could not read DepthFlag")
	}
	if ext.NonIdrFlag, err = br.ReadBool(); err != nil {
		return nil, errors.Wrap(err, "could not read NonIdrFlag")
	}
	if ext.TemporalID, err = br.ReadBitsInt(3); err != nil {
		return nil, errors.Wrap(err, "could not read TemporalId")
	}
	if ext.AnchorPicFlag, err = br.ReadBool(); err != nil {
		return nil, errors.Wrap(err, "could not read AnchorPicFlag")
	}
	if ext.InterViewFlag, err = br.ReadBool(); err != nil {
		return nil, errors.Wrap(err, "could not read InterViewFlag")
	}
	return &ext, nil
}
// SVCExtension holds the fields of the SVC NAL unit header extension
// (nal_unit_header_svc_extension), as read by NewSVCExtension. Together with
// the preceding 1-bit extension flag, these fields occupy 3 bytes.
type SVCExtension struct {
	// IdrFlag is a 1-bit flag.
	IdrFlag bool
	// PriorityId is a 6-bit priority identifier.
	PriorityId int
	// NoInterLayerPredFlag is a 1-bit flag.
	NoInterLayerPredFlag bool
	// DependencyId is a 3-bit dependency identifier.
	DependencyId int
	// QualityId is a 4-bit quality identifier.
	QualityId int
	// TemporalId is a 3-bit temporal layer identifier.
	TemporalId int
	// UseRefBasePicFlag is a 1-bit flag.
	UseRefBasePicFlag bool
	// DiscardableFlag is a 1-bit flag.
	DiscardableFlag bool
	// OutputFlag is a 1-bit flag.
	OutputFlag bool
	// ReservedThree2Bits holds the 2 reserved bits that end the extension.
	ReservedThree2Bits int
}
// NewSVCExtension reads an SVC NAL unit header extension from br, in
// bitstream order: idr_flag, priority_id (6 bits), no_inter_layer_pred_flag,
// dependency_id (3 bits), quality_id (4 bits), temporal_id (3 bits),
// use_ref_base_pic_flag, discardable_flag, output_flag and
// reserved_three_2bits (2 bits).
//
// Reading stops at the first failure, and a wrapped error naming the field
// that could not be read is returned with a nil extension.
func NewSVCExtension(br *bits.BitReader) (*SVCExtension, error) {
	var (
		ext SVCExtension
		err error
	)
	if ext.IdrFlag, err = br.ReadBool(); err != nil {
		return nil, errors.Wrap(err, "could not read IdrFlag")
	}
	if ext.PriorityId, err = br.ReadBitsInt(6); err != nil {
		return nil, errors.Wrap(err, "could not read PriorityId")
	}
	if ext.NoInterLayerPredFlag, err = br.ReadBool(); err != nil {
		return nil, errors.Wrap(err, "could not read NoInterLayerPredFlag")
	}
	if ext.DependencyId, err = br.ReadBitsInt(3); err != nil {
		return nil, errors.Wrap(err, "could not read DependencyId")
	}
	if ext.QualityId, err = br.ReadBitsInt(4); err != nil {
		return nil, errors.Wrap(err, "could not read QualityId")
	}
	if ext.TemporalId, err = br.ReadBitsInt(3); err != nil {
		return nil, errors.Wrap(err, "could not read TemporalId")
	}
	if ext.UseRefBasePicFlag, err = br.ReadBool(); err != nil {
		return nil, errors.Wrap(err, "could not read UseRefBasePicFlag")
	}
	if ext.DiscardableFlag, err = br.ReadBool(); err != nil {
		return nil, errors.Wrap(err, "could not read DiscardableFlag")
	}
	if ext.OutputFlag, err = br.ReadBool(); err != nil {
		return nil, errors.Wrap(err, "could not read OutputFlag")
	}
	if ext.ReservedThree2Bits, err = br.ReadBitsInt(2); err != nil {
		return nil, errors.Wrap(err, "could not read ReservedThree2Bits")
	}
	return &ext, nil
}
// NalUnitHeaderMvcExtension reads the MVC NAL unit header extension
// (nal_unit_header_mvc_extension) from br directly into nalUnit's fields:
// non_idr_flag (1 bit), priority_id (6), view_id (10), temporal_id (3),
// anchor_pic_flag (1), inter_view_flag (1) and reserved_one_bit (1).
// Any error comes from readFields.
func NalUnitHeaderMvcExtension(nalUnit *NalUnit, br *bits.BitReader) error {
	// Each entry gives the destination, the field name used in error text,
	// and the field width in bits, in bitstream order.
	mvcFields := []field{
		{&nalUnit.NonIdrFlag, "NonIdrFlag", 1},
		{&nalUnit.PriorityId, "PriorityId", 6},
		{&nalUnit.ViewId, "ViewId", 10},
		{&nalUnit.TemporalId, "TemporalId", 3},
		{&nalUnit.AnchorPicFlag, "AnchorPicFlag", 1},
		{&nalUnit.InterViewFlag, "InterViewFlag", 1},
		{&nalUnit.ReservedOneBit, "ReservedOneBit", 1},
	}
	return readFields(br, mvcFields)
}
// RBSP returns the NAL unit's raw byte sequence payload. This is the bytes
// that follow the NAL header, with emulation prevention bytes removed. The
// returned slice shares storage with the NalUnit; it is not a copy.
func (n *NalUnit) RBSP() []byte {
	payload := n.rbsp
	return payload
}
// NewNalUnit parses one NAL unit from frame. frame must begin at the NAL unit
// header byte and should hold at least numBytesInNal bytes.
//
// The steps are:
//   - Read the one-byte NAL header.
//   - For nal_unit_type 14, 20 and 21, read the extension flag and the SVC,
//     3D-AVC or MVC header extension it selects.
//   - Copy the rest of the NAL into the RBSP, dropping the
//     emulation_prevention_three_byte from every 0x000003 sequence.
//
// A wrapped error is returned in these cases:
//   - the header or a header extension cannot be read;
//   - a byte inside an emulation-prevention sequence cannot be read.
//
// Running out of data while reading ordinary payload bytes is logged, and
// the NAL is returned with a shorter RBSP.
func NewNalUnit(frame []byte, numBytesInNal int) (*NalUnit, error) {
	logger.Printf("debug: reading %d byte NAL\n", numBytesInNal)
	nalUnit := NalUnit{NumBytes: numBytesInNal}

	// headerBytes counts the header bytes: 1, plus the size of any extension.
	// The payload loop starts at this offset.
	headerBytes := 1

	br := bits.NewBitReader(bytes.NewReader(frame))
	err := readFields(br, []field{
		{&nalUnit.ForbiddenZeroBit, "ForbiddenZeroBit", 1},
		{&nalUnit.RefIdc, "NalRefIdc", 2},
		{&nalUnit.Type, "NalUnitType", 5},
	})
	if err != nil {
		return nil, err
	}

	if nalUnit.Type == 14 || nalUnit.Type == 20 || nalUnit.Type == 21 {
		if nalUnit.Type != 21 {
			b, err := br.ReadBits(1)
			if err != nil {
				return nil, errors.Wrap(err, "could not read SvcExtensionFlag")
			}
			nalUnit.SvcExtensionFlag = int(b)
		} else {
			b, err := br.ReadBits(1)
			if err != nil {
				return nil, errors.Wrap(err, "could not read Avc3dExtensionFlag")
			}
			nalUnit.Avc3dExtensionFlag = int(b)
		}

		// Each extension, together with its 1-bit flag, fills whole bytes:
		// 3 for SVC, 2 for 3D-AVC, 3 for MVC.
		switch {
		case nalUnit.SvcExtensionFlag == 1:
			if nalUnit.SVCExtension, err = NewSVCExtension(br); err != nil {
				return nil, errors.Wrap(err, "could not read SVC extension")
			}
			headerBytes += 3
		case nalUnit.Avc3dExtensionFlag == 1:
			if nalUnit.ThreeDAVCExtension, err = NewThreeDAVCExtension(br); err != nil {
				return nil, errors.Wrap(err, "could not read 3D-AVC extension")
			}
			headerBytes += 2
		default:
			if err := NalUnitHeaderMvcExtension(&nalUnit, br); err != nil {
				return nil, errors.Wrap(err, "could not read MVC extension")
			}
			headerBytes += 3
		}
	}

	logger.Printf("debug: found %d byte header. Reading body\n", headerBytes)
	for i := headerBytes; i < nalUnit.NumBytes; i++ {
		// An emulation-prevention sequence needs three remaining bytes. Only
		// peek when they exist. Peeking 24 bits near the end of the data
		// would fail and cause the final one or two payload bytes to be
		// dropped.
		if i+2 < nalUnit.NumBytes {
			next3Bytes, err := br.PeekBits(24)
			if err != nil {
				logger.Printf("error: while reading next 3 NAL bytes: %v\n", err)
				break
			}
			if next3Bytes == 0x000003 {
				// Keep the two zero bytes, then consume and discard the 0x03.
				for j := 0; j < 2; j++ {
					rbspByte, err := br.ReadBits(8)
					if err != nil {
						return nil, errors.Wrap(err, "could not read rbspByte")
					}
					nalUnit.rbsp = append(nalUnit.rbsp, byte(rbspByte))
				}
				i += 2
				eptByte, err := br.ReadBits(8)
				if err != nil {
					return nil, errors.Wrap(err, "could not read eptByte")
				}
				nalUnit.EmulationPreventionThreeByte = byte(eptByte)
				continue
			}
		}

		b, err := br.ReadBits(8)
		if err != nil {
			logger.Printf("error: while reading byte %d of %d nal bytes: %v\n", i, nalUnit.NumBytes, err)
			break
		}
		nalUnit.rbsp = append(nalUnit.rbsp, byte(b))
	}

	logger.Printf("info: decoded %s NAL with %d RBSP bytes\n", NALUnitType[nalUnit.Type], len(nalUnit.rbsp))
	return &nalUnit, nil
}