codec/h264/h264dec: renamed NalUnit function to NewNALUnit and cleaned it up

This commit is contained in:
Saxon 2019-07-19 20:20:33 +09:30
parent 24f6b2d12f
commit f11ba2b433
1 changed files with 65 additions and 69 deletions

View File

@ -1,23 +1,12 @@
package h264 package h264
import ( import (
"bitbucket.org/ausocean/av/codec/h264/decode/bits" "io"
"github.com/pkg/errors"
)
type NalUnit struct { "github.com/pkg/errors"
NumBytes int
ForbiddenZeroBit int "bitbucket.org/ausocean/av/codec/h264/decode/bits"
RefIdc int )
Type int
SvcExtensionFlag int
Avc3dExtensionFlag int
SVCExtension *SVCExtension
ThreeDAVCExtension *ThreeDAVCExtension
MVCExtension *MVCExtension
EmulationPreventionThreeByte byte
rbsp []byte
}
type MVCExtension struct { type MVCExtension struct {
NonIdrFlag bool NonIdrFlag bool
@ -187,70 +176,80 @@ func NewSVCExtension(br *bits.BitReader) (*SVCExtension, error) {
return e, nil return e, nil
} }
func (n *NalUnit) RBSP() []byte { type NALUnit struct {
return n.rbsp ForbiddenZeroBit int
RefIdc int
Type int
SVCExtensionFlag int
AVC3dExtensionFlag int
SVCExtension *SVCExtension
ThreeDAVCExtension *ThreeDAVCExtension
MVCExtension *MVCExtension
EmulationPreventionThreeByte byte
RBSP []byte
} }
func NewNalUnit(frame []byte, numBytesInNal int) (*NalUnit, error) { func NewNALUnit(br *bits.BitReader) (*NALUnit, error) {
logger.Printf("debug: reading %d byte NAL\n", numBytesInNal) n := &NALUnit{}
nalUnit := NalUnit{
NumBytes: numBytesInNal,
HeaderBytes: 1,
}
// TODO: pass in actual io.Reader to NewBitReader
br := bits.NewBitReader(nil)
err := readFields(br, []field{ err := readFields(br, []field{
{&nalUnit.ForbiddenZeroBit, "ForbiddenZeroBit", 1}, {&n.ForbiddenZeroBit, "ForbiddenZeroBit", 1},
{&nalUnit.RefIdc, "NalRefIdc", 2}, {&n.RefIdc, "NalRefIdc", 2},
{&nalUnit.Type, "NalUnitType", 5}, {&n.Type, "NalUnitType", 5},
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
if nalUnit.Type == 14 || nalUnit.Type == 20 || nalUnit.Type == 21 { var headBytes, rbspBytes = 1, 0
if nalUnit.Type != 21 {
b, err := br.ReadBits(1)
if err != nil {
return nil, errors.Wrap(err, "could not read SvcExtensionFlag")
}
nalUnit.SvcExtensionFlag = int(b)
} else {
b, err := br.ReadBits(1)
if err != nil {
return nil, errors.Wrap(err, "could not read Avc3dExtensionFlag")
}
nalUnit.Avc3dExtensionFlag = int(b)
}
if nalUnit.SvcExtensionFlag == 1 {
NalUnitHeaderSvcExtension(&nalUnit, br)
nalUnit.HeaderBytes += 3
} else if nalUnit.Avc3dExtensionFlag == 1 {
NalUnitHeader3davcExtension(&nalUnit, br)
nalUnit.HeaderBytes += 2
} else {
NalUnitHeaderMvcExtension(&nalUnit, br)
nalUnit.HeaderBytes += 3
// TODO: use consts for the NAL types here
if n.Type == 14 || n.Type == 20 || n.Type == 21 {
if n.Type != 21 {
n.SVCExtensionFlag, err = br.ReadBool()
if err != nil {
return nil, errors.Wrap(err, "could not read SVCExtensionFlag")
}
} else {
n.AVC3DExtensionFlag, err = br.ReadBool()
if err != nil {
return nil, errors.Wrap(err, "could not read AVC3DExtensionFlag")
}
}
if n.SVCExtensionFlag {
n.SVCExtension, err = NewSVCExtension(br)
if err != nil {
return nil, errors.Wrap(err, "could not parse SVCExtension")
}
n.HeaderBytes += 3
} else if n.AVC3DExtensionFlag {
n.ThreeDAVCExtension, err = NewThreeDAVCExtension(br)
if err != nil {
return nil, errors.Wrap(err, "could not parse ThreeDAVCExtension")
}
n.HeaderBytes += 2
} else {
n.MVCExtension, err = NewMVCExtension(br)
if err != nil {
return nil, errors.Wrap(err, "could not parse MVCExtension")
}
n.HeaderBytes += 3
} }
} }
logger.Printf("debug: found %d byte header. Reading body\n", nalUnit.HeaderBytes) for moreRBSPData(br) {
for i := nalUnit.HeaderBytes; i < nalUnit.NumBytes; i++ {
next3Bytes, err := br.PeekBits(24) next3Bytes, err := br.PeekBits(24)
if err != nil { if err != nil && errors.Cause(err) != io.EOF {
logger.Printf("error: while reading next 3 NAL bytes: %v\n", err) return nil, errors.Wrap("could not Peek next 3 bytes")
break
} }
// Little odd, the err above and the i+2 check might be synonyms
if i+2 < nalUnit.NumBytes && next3Bytes == 0x000003 { if next3Bytes == 0x000003 {
for j := 0; j < 2; j++ { for j := 0; j < 2; j++ {
rbspByte, err := br.ReadBits(8) rbspByte, err := br.ReadBits(8)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not read rbspByte") return nil, errors.Wrap(err, "could not read rbspByte")
} }
nalUnit.rbsp = append(nalUnit.rbsp, byte(rbspByte)) n.rbsp = append(n.rbsp, byte(rbspByte))
} }
i += 2 i += 2
@ -259,18 +258,15 @@ func NewNalUnit(frame []byte, numBytesInNal int) (*NalUnit, error) {
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not read eptByte") return nil, errors.Wrap(err, "could not read eptByte")
} }
nalUnit.EmulationPreventionThreeByte = byte(eptByte) n.EmulationPreventionThreeByte = byte(eptByte)
} else { } else {
if b, err := br.ReadBits(8); err == nil { b, err := br.ReadBits(8)
nalUnit.rbsp = append(nalUnit.rbsp, byte(b)) if err != nil {
} else { return nil, errors.Wrap(err, "could not read RBSP byte")
logger.Printf("error: while reading byte %d of %d nal bytes: %v\n", i, nalUnit.NumBytes, err)
break
} }
n.rbsp = append(n.rbsp, byte(b))
} }
} }
// nalUnit.rbsp = frame[nalUnit.HeaderBytes:] return n, nil
logger.Printf("info: decoded %s NAL with %d RBSP bytes\n", NALUnitType[nalUnit.Type], len(nalUnit.rbsp))
return &nalUnit, nil
} }