diff --git a/decode_test.go b/decode_test.go index a0d4eda..8319b0f 100644 --- a/decode_test.go +++ b/decode_test.go @@ -3699,3 +3699,36 @@ func TestIssue251(t *testing.T) { } t.Log(array) } + +func TestDecodeBinaryTypeWithEscapedChar(t *testing.T) { + type T struct { + Msg []byte `json:"msg"` + } + content := []byte(`{"msg":"aGVsbG8K\n"}`) + t.Run("unmarshal", func(t *testing.T) { + var expected T + if err := stdjson.Unmarshal(content, &expected); err != nil { + t.Fatal(err) + } + var got T + if err := json.Unmarshal(content, &got); err != nil { + t.Fatal(err) + } + if !bytes.Equal(expected.Msg, got.Msg) { + t.Fatalf("failed to decode binary type with escaped char. expected %q but got %q", expected.Msg, got.Msg) + } + }) + t.Run("stream", func(t *testing.T) { + var expected T + if err := stdjson.NewDecoder(bytes.NewBuffer(content)).Decode(&expected); err != nil { + t.Fatal(err) + } + var got T + if err := json.NewDecoder(bytes.NewBuffer(content)).Decode(&got); err != nil { + t.Fatal(err) + } + if !bytes.Equal(expected.Msg, got.Msg) { + t.Fatalf("failed to decode binary type with escaped char. expected %q but got %q", expected.Msg, got.Msg) + } + }) +} diff --git a/internal/decoder/bytes.go b/internal/decoder/bytes.go index 0c4681a..01a37fe 100644 --- a/internal/decoder/bytes.go +++ b/internal/decoder/bytes.go @@ -9,10 +9,11 @@ import ( ) type bytesDecoder struct { - typ *runtime.Type - sliceDecoder Decoder - structName string - fieldName string + typ *runtime.Type + sliceDecoder Decoder + stringDecoder *stringDecoder + structName string + fieldName string } func byteUnmarshalerSliceDecoder(typ *runtime.Type, structName string, fieldName string) Decoder { @@ -31,10 +32,11 @@ func byteUnmarshalerSliceDecoder(typ *runtime.Type, structName string, fieldName func newBytesDecoder(typ *runtime.Type, structName string, fieldName string) *bytesDecoder { return &bytesDecoder{ - typ: typ, - sliceDecoder: byteUnmarshalerSliceDecoder(typ, structName, fieldName), - structName: structName, - fieldName: fieldName, + typ: typ, + sliceDecoder: byteUnmarshalerSliceDecoder(typ, structName, fieldName), + stringDecoder: newStringDecoder(structName, fieldName), + structName: structName, + fieldName: fieldName, } } @@ -77,101 +79,36 @@ func (d *bytesDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe return cursor, nil } -func binaryBytes(s *Stream) ([]byte, error) { - s.cursor++ - start := s.cursor - for { - switch s.char() { - case '"': - literal := s.buf[start:s.cursor] - s.cursor++ - return literal, nil - case nul: - if s.read() { - continue - } - goto ERROR - } - s.cursor++ - } -ERROR: - return nil, errors.ErrUnexpectedEndOfJSON("[]byte", s.totalOffset()) -} - func (d *bytesDecoder) decodeStreamBinary(s *Stream, depth int64, p unsafe.Pointer) ([]byte, error) { - for { - switch s.char() { - case ' ', '\n', '\t', '\r': - s.cursor++ - continue - case '"': - return binaryBytes(s) - case 'n': - if err := nullBytes(s); err != nil { - return nil, err - } - return nil, nil - case '[': - if d.sliceDecoder == nil { - return nil, &errors.UnmarshalTypeError{ - Type: runtime.RType2Type(d.typ), - Offset: s.totalOffset(), - } - } - if err := d.sliceDecoder.DecodeStream(s, depth, p); err != nil { - return nil, err - } - return nil, nil - case nul: - if s.read() { - continue + c := s.skipWhiteSpace() + if c == '[' { + if d.sliceDecoder == nil { + return nil, &errors.UnmarshalTypeError{ + Type: runtime.RType2Type(d.typ), + Offset: s.totalOffset(), } } - break + err := d.sliceDecoder.DecodeStream(s, depth, p) + return nil, err } - return nil, errors.ErrNotAtBeginningOfValue(s.totalOffset()) + return d.stringDecoder.decodeStreamByte(s) } func (d *bytesDecoder) decodeBinary(ctx *RuntimeContext, cursor, depth int64, p unsafe.Pointer) ([]byte, int64, error) { buf := ctx.Buf - for { - switch buf[cursor] { - case ' ', '\n', '\t', '\r': - cursor++ - case '"': - cursor++ - start := cursor - for { - switch buf[cursor] { - case '"': - literal := buf[start:cursor] - cursor++ - return literal, cursor, nil - case nul: - return nil, 0, errors.ErrUnexpectedEndOfJSON("[]byte", cursor) - } - cursor++ + cursor = skipWhiteSpace(buf, cursor) + if buf[cursor] == '[' { + if d.sliceDecoder == nil { + return nil, 0, &errors.UnmarshalTypeError{ + Type: runtime.RType2Type(d.typ), + Offset: cursor, } - case '[': - if d.sliceDecoder == nil { - return nil, 0, &errors.UnmarshalTypeError{ - Type: runtime.RType2Type(d.typ), - Offset: cursor, - } - } - c, err := d.sliceDecoder.Decode(ctx, cursor, depth, p) - if err != nil { - return nil, 0, err - } - return nil, c, nil - case 'n': - if err := validateNull(buf, cursor); err != nil { - return nil, 0, err - } - cursor += 4 - return nil, cursor, nil - default: - return nil, 0, errors.ErrNotAtBeginningOfValue(cursor) } + c, err := d.sliceDecoder.Decode(ctx, cursor, depth, p) + if err != nil { + return nil, 0, err + } + return nil, c, nil } + return d.stringDecoder.decodeByte(buf, cursor) }