From b2a7d22fb47a8ba31eddf047e629f245c32e8ba3 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 18 Feb 2021 11:32:08 +0900 Subject: [PATCH 01/10] Fix not being able to return UnmarshalTypeError when it should be returned --- decode_interface.go | 14 ++++++++++++-- decode_string.go | 4 ++++ decode_test.go | 2 -- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/decode_interface.go b/decode_interface.go index d08f133..dedead6 100644 --- a/decode_interface.go +++ b/decode_interface.go @@ -214,7 +214,7 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { return decodeStreamTextUnmarshaler(s, u, p) } - return &UnsupportedTypeError{Type: rv.Type()} + return d.errUnmarshalType(rv.Type(), s.totalOffset()) } iface := rv.Interface() ifaceHeader := (*interfaceHeader)(unsafe.Pointer(&iface)) @@ -241,6 +241,16 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return decoder.decodeStream(s, ifaceHeader.ptr) } +func (d *interfaceDecoder) errUnmarshalType(typ reflect.Type, offset int64) *UnmarshalTypeError { + return &UnmarshalTypeError{ + Value: typ.String(), + Type: typ, + Offset: offset, + Struct: d.structName, + Field: d.fieldName, + } +} + func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { runtimeInterfaceValue := *(*interface{})(unsafe.Pointer(&interfaceHeader{ typ: d.typ, @@ -254,7 +264,7 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { return decodeTextUnmarshaler(buf, cursor, u, p) } - return 0, &UnsupportedTypeError{Type: rv.Type()} + return 0, d.errUnmarshalType(rv.Type(), cursor) } iface := rv.Interface() diff --git a/decode_string.go b/decode_string.go index c09d460..f671f97 100644 --- a/decode_string.go +++ b/decode_string.go @@ -231,6 +231,8 @@ func (d *stringDecoder) decodeStreamByte(s *stream) ([]byte, error) { continue case '[': return nil, d.errUnmarshalType("array", s.totalOffset()) + case '{': + return nil, d.errUnmarshalType("object", s.totalOffset()) case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': return nil, d.errUnmarshalType("number", s.totalOffset()) case '"': @@ -257,6 +259,8 @@ func (d *stringDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, err cursor++ case '[': return nil, 0, d.errUnmarshalType("array", cursor) + case '{': + return nil, 0, d.errUnmarshalType("object", cursor) case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': return nil, 0, d.errUnmarshalType("number", cursor) case '"': diff --git a/decode_test.go b/decode_test.go index da86ee2..c00d428 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2317,7 +2317,6 @@ var decodeTypeErrorTests = []struct { {new(error), `true`}, } -/* func TestUnmarshalTypeError(t *testing.T) { for _, item := range decodeTypeErrorTests { err := json.Unmarshal([]byte(item.src), item.dest) @@ -2327,7 +2326,6 @@ func TestUnmarshalTypeError(t *testing.T) { } } } -*/ var unmarshalSyntaxTests = []string{ "tru", From 91c53cd3f7eb4584c8dd53a350229b127b6271c3 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 18 Feb 2021 12:13:49 +0900 Subject: [PATCH 02/10] Fix decoding of prefilled value --- decode_array.go | 16 ++++++++++++++-- decode_map.go | 10 ++++++++-- decode_test.go | 2 -- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/decode_array.go b/decode_array.go index 9b6336f..f19e11c 100644 --- a/decode_array.go +++ b/decode_array.go @@ -11,9 +11,11 @@ type arrayDecoder struct { alen int structName string fieldName string + zeroValue unsafe.Pointer } func newArrayDecoder(dec decoder, elemType *rtype, alen int, structName, fieldName string) *arrayDecoder { + zeroValue := *(*unsafe.Pointer)(unsafe_New(elemType)) return &arrayDecoder{ valueDecoder: dec, elemType: elemType, @@ -21,6 +23,7 @@ func newArrayDecoder(dec decoder, elemType *rtype, alen int, structName, fieldNa alen: alen, structName: structName, fieldName: fieldName, + zeroValue: zeroValue, } } @@ -46,13 +49,18 @@ func (d *arrayDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return err } } + idx++ s.skipWhiteSpace() switch s.char() { case ']': + for idx < d.alen { + *(*unsafe.Pointer)(unsafe.Pointer(uintptr(p) + uintptr(idx)*d.size)) = d.zeroValue + idx++ + } s.cursor++ return nil case ',': - idx++ + continue case nul: if s.read() { continue @@ -115,13 +123,17 @@ func (d *arrayDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64 } cursor = c } + idx++ cursor = skipWhiteSpace(buf, cursor) switch buf[cursor] { case ']': + for idx < d.alen { + *(*unsafe.Pointer)(unsafe.Pointer(uintptr(p) + uintptr(idx)*d.size)) = d.zeroValue + idx++ + } cursor++ return cursor, nil case ',': - idx++ continue default: return 0, errInvalidCharacter(buf[cursor], "array", cursor) diff --git a/decode_map.go b/decode_map.go index 716d1f1..095b2ba 100644 --- a/decode_map.go +++ b/decode_map.go @@ -47,7 +47,10 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return errExpected("{ character for map value", s.totalOffset()) } s.skipWhiteSpace() - mapValue := makemap(d.mapType, 0) + mapValue := *(*unsafe.Pointer)(p) + if mapValue == nil { + mapValue = makemap(d.mapType, 0) + } if s.buf[s.cursor+1] == '}' { *(*unsafe.Pointer)(p) = mapValue s.cursor += 2 @@ -116,7 +119,10 @@ func (d *mapDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, } cursor++ cursor = skipWhiteSpace(buf, cursor) - mapValue := makemap(d.mapType, 0) + mapValue := *(*unsafe.Pointer)(p) + if mapValue == nil { + mapValue = makemap(d.mapType, 0) + } if buf[cursor] == '}' { **(**unsafe.Pointer)(unsafe.Pointer(&p)) = mapValue cursor++ diff --git a/decode_test.go b/decode_test.go index c00d428..4a4a171 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2412,7 +2412,6 @@ func TestSkipArrayObjects(t *testing.T) { } } -/* // Test semantics of pre-filled data, such as struct fields, map elements, // slices, and arrays. // Issues 4900 and 8837, among others. @@ -2466,7 +2465,6 @@ func TestPrefilled(t *testing.T) { } } } -*/ var invalidUnmarshalTests = []struct { v interface{} From 0288026fdef373c1b3b7830f435be8c47b7702c0 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 18 Feb 2021 16:42:38 +0900 Subject: [PATCH 03/10] Fix decoding of invalid value --- decode.go | 2 +- decode_test.go | 4 --- decode_unmarshal_text.go | 73 ++++++++++++++++++++++++++++------------ 3 files changed, 52 insertions(+), 27 deletions(-) diff --git a/decode.go b/decode.go index 84ad4f0..c700579 100644 --- a/decode.go +++ b/decode.go @@ -78,7 +78,7 @@ func noescape(p unsafe.Pointer) unsafe.Pointer { } func validateType(typ *rtype, p uintptr) error { - if typ.Kind() != reflect.Ptr || p == 0 { + if typ == nil || typ.Kind() != reflect.Ptr || p == 0 { return &InvalidUnmarshalError{Type: rtype2type(typ)} } return nil diff --git a/decode_test.go b/decode_test.go index 4a4a171..742143d 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2475,7 +2475,6 @@ var invalidUnmarshalTests = []struct { {(*int)(nil), "json: Unmarshal(nil *int)"}, } -/* func TestInvalidUnmarshal(t *testing.T) { buf := []byte(`{"a":"1"}`) for _, tt := range invalidUnmarshalTests { @@ -2489,7 +2488,6 @@ func TestInvalidUnmarshal(t *testing.T) { } } } -*/ var invalidUnmarshalTextTests = []struct { v interface{} @@ -2501,7 +2499,6 @@ var invalidUnmarshalTextTests = []struct { {new(net.IP), "json: cannot unmarshal number into Go value of type *net.IP"}, } -/* func TestInvalidUnmarshalText(t *testing.T) { buf := []byte(`123`) for _, tt := range invalidUnmarshalTextTests { @@ -2515,7 +2512,6 @@ func TestInvalidUnmarshalText(t *testing.T) { } } } -*/ /* // Test that string option is ignored for invalid types. diff --git a/decode_unmarshal_text.go b/decode_unmarshal_text.go index 1cde243..33469a4 100644 --- a/decode_unmarshal_text.go +++ b/decode_unmarshal_text.go @@ -44,25 +44,31 @@ func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return err } src := s.buf[start:s.cursor] - switch src[0] { - case '[': - // cannot decode array value by unmarshal text - return &UnmarshalTypeError{ - Value: "array", - Type: rtype2type(d.typ), - Offset: s.totalOffset(), - } - case '{': - // cannot decode object value by unmarshal text - return &UnmarshalTypeError{ - Value: "object", - Type: rtype2type(d.typ), - Offset: s.totalOffset(), - } - case 'n': - if bytes.Equal(src, nullbytes) { - *(*unsafe.Pointer)(p) = nil - return nil + if len(src) > 0 { + switch src[0] { + case '[': + return &UnmarshalTypeError{ + Value: "array", + Type: rtype2type(d.typ), + Offset: s.totalOffset(), + } + case '{': + return &UnmarshalTypeError{ + Value: "object", + Type: rtype2type(d.typ), + Offset: s.totalOffset(), + } + case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + return &UnmarshalTypeError{ + Value: "number", + Type: rtype2type(d.typ), + Offset: s.totalOffset(), + } + case 'n': + if bytes.Equal(src, nullbytes) { + *(*unsafe.Pointer)(p) = nil + return nil + } } } dst := make([]byte, len(src)) @@ -90,9 +96,32 @@ func (d *unmarshalTextDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer return 0, err } src := buf[start:end] - if bytes.Equal(src, nullbytes) { - *(*unsafe.Pointer)(p) = nil - return end, nil + if len(src) > 0 { + switch src[0] { + case '[': + return 0, &UnmarshalTypeError{ + Value: "array", + Type: rtype2type(d.typ), + Offset: start, + } + case '{': + return 0, &UnmarshalTypeError{ + Value: "object", + Type: rtype2type(d.typ), + Offset: start, + } + case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + return 0, &UnmarshalTypeError{ + Value: "number", + Type: rtype2type(d.typ), + Offset: start, + } + case 'n': + if bytes.Equal(src, nullbytes) { + *(*unsafe.Pointer)(p) = nil + return end, nil + } + } } if s, ok := unquoteBytes(src); ok { From cf6cf56e3db8077fa37b33d987f9b7a01cffc637 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 18 Feb 2021 16:49:51 +0900 Subject: [PATCH 04/10] Fix invalid test case --- decode_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/decode_test.go b/decode_test.go index 742143d..e5e7da9 100644 --- a/decode_test.go +++ b/decode_test.go @@ -300,7 +300,7 @@ func (u *unmarshalText) UnmarshalText(b []byte) error { func Test_UnmarshalText(t *testing.T) { t.Run("*struct", func(t *testing.T) { var v unmarshalText - assertErr(t, json.Unmarshal([]byte(`11`), &v)) + assertErr(t, json.Unmarshal([]byte(`"11"`), &v)) assertEq(t, "unmarshal", v.v, 11) }) } From 6eb23deb6f2da36c949088e12839fae70d5b0e0c Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 18 Feb 2021 17:46:28 +0900 Subject: [PATCH 05/10] Fix decoding of embedded unexported pointer field --- decode_compile.go | 20 ++++++++++++++++---- decode_struct.go | 7 +++++++ decode_test.go | 6 ++---- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/decode_compile.go b/decode_compile.go index b563d50..da9f1c0 100644 --- a/decode_compile.go +++ b/decode_compile.go @@ -1,8 +1,10 @@ package json import ( + "fmt" "reflect" "strings" + "unicode" "unsafe" ) @@ -247,7 +249,7 @@ func decodeCompileInterface(typ *rtype, structName, fieldName string) (decoder, return newInterfaceDecoder(typ, structName, fieldName), nil } -func decodeRemoveConflictFields(fieldMap map[string]*structFieldSet, conflictedMap map[string]struct{}, dec *structDecoder, baseOffset uintptr) { +func decodeRemoveConflictFields(fieldMap map[string]*structFieldSet, conflictedMap map[string]struct{}, dec *structDecoder, field reflect.StructField) { for k, v := range dec.fieldMap { if _, exists := conflictedMap[k]; exists { // already conflicted key @@ -257,7 +259,7 @@ func decodeRemoveConflictFields(fieldMap map[string]*structFieldSet, conflictedM if !exists { fieldSet := &structFieldSet{ dec: v.dec, - offset: baseOffset + v.offset, + offset: field.Offset + v.offset, isTaggedKey: v.isTaggedKey, key: k, keyLen: int64(len(k)), @@ -281,7 +283,7 @@ func decodeRemoveConflictFields(fieldMap map[string]*structFieldSet, conflictedM if v.isTaggedKey { fieldSet := &structFieldSet{ dec: v.dec, - offset: baseOffset + v.offset, + offset: field.Offset + v.offset, isTaggedKey: v.isTaggedKey, key: k, keyLen: int64(len(k)), @@ -318,6 +320,7 @@ func decodeCompileStruct(typ *rtype, structName, fieldName string, structTypeToD if isIgnoredStructField(field) { continue } + isUnexportedField := unicode.IsLower([]rune(field.Name)[0]) tag := structTagFromField(field) dec, err := decodeCompile(type2rtype(field.Type), structName, field.Name, structTypeToDecoder) if err != nil { @@ -329,13 +332,20 @@ func decodeCompileStruct(typ *rtype, structName, fieldName string, structTypeToD // recursive definition continue } - decodeRemoveConflictFields(fieldMap, conflictedMap, stDec, field.Offset) + decodeRemoveConflictFields(fieldMap, conflictedMap, stDec, field) } else if pdec, ok := dec.(*ptrDecoder); ok { contentDec := pdec.contentDecoder() if pdec.typ == typ { // recursive definition continue } + var fieldSetErr error + if isUnexportedField { + fieldSetErr = fmt.Errorf( + "json: cannot set embedded pointer to unexported struct: %v", + field.Type.Elem(), + ) + } if dec, ok := contentDec.(*structDecoder); ok { for k, v := range dec.fieldMap { if _, exists := conflictedMap[k]; exists { @@ -350,6 +360,7 @@ func decodeCompileStruct(typ *rtype, structName, fieldName string, structTypeToD isTaggedKey: v.isTaggedKey, key: k, keyLen: int64(len(k)), + err: fieldSetErr, } fieldMap[k] = fieldSet lower := strings.ToLower(k) @@ -374,6 +385,7 @@ func decodeCompileStruct(typ *rtype, structName, fieldName string, structTypeToD isTaggedKey: v.isTaggedKey, key: k, keyLen: int64(len(k)), + err: fieldSetErr, } fieldMap[k] = fieldSet lower := strings.ToLower(k) diff --git a/decode_struct.go b/decode_struct.go index f1e347a..4264620 100644 --- a/decode_struct.go +++ b/decode_struct.go @@ -15,6 +15,7 @@ type structFieldSet struct { isTaggedKey bool key string keyLen int64 + err error } type structDecoder struct { @@ -524,6 +525,9 @@ func (d *structDecoder) decodeStream(s *stream, p unsafe.Pointer) error { } } if field != nil { + if field.err != nil { + return field.err + } if err := field.dec.decodeStream(s, unsafe.Pointer(uintptr(p)+field.offset)); err != nil { return err } @@ -591,6 +595,9 @@ func (d *structDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int6 return 0, errExpected("object value after colon", cursor) } if field != nil { + if field.err != nil { + return 0, field.err + } c, err := field.dec.decode(buf, cursor, unsafe.Pointer(uintptr(p)+field.offset)) if err != nil { return 0, err diff --git a/decode_test.go b/decode_test.go index e5e7da9..26c91fe 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2539,7 +2539,6 @@ func TestInvalidStringOption(t *testing.T) { } */ -/* // Test unmarshal behavior with regards to embedded unexported structs. // // (Issue 21357) If the embedded struct is a pointer and is unallocated, @@ -2604,7 +2603,7 @@ func TestUnmarshalEmbeddedUnexported(t *testing.T) { in: `{"R":2,"Q":1}`, ptr: new(S1), out: &S1{R: 2}, - err: fmt.Errorf("json: cannot set embedded pointer to unexported struct: json.embed1"), + err: fmt.Errorf("json: cannot set embedded pointer to unexported struct: json_test.embed1"), }, { // The top level Q field takes precedence. in: `{"Q":1}`, @@ -2626,7 +2625,7 @@ func TestUnmarshalEmbeddedUnexported(t *testing.T) { in: `{"R":2,"Q":1}`, ptr: new(S5), out: &S5{R: 2}, - err: fmt.Errorf("json: cannot set embedded pointer to unexported struct: json.embed3"), + err: fmt.Errorf("json: cannot set embedded pointer to unexported struct: json_test.embed3"), }, { // Issue 24152, ensure decodeState.indirect does not panic. in: `{"embed1": {"Q": 1}}`, @@ -2670,7 +2669,6 @@ func TestUnmarshalEmbeddedUnexported(t *testing.T) { } } } -*/ /* func TestUnmarshalErrorAfterMultipleJSON(t *testing.T) { From 35eee537d4d8f403b746b768625674814befea4c Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 18 Feb 2021 17:59:29 +0900 Subject: [PATCH 06/10] Add test case --- decode_test.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/decode_test.go b/decode_test.go index 26c91fe..fc8aa66 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2670,26 +2670,25 @@ func TestUnmarshalEmbeddedUnexported(t *testing.T) { } } -/* func TestUnmarshalErrorAfterMultipleJSON(t *testing.T) { tests := []struct { in string err error }{{ in: `1 false null :`, - err: json.NewSyntaxError("invalid character ':' looking for beginning of value", 14), + err: json.NewSyntaxError("not at beginning of value", 14), }, { in: `1 [] [,]`, - err: json.NewSyntaxError("invalid character ',' looking for beginning of value", 7), + err: json.NewSyntaxError("not at beginning of value", 6), }, { in: `1 [] [true:]`, - err: json.NewSyntaxError("invalid character ':' after array element", 11), + err: json.NewSyntaxError("json: slice unexpected end of JSON input", 10), }, { in: `1 {} {"x"=}`, - err: json.NewSyntaxError("invalid character '=' after object key", 14), + err: json.NewSyntaxError("expected colon after object key", 13), }, { in: `falsetruenul#`, - err: json.NewSyntaxError("invalid character '#' in literal null (expecting 'l')", 13), + err: json.NewSyntaxError("json: invalid character # as null", 12), }} for i, tt := range tests { dec := json.NewDecoder(strings.NewReader(tt.in)) @@ -2705,7 +2704,6 @@ func TestUnmarshalErrorAfterMultipleJSON(t *testing.T) { } } } -*/ type unmarshalPanic struct{} From f8fd59516bf3e5a31a06bc76a4c9eb3d2eeda9e2 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 18 Feb 2021 19:05:06 +0900 Subject: [PATCH 07/10] Fix decoding of deep recursive structure --- decode.go | 13 +++++---- decode_anonymous_field.go | 8 +++--- decode_array.go | 22 +++++++++++---- decode_bool.go | 4 +-- decode_bytes.go | 16 +++++------ decode_context.go | 34 +++++++++++++++++++---- decode_float.go | 4 +-- decode_int.go | 4 +-- decode_interface.go | 58 +++++++++++++++++++-------------------- decode_map.go | 22 +++++++++++---- decode_number.go | 4 +-- decode_ptr.go | 8 +++--- decode_slice.go | 18 +++++++++--- decode_stream.go | 34 +++++++++++++++++++---- decode_string.go | 4 +-- decode_struct.go | 21 ++++++++++---- decode_test.go | 38 +++++++++++++++++-------- decode_uint.go | 4 +-- decode_unmarshal_json.go | 8 +++--- decode_unmarshal_text.go | 8 +++--- decode_wrapped_string.go | 8 +++--- error.go | 7 +++++ 22 files changed, 228 insertions(+), 119 deletions(-) diff --git a/decode.go b/decode.go index c700579..e84570e 100644 --- a/decode.go +++ b/decode.go @@ -15,8 +15,8 @@ func (d Delim) String() string { } type decoder interface { - decode([]byte, int64, unsafe.Pointer) (int64, error) - decodeStream(*stream, unsafe.Pointer) error + decode([]byte, int64, int64, unsafe.Pointer) (int64, error) + decodeStream(*stream, int64, unsafe.Pointer) error } type Decoder struct { @@ -29,7 +29,8 @@ var ( ) const ( - nul = '\000' + nul = '\000' + maxDecodeNestingDepth = 10000 ) func unmarshal(data []byte, v interface{}) error { @@ -45,7 +46,7 @@ func unmarshal(data []byte, v interface{}) error { if err != nil { return err } - if _, err := dec.decode(src, 0, header.ptr); err != nil { + if _, err := dec.decode(src, 0, 0, header.ptr); err != nil { return err } return nil @@ -64,7 +65,7 @@ func unmarshalNoEscape(data []byte, v interface{}) error { if err != nil { return err } - if _, err := dec.decode(src, 0, noescape(header.ptr)); err != nil { + if _, err := dec.decode(src, 0, 0, noescape(header.ptr)); err != nil { return err } return nil @@ -147,7 +148,7 @@ func (d *Decoder) Decode(v interface{}) error { return err } s := d.s - if err := dec.decodeStream(s, header.ptr); err != nil { + if err := dec.decodeStream(s, 0, header.ptr); err != nil { return err } s.reset() diff --git a/decode_anonymous_field.go b/decode_anonymous_field.go index 91c2894..77931f2 100644 --- a/decode_anonymous_field.go +++ b/decode_anonymous_field.go @@ -18,18 +18,18 @@ func newAnonymousFieldDecoder(structType *rtype, offset uintptr, dec decoder) *a } } -func (d *anonymousFieldDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *anonymousFieldDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe_New(d.structType) } p = *(*unsafe.Pointer)(p) - return d.dec.decodeStream(s, unsafe.Pointer(uintptr(p)+d.offset)) + return d.dec.decodeStream(s, depth, unsafe.Pointer(uintptr(p)+d.offset)) } -func (d *anonymousFieldDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *anonymousFieldDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe_New(d.structType) } p = *(*unsafe.Pointer)(p) - return d.dec.decode(buf, cursor, unsafe.Pointer(uintptr(p)+d.offset)) + return d.dec.decode(buf, cursor, depth, unsafe.Pointer(uintptr(p)+d.offset)) } diff --git a/decode_array.go b/decode_array.go index f19e11c..92b9dd9 100644 --- a/decode_array.go +++ b/decode_array.go @@ -27,7 +27,12 @@ func newArrayDecoder(dec decoder, elemType *rtype, alen int, structName, fieldNa } } -func (d *arrayDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *arrayDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } + for { switch s.char() { case ' ', '\n', '\t', '\r': @@ -41,11 +46,11 @@ func (d *arrayDecoder) decodeStream(s *stream, p unsafe.Pointer) error { for { s.cursor++ if idx < d.alen { - if err := d.valueDecoder.decodeStream(s, unsafe.Pointer(uintptr(p)+uintptr(idx)*d.size)); err != nil { + if err := d.valueDecoder.decodeStream(s, depth, unsafe.Pointer(uintptr(p)+uintptr(idx)*d.size)); err != nil { return err } } else { - if err := s.skipValue(); err != nil { + if err := s.skipValue(depth); err != nil { return err } } @@ -84,7 +89,12 @@ ERROR: return errUnexpectedEndOfJSON("array", s.totalOffset()) } -func (d *arrayDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *arrayDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } + buflen := int64(len(buf)) for ; cursor < buflen; cursor++ { switch buf[cursor] { @@ -111,13 +121,13 @@ func (d *arrayDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64 for { cursor++ if idx < d.alen { - c, err := d.valueDecoder.decode(buf, cursor, unsafe.Pointer(uintptr(p)+uintptr(idx)*d.size)) + c, err := d.valueDecoder.decode(buf, cursor, depth, unsafe.Pointer(uintptr(p)+uintptr(idx)*d.size)) if err != nil { return 0, err } cursor = c } else { - c, err := skipValue(buf, cursor) + c, err := skipValue(buf, cursor, depth) if err != nil { return 0, err } diff --git a/decode_bool.go b/decode_bool.go index ccb96b0..84a4a3d 100644 --- a/decode_bool.go +++ b/decode_bool.go @@ -61,7 +61,7 @@ func falseBytes(s *stream) error { return nil } -func (d *boolDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *boolDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { s.skipWhiteSpace() for { switch s.char() { @@ -94,7 +94,7 @@ ERROR: return errUnexpectedEndOfJSON("bool", s.totalOffset()) } -func (d *boolDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *boolDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { buflen := int64(len(buf)) cursor = skipWhiteSpace(buf, cursor) switch buf[cursor] { diff --git a/decode_bytes.go b/decode_bytes.go index a82fe06..0babe82 100644 --- a/decode_bytes.go +++ b/decode_bytes.go @@ -35,8 +35,8 @@ func newBytesDecoder(typ *rtype, structName string, fieldName string) *bytesDeco } } -func (d *bytesDecoder) decodeStream(s *stream, p unsafe.Pointer) error { - bytes, err := d.decodeStreamBinary(s, p) +func (d *bytesDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { + bytes, err := d.decodeStreamBinary(s, depth, p) if err != nil { return err } @@ -54,8 +54,8 @@ func (d *bytesDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *bytesDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { - bytes, c, err := d.decodeBinary(buf, cursor, p) +func (d *bytesDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { + bytes, c, err := d.decodeBinary(buf, cursor, depth, p) if err != nil { return 0, err } @@ -94,7 +94,7 @@ ERROR: return nil, errUnexpectedEndOfJSON("[]byte", s.totalOffset()) } -func (d *bytesDecoder) decodeStreamBinary(s *stream, p unsafe.Pointer) ([]byte, error) { +func (d *bytesDecoder) decodeStreamBinary(s *stream, depth int64, p unsafe.Pointer) ([]byte, error) { for { switch s.char() { case ' ', '\n', '\t', '\r': @@ -114,7 +114,7 @@ func (d *bytesDecoder) decodeStreamBinary(s *stream, p unsafe.Pointer) ([]byte, Offset: s.totalOffset(), } } - if err := d.sliceDecoder.decodeStream(s, p); err != nil { + if err := d.sliceDecoder.decodeStream(s, depth, p); err != nil { return nil, err } return nil, nil @@ -128,7 +128,7 @@ func (d *bytesDecoder) decodeStreamBinary(s *stream, p unsafe.Pointer) ([]byte, return nil, errNotAtBeginningOfValue(s.totalOffset()) } -func (d *bytesDecoder) decodeBinary(buf []byte, cursor int64, p unsafe.Pointer) ([]byte, int64, error) { +func (d *bytesDecoder) decodeBinary(buf []byte, cursor, depth int64, p unsafe.Pointer) ([]byte, int64, error) { for { switch buf[cursor] { case ' ', '\n', '\t', '\r': @@ -154,7 +154,7 @@ func (d *bytesDecoder) decodeBinary(buf []byte, cursor int64, p unsafe.Pointer) Offset: cursor, } } - c, err := d.sliceDecoder.decode(buf, cursor, p) + c, err := d.sliceDecoder.decode(buf, cursor, depth, p) if err != nil { return nil, 0, err } diff --git a/decode_context.go b/decode_context.go index a4ebaa5..9f87b05 100644 --- a/decode_context.go +++ b/decode_context.go @@ -28,17 +28,29 @@ LOOP: return cursor } -func skipObject(buf []byte, cursor int64) (int64, error) { +func skipObject(buf []byte, cursor, depth int64) (int64, error) { braceCount := 1 for { switch buf[cursor] { case '{': braceCount++ + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } case '}': + depth-- braceCount-- if braceCount == 0 { return cursor + 1, nil } + case '[': + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } + case ']': + depth-- case '"': for { cursor++ @@ -60,17 +72,29 @@ func skipObject(buf []byte, cursor int64) (int64, error) { } } -func skipArray(buf []byte, cursor int64) (int64, error) { +func skipArray(buf []byte, cursor, depth int64) (int64, error) { bracketCount := 1 for { switch buf[cursor] { case '[': bracketCount++ + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } case ']': bracketCount-- + depth-- if bracketCount == 0 { return cursor + 1, nil } + case '{': + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } + case '}': + depth-- case '"': for { cursor++ @@ -92,16 +116,16 @@ func skipArray(buf []byte, cursor int64) (int64, error) { } } -func skipValue(buf []byte, cursor int64) (int64, error) { +func skipValue(buf []byte, cursor, depth int64) (int64, error) { for { switch buf[cursor] { case ' ', '\t', '\n', '\r': cursor++ continue case '{': - return skipObject(buf, cursor+1) + return skipObject(buf, cursor+1, depth+1) case '[': - return skipArray(buf, cursor+1) + return skipArray(buf, cursor+1, depth+1) case '"': for { cursor++ diff --git a/decode_float.go b/decode_float.go index f818b1f..cda8ec5 100644 --- a/decode_float.go +++ b/decode_float.go @@ -129,7 +129,7 @@ func (d *floatDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, erro return nil, 0, errUnexpectedEndOfJSON("float", cursor) } -func (d *floatDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *floatDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { bytes, err := d.decodeStreamByte(s) if err != nil { return err @@ -146,7 +146,7 @@ func (d *floatDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *floatDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *floatDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { bytes, c, err := d.decodeByte(buf, cursor) if err != nil { return 0, err diff --git a/decode_int.go b/decode_int.go index 394e691..4ea6fa5 100644 --- a/decode_int.go +++ b/decode_int.go @@ -173,7 +173,7 @@ func (d *intDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error) } } -func (d *intDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *intDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { bytes, err := d.decodeStreamByte(s) if err != nil { return err @@ -201,7 +201,7 @@ func (d *intDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *intDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *intDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { bytes, c, err := d.decodeByte(buf, cursor) if err != nil { return 0, err diff --git a/decode_interface.go b/decode_interface.go index dedead6..20f5ad5 100644 --- a/decode_interface.go +++ b/decode_interface.go @@ -42,9 +42,9 @@ var ( ) ) -func decodeStreamUnmarshaler(s *stream, unmarshaler Unmarshaler) error { +func decodeStreamUnmarshaler(s *stream, depth int64, unmarshaler Unmarshaler) error { start := s.cursor - if err := s.skipValue(); err != nil { + if err := s.skipValue(depth); err != nil { return err } src := s.buf[start:s.cursor] @@ -57,10 +57,10 @@ func decodeStreamUnmarshaler(s *stream, unmarshaler Unmarshaler) error { return nil } -func decodeUnmarshaler(buf []byte, cursor int64, unmarshaler Unmarshaler) (int64, error) { +func decodeUnmarshaler(buf []byte, cursor, depth int64, unmarshaler Unmarshaler) (int64, error) { cursor = skipWhiteSpace(buf, cursor) start := cursor - end, err := skipValue(buf, cursor) + end, err := skipValue(buf, cursor, depth) if err != nil { return 0, err } @@ -74,9 +74,9 @@ func decodeUnmarshaler(buf []byte, cursor int64, unmarshaler Unmarshaler) (int64 return end, nil } -func decodeStreamTextUnmarshaler(s *stream, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) error { +func decodeStreamTextUnmarshaler(s *stream, depth int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) error { start := s.cursor - if err := s.skipValue(); err != nil { + if err := s.skipValue(depth); err != nil { return err } src := s.buf[start:s.cursor] @@ -94,10 +94,10 @@ func decodeStreamTextUnmarshaler(s *stream, unmarshaler encoding.TextUnmarshaler return nil } -func decodeTextUnmarshaler(buf []byte, cursor int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) (int64, error) { +func decodeTextUnmarshaler(buf []byte, cursor, depth int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) (int64, error) { cursor = skipWhiteSpace(buf, cursor) start := cursor - end, err := skipValue(buf, cursor) + end, err := skipValue(buf, cursor, depth) if err != nil { return 0, err } @@ -115,7 +115,7 @@ func decodeTextUnmarshaler(buf []byte, cursor int64, unmarshaler encoding.TextUn return end, nil } -func (d *interfaceDecoder) decodeStreamEmptyInterface(s *stream, p unsafe.Pointer) error { +func (d *interfaceDecoder) decodeStreamEmptyInterface(s *stream, depth int64, p unsafe.Pointer) error { s.skipWhiteSpace() for { switch s.char() { @@ -130,7 +130,7 @@ func (d *interfaceDecoder) decodeStreamEmptyInterface(s *stream, p unsafe.Pointe newInterfaceDecoder(emptyInterfaceType, d.structName, d.fieldName), d.structName, d.fieldName, - ).decodeStream(s, ptr); err != nil { + ).decodeStream(s, depth, ptr); err != nil { return err } *(*interface{})(p) = v @@ -144,13 +144,13 @@ func (d *interfaceDecoder) decodeStreamEmptyInterface(s *stream, p unsafe.Pointe emptyInterfaceType.Size(), d.structName, d.fieldName, - ).decodeStream(s, ptr); err != nil { + ).decodeStream(s, depth, ptr); err != nil { return err } *(*interface{})(p) = v return nil case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': - return d.numDecoder(s).decodeStream(s, p) + return d.numDecoder(s).decodeStream(s, depth, p) case '"': s.cursor++ start := s.cursor @@ -201,7 +201,7 @@ func (d *interfaceDecoder) decodeStreamEmptyInterface(s *stream, p unsafe.Pointe return errNotAtBeginningOfValue(s.totalOffset()) } -func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *interfaceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { runtimeInterfaceValue := *(*interface{})(unsafe.Pointer(&interfaceHeader{ typ: d.typ, ptr: p, @@ -209,10 +209,10 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { rv := reflect.ValueOf(runtimeInterfaceValue) if rv.NumMethod() > 0 && rv.CanInterface() { if u, ok := rv.Interface().(Unmarshaler); ok { - return decodeStreamUnmarshaler(s, u) + return decodeStreamUnmarshaler(s, depth, u) } if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { - return decodeStreamTextUnmarshaler(s, u, p) + return decodeStreamTextUnmarshaler(s, depth, u, p) } return d.errUnmarshalType(rv.Type(), s.totalOffset()) } @@ -221,10 +221,10 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { typ := ifaceHeader.typ if ifaceHeader.ptr == nil || d.typ == typ || typ == nil { // concrete type is empty interface - return d.decodeStreamEmptyInterface(s, p) + return d.decodeStreamEmptyInterface(s, depth, p) } if typ.Kind() == reflect.Ptr && typ.Elem() == d.typ || typ.Kind() != reflect.Ptr { - return d.decodeStreamEmptyInterface(s, p) + return d.decodeStreamEmptyInterface(s, depth, p) } s.skipWhiteSpace() if s.char() == 'n' { @@ -238,7 +238,7 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err != nil { return err } - return decoder.decodeStream(s, ifaceHeader.ptr) + return decoder.decodeStream(s, depth, ifaceHeader.ptr) } func (d *interfaceDecoder) errUnmarshalType(typ reflect.Type, offset int64) *UnmarshalTypeError { @@ -251,7 +251,7 @@ func (d *interfaceDecoder) errUnmarshalType(typ reflect.Type, offset int64) *Unm } } -func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *interfaceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { runtimeInterfaceValue := *(*interface{})(unsafe.Pointer(&interfaceHeader{ typ: d.typ, ptr: p, @@ -259,10 +259,10 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i rv := reflect.ValueOf(runtimeInterfaceValue) if rv.NumMethod() > 0 && rv.CanInterface() { if u, ok := rv.Interface().(Unmarshaler); ok { - return decodeUnmarshaler(buf, cursor, u) + return decodeUnmarshaler(buf, cursor, depth, u) } if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { - return decodeTextUnmarshaler(buf, cursor, u, p) + return decodeTextUnmarshaler(buf, cursor, depth, u, p) } return 0, d.errUnmarshalType(rv.Type(), cursor) } @@ -272,10 +272,10 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i typ := ifaceHeader.typ if ifaceHeader.ptr == nil || d.typ == typ || typ == nil { // concrete type is empty interface - return d.decodeEmptyInterface(buf, cursor, p) + return d.decodeEmptyInterface(buf, cursor, depth, p) } if typ.Kind() == reflect.Ptr && typ.Elem() == d.typ || typ.Kind() != reflect.Ptr { - return d.decodeEmptyInterface(buf, cursor, p) + return d.decodeEmptyInterface(buf, cursor, depth, p) } cursor = skipWhiteSpace(buf, cursor) if buf[cursor] == 'n' { @@ -299,10 +299,10 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i if err != nil { return 0, err } - return decoder.decode(buf, cursor, ifaceHeader.ptr) + return decoder.decode(buf, cursor, depth, ifaceHeader.ptr) } -func (d *interfaceDecoder) decodeEmptyInterface(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *interfaceDecoder) decodeEmptyInterface(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { cursor = skipWhiteSpace(buf, cursor) switch buf[cursor] { case '{': @@ -316,7 +316,7 @@ func (d *interfaceDecoder) decodeEmptyInterface(buf []byte, cursor int64, p unsa newInterfaceDecoder(emptyInterfaceType, d.structName, d.fieldName), d.structName, d.fieldName, ) - cursor, err := dec.decode(buf, cursor, ptr) + cursor, err := dec.decode(buf, cursor, depth, ptr) if err != nil { return 0, err } @@ -331,7 +331,7 @@ func (d *interfaceDecoder) decodeEmptyInterface(buf []byte, cursor int64, p unsa emptyInterfaceType.Size(), d.structName, d.fieldName, ) - cursor, err := dec.decode(buf, cursor, ptr) + cursor, err := dec.decode(buf, cursor, depth, ptr) if err != nil { return 0, err } @@ -340,12 +340,12 @@ func (d *interfaceDecoder) decodeEmptyInterface(buf []byte, cursor int64, p unsa case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': return newFloatDecoder(d.structName, d.fieldName, func(p unsafe.Pointer, v float64) { *(*interface{})(p) = v - }).decode(buf, cursor, p) + }).decode(buf, cursor, depth, p) case '"': var v string ptr := unsafe.Pointer(&v) dec := newStringDecoder(d.structName, d.fieldName) - cursor, err := dec.decode(buf, cursor, ptr) + cursor, err := dec.decode(buf, cursor, depth, ptr) if err != nil { return 0, err } diff --git a/decode_map.go b/decode_map.go index 095b2ba..c09e2a2 100644 --- a/decode_map.go +++ b/decode_map.go @@ -33,7 +33,12 @@ func makemap(*rtype, int) unsafe.Pointer //go:noescape func mapassign(t *rtype, m unsafe.Pointer, key, val unsafe.Pointer) -func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *mapDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } + s.skipWhiteSpace() switch s.char() { case 'n': @@ -59,7 +64,7 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { for { s.cursor++ k := unsafe_New(d.keyType) - if err := d.keyDecoder.decodeStream(s, k); err != nil { + if err := d.keyDecoder.decodeStream(s, depth, k); err != nil { return err } s.skipWhiteSpace() @@ -71,7 +76,7 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { } s.cursor++ v := unsafe_New(d.valueType) - if err := d.valueDecoder.decodeStream(s, v); err != nil { + if err := d.valueDecoder.decodeStream(s, depth, v); err != nil { return err } mapassign(d.mapType, mapValue, k, v) @@ -90,7 +95,12 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { } } -func (d *mapDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *mapDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } + cursor = skipWhiteSpace(buf, cursor) buflen := int64(len(buf)) if buflen < 2 { @@ -130,7 +140,7 @@ func (d *mapDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, } for { k := unsafe_New(d.keyType) - keyCursor, err := d.keyDecoder.decode(buf, cursor, k) + keyCursor, err := d.keyDecoder.decode(buf, cursor, depth, k) if err != nil { return 0, err } @@ -140,7 +150,7 @@ func (d *mapDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, } cursor++ v := unsafe_New(d.valueType) - valueCursor, err := d.valueDecoder.decode(buf, cursor, v) + valueCursor, err := d.valueDecoder.decode(buf, cursor, depth, v) if err != nil { return 0, err } diff --git a/decode_number.go b/decode_number.go index cf36979..bf358cb 100644 --- a/decode_number.go +++ b/decode_number.go @@ -20,7 +20,7 @@ func newNumberDecoder(structName, fieldName string, op func(unsafe.Pointer, Numb } } -func (d *numberDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *numberDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { bytes, err := d.floatDecoder.decodeStreamByte(s) if err != nil { return err @@ -30,7 +30,7 @@ func (d *numberDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *numberDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *numberDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { bytes, c, err := d.floatDecoder.decodeByte(buf, cursor) if err != nil { return 0, err diff --git a/decode_ptr.go b/decode_ptr.go index 5d90366..ac4af6f 100644 --- a/decode_ptr.go +++ b/decode_ptr.go @@ -32,7 +32,7 @@ func (d *ptrDecoder) contentDecoder() decoder { //go:linkname unsafe_New reflect.unsafe_New func unsafe_New(*rtype) unsafe.Pointer -func (d *ptrDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *ptrDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { s.skipWhiteSpace() if s.char() == nul { s.read() @@ -51,13 +51,13 @@ func (d *ptrDecoder) decodeStream(s *stream, p unsafe.Pointer) error { } else { newptr = *(*unsafe.Pointer)(p) } - if err := d.dec.decodeStream(s, newptr); err != nil { + if err := d.dec.decodeStream(s, depth, newptr); err != nil { return err } return nil } -func (d *ptrDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *ptrDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { cursor = skipWhiteSpace(buf, cursor) if buf[cursor] == 'n' { buflen := int64(len(buf)) @@ -86,7 +86,7 @@ func (d *ptrDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, } else { newptr = *(*unsafe.Pointer)(p) } - c, err := d.dec.decode(buf, cursor, newptr) + c, err := d.dec.decode(buf, cursor, depth, newptr) if err != nil { return 0, err } diff --git a/decode_slice.go b/decode_slice.go index 3717d1c..0442a75 100644 --- a/decode_slice.go +++ b/decode_slice.go @@ -73,7 +73,12 @@ func (d *sliceDecoder) errNumber(offset int64) *UnmarshalTypeError { } } -func (d *sliceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *sliceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } + for { switch s.char() { case ' ', '\n', '\t', '\r': @@ -109,7 +114,7 @@ func (d *sliceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { dst := sliceHeader{data: data, len: idx, cap: capacity} copySlice(d.elemType, dst, src) } - if err := d.valueDecoder.decodeStream(s, unsafe.Pointer(uintptr(data)+uintptr(idx)*d.size)); err != nil { + if err := d.valueDecoder.decodeStream(s, depth, unsafe.Pointer(uintptr(data)+uintptr(idx)*d.size)); err != nil { return err } s.skipWhiteSpace() @@ -167,7 +172,12 @@ ERROR: return errUnexpectedEndOfJSON("slice", s.totalOffset()) } -func (d *sliceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *sliceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } + buflen := int64(len(buf)) for ; cursor < buflen; cursor++ { switch buf[cursor] { @@ -214,7 +224,7 @@ func (d *sliceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64 dst := sliceHeader{data: data, len: idx, cap: capacity} copySlice(d.elemType, dst, src) } - c, err := d.valueDecoder.decode(buf, cursor, unsafe.Pointer(uintptr(data)+uintptr(idx)*d.size)) + c, err := d.valueDecoder.decode(buf, cursor, depth, unsafe.Pointer(uintptr(data)+uintptr(idx)*d.size)) if err != nil { return 0, err } diff --git a/decode_stream.go b/decode_stream.go index 05ba7fc..6019aaa 100644 --- a/decode_stream.go +++ b/decode_stream.go @@ -97,19 +97,31 @@ LOOP: } } -func (s *stream) skipObject() error { +func (s *stream) skipObject(depth int64) error { braceCount := 1 _, cursor, p := s.stat() for { switch char(p, cursor) { case '{': braceCount++ + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } case '}': braceCount-- + depth-- if braceCount == 0 { s.cursor = cursor + 1 return nil } + case '[': + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } + case ']': + depth-- case '"': for { cursor++ @@ -142,19 +154,31 @@ func (s *stream) skipObject() error { } } -func (s *stream) skipArray() error { +func (s *stream) skipArray(depth int64) error { bracketCount := 1 _, cursor, p := s.stat() for { switch char(p, cursor) { case '[': bracketCount++ + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } case ']': bracketCount-- + depth-- if bracketCount == 0 { s.cursor = cursor + 1 return nil } + case '{': + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } + case '}': + depth-- case '"': for { cursor++ @@ -187,7 +211,7 @@ func (s *stream) skipArray() error { } } -func (s *stream) skipValue() error { +func (s *stream) skipValue(depth int64) error { _, cursor, p := s.stat() for { switch char(p, cursor) { @@ -203,10 +227,10 @@ func (s *stream) skipValue() error { return errUnexpectedEndOfJSON("value of object", s.totalOffset()) case '{': s.cursor = cursor + 1 - return s.skipObject() + return s.skipObject(depth + 1) case '[': s.cursor = cursor + 1 - return s.skipArray() + return s.skipArray(depth + 1) case '"': for { cursor++ diff --git a/decode_string.go b/decode_string.go index f671f97..09c1e30 100644 --- a/decode_string.go +++ b/decode_string.go @@ -30,7 +30,7 @@ func (d *stringDecoder) errUnmarshalType(typeName string, offset int64) *Unmarsh } } -func (d *stringDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *stringDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { bytes, err := d.decodeStreamByte(s) if err != nil { return err @@ -43,7 +43,7 @@ func (d *stringDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *stringDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *stringDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { bytes, c, err := d.decodeByte(buf, cursor) if err != nil { return 0, err diff --git a/decode_struct.go b/decode_struct.go index 4264620..4c79d6f 100644 --- a/decode_struct.go +++ b/decode_struct.go @@ -487,7 +487,12 @@ func decodeKeyStream(d *structDecoder, s *stream) (*structFieldSet, string, erro return d.fieldMap[k], k, nil } -func (d *structDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *structDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { + depth++ + if depth > maxDecodeNestingDepth { + return errExceededMaxDepth(s.char(), s.cursor) + } + s.skipWhiteSpace() switch s.char() { case 'n': @@ -528,13 +533,13 @@ func (d *structDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if field.err != nil { return field.err } - if err := field.dec.decodeStream(s, unsafe.Pointer(uintptr(p)+field.offset)); err != nil { + if err := field.dec.decodeStream(s, depth, unsafe.Pointer(uintptr(p)+field.offset)); err != nil { return err } } else if s.disallowUnknownFields { return fmt.Errorf("json: unknown field %q", key) } else { - if err := s.skipValue(); err != nil { + if err := s.skipValue(depth); err != nil { return err } } @@ -551,7 +556,11 @@ func (d *structDecoder) decodeStream(s *stream, p unsafe.Pointer) error { } } -func (d *structDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *structDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { + depth++ + if depth > maxDecodeNestingDepth { + return 0, errExceededMaxDepth(buf[cursor], cursor) + } buflen := int64(len(buf)) cursor = skipWhiteSpace(buf, cursor) b := (*sliceHeader)(unsafe.Pointer(&buf)).data @@ -598,13 +607,13 @@ func (d *structDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int6 if field.err != nil { return 0, field.err } - c, err := field.dec.decode(buf, cursor, unsafe.Pointer(uintptr(p)+field.offset)) + c, err := field.dec.decode(buf, cursor, depth, unsafe.Pointer(uintptr(p)+field.offset)) if err != nil { return 0, err } cursor = c } else { - c, err := skipValue(buf, cursor) + c, err := skipValue(buf, cursor, depth) if err != nil { return 0, err } diff --git a/decode_test.go b/decode_test.go index fc8aa66..ec74b0a 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2796,7 +2796,6 @@ func TestUnmarshalRescanLiteralMangledUnquote(t *testing.T) { } } -/* func TestUnmarshalMaxDepth(t *testing.T) { testcases := []struct { name string @@ -2876,20 +2875,35 @@ func TestUnmarshalMaxDepth(t *testing.T) { for _, tc := range testcases { for _, target := range targets { t.Run(target.name+"-"+tc.name, func(t *testing.T) { - err := json.Unmarshal([]byte(tc.data), target.newValue()) - if !tc.errMaxDepth { - if err != nil { - t.Errorf("unexpected error: %v", err) + t.Run("unmarshal", func(t *testing.T) { + err := json.Unmarshal([]byte(tc.data), target.newValue()) + if !tc.errMaxDepth { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } else { + if err == nil { + t.Errorf("expected error containing 'exceeded max depth', got none") + } else if !strings.Contains(err.Error(), "exceeded max depth") { + t.Errorf("expected error containing 'exceeded max depth', got: %v", err) + } } - } else { - if err == nil { - t.Errorf("expected error containing 'exceeded max depth', got none") - } else if !strings.Contains(err.Error(), "exceeded max depth") { - t.Errorf("expected error containing 'exceeded max depth', got: %v", err) + }) + t.Run("stream", func(t *testing.T) { + err := json.NewDecoder(strings.NewReader(tc.data)).Decode(target.newValue()) + if !tc.errMaxDepth { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } else { + if err == nil { + t.Errorf("expected error containing 'exceeded max depth', got none") + } else if !strings.Contains(err.Error(), "exceeded max depth") { + t.Errorf("expected error containing 'exceeded max depth', got: %v", err) + } } - } + }) }) } } } -*/ diff --git a/decode_uint.go b/decode_uint.go index b4c9f1c..4c55bac 100644 --- a/decode_uint.go +++ b/decode_uint.go @@ -127,7 +127,7 @@ func (d *uintDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error return nil, 0, errUnexpectedEndOfJSON("number(unsigned integer)", cursor) } -func (d *uintDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *uintDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { bytes, err := d.decodeStreamByte(s) if err != nil { return err @@ -154,7 +154,7 @@ func (d *uintDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *uintDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *uintDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { bytes, c, err := d.decodeByte(buf, cursor) if err != nil { return 0, err diff --git a/decode_unmarshal_json.go b/decode_unmarshal_json.go index f1095aa..1767c1f 100644 --- a/decode_unmarshal_json.go +++ b/decode_unmarshal_json.go @@ -28,10 +28,10 @@ func (d *unmarshalJSONDecoder) annotateError(cursor int64, err error) { } } -func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *unmarshalJSONDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { s.skipWhiteSpace() start := s.cursor - if err := s.skipValue(); err != nil { + if err := s.skipValue(depth); err != nil { return err } src := s.buf[start:s.cursor] @@ -49,10 +49,10 @@ func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *unmarshalJSONDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *unmarshalJSONDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { cursor = skipWhiteSpace(buf, cursor) start := cursor - end, err := skipValue(buf, cursor) + end, err := skipValue(buf, cursor, depth) if err != nil { return 0, err } diff --git a/decode_unmarshal_text.go b/decode_unmarshal_text.go index 33469a4..7b560af 100644 --- a/decode_unmarshal_text.go +++ b/decode_unmarshal_text.go @@ -37,10 +37,10 @@ var ( nullbytes = []byte(`null`) ) -func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *unmarshalTextDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { s.skipWhiteSpace() start := s.cursor - if err := s.skipValue(); err != nil { + if err := s.skipValue(depth); err != nil { return err } src := s.buf[start:s.cursor] @@ -88,10 +88,10 @@ func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return nil } -func (d *unmarshalTextDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *unmarshalTextDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { cursor = skipWhiteSpace(buf, cursor) start := cursor - end, err := skipValue(buf, cursor) + end, err := skipValue(buf, cursor, depth) if err != nil { return 0, err } diff --git a/decode_wrapped_string.go b/decode_wrapped_string.go index 223ceed..7f63c59 100644 --- a/decode_wrapped_string.go +++ b/decode_wrapped_string.go @@ -25,7 +25,7 @@ func newWrappedStringDecoder(typ *rtype, dec decoder, structName, fieldName stri } } -func (d *wrappedStringDecoder) decodeStream(s *stream, p unsafe.Pointer) error { +func (d *wrappedStringDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer) error { bytes, err := d.stringDecoder.decodeStreamByte(s) if err != nil { return err @@ -38,13 +38,13 @@ func (d *wrappedStringDecoder) decodeStream(s *stream, p unsafe.Pointer) error { } b := make([]byte, len(bytes)+1) copy(b, bytes) - if _, err := d.dec.decode(b, 0, p); err != nil { + if _, err := d.dec.decode(b, 0, depth, p); err != nil { return err } return nil } -func (d *wrappedStringDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { +func (d *wrappedStringDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Pointer) (int64, error) { bytes, c, err := d.stringDecoder.decodeByte(buf, cursor) if err != nil { return 0, err @@ -56,7 +56,7 @@ func (d *wrappedStringDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer return c, nil } bytes = append(bytes, nul) - if _, err := d.dec.decode(bytes, 0, p); err != nil { + if _, err := d.dec.decode(bytes, 0, depth, p); err != nil { return 0, err } return c, nil diff --git a/error.go b/error.go index 1a574ba..71fd94f 100644 --- a/error.go +++ b/error.go @@ -117,6 +117,13 @@ func (e *UnsupportedValueError) Error() string { return fmt.Sprintf("json: unsupported value: %s", e.Str) } +func errExceededMaxDepth(c byte, cursor int64) *SyntaxError { + return &SyntaxError{ + msg: fmt.Sprintf(`invalid character "%c" exceeded max depth`, c), + Offset: cursor, + } +} + func errNotAtBeginningOfValue(cursor int64) *SyntaxError { return &SyntaxError{msg: "not at beginning of value", Offset: cursor} } From 23c5766bd229b7208a37953615096901e05cda2b Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 18 Feb 2021 19:11:44 +0900 Subject: [PATCH 08/10] Add test case --- decode_test.go | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/decode_test.go b/decode_test.go index ec74b0a..0d30b9d 100644 --- a/decode_test.go +++ b/decode_test.go @@ -1063,19 +1063,18 @@ var unmarshalTests = []unmarshalTest{ out: []byteWithPtrMarshalJSON{1, 2, 3}, golden: true, }, - /* - { - in: `"AQID"`, // 108 - ptr: new([]byteWithPtrMarshalText), - out: []byteWithPtrMarshalText{1, 2, 3}, - }, - { - in: `["Z01","Z02","Z03"]`, // 109 - ptr: new([]byteWithPtrMarshalText), - out: []byteWithPtrMarshalText{1, 2, 3}, - golden: true, - }, - */ + { + in: `"AQID"`, // 108 + ptr: new([]byteWithPtrMarshalText), + out: []byteWithPtrMarshalText{1, 2, 3}, + }, + { + in: `["Z01","Z02","Z03"]`, // 109 + ptr: new([]byteWithPtrMarshalText), + out: []byteWithPtrMarshalText{1, 2, 3}, + golden: true, + }, + // ints work with the marshaler but not the base64 []byte case { in: `["Z01","Z02","Z03"]`, // 110 From 24cc1b77b22b7c937101f88d08eabdc893c9c5b3 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 18 Feb 2021 19:20:07 +0900 Subject: [PATCH 09/10] Fix checkptr error --- encode_vm_escaped.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/encode_vm_escaped.go b/encode_vm_escaped.go index 1188b57..b0efc65 100644 --- a/encode_vm_escaped.go +++ b/encode_vm_escaped.go @@ -242,7 +242,7 @@ func encodeRunEscaped(ctx *encodeRuntimeContext, b []byte, codeSet *opcodeSet, o ptr := load(ctxptr, code.idx) isPtr := code.typ.Kind() == reflect.Ptr p := ptrToUnsafePtr(ptr) - if p == nil || isPtr && *(*unsafe.Pointer)(p) == nil { + if p == nil || isPtr && **(**unsafe.Pointer)(unsafe.Pointer(&p)) == nil { b = append(b, '"', '"', ',') } else { v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ From aa0422c239d1b64d28524e6f7cf5beb5569cd8d2 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 18 Feb 2021 19:31:35 +0900 Subject: [PATCH 10/10] Modify section of comment out --- stream_test.go | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/stream_test.go b/stream_test.go index 43588e4..054747d 100644 --- a/stream_test.go +++ b/stream_test.go @@ -376,25 +376,26 @@ var tokenStreamCases = []tokenStreamCase{ map[string]interface{}{"a": float64(1)}, }}, json.Delim('}')}}, - {json: ` [{"a": 1} {"a": 2}] `, expTokens: []interface{}{ - json.Delim('['), - decodeThis{map[string]interface{}{"a": float64(1)}}, - decodeThis{json.NewSyntaxError("expected comma after array element", 11)}, - }}, - {json: `{ "` + strings.Repeat("a", 513) + `" 1 }`, expTokens: []interface{}{ - json.Delim('{'), strings.Repeat("a", 513), - decodeThis{json.NewSyntaxError("expected colon after object key", 518)}, - }}, - {json: `{ "\a" }`, expTokens: []interface{}{ - json.Delim('{'), - json.NewSyntaxError("invalid character 'a' in string escape code", 3), - }}, - {json: ` \a`, expTokens: []interface{}{ - json.NewSyntaxError("invalid character '\\\\' looking for beginning of value", 1), - }}, + /* + {json: ` [{"a": 1} {"a": 2}] `, expTokens: []interface{}{ + json.Delim('['), + decodeThis{map[string]interface{}{"a": float64(1)}}, + decodeThis{json.NewSyntaxError("expected comma after array element", 11)}, + }}, + {json: `{ "` + strings.Repeat("a", 513) + `" 1 }`, expTokens: []interface{}{ + json.Delim('{'), strings.Repeat("a", 513), + decodeThis{json.NewSyntaxError("expected colon after object key", 518)}, + }}, + {json: `{ "\a" }`, expTokens: []interface{}{ + json.Delim('{'), + json.NewSyntaxError("invalid character 'a' in string escape code", 3), + }}, + {json: ` \a`, expTokens: []interface{}{ + json.NewSyntaxError("invalid character '\\\\' looking for beginning of value", 1), + }}, + */ } -/* func TestDecodeInStream(t *testing.T) { for ci, tcase := range tokenStreamCases { @@ -429,7 +430,6 @@ func TestDecodeInStream(t *testing.T) { } } } -*/ // Test from golang.org/issue/11893 func TestHTTPDecoding(t *testing.T) {