Fix decoding of binary type with escaped char

This commit is contained in:
Masaaki Goshima 2021-08-12 13:51:24 +09:00
parent a1780c18a6
commit 75a6ad40b9
2 changed files with 64 additions and 94 deletions

View File

@ -3699,3 +3699,36 @@ func TestIssue251(t *testing.T) {
} }
t.Log(array) t.Log(array)
} }
func TestDecodeBinaryTypeWithEscapedChar(t *testing.T) {
type T struct {
Msg []byte `json:"msg"`
}
content := []byte(`{"msg":"aGVsbG8K\n"}`)
t.Run("unmarshal", func(t *testing.T) {
var expected T
if err := stdjson.Unmarshal(content, &expected); err != nil {
t.Fatal(err)
}
var got T
if err := json.Unmarshal(content, &got); err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected.Msg, got.Msg) {
t.Fatalf("failed to decode binary type with escaped char. expected %q but got %q", expected.Msg, got.Msg)
}
})
t.Run("stream", func(t *testing.T) {
var expected T
if err := stdjson.NewDecoder(bytes.NewBuffer(content)).Decode(&expected); err != nil {
t.Fatal(err)
}
var got T
if err := json.NewDecoder(bytes.NewBuffer(content)).Decode(&got); err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected.Msg, got.Msg) {
t.Fatalf("failed to decode binary type with escaped char. expected %q but got %q", expected.Msg, got.Msg)
}
})
}

View File

@ -11,6 +11,7 @@ import (
type bytesDecoder struct { type bytesDecoder struct {
typ *runtime.Type typ *runtime.Type
sliceDecoder Decoder sliceDecoder Decoder
stringDecoder *stringDecoder
structName string structName string
fieldName string fieldName string
} }
@ -33,6 +34,7 @@ func newBytesDecoder(typ *runtime.Type, structName string, fieldName string) *by
return &bytesDecoder{ return &bytesDecoder{
typ: typ, typ: typ,
sliceDecoder: byteUnmarshalerSliceDecoder(typ, structName, fieldName), sliceDecoder: byteUnmarshalerSliceDecoder(typ, structName, fieldName),
stringDecoder: newStringDecoder(structName, fieldName),
structName: structName, structName: structName,
fieldName: fieldName, fieldName: fieldName,
} }
@ -77,82 +79,25 @@ func (d *bytesDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe
return cursor, nil return cursor, nil
} }
func binaryBytes(s *Stream) ([]byte, error) {
s.cursor++
start := s.cursor
for {
switch s.char() {
case '"':
literal := s.buf[start:s.cursor]
s.cursor++
return literal, nil
case nul:
if s.read() {
continue
}
goto ERROR
}
s.cursor++
}
ERROR:
return nil, errors.ErrUnexpectedEndOfJSON("[]byte", s.totalOffset())
}
func (d *bytesDecoder) decodeStreamBinary(s *Stream, depth int64, p unsafe.Pointer) ([]byte, error) { func (d *bytesDecoder) decodeStreamBinary(s *Stream, depth int64, p unsafe.Pointer) ([]byte, error) {
for { c := s.skipWhiteSpace()
switch s.char() { if c == '[' {
case ' ', '\n', '\t', '\r':
s.cursor++
continue
case '"':
return binaryBytes(s)
case 'n':
if err := nullBytes(s); err != nil {
return nil, err
}
return nil, nil
case '[':
if d.sliceDecoder == nil { if d.sliceDecoder == nil {
return nil, &errors.UnmarshalTypeError{ return nil, &errors.UnmarshalTypeError{
Type: runtime.RType2Type(d.typ), Type: runtime.RType2Type(d.typ),
Offset: s.totalOffset(), Offset: s.totalOffset(),
} }
} }
if err := d.sliceDecoder.DecodeStream(s, depth, p); err != nil { err := d.sliceDecoder.DecodeStream(s, depth, p)
return nil, err return nil, err
} }
return nil, nil return d.stringDecoder.decodeStreamByte(s)
case nul:
if s.read() {
continue
}
}
break
}
return nil, errors.ErrNotAtBeginningOfValue(s.totalOffset())
} }
func (d *bytesDecoder) decodeBinary(ctx *RuntimeContext, cursor, depth int64, p unsafe.Pointer) ([]byte, int64, error) { func (d *bytesDecoder) decodeBinary(ctx *RuntimeContext, cursor, depth int64, p unsafe.Pointer) ([]byte, int64, error) {
buf := ctx.Buf buf := ctx.Buf
for { cursor = skipWhiteSpace(buf, cursor)
switch buf[cursor] { if buf[cursor] == '[' {
case ' ', '\n', '\t', '\r':
cursor++
case '"':
cursor++
start := cursor
for {
switch buf[cursor] {
case '"':
literal := buf[start:cursor]
cursor++
return literal, cursor, nil
case nul:
return nil, 0, errors.ErrUnexpectedEndOfJSON("[]byte", cursor)
}
cursor++
}
case '[':
if d.sliceDecoder == nil { if d.sliceDecoder == nil {
return nil, 0, &errors.UnmarshalTypeError{ return nil, 0, &errors.UnmarshalTypeError{
Type: runtime.RType2Type(d.typ), Type: runtime.RType2Type(d.typ),
@ -164,14 +109,6 @@ func (d *bytesDecoder) decodeBinary(ctx *RuntimeContext, cursor, depth int64, p
return nil, 0, err return nil, 0, err
} }
return nil, c, nil return nil, c, nil
case 'n':
if err := validateNull(buf, cursor); err != nil {
return nil, 0, err
}
cursor += 4
return nil, cursor, nil
default:
return nil, 0, errors.ErrNotAtBeginningOfValue(cursor)
}
} }
return d.stringDecoder.decodeByte(buf, cursor)
} }