diff --git a/decode_interface.go b/decode_interface.go index a445848..2138f68 100644 --- a/decode_interface.go +++ b/decode_interface.go @@ -1,6 +1,7 @@ package json import ( + "encoding" "reflect" "unsafe" ) @@ -39,11 +40,56 @@ var ( ) ) +func decodeWithUnmarshaler(s *stream, unmarshaler Unmarshaler) error { + start := s.cursor + if err := s.skipValue(); err != nil { + return err + } + src := s.buf[start:s.cursor] + dst := make([]byte, len(src)) + copy(dst, src) + + if err := unmarshaler.UnmarshalJSON(dst); err != nil { + return err + } + return nil +} + +func decodeWithTextUnmarshaler(s *stream, unmarshaler encoding.TextUnmarshaler) error { + start := s.cursor + if err := s.skipValue(); err != nil { + return err + } + src := s.buf[start:s.cursor] + dst := make([]byte, len(src)) + copy(dst, src) + + if err := unmarshaler.UnmarshalText(dst); err != nil { + return err + } + return nil +} + func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { s.skipWhiteSpace() for { switch s.char() { case '{': + runtimeInterfaceValue := *(*interface{})(unsafe.Pointer(&interfaceHeader{ + typ: d.typ, + ptr: p, + })) + rv := reflect.ValueOf(runtimeInterfaceValue) + if rv.NumMethod() > 0 && rv.CanInterface() { + if u, ok := rv.Interface().(Unmarshaler); ok { + return decodeWithUnmarshaler(s, u) + } + if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { + return decodeWithTextUnmarshaler(s, u) + } + return nil + } + // empty interface var v map[string]interface{} ptr := unsafe.Pointer(&v) if err := newMapDecoder(