diff --git a/decode.go b/decode.go index d06136d..14d1a33 100644 --- a/decode.go +++ b/decode.go @@ -2,6 +2,7 @@ package json import ( "encoding" + "fmt" "io" "reflect" "strconv" @@ -40,10 +41,11 @@ func unmarshal(data []byte, v interface{}) error { if err != nil { return err } - if _, err := dec.decode(src, 0, 0, header.ptr); err != nil { + cursor, err := dec.decode(src, 0, 0, header.ptr) + if err != nil { return err } - return nil + return validateEndBuf(src, cursor) } func unmarshalNoEscape(data []byte, v interface{}) error { @@ -59,10 +61,27 @@ func unmarshalNoEscape(data []byte, v interface{}) error { if err != nil { return err } - if _, err := dec.decode(src, 0, 0, noescape(header.ptr)); err != nil { + cursor, err := dec.decode(src, 0, 0, noescape(header.ptr)) + if err != nil { return err } - return nil + return validateEndBuf(src, cursor) +} + +func validateEndBuf(src []byte, cursor int64) error { + for { + switch src[cursor] { + case ' ', '\t', '\n', '\r': + cursor++ + continue + case nul: + return nil + } + return errSyntax( + fmt.Sprintf("invalid character '%c' after top-level value", src[cursor]), + cursor+1, + ) + } } //nolint:staticcheck diff --git a/decode_test.go b/decode_test.go index 6595f4b..356a754 100644 --- a/decode_test.go +++ b/decode_test.go @@ -3,6 +3,7 @@ package json_test import ( "bytes" "encoding" + stdjson "encoding/json" "errors" "fmt" "image" @@ -2929,3 +2930,24 @@ func TestDecodeSlice(t *testing.T) { t.Fatal("invaid address") } } + +func TestInvalidTopLevelValue(t *testing.T) { + t.Run("invalid end of buffer", func(t *testing.T) { + var v struct{} + if err := stdjson.Unmarshal([]byte(`{}0`), &v); err == nil { + t.Fatal("expected error") + } + if err := json.Unmarshal([]byte(`{}0`), &v); err == nil { + t.Fatal("expected error") + } + }) + t.Run("invalid object", func(t *testing.T) { + var v interface{} + if err := stdjson.Unmarshal([]byte(`{"a":4}{"a"5}`), &v); err == nil { + t.Fatal("expected error") + } + if err := json.Unmarshal([]byte(`{"a":4}{"a"5}`), &v); err == nil { + t.Fatal("expected error") + } + }) +}