Fix decoding for UnmarshalJSON / UnmarshalText

This commit is contained in:
Masaaki Goshima 2020-11-24 20:15:11 +09:00
parent 5c82b00ee7
commit 47b7f4a5a2
4 changed files with 50 additions and 13 deletions

View File

@ -9,13 +9,13 @@ import (
func (d *Decoder) compileHead(typ *rtype) (decoder, error) { func (d *Decoder) compileHead(typ *rtype) (decoder, error) {
switch { switch {
case typ.Implements(unmarshalJSONType): case typ.Implements(unmarshalJSONType):
return newUnmarshalJSONDecoder(typ), nil return newUnmarshalJSONDecoder(typ, "", ""), nil
case rtype_ptrTo(typ).Implements(unmarshalJSONType): case rtype_ptrTo(typ).Implements(unmarshalJSONType):
return newUnmarshalJSONDecoder(rtype_ptrTo(typ)), nil return newUnmarshalJSONDecoder(rtype_ptrTo(typ), "", ""), nil
case typ.Implements(unmarshalTextType): case typ.Implements(unmarshalTextType):
return newUnmarshalTextDecoder(typ), nil return newUnmarshalTextDecoder(typ, "", ""), nil
case rtype_ptrTo(typ).Implements(unmarshalTextType): case rtype_ptrTo(typ).Implements(unmarshalTextType):
return newUnmarshalTextDecoder(rtype_ptrTo(typ)), nil return newUnmarshalTextDecoder(rtype_ptrTo(typ), "", ""), nil
} }
return d.compile(typ.Elem(), "", "") return d.compile(typ.Elem(), "", "")
} }
@ -23,13 +23,13 @@ func (d *Decoder) compileHead(typ *rtype) (decoder, error) {
func (d *Decoder) compile(typ *rtype, structName, fieldName string) (decoder, error) { func (d *Decoder) compile(typ *rtype, structName, fieldName string) (decoder, error) {
switch { switch {
case typ.Implements(unmarshalJSONType): case typ.Implements(unmarshalJSONType):
return newUnmarshalJSONDecoder(typ), nil return newUnmarshalJSONDecoder(typ, structName, fieldName), nil
case rtype_ptrTo(typ).Implements(unmarshalJSONType): case rtype_ptrTo(typ).Implements(unmarshalJSONType):
return newUnmarshalJSONDecoder(rtype_ptrTo(typ)), nil return newUnmarshalJSONDecoder(rtype_ptrTo(typ), structName, fieldName), nil
case typ.Implements(unmarshalTextType): case typ.Implements(unmarshalTextType):
return newUnmarshalTextDecoder(typ), nil return newUnmarshalTextDecoder(typ, structName, fieldName), nil
case rtype_ptrTo(typ).Implements(unmarshalTextType): case rtype_ptrTo(typ).Implements(unmarshalTextType):
return newUnmarshalTextDecoder(rtype_ptrTo(typ)), nil return newUnmarshalTextDecoder(rtype_ptrTo(typ), structName, fieldName), nil
} }
switch typ.Kind() { switch typ.Kind() {

View File

@ -35,6 +35,9 @@ func (d *structDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
return errNotAtBeginningOfValue(s.totalOffset()) return errNotAtBeginningOfValue(s.totalOffset())
} }
s.cursor++ s.cursor++
if s.char() == '}' {
return nil
}
for { for {
s.reset() s.reset()
key, err := d.keyDecoder.decodeStreamByte(s) key, err := d.keyDecoder.decodeStreamByte(s)

View File

@ -7,10 +7,24 @@ import (
type unmarshalJSONDecoder struct { type unmarshalJSONDecoder struct {
typ *rtype typ *rtype
isDoublePointer bool isDoublePointer bool
structName string
fieldName string
} }
func newUnmarshalJSONDecoder(typ *rtype) *unmarshalJSONDecoder { func newUnmarshalJSONDecoder(typ *rtype, structName, fieldName string) *unmarshalJSONDecoder {
return &unmarshalJSONDecoder{typ: typ} return &unmarshalJSONDecoder{
typ: typ,
structName: structName,
fieldName: fieldName,
}
}
func (d *unmarshalJSONDecoder) annotateError(err error) {
ut, ok := err.(*UnmarshalTypeError)
if ok {
ut.Struct = d.structName
ut.Field = d.fieldName
}
} }
func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error { func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
@ -27,6 +41,7 @@ func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
ptr: newptr, ptr: newptr,
})) }))
if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil {
d.annotateError(err)
return err return err
} }
*(*unsafe.Pointer)(p) = newptr *(*unsafe.Pointer)(p) = newptr
@ -36,6 +51,7 @@ func (d *unmarshalJSONDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
ptr: p, ptr: p,
})) }))
if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil {
d.annotateError(err)
return err return err
} }
} }
@ -57,6 +73,7 @@ func (d *unmarshalJSONDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer
ptr: newptr, ptr: newptr,
})) }))
if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil {
d.annotateError(err)
return 0, err return 0, err
} }
*(*unsafe.Pointer)(p) = newptr *(*unsafe.Pointer)(p) = newptr
@ -66,6 +83,7 @@ func (d *unmarshalJSONDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer
ptr: p, ptr: p,
})) }))
if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil { if err := v.(Unmarshaler).UnmarshalJSON(src); err != nil {
d.annotateError(err)
return 0, err return 0, err
} }
} }

View File

@ -10,10 +10,24 @@ import (
type unmarshalTextDecoder struct { type unmarshalTextDecoder struct {
typ *rtype typ *rtype
structName string
fieldName string
} }
func newUnmarshalTextDecoder(typ *rtype) *unmarshalTextDecoder { func newUnmarshalTextDecoder(typ *rtype, structName, fieldName string) *unmarshalTextDecoder {
return &unmarshalTextDecoder{typ: typ} return &unmarshalTextDecoder{
typ: typ,
structName: structName,
fieldName: fieldName,
}
}
func (d *unmarshalTextDecoder) annotateError(err error) {
ut, ok := err.(*UnmarshalTypeError)
if ok {
ut.Struct = d.structName
ut.Field = d.fieldName
}
} }
func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error { func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
@ -32,6 +46,7 @@ func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
ptr: newptr, ptr: newptr,
})) }))
if err := v.(encoding.TextUnmarshaler).UnmarshalText(src); err != nil { if err := v.(encoding.TextUnmarshaler).UnmarshalText(src); err != nil {
d.annotateError(err)
return err return err
} }
*(*unsafe.Pointer)(p) = newptr *(*unsafe.Pointer)(p) = newptr
@ -54,6 +69,7 @@ func (d *unmarshalTextDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer
ptr: *(*unsafe.Pointer)(unsafe.Pointer(&p)), ptr: *(*unsafe.Pointer)(unsafe.Pointer(&p)),
})) }))
if err := v.(encoding.TextUnmarshaler).UnmarshalText(src); err != nil { if err := v.(encoding.TextUnmarshaler).UnmarshalText(src); err != nil {
d.annotateError(err)
return 0, err return 0, err
} }
return end, nil return end, nil