Fix decoding of binary type with escaped char

This commit is contained in:
Masaaki Goshima 2021-08-12 13:51:24 +09:00
parent a1780c18a6
commit 75a6ad40b9
2 changed files with 64 additions and 94 deletions

View File

@ -3699,3 +3699,36 @@ func TestIssue251(t *testing.T) {
}
t.Log(array)
}
func TestDecodeBinaryTypeWithEscapedChar(t *testing.T) {
type T struct {
Msg []byte `json:"msg"`
}
content := []byte(`{"msg":"aGVsbG8K\n"}`)
t.Run("unmarshal", func(t *testing.T) {
var expected T
if err := stdjson.Unmarshal(content, &expected); err != nil {
t.Fatal(err)
}
var got T
if err := json.Unmarshal(content, &got); err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected.Msg, got.Msg) {
t.Fatalf("failed to decode binary type with escaped char. expected %q but got %q", expected.Msg, got.Msg)
}
})
t.Run("stream", func(t *testing.T) {
var expected T
if err := stdjson.NewDecoder(bytes.NewBuffer(content)).Decode(&expected); err != nil {
t.Fatal(err)
}
var got T
if err := json.NewDecoder(bytes.NewBuffer(content)).Decode(&got); err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected.Msg, got.Msg) {
t.Fatalf("failed to decode binary type with escaped char. expected %q but got %q", expected.Msg, got.Msg)
}
})
}

View File

@ -9,10 +9,11 @@ import (
)
type bytesDecoder struct {
typ *runtime.Type
sliceDecoder Decoder
structName string
fieldName string
typ *runtime.Type
sliceDecoder Decoder
stringDecoder *stringDecoder
structName string
fieldName string
}
func byteUnmarshalerSliceDecoder(typ *runtime.Type, structName string, fieldName string) Decoder {
@ -31,10 +32,11 @@ func byteUnmarshalerSliceDecoder(typ *runtime.Type, structName string, fieldName
func newBytesDecoder(typ *runtime.Type, structName string, fieldName string) *bytesDecoder {
return &bytesDecoder{
typ: typ,
sliceDecoder: byteUnmarshalerSliceDecoder(typ, structName, fieldName),
structName: structName,
fieldName: fieldName,
typ: typ,
sliceDecoder: byteUnmarshalerSliceDecoder(typ, structName, fieldName),
stringDecoder: newStringDecoder(structName, fieldName),
structName: structName,
fieldName: fieldName,
}
}
@ -77,101 +79,36 @@ func (d *bytesDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe
return cursor, nil
}
func binaryBytes(s *Stream) ([]byte, error) {
s.cursor++
start := s.cursor
for {
switch s.char() {
case '"':
literal := s.buf[start:s.cursor]
s.cursor++
return literal, nil
case nul:
if s.read() {
continue
}
goto ERROR
}
s.cursor++
}
ERROR:
return nil, errors.ErrUnexpectedEndOfJSON("[]byte", s.totalOffset())
}
func (d *bytesDecoder) decodeStreamBinary(s *Stream, depth int64, p unsafe.Pointer) ([]byte, error) {
for {
switch s.char() {
case ' ', '\n', '\t', '\r':
s.cursor++
continue
case '"':
return binaryBytes(s)
case 'n':
if err := nullBytes(s); err != nil {
return nil, err
}
return nil, nil
case '[':
if d.sliceDecoder == nil {
return nil, &errors.UnmarshalTypeError{
Type: runtime.RType2Type(d.typ),
Offset: s.totalOffset(),
}
}
if err := d.sliceDecoder.DecodeStream(s, depth, p); err != nil {
return nil, err
}
return nil, nil
case nul:
if s.read() {
continue
c := s.skipWhiteSpace()
if c == '[' {
if d.sliceDecoder == nil {
return nil, &errors.UnmarshalTypeError{
Type: runtime.RType2Type(d.typ),
Offset: s.totalOffset(),
}
}
break
err := d.sliceDecoder.DecodeStream(s, depth, p)
return nil, err
}
return nil, errors.ErrNotAtBeginningOfValue(s.totalOffset())
return d.stringDecoder.decodeStreamByte(s)
}
func (d *bytesDecoder) decodeBinary(ctx *RuntimeContext, cursor, depth int64, p unsafe.Pointer) ([]byte, int64, error) {
buf := ctx.Buf
for {
switch buf[cursor] {
case ' ', '\n', '\t', '\r':
cursor++
case '"':
cursor++
start := cursor
for {
switch buf[cursor] {
case '"':
literal := buf[start:cursor]
cursor++
return literal, cursor, nil
case nul:
return nil, 0, errors.ErrUnexpectedEndOfJSON("[]byte", cursor)
}
cursor++
cursor = skipWhiteSpace(buf, cursor)
if buf[cursor] == '[' {
if d.sliceDecoder == nil {
return nil, 0, &errors.UnmarshalTypeError{
Type: runtime.RType2Type(d.typ),
Offset: cursor,
}
case '[':
if d.sliceDecoder == nil {
return nil, 0, &errors.UnmarshalTypeError{
Type: runtime.RType2Type(d.typ),
Offset: cursor,
}
}
c, err := d.sliceDecoder.Decode(ctx, cursor, depth, p)
if err != nil {
return nil, 0, err
}
return nil, c, nil
case 'n':
if err := validateNull(buf, cursor); err != nil {
return nil, 0, err
}
cursor += 4
return nil, cursor, nil
default:
return nil, 0, errors.ErrNotAtBeginningOfValue(cursor)
}
c, err := d.sliceDecoder.Decode(ctx, cursor, depth, p)
if err != nil {
return nil, 0, err
}
return nil, c, nil
}
return d.stringDecoder.decodeByte(buf, cursor)
}