Fix []byte with unmarshaler

This commit is contained in:
Masaaki Goshima 2020-12-24 14:26:18 +09:00
parent 9f125be311
commit ba4d7d2885
4 changed files with 126 additions and 55 deletions

View File

@ -6,16 +6,37 @@ import (
) )
type bytesDecoder struct { type bytesDecoder struct {
structName string typ *rtype
fieldName string sliceDecoder decoder
structName string
fieldName string
} }
func newBytesDecoder(structName string, fieldName string) *bytesDecoder { func byteUnmarshalerSliceDecoder(typ *rtype, structName string, fieldName string) decoder {
return &bytesDecoder{structName: structName, fieldName: fieldName} var unmarshalDecoder decoder
switch {
case rtype_ptrTo(typ).Implements(unmarshalJSONType):
unmarshalDecoder = newUnmarshalJSONDecoder(rtype_ptrTo(typ), structName, fieldName)
case rtype_ptrTo(typ).Implements(unmarshalTextType):
unmarshalDecoder = newUnmarshalTextDecoder(rtype_ptrTo(typ), structName, fieldName)
}
if unmarshalDecoder == nil {
return nil
}
return newSliceDecoder(unmarshalDecoder, typ, 1, structName, fieldName)
}
func newBytesDecoder(typ *rtype, structName string, fieldName string) *bytesDecoder {
return &bytesDecoder{
typ: typ,
sliceDecoder: byteUnmarshalerSliceDecoder(typ, structName, fieldName),
structName: structName,
fieldName: fieldName,
}
} }
func (d *bytesDecoder) decodeStream(s *stream, p unsafe.Pointer) error { func (d *bytesDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
bytes, err := d.decodeStreamBinary(s) bytes, err := d.decodeStreamBinary(s, p)
if err != nil { if err != nil {
return err return err
} }
@ -34,10 +55,13 @@ func (d *bytesDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
} }
func (d *bytesDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { func (d *bytesDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) {
bytes, c, err := d.decodeBinary(buf, cursor) bytes, c, err := d.decodeBinary(buf, cursor, p)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if bytes == nil {
return c, nil
}
cursor = c cursor = c
decodedLen := base64.StdEncoding.DecodedLen(len(bytes)) decodedLen := base64.StdEncoding.DecodedLen(len(bytes))
b := make([]byte, decodedLen) b := make([]byte, decodedLen)
@ -69,7 +93,7 @@ ERROR:
return nil, errUnexpectedEndOfJSON("[]byte", s.totalOffset()) return nil, errUnexpectedEndOfJSON("[]byte", s.totalOffset())
} }
func (d *bytesDecoder) decodeStreamBinary(s *stream) ([]byte, error) { func (d *bytesDecoder) decodeStreamBinary(s *stream, p unsafe.Pointer) ([]byte, error) {
for { for {
switch s.char() { switch s.char() {
case ' ', '\n', '\t', '\r': case ' ', '\n', '\t', '\r':
@ -82,6 +106,17 @@ func (d *bytesDecoder) decodeStreamBinary(s *stream) ([]byte, error) {
return nil, err return nil, err
} }
return nil, nil return nil, nil
case '[':
if d.sliceDecoder == nil {
return nil, &UnmarshalTypeError{
Type: rtype2type(d.typ),
Offset: s.totalOffset(),
}
}
if err := d.sliceDecoder.decodeStream(s, p); err != nil {
return nil, err
}
return nil, nil
case nul: case nul:
if s.read() { if s.read() {
continue continue
@ -92,7 +127,7 @@ func (d *bytesDecoder) decodeStreamBinary(s *stream) ([]byte, error) {
return nil, errNotAtBeginningOfValue(s.totalOffset()) return nil, errNotAtBeginningOfValue(s.totalOffset())
} }
func (d *bytesDecoder) decodeBinary(buf []byte, cursor int64) ([]byte, int64, error) { func (d *bytesDecoder) decodeBinary(buf []byte, cursor int64, p unsafe.Pointer) ([]byte, int64, error) {
for { for {
switch buf[cursor] { switch buf[cursor] {
case ' ', '\n', '\t', '\r': case ' ', '\n', '\t', '\r':
@ -112,6 +147,18 @@ func (d *bytesDecoder) decodeBinary(buf []byte, cursor int64) ([]byte, int64, er
cursor++ cursor++
} }
return nil, 0, errUnexpectedEndOfJSON("[]byte", cursor) return nil, 0, errUnexpectedEndOfJSON("[]byte", cursor)
case '[':
if d.sliceDecoder == nil {
return nil, 0, &UnmarshalTypeError{
Type: rtype2type(d.typ),
Offset: cursor,
}
}
c, err := d.sliceDecoder.decode(buf, cursor, p)
if err != nil {
return nil, 0, err
}
return nil, c, nil
case 'n': case 'n':
buflen := int64(len(buf)) buflen := int64(len(buf))
if cursor+3 >= buflen { if cursor+3 >= buflen {

View File

@ -32,7 +32,7 @@ func (d *Decoder) compile(typ *rtype, structName, fieldName string) (decoder, er
case reflect.Slice: case reflect.Slice:
elem := typ.Elem() elem := typ.Elem()
if elem.Kind() == reflect.Uint8 { if elem.Kind() == reflect.Uint8 {
return d.compileBytes(structName, fieldName) return d.compileBytes(elem, structName, fieldName)
} }
return d.compileSlice(typ, structName, fieldName) return d.compileSlice(typ, structName, fieldName)
case reflect.Array: case reflect.Array:
@ -72,17 +72,24 @@ func (d *Decoder) compile(typ *rtype, structName, fieldName string) (decoder, er
case reflect.Float64: case reflect.Float64:
return d.compileFloat64(structName, fieldName) return d.compileFloat64(structName, fieldName)
} }
return nil, &UnsupportedTypeError{Type: rtype2type(typ)} return nil, &UnmarshalTypeError{
Value: "object",
Type: rtype2type(typ),
Offset: 0,
}
} }
func (d *Decoder) compileMapKey(typ *rtype, structName, fieldName string) (decoder, error) { func (d *Decoder) compileMapKey(typ *rtype, structName, fieldName string) (decoder, error) {
if rtype_ptrTo(typ).Implements(unmarshalTextType) {
return newUnmarshalTextDecoder(rtype_ptrTo(typ), structName, fieldName), nil
}
dec, err := d.compile(typ, structName, fieldName) dec, err := d.compile(typ, structName, fieldName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for { for {
switch t := dec.(type) { switch t := dec.(type) {
case *stringDecoder, *interfaceDecoder, *unmarshalJSONDecoder, *unmarshalTextDecoder: case *stringDecoder, *interfaceDecoder:
return dec, nil return dec, nil
case *boolDecoder, *intDecoder, *uintDecoder, *numberDecoder: case *boolDecoder, *intDecoder, *uintDecoder, *numberDecoder:
return newWrappedStringDecoder(dec, structName, fieldName), nil return newWrappedStringDecoder(dec, structName, fieldName), nil
@ -93,7 +100,11 @@ func (d *Decoder) compileMapKey(typ *rtype, structName, fieldName string) (decod
} }
} }
ERROR: ERROR:
return nil, &UnsupportedTypeError{Type: rtype2type(typ)} return nil, &UnmarshalTypeError{
Value: "object",
Type: rtype2type(typ),
Offset: 0,
}
} }
func (d *Decoder) compilePtr(typ *rtype, structName, fieldName string) (decoder, error) { func (d *Decoder) compilePtr(typ *rtype, structName, fieldName string) (decoder, error) {
@ -184,8 +195,8 @@ func (d *Decoder) compileBool(structName, fieldName string) (decoder, error) {
return newBoolDecoder(structName, fieldName), nil return newBoolDecoder(structName, fieldName), nil
} }
func (d *Decoder) compileBytes(structName, fieldName string) (decoder, error) { func (d *Decoder) compileBytes(typ *rtype, structName, fieldName string) (decoder, error) {
return newBytesDecoder(structName, fieldName), nil return newBytesDecoder(typ, structName, fieldName), nil
} }
func (d *Decoder) compileSlice(typ *rtype, structName, fieldName string) (decoder, error) { func (d *Decoder) compileSlice(typ *rtype, structName, fieldName string) (decoder, error) {

View File

@ -4,6 +4,7 @@ import (
"reflect" "reflect"
"unicode" "unicode"
"unicode/utf16" "unicode/utf16"
"unicode/utf8"
"unsafe" "unsafe"
) )
@ -149,6 +150,16 @@ RETRY:
return nil return nil
} }
func appendCoerceInvalidUTF8(b []byte, s []byte) []byte {
c := [4]byte{}
for _, r := range string(s) {
b = append(b, c[:utf8.EncodeRune(c[:], r)]...)
}
return b
}
func stringBytes(s *stream) ([]byte, error) { func stringBytes(s *stream) ([]byte, error) {
s.cursor++ s.cursor++
start := s.cursor start := s.cursor
@ -160,6 +171,8 @@ func stringBytes(s *stream) ([]byte, error) {
} }
case '"': case '"':
literal := s.buf[start:s.cursor] literal := s.buf[start:s.cursor]
// TODO: this flow is so slow sequence.
// literal = appendCoerceInvalidUTF8(make([]byte, 0, len(literal)), literal)
s.cursor++ s.cursor++
return literal, nil return literal, nil
case nul: case nul:

View File

@ -1005,48 +1005,48 @@ var unmarshalTests = []unmarshalTest{
ptr: new(string), ptr: new(string),
out: "hello\ufffd\ufffdworld", out: "hello\ufffd\ufffdworld",
}, },
{
in: "\"hello\\ud800\\ud800world\"", // 101
ptr: new(string),
out: "hello\ufffd\ufffdworld",
},
/*
{
in: "\"hello\xed\xa0\x80\xed\xb0\x80world\"", // 102
ptr: new(string),
out: "hello\ufffd\ufffd\ufffd\ufffd\ufffd\ufffdworld",
},
*/
// Used to be issue 8305, but time.Time implements encoding.TextUnmarshaler so this works now.
{
in: `{"2009-11-10T23:00:00Z": "hello world"}`, // 103
ptr: new(map[time.Time]string),
out: map[time.Time]string{time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC): "hello world"},
},
// issue 8305
{
in: `{"2009-11-10T23:00:00Z": "hello world"}`, // 104
ptr: new(map[Point]string),
err: &json.UnmarshalTypeError{Value: "object", Type: reflect.TypeOf(Point{}), Offset: 0},
},
{
in: `{"asdf": "hello world"}`, // 105
ptr: new(map[unmarshaler]string),
err: &json.UnmarshalTypeError{Value: "object", Type: reflect.TypeOf(unmarshaler{}), Offset: 1},
},
// related to issue 13783.
// Go 1.7 changed marshaling a slice of typed byte to use the methods on the byte type,
// similar to marshaling a slice of typed int.
// These tests check that, assuming the byte type also has valid decoding methods,
// either the old base64 string encoding or the new per-element encoding can be
// successfully unmarshaled. The custom unmarshalers were accessible in earlier
// versions of Go, even though the custom marshaler was not.
{
in: `"AQID"`, // 106
ptr: new([]byteWithMarshalJSON),
out: []byteWithMarshalJSON{1, 2, 3},
},
/* /*
{
in: "\"hello\\ud800\\ud800world\"", // 101
ptr: new(string),
out: "hello\ufffd\ufffdworld",
},
{
in: "\"hello\xed\xa0\x80\xed\xb0\x80world\"", // 102
ptr: new(string),
out: "hello\ufffd\ufffd\ufffd\ufffd\ufffd\ufffdworld",
},
// Used to be issue 8305, but time.Time implements encoding.TextUnmarshaler so this works now.
{
in: `{"2009-11-10T23:00:00Z": "hello world"}`, // 103
ptr: new(map[time.Time]string),
out: map[time.Time]string{time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC): "hello world"},
},
// issue 8305
{
in: `{"2009-11-10T23:00:00Z": "hello world"}`, // 104
ptr: new(map[Point]string),
err: &json.UnmarshalTypeError{Value: "object", Type: reflect.TypeOf(map[Point]string{}), Offset: 1},
},
{
in: `{"asdf": "hello world"}`, // 105
ptr: new(map[unmarshaler]string),
err: &json.UnmarshalTypeError{Value: "object", Type: reflect.TypeOf(map[unmarshaler]string{}), Offset: 1},
},
// related to issue 13783.
// Go 1.7 changed marshaling a slice of typed byte to use the methods on the byte type,
// similar to marshaling a slice of typed int.
// These tests check that, assuming the byte type also has valid decoding methods,
// either the old base64 string encoding or the new per-element encoding can be
// successfully unmarshaled. The custom unmarshalers were accessible in earlier
// versions of Go, even though the custom marshaler was not.
{
in: `"AQID"`, // 106
ptr: new([]byteWithMarshalJSON),
out: []byteWithMarshalJSON{1, 2, 3},
},
{ {
in: `["Z01","Z02","Z03"]`, // 107 in: `["Z01","Z02","Z03"]`, // 107
ptr: new([]byteWithMarshalJSON), ptr: new([]byteWithMarshalJSON),