Fix decoding of invalid top level value

This commit is contained in:
Masaaki Goshima 2021-04-10 18:44:43 +09:00
parent 6fbf3fc7c0
commit 2cb792ca28
2 changed files with 45 additions and 3 deletions

View File

@ -2,6 +2,7 @@ package json
import ( import (
"encoding" "encoding"
"fmt"
"io" "io"
"reflect" "reflect"
"strconv" "strconv"
@ -40,10 +41,11 @@ func unmarshal(data []byte, v interface{}) error {
if err != nil { if err != nil {
return err return err
} }
if _, err := dec.decode(src, 0, 0, header.ptr); err != nil { cursor, err := dec.decode(src, 0, 0, header.ptr)
if err != nil {
return err return err
} }
return nil return validateEndBuf(src, cursor)
} }
func unmarshalNoEscape(data []byte, v interface{}) error { func unmarshalNoEscape(data []byte, v interface{}) error {
@ -59,9 +61,27 @@ func unmarshalNoEscape(data []byte, v interface{}) error {
if err != nil { if err != nil {
return err return err
} }
if _, err := dec.decode(src, 0, 0, noescape(header.ptr)); err != nil { cursor, err := dec.decode(src, 0, 0, noescape(header.ptr))
if err != nil {
return err return err
} }
return validateEndBuf(src, cursor)
}
func validateEndBuf(src []byte, cursor int64) error {
for {
switch src[cursor] {
case ' ', '\t', '\n', '\r':
cursor++
continue
case nul:
return nil
}
return errSyntax(
fmt.Sprintf("invalid character '%c' after top-level value", src[cursor]),
cursor+1,
)
}
return nil return nil
} }

View File

@ -3,6 +3,7 @@ package json_test
import ( import (
"bytes" "bytes"
"encoding" "encoding"
stdjson "encoding/json"
"errors" "errors"
"fmt" "fmt"
"image" "image"
@ -2929,3 +2930,24 @@ func TestDecodeSlice(t *testing.T) {
t.Fatal("invaid address") t.Fatal("invaid address")
} }
} }
func TestInvalidTopLevelValue(t *testing.T) {
t.Run("invalid end of buffer", func(t *testing.T) {
var v struct{}
if err := stdjson.Unmarshal([]byte(`{}0`), &v); err == nil {
t.Fatal("expected error")
}
if err := json.Unmarshal([]byte(`{}0`), &v); err == nil {
t.Fatal("expected error")
}
})
t.Run("invalid object", func(t *testing.T) {
var v interface{}
if err := stdjson.Unmarshal([]byte(`{"a":4}{"a"5}`), &v); err == nil {
t.Fatal("expected error")
}
if err := json.Unmarshal([]byte(`{"a":4}{"a"5}`), &v); err == nil {
t.Fatal("expected error")
}
})
}