Fix decoding of binary type with escaped char

This commit is contained in:
Masaaki Goshima 2021-08-12 13:51:24 +09:00
parent a1780c18a6
commit 75a6ad40b9
2 changed files with 64 additions and 94 deletions

View File

@ -3699,3 +3699,36 @@ func TestIssue251(t *testing.T) {
} }
t.Log(array) t.Log(array)
} }
func TestDecodeBinaryTypeWithEscapedChar(t *testing.T) {
type T struct {
Msg []byte `json:"msg"`
}
content := []byte(`{"msg":"aGVsbG8K\n"}`)
t.Run("unmarshal", func(t *testing.T) {
var expected T
if err := stdjson.Unmarshal(content, &expected); err != nil {
t.Fatal(err)
}
var got T
if err := json.Unmarshal(content, &got); err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected.Msg, got.Msg) {
t.Fatalf("failed to decode binary type with escaped char. expected %q but got %q", expected.Msg, got.Msg)
}
})
t.Run("stream", func(t *testing.T) {
var expected T
if err := stdjson.NewDecoder(bytes.NewBuffer(content)).Decode(&expected); err != nil {
t.Fatal(err)
}
var got T
if err := json.NewDecoder(bytes.NewBuffer(content)).Decode(&got); err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected.Msg, got.Msg) {
t.Fatalf("failed to decode binary type with escaped char. expected %q but got %q", expected.Msg, got.Msg)
}
})
}

View File

@ -9,10 +9,11 @@ import (
) )
type bytesDecoder struct { type bytesDecoder struct {
typ *runtime.Type typ *runtime.Type
sliceDecoder Decoder sliceDecoder Decoder
structName string stringDecoder *stringDecoder
fieldName string structName string
fieldName string
} }
func byteUnmarshalerSliceDecoder(typ *runtime.Type, structName string, fieldName string) Decoder { func byteUnmarshalerSliceDecoder(typ *runtime.Type, structName string, fieldName string) Decoder {
@ -31,10 +32,11 @@ func byteUnmarshalerSliceDecoder(typ *runtime.Type, structName string, fieldName
func newBytesDecoder(typ *runtime.Type, structName string, fieldName string) *bytesDecoder { func newBytesDecoder(typ *runtime.Type, structName string, fieldName string) *bytesDecoder {
return &bytesDecoder{ return &bytesDecoder{
typ: typ, typ: typ,
sliceDecoder: byteUnmarshalerSliceDecoder(typ, structName, fieldName), sliceDecoder: byteUnmarshalerSliceDecoder(typ, structName, fieldName),
structName: structName, stringDecoder: newStringDecoder(structName, fieldName),
fieldName: fieldName, structName: structName,
fieldName: fieldName,
} }
} }
@ -77,101 +79,36 @@ func (d *bytesDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe
return cursor, nil return cursor, nil
} }
func binaryBytes(s *Stream) ([]byte, error) {
s.cursor++
start := s.cursor
for {
switch s.char() {
case '"':
literal := s.buf[start:s.cursor]
s.cursor++
return literal, nil
case nul:
if s.read() {
continue
}
goto ERROR
}
s.cursor++
}
ERROR:
return nil, errors.ErrUnexpectedEndOfJSON("[]byte", s.totalOffset())
}
func (d *bytesDecoder) decodeStreamBinary(s *Stream, depth int64, p unsafe.Pointer) ([]byte, error) { func (d *bytesDecoder) decodeStreamBinary(s *Stream, depth int64, p unsafe.Pointer) ([]byte, error) {
for { c := s.skipWhiteSpace()
switch s.char() { if c == '[' {
case ' ', '\n', '\t', '\r': if d.sliceDecoder == nil {
s.cursor++ return nil, &errors.UnmarshalTypeError{
continue Type: runtime.RType2Type(d.typ),
case '"': Offset: s.totalOffset(),
return binaryBytes(s)
case 'n':
if err := nullBytes(s); err != nil {
return nil, err
}
return nil, nil
case '[':
if d.sliceDecoder == nil {
return nil, &errors.UnmarshalTypeError{
Type: runtime.RType2Type(d.typ),
Offset: s.totalOffset(),
}
}
if err := d.sliceDecoder.DecodeStream(s, depth, p); err != nil {
return nil, err
}
return nil, nil
case nul:
if s.read() {
continue
} }
} }
break err := d.sliceDecoder.DecodeStream(s, depth, p)
return nil, err
} }
return nil, errors.ErrNotAtBeginningOfValue(s.totalOffset()) return d.stringDecoder.decodeStreamByte(s)
} }
func (d *bytesDecoder) decodeBinary(ctx *RuntimeContext, cursor, depth int64, p unsafe.Pointer) ([]byte, int64, error) { func (d *bytesDecoder) decodeBinary(ctx *RuntimeContext, cursor, depth int64, p unsafe.Pointer) ([]byte, int64, error) {
buf := ctx.Buf buf := ctx.Buf
for { cursor = skipWhiteSpace(buf, cursor)
switch buf[cursor] { if buf[cursor] == '[' {
case ' ', '\n', '\t', '\r': if d.sliceDecoder == nil {
cursor++ return nil, 0, &errors.UnmarshalTypeError{
case '"': Type: runtime.RType2Type(d.typ),
cursor++ Offset: cursor,
start := cursor
for {
switch buf[cursor] {
case '"':
literal := buf[start:cursor]
cursor++
return literal, cursor, nil
case nul:
return nil, 0, errors.ErrUnexpectedEndOfJSON("[]byte", cursor)
}
cursor++
} }
case '[':
if d.sliceDecoder == nil {
return nil, 0, &errors.UnmarshalTypeError{
Type: runtime.RType2Type(d.typ),
Offset: cursor,
}
}
c, err := d.sliceDecoder.Decode(ctx, cursor, depth, p)
if err != nil {
return nil, 0, err
}
return nil, c, nil
case 'n':
if err := validateNull(buf, cursor); err != nil {
return nil, 0, err
}
cursor += 4
return nil, cursor, nil
default:
return nil, 0, errors.ErrNotAtBeginningOfValue(cursor)
} }
c, err := d.sliceDecoder.Decode(ctx, cursor, depth, p)
if err != nil {
return nil, 0, err
}
return nil, c, nil
} }
return d.stringDecoder.decodeByte(buf, cursor)
} }