diff --git a/codec/h264/h264dec/nalunit.go b/codec/h264/h264dec/nalunit.go index 2e076567..2e754f12 100644 --- a/codec/h264/h264dec/nalunit.go +++ b/codec/h264/h264dec/nalunit.go @@ -10,6 +10,7 @@ AUTHORS package h264dec import ( + "fmt" "io" "github.com/pkg/errors" @@ -262,18 +263,18 @@ func NewSVCExtension(br *bits.BitReader) (*SVCExtension, error) { // Field semantics are defined in section 7.4.1. type NALUnit struct { // forbidden_zero_bit, always 0. - ForbiddenZeroBit int + ForbiddenZeroBit uint8 // nal_ref_idc, if not 0 indicates content of NAL contains a sequence parameter // set, a sequence parameter set extension, a subset sequence parameter set, // a picture parameter set, a slice of a reference picture, a slice data // partition of a reference picture, or a prefix NAL preceding a slice of // a reference picture. - RefIdc int + RefIdc uint8 // nal_unit_type, specifies the type of RBSP data contained in the NAL as // defined in Table 7-1. - Type int + Type uint8 // svc_extension_flag, indicates whether a nal_unit_header_svc_extension() // (G.7.3.1.1) or nal_unit_header_mvc_extension() (H.7.3.1.1) will follow next @@ -306,28 +307,19 @@ type NALUnit struct { // syntax structure specified in section 7.3.1, and returns as a new NALUnit. func NewNALUnit(br *bits.BitReader) (*NALUnit, error) { n := &NALUnit{} + r := newFieldReader(br) - err := readFields(br, []field{ - {&n.ForbiddenZeroBit, "ForbiddenZeroBit", 1}, - {&n.RefIdc, "NalRefIdc", 2}, - {&n.Type, "NalUnitType", 5}, - }) - if err != nil { - return nil, err - } + n.ForbiddenZeroBit = uint8(r.readBits(1)) + n.RefIdc = uint8(r.readBits(2)) + n.Type = uint8(r.readBits(5)) // TODO: use consts for the NAL types here + var err error if n.Type == 14 || n.Type == 20 || n.Type == 21 { if n.Type != 21 { - n.SVCExtensionFlag, err = br.ReadBool() - if err != nil { - return nil, errors.Wrap(err, "could not read SVCExtensionFlag") - } + n.SVCExtensionFlag = r.readBits(1) == 1 } else { - n.AVC3DExtensionFlag, err = br.ReadBool() - if err != nil { - return nil, errors.Wrap(err, "could not read AVC3DExtensionFlag") - } + n.AVC3DExtensionFlag = r.readBits(1) == 1 } if n.SVCExtensionFlag { n.SVCExtension, err = NewSVCExtension(br) @@ -360,27 +352,20 @@ func NewNALUnit(br *bits.BitReader) (*NALUnit, error) { if next3Bytes == 0x000003 { for j := 0; j < 2; j++ { - rbspByte, err := br.ReadBits(8) - if err != nil { - return nil, errors.Wrap(err, "could not read rbspByte") - } + rbspByte := byte(r.readBits(8)) n.RBSP = append(n.RBSP, byte(rbspByte)) } // Read Emulation prevention three byte. - eptByte, err := br.ReadBits(8) - if err != nil { - return nil, errors.Wrap(err, "could not read eptByte") - } - n.EmulationPreventionThreeByte = byte(eptByte) + n.EmulationPreventionThreeByte = byte(r.readBits(8)) } else { - b, err := br.ReadBits(8) - if err != nil { - return nil, errors.Wrap(err, "could not read RBSP byte") - } - n.RBSP = append(n.RBSP, byte(b)) + n.RBSP = append(n.RBSP, byte(r.readBits(8))) } } + if r.err() != nil { + return nil, fmt.Errorf("fieldReader error: %v", r.err()) + } + return n, nil } diff --git a/codec/h264/h264dec/slice.go b/codec/h264/h264dec/slice.go index 360282a6..77e369d9 100644 --- a/codec/h264/h264dec/slice.go +++ b/codec/h264/h264dec/slice.go @@ -949,7 +949,7 @@ func NewSliceContext(videoStream *VideoStream, nalUnit *NALUnit, rbsp []byte, sh var err error sps := videoStream.SPS pps := videoStream.PPS - logger.Printf("debug: %s RBSP %d bytes %d bits == \n", NALUnitType[nalUnit.Type], len(rbsp), len(rbsp)*8) + logger.Printf("debug: %s RBSP %d bytes %d bits == \n", NALUnitType[int(nalUnit.Type)], len(rbsp), len(rbsp)*8) logger.Printf("debug: \t%#v\n", rbsp[0:8]) var idrPic bool if nalUnit.Type == 5 { @@ -974,7 +974,7 @@ func NewSliceContext(videoStream *VideoStream, nalUnit *NALUnit, rbsp []byte, sh } sliceType := sliceTypeMap[header.SliceType] - logger.Printf("debug: %s (%s) slice of %d bytes\n", NALUnitType[nalUnit.Type], sliceType, len(rbsp)) + logger.Printf("debug: %s (%s) slice of %d bytes\n", NALUnitType[int(nalUnit.Type)], sliceType, len(rbsp)) header.PPSID, err = readUe(br) if err != nil { return nil, errors.Wrap(err, "could not parse PPSID")