From 6506007b6cd8c6159eb737602224f738c1361e16 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Tue, 22 Dec 2020 21:55:59 +0900 Subject: [PATCH] Fix decoding of embeded struct --- decode_anonymous_field.go | 37 ++++++++++++ decode_compile.go | 120 +++++++++++++++++++++++++++++++++++--- decode_interface.go | 7 +++ decode_map.go | 26 ++++----- decode_ptr.go | 8 +++ decode_struct.go | 5 +- 6 files changed, 178 insertions(+), 25 deletions(-) create mode 100644 decode_anonymous_field.go diff --git a/decode_anonymous_field.go b/decode_anonymous_field.go new file mode 100644 index 0000000..2ee28f8 --- /dev/null +++ b/decode_anonymous_field.go @@ -0,0 +1,37 @@ +package json + +import ( + "fmt" + "unsafe" +) + +type anonymousFieldDecoder struct { + structType *rtype + offset uintptr + dec decoder +} + +func newAnonymousFieldDecoder(structType *rtype, offset uintptr, dec decoder) *anonymousFieldDecoder { + return &anonymousFieldDecoder{ + structType: structType, + offset: offset, + dec: dec, + } +} + +func (d *anonymousFieldDecoder) decodeStream(s *stream, p unsafe.Pointer) error { + fmt.Println("called anonymous field decoder", *(*unsafe.Pointer)(p)) + if *(*unsafe.Pointer)(p) == nil { + *(*unsafe.Pointer)(p) = unsafe_New(d.structType) + } + p = *(*unsafe.Pointer)(p) + return d.dec.decodeStream(s, unsafe.Pointer(uintptr(p)+d.offset)) +} + +func (d *anonymousFieldDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) { + if *(*unsafe.Pointer)(p) == nil { + *(*unsafe.Pointer)(p) = unsafe_New(d.structType) + } + p = *(*unsafe.Pointer)(p) + return d.dec.decode(buf, cursor, unsafe.Pointer(uintptr(p)+d.offset)) +} diff --git a/decode_compile.go b/decode_compile.go index c23f7b4..c75b8a9 100644 --- a/decode_compile.go +++ b/decode_compile.go @@ -215,15 +215,59 @@ func (d *Decoder) compileMap(typ *rtype, structName, fieldName string) (decoder, if err != nil { return nil, err } - return newMapDecoder(typ, keyDec, valueDec, structName, fieldName), nil + return newMapDecoder(typ, typ.Key(), keyDec, typ.Elem(), valueDec, structName, fieldName), nil } func (d *Decoder) compileInterface(typ *rtype, structName, fieldName string) (decoder, error) { return newInterfaceDecoder(typ, structName, fieldName), nil } +func (d *Decoder) removeConflictFields(fieldMap map[string]*structFieldSet, conflictedMap map[string]struct{}, dec *structDecoder, baseOffset uintptr) { + for k, v := range dec.fieldMap { + if _, exists := conflictedMap[k]; exists { + // already conflicted key + continue + } + set, exists := fieldMap[k] + if !exists { + fieldSet := &structFieldSet{ + dec: v.dec, + offset: baseOffset + v.offset, + isTaggedKey: v.isTaggedKey, + } + fieldMap[k] = fieldSet + fieldMap[strings.ToLower(k)] = fieldSet + continue + } + if set.isTaggedKey { + if v.isTaggedKey { + // conflict tag key + delete(fieldMap, k) + conflictedMap[k] = struct{}{} + conflictedMap[strings.ToLower(k)] = struct{}{} + } + } else { + if v.isTaggedKey { + fieldSet := &structFieldSet{ + dec: v.dec, + offset: baseOffset + v.offset, + isTaggedKey: v.isTaggedKey, + } + fieldMap[k] = fieldSet + fieldMap[strings.ToLower(k)] = fieldSet + } else { + // conflict tag key + delete(fieldMap, k) + conflictedMap[k] = struct{}{} + conflictedMap[strings.ToLower(k)] = struct{}{} + } + } + } +} + func (d *Decoder) compileStruct(typ *rtype, structName, fieldName string) (decoder, error) { fieldNum := typ.NumField() + conflictedMap := map[string]struct{}{} fieldMap := map[string]*structFieldSet{} typeptr := uintptr(unsafe.Pointer(typ)) if dec, exists := d.structTypeToDecoder[typeptr]; exists { @@ -242,13 +286,75 @@ func (d *Decoder) compileStruct(typ *rtype, structName, fieldName string) (decod if err != nil { return nil, err } - if tag.isString { - dec = newWrappedStringDecoder(dec, structName, field.Name) + if field.Anonymous && !tag.isTaggedKey { + if stDec, ok := dec.(*structDecoder); ok { + if type2rtype(field.Type) == typ { + // recursive definition + continue + } + d.removeConflictFields(fieldMap, conflictedMap, stDec, uintptr(field.Offset)) + } else if pdec, ok := dec.(*ptrDecoder); ok { + contentDec := pdec.contentDecoder() + if pdec.typ == typ { + // recursive definition + continue + } + if dec, ok := contentDec.(*structDecoder); ok { + for k, v := range dec.fieldMap { + if _, exists := conflictedMap[k]; exists { + // already conflicted key + continue + } + set, exists := fieldMap[k] + if !exists { + fieldSet := &structFieldSet{ + dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec), + offset: uintptr(field.Offset), + isTaggedKey: v.isTaggedKey, + } + fieldMap[k] = fieldSet + fieldMap[strings.ToLower(k)] = fieldSet + continue + } + if set.isTaggedKey { + if v.isTaggedKey { + // conflict tag key + delete(fieldMap, k) + conflictedMap[k] = struct{}{} + conflictedMap[strings.ToLower(k)] = struct{}{} + } + } else { + if v.isTaggedKey { + fieldSet := &structFieldSet{ + dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec), + offset: uintptr(field.Offset), + isTaggedKey: v.isTaggedKey, + } + fieldMap[k] = fieldSet + fieldMap[strings.ToLower(k)] = fieldSet + } else { + // conflict tag key + delete(fieldMap, k) + conflictedMap[k] = struct{}{} + conflictedMap[strings.ToLower(k)] = struct{}{} + } + } + } + } + } + } else { + if tag.isString { + dec = newWrappedStringDecoder(dec, structName, field.Name) + } + fieldSet := &structFieldSet{dec: dec, offset: field.Offset, isTaggedKey: tag.isTaggedKey} + if tag.key != "" { + fieldMap[tag.key] = fieldSet + fieldMap[strings.ToLower(tag.key)] = fieldSet + } else { + fieldMap[field.Name] = fieldSet + fieldMap[strings.ToLower(field.Name)] = fieldSet + } } - fieldSet := &structFieldSet{dec: dec, offset: field.Offset} - fieldMap[field.Name] = fieldSet - fieldMap[tag.key] = fieldSet - fieldMap[strings.ToLower(tag.key)] = fieldSet } delete(d.structTypeToDecoder, typeptr) return structDec, nil diff --git a/decode_interface.go b/decode_interface.go index 89aa559..a445848 100644 --- a/decode_interface.go +++ b/decode_interface.go @@ -34,6 +34,9 @@ var ( interfaceMapType = type2rtype( reflect.TypeOf((*map[string]interface{})(nil)).Elem(), ) + stringType = type2rtype( + reflect.TypeOf(""), + ) ) func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { @@ -45,7 +48,9 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { ptr := unsafe.Pointer(&v) if err := newMapDecoder( interfaceMapType, + stringType, newStringDecoder(d.structName, d.fieldName), + interfaceMapType.Elem(), newInterfaceDecoder(d.typ, d.structName, d.fieldName), d.structName, d.fieldName, @@ -129,7 +134,9 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i ptr := unsafe.Pointer(&v) dec := newMapDecoder( interfaceMapType, + stringType, newStringDecoder(d.structName, d.fieldName), + interfaceMapType.Elem(), newInterfaceDecoder(d.typ, d.structName, d.fieldName), d.structName, d.fieldName, ) diff --git a/decode_map.go b/decode_map.go index 4aa8e5e..e59f0f7 100644 --- a/decode_map.go +++ b/decode_map.go @@ -6,16 +6,20 @@ import ( type mapDecoder struct { mapType *rtype + keyType *rtype + valueType *rtype keyDecoder decoder valueDecoder decoder structName string fieldName string } -func newMapDecoder(mapType *rtype, keyDec decoder, valueDec decoder, structName, fieldName string) *mapDecoder { +func newMapDecoder(mapType *rtype, keyType *rtype, keyDec decoder, valueType *rtype, valueDec decoder, structName, fieldName string) *mapDecoder { return &mapDecoder{ mapType: mapType, keyDecoder: keyDec, + keyType: keyType, + valueType: valueType, valueDecoder: valueDec, structName: structName, fieldName: fieldName, @@ -39,16 +43,6 @@ func (d *mapDecoder) setValue(buf []byte, cursor int64, key interface{}) (int64, return d.valueDecoder.decode(buf, cursor, header.ptr) } -func (d *mapDecoder) setKeyStream(s *stream, key interface{}) error { - header := (*interfaceHeader)(unsafe.Pointer(&key)) - return d.keyDecoder.decodeStream(s, header.ptr) -} - -func (d *mapDecoder) setValueStream(s *stream, key interface{}) error { - header := (*interfaceHeader)(unsafe.Pointer(&key)) - return d.valueDecoder.decodeStream(s, header.ptr) -} - func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { s.skipWhiteSpace() switch s.char() { @@ -70,8 +64,8 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { } for { s.cursor++ - var key interface{} - if err := d.setKeyStream(s, &key); err != nil { + k := unsafe_New(d.keyType) + if err := d.keyDecoder.decodeStream(s, k); err != nil { return err } s.skipWhiteSpace() @@ -82,11 +76,11 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return errExpected("colon after object key", s.totalOffset()) } s.cursor++ - var value interface{} - if err := d.setValueStream(s, &value); err != nil { + v := unsafe_New(d.valueType) + if err := d.valueDecoder.decodeStream(s, v); err != nil { return err } - mapassign(d.mapType, mapValue, unsafe.Pointer(&key), unsafe.Pointer(&value)) + mapassign(d.mapType, mapValue, k, v) s.skipWhiteSpace() if s.char() == nul { s.read() diff --git a/decode_ptr.go b/decode_ptr.go index c25d893..a62a691 100644 --- a/decode_ptr.go +++ b/decode_ptr.go @@ -20,6 +20,14 @@ func newPtrDecoder(dec decoder, typ *rtype, structName, fieldName string) *ptrDe } } +func (d *ptrDecoder) contentDecoder() decoder { + dec, ok := d.dec.(*ptrDecoder) + if !ok { + return d.dec + } + return dec.contentDecoder() +} + //go:linkname unsafe_New reflect.unsafe_New func unsafe_New(*rtype) unsafe.Pointer diff --git a/decode_struct.go b/decode_struct.go index 8eeca73..f4e1847 100644 --- a/decode_struct.go +++ b/decode_struct.go @@ -6,8 +6,9 @@ import ( ) type structFieldSet struct { - dec decoder - offset uintptr + dec decoder + offset uintptr + isTaggedKey bool } type structDecoder struct {