From 47b7f4a5a286e431e8e0a5cfe41fedf1e5f317da Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Tue, 24 Nov 2020 20:15:11 +0900 Subject: [PATCH] Fix decoding for UnmarshalJSON / UnmarshalText --- decode_compile.go | 16 ++++++++-------- decode_struct.go | 3 +++ decode_unmarshal_json.go | 22 ++++++++++++++++++++-- decode_unmarshal_text.go | 22 +++++++++++++++++++--- 4 files changed, 50 insertions(+), 13 deletions(-) diff --git a/decode_compile.go b/decode_compile.go index be51732..8d92f5d 100644 --- a/decode_compile.go +++ b/decode_compile.go @@ -9,13 +9,13 @@ import ( func (d *Decoder) compileHead(typ *rtype) (decoder, error) { switch { case typ.Implements(unmarshalJSONType): - return newUnmarshalJSONDecoder(typ), nil + return newUnmarshalJSONDecoder(typ, "", ""), nil case rtype_ptrTo(typ).Implements(unmarshalJSONType): - return newUnmarshalJSONDecoder(rtype_ptrTo(typ)), nil + return newUnmarshalJSONDecoder(rtype_ptrTo(typ), "", ""), nil case typ.Implements(unmarshalTextType): - return newUnmarshalTextDecoder(typ), nil + return newUnmarshalTextDecoder(typ, "", ""), nil case rtype_ptrTo(typ).Implements(unmarshalTextType): - return newUnmarshalTextDecoder(rtype_ptrTo(typ)), nil + return newUnmarshalTextDecoder(rtype_ptrTo(typ), "", ""), nil } return d.compile(typ.Elem(), "", "") } @@ -23,13 +23,13 @@ func (d *Decoder) compileHead(typ *rtype) (decoder, error) { func (d *Decoder) compile(typ *rtype, structName, fieldName string) (decoder, error) { switch { case typ.Implements(unmarshalJSONType): - return newUnmarshalJSONDecoder(typ), nil + return newUnmarshalJSONDecoder(typ, structName, fieldName), nil case rtype_ptrTo(typ).Implements(unmarshalJSONType): - return newUnmarshalJSONDecoder(rtype_ptrTo(typ)), nil + return newUnmarshalJSONDecoder(rtype_ptrTo(typ), structName, fieldName), nil case typ.Implements(unmarshalTextType): - return newUnmarshalTextDecoder(typ), nil + return newUnmarshalTextDecoder(typ, structName, fieldName), nil case rtype_ptrTo(typ).Implements(unmarshalTextType): - return newUnmarshalTextDecoder(rtype_ptrTo(typ)), nil + return newUnmarshalTextDecoder(rtype_ptrTo(typ), structName, fieldName), nil } switch typ.Kind() { diff --git a/decode_struct.go b/decode_struct.go index 34854c1..10f0330 100644 --- a/decode_struct.go +++ b/decode_struct.go @@ -35,6 +35,9 @@ func (d *structDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return errNotAtBeginningOfValue(s.totalOffset()) } s.cursor++ + if s.char() == '}' { + return nil + } for { s.reset() key, err := d.keyDecoder.decodeStreamByte(s) diff --git a/decode_unmarshal_json.go b/decode_unmarshal_json.go index c3274e6..595b70f 100644 --- a/decode_unmarshal_json.go +++ b/decode_unmarshal_json.go @@ -7,10 +7,24 @@ import ( type unmarshalJSONDecoder struct { typ *rtype isDoublePointer bool + structName string + fieldName string } -func newUnmarshalJSONDecoder(typ *rtype) *unmarshalJSONDecoder { - return &unmarshalJSONDecoder{typ: typ} +func newUnmarshalJSONDecoder(typ *rtype, structName, fieldName string) *unmarshalJSONDecoder { + return &unmarshalJSONDecoder{ + typ: typ, + structName: structName, + fieldName: fieldName, + } +} + +func (d *unmarshalJSONDecoder) annotateError(err error) { + ut, ok := err.(*UnmarshalTypeError) + if ok { + ut.Struct = d.structName + ut.Field = d.fieldName + } } func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error { @@ -27,6 +41,7 @@ func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error { ptr: newptr, })) if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { + d.annotateError(err) return err } *(*unsafe.Pointer)(p) = newptr @@ -36,6 +51,7 @@ func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error { ptr: p, })) if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { + d.annotateError(err) return err } } @@ -57,6 +73,7 @@ func (d *unmarshalJSONDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer ptr: newptr, })) if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { + d.annotateError(err) return 0, err } *(*unsafe.Pointer)(p) = newptr @@ -66,6 +83,7 @@ func (d *unmarshalJSONDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer ptr: p, })) if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { + d.annotateError(err) return 0, err } } diff --git a/decode_unmarshal_text.go b/decode_unmarshal_text.go index 5b25834..d60b521 100644 --- a/decode_unmarshal_text.go +++ b/decode_unmarshal_text.go @@ -9,11 +9,25 @@ import ( ) type unmarshalTextDecoder struct { - typ *rtype + typ *rtype + structName string + fieldName string } -func newUnmarshalTextDecoder(typ *rtype) *unmarshalTextDecoder { - return &unmarshalTextDecoder{typ: typ} +func newUnmarshalTextDecoder(typ *rtype, structName, fieldName string) *unmarshalTextDecoder { + return &unmarshalTextDecoder{ + typ: typ, + structName: structName, + fieldName: fieldName, + } +} + +func (d *unmarshalTextDecoder) annotateError(err error) { + ut, ok := err.(*UnmarshalTypeError) + if ok { + ut.Struct = d.structName + ut.Field = d.fieldName + } } func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error { @@ -32,6 +46,7 @@ func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error { ptr: newptr, })) if err := v.(encoding.TextUnmarshaler).UnmarshalText(src); err != nil { + d.annotateError(err) return err } *(*unsafe.Pointer)(p) = newptr @@ -54,6 +69,7 @@ func (d *unmarshalTextDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer ptr: *(*unsafe.Pointer)(unsafe.Pointer(&p)), })) if err := v.(encoding.TextUnmarshaler).UnmarshalText(src); err != nil { + d.annotateError(err) return 0, err } return end, nil