From f8fd59516bf3e5a31a06bc76a4c9eb3d2eeda9e2 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 18 Feb 2021 19:05:06 +0900 Subject: [PATCH] Fix decoding of deep recursive structure --- decode.go | 13 +++++---- decode_anonymous_field.go | 8 +++--- decode_array.go | 22 +++++++++++---- decode_bool.go | 4 +-- decode_bytes.go | 16 +++++------ decode_context.go | 34 +++++++++++++++++++---- decode_float.go | 4 +-- decode_int.go | 4 +-- decode_interface.go | 58 +++++++++++++++++++-------------------- decode_map.go | 22 +++++++++++---- decode_number.go | 4 +-- decode_ptr.go | 8 +++--- decode_slice.go | 18 +++++++++--- decode_stream.go | 34 +++++++++++++++++++---- decode_string.go | 4 +-- decode_struct.go | 21 ++++++++++---- decode_test.go | 38 +++++++++++++++++-------- decode_uint.go | 4 +-- decode_unmarshal_json.go | 8 +++--- decode_unmarshal_text.go | 8 +++--- decode_wrapped_string.go | 8 +++--- error.go | 7 +++++ 22 files changed, 228 insertions(+), 119 deletions(-) diff --git a/decode.go b/decode.go index c700579..e84570e 100644 --- a/decode.go +++ b/decode.go @@ -15,8 +15,8 @@ func (d Delim) String() string { } type decoder interface { - decode([]byte, int64, unsafe.Pointer) (int64, error) - decodeStream(*stream, unsafe.Pointer) error + decode([]byte, int64, int64, unsafe.Pointer) (int64, error) + decodeStream(*stream, int64, unsafe.Pointer) error } type Decoder struct { @@ -29,7 +29,8 @@ var ( ) const ( - nul = '\000' + nul = '\000' + maxDecodeNestingDepth = 10000 ) func unmarshal(data []byte, v interface{}) error { @@ -45,7 +46,7 @@ func unmarshal(data []byte, v interface{}) error { if err != nil { return err } - if _, err := dec.decode(src, 0, header.ptr); err != nil { + if _, err := dec.decode(src, 0, 0, header.ptr); err != nil { return err } return nil @@ -64,7 +65,7 @@ func unmarshalNoEscape(data []byte, v interface{}) error { if err != nil { return err } - if _, err := dec.decode(src, 0, noescape(header.ptr)); err != nil { + if _, err := dec.decode(src, 0, 0, noescape(header.ptr)); err != nil { return err } return nil @@ -147,7 +148,7 @@ func (d *Decoder) Decode(v interface{}) error { return err } s := d.s - if err := dec.decodeStream(s, header.ptr); err != nil { + if err := dec.decodeStream(s, 0, header.ptr); err != nil { return err } s.reset() diff --git a/decode_anonymous_field.go b/decode_anonymous_field.go index 91c2894..77931f2 100644 --- a/decode_anonymous_field.go +++ b/decode_anonymous_field.go @@ -18,18 +18,18 @@ func newAnonymousFieldDecoder(structType *rtype, offset uintptr, dec decoder) *a } } -func (d *anonymousFieldDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *anonymousFieldDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe_New(d.structType) } p = *(*unsafe.Pointer)(p) - return d.dec.decodeStream(s, unsafe.Pointer(uintptr(p)+d.offset)) + return d.dec.decodeStream(s, depth, unsafe.Pointer(uintptr(p)+d.offset)) } -func (d *anonymousFieldDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *anonymousFieldDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe_New(d.structType) } p = *(*unsafe.Pointer)(p) - return d.dec.decode(buf, cursor, unsafe.Pointer(uintptr(p)+d.offset)) + return d.dec.decode(buf, cursor, depth, unsafe.Pointer(uintptr(p)+d.offset)) } diff --git a/decode_array.go b/decode_array.go index f19e11c..92b9dd9 100644 --- a/decode_array.go +++ b/decode_array.go @@ -27,7 +27,12 @@ func newArrayDecoder(dec decoder, elemType *rtype, alen int, structName, fieldNa } } -func (d *arrayDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *arrayDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } + for { switch s.char() { case ' ', '\n', '\t', '\r': @@ -41,11 +46,11 @@ func (d *arrayDecoder) decodeStream(s *stream, p unsafe.Pointer) error { for { s.cursor++ if idx < d.alen { - if err := d.valueDecoder.decodeStream(s, unsafe.Pointer(uintptr(p)+uintptr(idx)*d.size)); err != nil { + if err := d.valueDecoder.decodeStream(s, depth, unsafe.Pointer(uintptr(p)+uintptr(idx)*d.size)); err != nil { return err } } else { - if err := s.skipValue(); err != nil { + if err := s.skipValue(depth); err != nil { return err } } @@ -84,7 +89,12 @@ ERROR: return errUnexpectedEndOfJSON("array", s.totalOffset()) } -func (d *arrayDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *arrayDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } + buflen := int64(len(buf)) for ; cursor < buflen; cursor++ { switch buf[cursor] { @@ -111,13 +121,13 @@ func (d *arrayDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64 for { cursor++ if idx < d.alen { - c, err := d.valueDecoder.decode(buf, cursor, unsafe.Pointer(uintptr(p)+uintptr(idx)*d.size)) + c, err := d.valueDecoder.decode(buf, cursor, depth, unsafe.Pointer(uintptr(p)+uintptr(idx)*d.size)) if err != nil { return 0, err } cursor = c } else { - c, err := skipValue(buf, cursor) + c, err := skipValue(buf, cursor, depth) if err != nil { return 0, err } diff --git a/decode_bool.go b/decode_bool.go index ccb96b0..84a4a3d 100644 --- a/decode_bool.go +++ b/decode_bool.go @@ -61,7 +61,7 @@ func falseBytes(s *stream) error { return nil } -func (d *boolDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *boolDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { s.skipWhiteSpace() for { switch s.char() { @@ -94,7 +94,7 @@ ERROR: return errUnexpectedEndOfJSON("bool", s.totalOffset()) } -func (d *boolDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *boolDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { buflen := int64(len(buf)) cursor = skipWhiteSpace(buf, cursor) switch buf[cursor] { diff --git a/decode_bytes.go b/decode_bytes.go index a82fe06..0babe82 100644 --- a/decode_bytes.go +++ b/decode_bytes.go @@ -35,8 +35,8 @@ func newBytesDecoder(typ *rtype, structName string, fieldName string) *bytesDeco } } -func (d *bytesDecoder) decodeStream(s *stream, p unsafe.Pointer) error { - bytes, err := d.decodeStreamBinary(s, p) +func (d *bytesDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { + bytes, err := d.decodeStreamBinary(s, depth, p) if err != nil { return err } @@ -54,8 +54,8 @@ func (d *bytesDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *bytesDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { - bytes, c, err := d.decodeBinary(buf, cursor, p) +func (d *bytesDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { + bytes, c, err := d.decodeBinary(buf, cursor, depth, p) if err != nil { return 0, err } @@ -94,7 +94,7 @@ ERROR: return nil, errUnexpectedEndOfJSON("[]byte", s.totalOffset()) } -func (d *bytesDecoder) decodeStreamBinary(s *stream, p unsafe.Pointer) ([]byte, error) { +func (d *bytesDecoder) decodeStreamBinary(s *stream, depth int64, p unsafe.Pointer) ([]byte, error) { for { switch s.char() { case ' ', '\n', '\t', '\r': @@ -114,7 +114,7 @@ func (d *bytesDecoder) decodeStreamBinary(s *stream, p unsafe.Pointer) ([]byte, Offset: s.totalOffset(), } } - if err := d.sliceDecoder.decodeStream(s, p); err != nil { + if err := d.sliceDecoder.decodeStream(s, depth, p); err != nil { return nil, err } return nil, nil @@ -128,7 +128,7 @@ func (d *bytesDecoder) decodeStreamBinary(s *stream, p unsafe.Pointer) ([]byte, return nil, errNotAtBeginningOfValue(s.totalOffset()) } -func (d *bytesDecoder) decodeBinary(buf []byte, cursor int64, p unsafe.Pointer) ([]byte, int64, error) { +func (d *bytesDecoder) decodeBinary(buf []byte, cursor, depth int64, p unsafe.Pointer) ([]byte, int64, error) { for { switch buf[cursor] { case ' ', '\n', '\t', '\r': @@ -154,7 +154,7 @@ func (d *bytesDecoder) decodeBinary(buf []byte, cursor int64, p unsafe.Pointer) Offset: cursor, } } - c, err := d.sliceDecoder.decode(buf, cursor, p) + c, err := d.sliceDecoder.decode(buf, cursor, depth, p) if err != nil { return nil, 0, err } diff --git a/decode_context.go b/decode_context.go index a4ebaa5..9f87b05 100644 --- a/decode_context.go +++ b/decode_context.go @@ -28,17 +28,29 @@ LOOP: return cursor } -func skipObject(buf []byte, cursor int64) (int64, error) { +func skipObject(buf []byte, cursor, depth int64) (int64, error) { braceCount := 1 for { switch buf[cursor] { case '{': braceCount++ + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } case '}': + depth-- braceCount-- if braceCount == 0 { return cursor + 1, nil } + case '[': + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } + case ']': + depth-- case '"': for { cursor++ @@ -60,17 +72,29 @@ func skipObject(buf []byte, cursor int64) (int64, error) { } } -func skipArray(buf []byte, cursor int64) (int64, error) { +func skipArray(buf []byte, cursor, depth int64) (int64, error) { bracketCount := 1 for { switch buf[cursor] { case '[': bracketCount++ + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } case ']': bracketCount-- + depth-- if bracketCount == 0 { return cursor + 1, nil } + case '{': + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } + case '}': + depth-- case '"': for { cursor++ @@ -92,16 +116,16 @@ func skipArray(buf []byte, cursor int64) (int64, error) { } } -func skipValue(buf []byte, cursor int64) (int64, error) { +func skipValue(buf []byte, cursor, depth int64) (int64, error) { for { switch buf[cursor] { case ' ', '\t', '\n', '\r': cursor++ continue case '{': - return skipObject(buf, cursor+1) + return skipObject(buf, cursor+1, depth+1) case '[': - return skipArray(buf, cursor+1) + return skipArray(buf, cursor+1, depth+1) case '"': for { cursor++ diff --git a/decode_float.go b/decode_float.go index f818b1f..cda8ec5 100644 --- a/decode_float.go +++ b/decode_float.go @@ -129,7 +129,7 @@ func (d *floatDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, erro return nil, 0, errUnexpectedEndOfJSON("float", cursor) } -func (d *floatDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *floatDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { bytes, err := d.decodeStreamByte(s) if err != nil { return err @@ -146,7 +146,7 @@ func (d *floatDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *floatDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *floatDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { bytes, c, err := d.decodeByte(buf, cursor) if err != nil { return 0, err diff --git a/decode_int.go b/decode_int.go index 394e691..4ea6fa5 100644 --- a/decode_int.go +++ b/decode_int.go @@ -173,7 +173,7 @@ func (d *intDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error) } } -func (d *intDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *intDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { bytes, err := d.decodeStreamByte(s) if err != nil { return err @@ -201,7 +201,7 @@ func (d *intDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *intDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *intDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { bytes, c, err := d.decodeByte(buf, cursor) if err != nil { return 0, err diff --git a/decode_interface.go b/decode_interface.go index dedead6..20f5ad5 100644 --- a/decode_interface.go +++ b/decode_interface.go @@ -42,9 +42,9 @@ var ( ) ) -func decodeStreamUnmarshaler(s *stream, unmarshaler Unmarshaler) error { +func decodeStreamUnmarshaler(s *stream, depth int64, unmarshaler Unmarshaler) error { start := s.cursor - if err := s.skipValue(); err != nil { + if err := s.skipValue(depth); err != nil { return err } src := s.buf[start:s.cursor] @@ -57,10 +57,10 @@ func decodeStreamUnmarshaler(s *stream, unmarshaler Unmarshaler) error { return nil } -func decodeUnmarshaler(buf []byte, cursor int64, unmarshaler Unmarshaler) (int64, error) { +func decodeUnmarshaler(buf []byte, cursor, depth int64, unmarshaler Unmarshaler) (int64, error) { cursor = skipWhiteSpace(buf, cursor) start := cursor - end, err := skipValue(buf, cursor) + end, err := skipValue(buf, cursor, depth) if err != nil { return 0, err } @@ -74,9 +74,9 @@ func decodeUnmarshaler(buf []byte, cursor int64, unmarshaler Unmarshaler) (int64 return end, nil } -func decodeStreamTextUnmarshaler(s *stream, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) error { +func decodeStreamTextUnmarshaler(s *stream, depth int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) error { start := s.cursor - if err := s.skipValue(); err != nil { + if err := s.skipValue(depth); err != nil { return err } src := s.buf[start:s.cursor] @@ -94,10 +94,10 @@ func decodeStreamTextUnmarshaler(s *stream, unmarshaler encoding.TextUnmarshaler return nil } -func decodeTextUnmarshaler(buf []byte, cursor int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) (int64, error) { +func decodeTextUnmarshaler(buf []byte, cursor, depth int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) (int64, error) { cursor = skipWhiteSpace(buf, cursor) start := cursor - end, err := skipValue(buf, cursor) + end, err := skipValue(buf, cursor, depth) if err != nil { return 0, err } @@ -115,7 +115,7 @@ func decodeTextUnmarshaler(buf []byte, cursor int64, unmarshaler encoding.TextUn return end, nil } -func (d *interfaceDecoder) decodeStreamEmptyInterface(s *stream, p unsafe.Pointer) error { +func (d *interfaceDecoder) decodeStreamEmptyInterface(s *stream, depth int64, p unsafe.Pointer) error { s.skipWhiteSpace() for { switch s.char() { @@ -130,7 +130,7 @@ func (d *interfaceDecoder) decodeStreamEmptyInterface(s *stream, p unsafe.Pointe newInterfaceDecoder(emptyInterfaceType, d.structName, d.fieldName), d.structName, d.fieldName, - ).decodeStream(s, ptr); err != nil { + ).decodeStream(s, depth, ptr); err != nil { return err } *(*interface{})(p) = v @@ -144,13 +144,13 @@ func (d *interfaceDecoder) decodeStreamEmptyInterface(s *stream, p unsafe.Pointe emptyInterfaceType.Size(), d.structName, d.fieldName, - ).decodeStream(s, ptr); err != nil { + ).decodeStream(s, depth, ptr); err != nil { return err } *(*interface{})(p) = v return nil case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': - return d.numDecoder(s).decodeStream(s, p) + return d.numDecoder(s).decodeStream(s, depth, p) case '"': s.cursor++ start := s.cursor @@ -201,7 +201,7 @@ func (d *interfaceDecoder) decodeStreamEmptyInterface(s *stream, p unsafe.Pointe return errNotAtBeginningOfValue(s.totalOffset()) } -func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *interfaceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { runtimeInterfaceValue := *(*interface{})(unsafe.Pointer(&interfaceHeader{ typ: d.typ, ptr: p, @@ -209,10 +209,10 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { rv := reflect.ValueOf(runtimeInterfaceValue) if rv.NumMethod() > 0 && rv.CanInterface() { if u, ok := rv.Interface().(Unmarshaler); ok { - return decodeStreamUnmarshaler(s, u) + return decodeStreamUnmarshaler(s, depth, u) } if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { - return decodeStreamTextUnmarshaler(s, u, p) + return decodeStreamTextUnmarshaler(s, depth, u, p) } return d.errUnmarshalType(rv.Type(), s.totalOffset()) } @@ -221,10 +221,10 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { typ := ifaceHeader.typ if ifaceHeader.ptr == nil || d.typ == typ || typ == nil { // concrete type is empty interface - return d.decodeStreamEmptyInterface(s, p) + return d.decodeStreamEmptyInterface(s, depth, p) } if typ.Kind() == reflect.Ptr && typ.Elem() == d.typ || typ.Kind() != reflect.Ptr { - return d.decodeStreamEmptyInterface(s, p) + return d.decodeStreamEmptyInterface(s, depth, p) } s.skipWhiteSpace() if s.char() == 'n' { @@ -238,7 +238,7 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err != nil { return err } - return decoder.decodeStream(s, ifaceHeader.ptr) + return decoder.decodeStream(s, depth, ifaceHeader.ptr) } func (d *interfaceDecoder) errUnmarshalType(typ reflect.Type, offset int64) *UnmarshalTypeError { @@ -251,7 +251,7 @@ func (d *interfaceDecoder) errUnmarshalType(typ reflect.Type, offset int64) *Unm } } -func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *interfaceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { runtimeInterfaceValue := *(*interface{})(unsafe.Pointer(&interfaceHeader{ typ: d.typ, ptr: p, @@ -259,10 +259,10 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i rv := reflect.ValueOf(runtimeInterfaceValue) if rv.NumMethod() > 0 && rv.CanInterface() { if u, ok := rv.Interface().(Unmarshaler); ok { - return decodeUnmarshaler(buf, cursor, u) + return decodeUnmarshaler(buf, cursor, depth, u) } if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { - return decodeTextUnmarshaler(buf, cursor, u, p) + return decodeTextUnmarshaler(buf, cursor, depth, u, p) } return 0, d.errUnmarshalType(rv.Type(), cursor) } @@ -272,10 +272,10 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i typ := ifaceHeader.typ if ifaceHeader.ptr == nil || d.typ == typ || typ == nil { // concrete type is empty interface - return d.decodeEmptyInterface(buf, cursor, p) + return d.decodeEmptyInterface(buf, cursor, depth, p) } if typ.Kind() == reflect.Ptr && typ.Elem() == d.typ || typ.Kind() != reflect.Ptr { - return d.decodeEmptyInterface(buf, cursor, p) + return d.decodeEmptyInterface(buf, cursor, depth, p) } cursor = skipWhiteSpace(buf, cursor) if buf[cursor] == 'n' { @@ -299,10 +299,10 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i if err != nil { return 0, err } - return decoder.decode(buf, cursor, ifaceHeader.ptr) + return decoder.decode(buf, cursor, depth, ifaceHeader.ptr) } -func (d *interfaceDecoder) decodeEmptyInterface(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *interfaceDecoder) decodeEmptyInterface(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { cursor = skipWhiteSpace(buf, cursor) switch buf[cursor] { case '{': @@ -316,7 +316,7 @@ func (d *interfaceDecoder) decodeEmptyInterface(buf []byte, cursor int64, p unsa newInterfaceDecoder(emptyInterfaceType, d.structName, d.fieldName), d.structName, d.fieldName, ) - cursor, err := dec.decode(buf, cursor, ptr) + cursor, err := dec.decode(buf, cursor, depth, ptr) if err != nil { return 0, err } @@ -331,7 +331,7 @@ func (d *interfaceDecoder) decodeEmptyInterface(buf []byte, cursor int64, p unsa emptyInterfaceType.Size(), d.structName, d.fieldName, ) - cursor, err := dec.decode(buf, cursor, ptr) + cursor, err := dec.decode(buf, cursor, depth, ptr) if err != nil { return 0, err } @@ -340,12 +340,12 @@ func (d *interfaceDecoder) decodeEmptyInterface(buf []byte, cursor int64, p unsa case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': return newFloatDecoder(d.structName, d.fieldName, func(p unsafe.Pointer, v float64) { *(*interface{})(p) = v - }).decode(buf, cursor, p) + }).decode(buf, cursor, depth, p) case '"': var v string ptr := unsafe.Pointer(&v) dec := newStringDecoder(d.structName, d.fieldName) - cursor, err := dec.decode(buf, cursor, ptr) + cursor, err := dec.decode(buf, cursor, depth, ptr) if err != nil { return 0, err } diff --git a/decode_map.go b/decode_map.go index 095b2ba..c09e2a2 100644 --- a/decode_map.go +++ b/decode_map.go @@ -33,7 +33,12 @@ func makemap(*rtype, int) unsafe.Pointer //go:noescape func mapassign(t *rtype, m unsafe.Pointer, key, val unsafe.Pointer) -func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *mapDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } + s.skipWhiteSpace() switch s.char() { case 'n': @@ -59,7 +64,7 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { for { s.cursor++ k := unsafe_New(d.keyType) - if err := d.keyDecoder.decodeStream(s, k); err != nil { + if err := d.keyDecoder.decodeStream(s, depth, k); err != nil { return err } s.skipWhiteSpace() @@ -71,7 +76,7 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { } s.cursor++ v := unsafe_New(d.valueType) - if err := d.valueDecoder.decodeStream(s, v); err != nil { + if err := d.valueDecoder.decodeStream(s, depth, v); err != nil { return err } mapassign(d.mapType, mapValue, k, v) @@ -90,7 +95,12 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { } } -func (d *mapDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *mapDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } + cursor = skipWhiteSpace(buf, cursor) buflen := int64(len(buf)) if buflen < 2 { @@ -130,7 +140,7 @@ func (d *mapDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, } for { k := unsafe_New(d.keyType) - keyCursor, err := d.keyDecoder.decode(buf, cursor, k) + keyCursor, err := d.keyDecoder.decode(buf, cursor, depth, k) if err != nil { return 0, err } @@ -140,7 +150,7 @@ func (d *mapDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, } cursor++ v := unsafe_New(d.valueType) - valueCursor, err := d.valueDecoder.decode(buf, cursor, v) + valueCursor, err := d.valueDecoder.decode(buf, cursor, depth, v) if err != nil { return 0, err } diff --git a/decode_number.go b/decode_number.go index cf36979..bf358cb 100644 --- a/decode_number.go +++ b/decode_number.go @@ -20,7 +20,7 @@ func newNumberDecoder(structName, fieldName string, op func(unsafe.Pointer, Numb } } -func (d *numberDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *numberDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { bytes, err := d.floatDecoder.decodeStreamByte(s) if err != nil { return err @@ -30,7 +30,7 @@ func (d *numberDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *numberDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *numberDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { bytes, c, err := d.floatDecoder.decodeByte(buf, cursor) if err != nil { return 0, err diff --git a/decode_ptr.go b/decode_ptr.go index 5d90366..ac4af6f 100644 --- a/decode_ptr.go +++ b/decode_ptr.go @@ -32,7 +32,7 @@ func (d *ptrDecoder) contentDecoder() decoder { //go:linkname unsafe_New reflect.unsafe_New func unsafe_New(*rtype) unsafe.Pointer -func (d *ptrDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *ptrDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { s.skipWhiteSpace() if s.char() == nul { s.read() @@ -51,13 +51,13 @@ func (d *ptrDecoder) decodeStream(s *stream, p unsafe.Pointer) error { } else { newptr = *(*unsafe.Pointer)(p) } - if err := d.dec.decodeStream(s, newptr); err != nil { + if err := d.dec.decodeStream(s, depth, newptr); err != nil { return err } return nil } -func (d *ptrDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *ptrDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { cursor = skipWhiteSpace(buf, cursor) if buf[cursor] == 'n' { buflen := int64(len(buf)) @@ -86,7 +86,7 @@ func (d *ptrDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, } else { newptr = *(*unsafe.Pointer)(p) } - c, err := d.dec.decode(buf, cursor, newptr) + c, err := d.dec.decode(buf, cursor, depth, newptr) if err != nil { return 0, err } diff --git a/decode_slice.go b/decode_slice.go index 3717d1c..0442a75 100644 --- a/decode_slice.go +++ b/decode_slice.go @@ -73,7 +73,12 @@ func (d *sliceDecoder) errNumber(offset int64) *UnmarshalTypeError { } } -func (d *sliceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } + for { switch s.char() { case ' ', '\n', '\t', '\r': @@ -109,7 +114,7 @@ func (d *sliceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { dst := sliceHeader{data: data, len: idx, cap: capacity} copySlice(d.elemType, dst, src) } - if err := d.valueDecoder.decodeStream(s, unsafe.Pointer(uintptr(data)+uintptr(idx)*d.size)); err != nil { + if err := d.valueDecoder.decodeStream(s, depth, unsafe.Pointer(uintptr(data)+uintptr(idx)*d.size)); err != nil { return err } s.skipWhiteSpace() @@ -167,7 +172,12 @@ ERROR: return errUnexpectedEndOfJSON("slice", s.totalOffset()) } -func (d *sliceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } + buflen := int64(len(buf)) for ; cursor < buflen; cursor++ { switch buf[cursor] { @@ -214,7 +224,7 @@ func (d *sliceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64 dst := sliceHeader{data: data, len: idx, cap: capacity} copySlice(d.elemType, dst, src) } - c, err := d.valueDecoder.decode(buf, cursor, unsafe.Pointer(uintptr(data)+uintptr(idx)*d.size)) + c, err := d.valueDecoder.decode(buf, cursor, depth, unsafe.Pointer(uintptr(data)+uintptr(idx)*d.size)) if err != nil { return 0, err } diff --git a/decode_stream.go b/decode_stream.go index 05ba7fc..6019aaa 100644 --- a/decode_stream.go +++ b/decode_stream.go @@ -97,19 +97,31 @@ LOOP: } } -func (s *stream) skipObject() error { +func (s *stream) skipObject(depth int64) error { braceCount := 1 _, cursor, p := s.stat() for { switch char(p, cursor) { case '{': braceCount++ + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } case '}': braceCount-- + depth-- if braceCount == 0 { s.cursor = cursor + 1 return nil } + case '[': + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } + case ']': + depth-- case '"': for { cursor++ @@ -142,19 +154,31 @@ func (s *stream) skipObject() error { } } -func (s *stream) skipArray() error { +func (s *stream) skipArray(depth int64) error { bracketCount := 1 _, cursor, p := s.stat() for { switch char(p, cursor) { case '[': bracketCount++ + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } case ']': bracketCount-- + depth-- if bracketCount == 0 { s.cursor = cursor + 1 return nil } + case '{': + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } + case '}': + depth-- case '"': for { cursor++ @@ -187,7 +211,7 @@ func (s *stream) skipArray() error { } } -func (s *stream) skipValue() error { +func (s *stream) skipValue(depth int64) error { _, cursor, p := s.stat() for { switch char(p, cursor) { @@ -203,10 +227,10 @@ func (s *stream) skipValue() error { return errUnexpectedEndOfJSON("value of object", s.totalOffset()) case '{': s.cursor = cursor + 1 - return s.skipObject() + return s.skipObject(depth + 1) case '[': s.cursor = cursor + 1 - return s.skipArray() + return s.skipArray(depth + 1) case '"': for { cursor++ diff --git a/decode_string.go b/decode_string.go index f671f97..09c1e30 100644 --- a/decode_string.go +++ b/decode_string.go @@ -30,7 +30,7 @@ func (d *stringDecoder) errUnmarshalType(typeName string, offset int64) *Unmarsh } } -func (d *stringDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *stringDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { bytes, err := d.decodeStreamByte(s) if err != nil { return err @@ -43,7 +43,7 @@ func (d *stringDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *stringDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *stringDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { bytes, c, err := d.decodeByte(buf, cursor) if err != nil { return 0, err diff --git a/decode_struct.go b/decode_struct.go index 4264620..4c79d6f 100644 --- a/decode_struct.go +++ b/decode_struct.go @@ -487,7 +487,12 @@ func decodeKeyStream(d *structDecoder, s *stream) (*structFieldSet, string, erro return d.fieldMap[k], k, nil } -func (d *structDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *structDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } + s.skipWhiteSpace() switch s.char() { case 'n': @@ -528,13 +533,13 @@ func (d *structDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if field.err != nil { return field.err } - if err := field.dec.decodeStream(s, unsafe.Pointer(uintptr(p)+field.offset)); err != nil { + if err := field.dec.decodeStream(s, depth, unsafe.Pointer(uintptr(p)+field.offset)); err != nil { return err } } else if s.disallowUnknownFields { return fmt.Errorf("json: unknown field %q", key) } else { - if err := s.skipValue(); err != nil { + if err := s.skipValue(depth); err != nil { return err } } @@ -551,7 +556,11 @@ func (d *structDecoder) decodeStream(s *stream, p unsafe.Pointer) error { } } -func (d *structDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *structDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } buflen := int64(len(buf)) cursor = skipWhiteSpace(buf, cursor) b := (*sliceHeader)(unsafe.Pointer(&buf)).data @@ -598,13 +607,13 @@ func (d *structDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int6 if field.err != nil { return 0, field.err } - c, err := field.dec.decode(buf, cursor, unsafe.Pointer(uintptr(p)+field.offset)) + c, err := field.dec.decode(buf, cursor, depth, unsafe.Pointer(uintptr(p)+field.offset)) if err != nil { return 0, err } cursor = c } else { - c, err := skipValue(buf, cursor) + c, err := skipValue(buf, cursor, depth) if err != nil { return 0, err } diff --git a/decode_test.go b/decode_test.go index fc8aa66..ec74b0a 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2796,7 +2796,6 @@ func TestUnmarshalRescanLiteralMangledUnquote(t *testing.T) { } } -/* func TestUnmarshalMaxDepth(t *testing.T) { testcases := []struct { name string @@ -2876,20 +2875,35 @@ func TestUnmarshalMaxDepth(t *testing.T) { for _, tc := range testcases { for _, target := range targets { t.Run(target.name+"-"+tc.name, func(t *testing.T) { - err := json.Unmarshal([]byte(tc.data), target.newValue()) - if !tc.errMaxDepth { - if err != nil { - t.Errorf("unexpected error: %v", err) + t.Run("unmarshal", func(t *testing.T) { + err := json.Unmarshal([]byte(tc.data), target.newValue()) + if !tc.errMaxDepth { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } else { + if err == nil { + t.Errorf("expected error containing 'exceeded max depth', got none") + } else if !strings.Contains(err.Error(), "exceeded max depth") { + t.Errorf("expected error containing 'exceeded max depth', got: %v", err) + } } - } else { - if err == nil { - t.Errorf("expected error containing 'exceeded max depth', got none") - } else if !strings.Contains(err.Error(), "exceeded max depth") { - t.Errorf("expected error containing 'exceeded max depth', got: %v", err) + }) + t.Run("stream", func(t *testing.T) { + err := json.NewDecoder(strings.NewReader(tc.data)).Decode(target.newValue()) + if !tc.errMaxDepth { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } else { + if err == nil { + t.Errorf("expected error containing 'exceeded max depth', got none") + } else if !strings.Contains(err.Error(), "exceeded max depth") { + t.Errorf("expected error containing 'exceeded max depth', got: %v", err) + } } - } + }) }) } } } -*/ diff --git a/decode_uint.go b/decode_uint.go index b4c9f1c..4c55bac 100644 --- a/decode_uint.go +++ b/decode_uint.go @@ -127,7 +127,7 @@ func (d *uintDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error return nil, 0, errUnexpectedEndOfJSON("number(unsigned integer)", cursor) } -func (d *uintDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *uintDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { bytes, err := d.decodeStreamByte(s) if err != nil { return err @@ -154,7 +154,7 @@ func (d *uintDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *uintDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *uintDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { bytes, c, err := d.decodeByte(buf, cursor) if err != nil { return 0, err diff --git a/decode_unmarshal_json.go b/decode_unmarshal_json.go index f1095aa..1767c1f 100644 --- a/decode_unmarshal_json.go +++ b/decode_unmarshal_json.go @@ -28,10 +28,10 @@ func (d *unmarshalJSONDecoder) annotateError(cursor int64, err error) { } } -func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *unmarshalJSONDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { s.skipWhiteSpace() start := s.cursor - if err := s.skipValue(); err != nil { + if err := s.skipValue(depth); err != nil { return err } src := s.buf[start:s.cursor] @@ -49,10 +49,10 @@ func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *unmarshalJSONDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *unmarshalJSONDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { cursor = skipWhiteSpace(buf, cursor) start := cursor - end, err := skipValue(buf, cursor) + end, err := skipValue(buf, cursor, depth) if err != nil { return 0, err } diff --git a/decode_unmarshal_text.go b/decode_unmarshal_text.go index 33469a4..7b560af 100644 --- a/decode_unmarshal_text.go +++ b/decode_unmarshal_text.go @@ -37,10 +37,10 @@ var ( nullbytes = []byte(`null`) ) -func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *unmarshalTextDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { s.skipWhiteSpace() start := s.cursor - if err := s.skipValue(); err != nil { + if err := s.skipValue(depth); err != nil { return err } src := s.buf[start:s.cursor] @@ -88,10 +88,10 @@ func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *unmarshalTextDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *unmarshalTextDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { cursor = skipWhiteSpace(buf, cursor) start := cursor - end, err := skipValue(buf, cursor) + end, err := skipValue(buf, cursor, depth) if err != nil { return 0, err } diff --git a/decode_wrapped_string.go b/decode_wrapped_string.go index 223ceed..7f63c59 100644 --- a/decode_wrapped_string.go +++ b/decode_wrapped_string.go @@ -25,7 +25,7 @@ func newWrappedStringDecoder(typ *rtype, dec decoder, structName, fieldName stri } } -func (d *wrappedStringDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *wrappedStringDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { bytes, err := d.stringDecoder.decodeStreamByte(s) if err != nil { return err @@ -38,13 +38,13 @@ func (d *wrappedStringDecoder) decodeStream(s *stream, p unsafe.Pointer) error { } b := make([]byte, len(bytes)+1) copy(b, bytes) - if _, err := d.dec.decode(b, 0, p); err != nil { + if _, err := d.dec.decode(b, 0, depth, p); err != nil { return err } return nil } -func (d *wrappedStringDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *wrappedStringDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { bytes, c, err := d.stringDecoder.decodeByte(buf, cursor) if err != nil { return 0, err @@ -56,7 +56,7 @@ func (d *wrappedStringDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer return c, nil } bytes = append(bytes, nul) - if _, err := d.dec.decode(bytes, 0, p); err != nil { + if _, err := d.dec.decode(bytes, 0, depth, p); err != nil { return 0, err } return c, nil diff --git a/error.go b/error.go index 1a574ba..71fd94f 100644 --- a/error.go +++ b/error.go @@ -117,6 +117,13 @@ func (e *UnsupportedValueError) Error() string { return fmt.Sprintf("json: unsupported value: %s", e.Str) } +func errExceededMaxDepth(c byte, cursor int64) *SyntaxError { + return &SyntaxError{ + msg: fmt.Sprintf(`invalid character "%c" exceeded max depth`, c), + Offset: cursor, + } +} + func errNotAtBeginningOfValue(cursor int64) *SyntaxError { return &SyntaxError{msg: "not at beginning of value", Offset: cursor} }