Fix decoding of embeded struct

This commit is contained in:
Masaaki Goshima 2020-12-22 21:55:59 +09:00
parent d7b9036e88
commit 6506007b6c
6 changed files with 178 additions and 25 deletions

37
decode_anonymous_field.go Normal file
View File

@ -0,0 +1,37 @@
package json
import (
"fmt"
"unsafe"
)
type anonymousFieldDecoder struct {
structType *rtype
offset uintptr
dec decoder
}
func newAnonymousFieldDecoder(structType *rtype, offset uintptr, dec decoder) *anonymousFieldDecoder {
return &anonymousFieldDecoder{
structType: structType,
offset: offset,
dec: dec,
}
}
func (d *anonymousFieldDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
fmt.Println("called anonymous field decoder", *(*unsafe.Pointer)(p))
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe_New(d.structType)
}
p = *(*unsafe.Pointer)(p)
return d.dec.decodeStream(s, unsafe.Pointer(uintptr(p)+d.offset))
}
func (d *anonymousFieldDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, error) {
if *(*unsafe.Pointer)(p) == nil {
*(*unsafe.Pointer)(p) = unsafe_New(d.structType)
}
p = *(*unsafe.Pointer)(p)
return d.dec.decode(buf, cursor, unsafe.Pointer(uintptr(p)+d.offset))
}

View File

@ -215,15 +215,59 @@ func (d *Decoder) compileMap(typ *rtype, structName, fieldName string) (decoder,
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newMapDecoder(typ, keyDec, valueDec, structName, fieldName), nil return newMapDecoder(typ, typ.Key(), keyDec, typ.Elem(), valueDec, structName, fieldName), nil
} }
func (d *Decoder) compileInterface(typ *rtype, structName, fieldName string) (decoder, error) { func (d *Decoder) compileInterface(typ *rtype, structName, fieldName string) (decoder, error) {
return newInterfaceDecoder(typ, structName, fieldName), nil return newInterfaceDecoder(typ, structName, fieldName), nil
} }
func (d *Decoder) removeConflictFields(fieldMap map[string]*structFieldSet, conflictedMap map[string]struct{}, dec *structDecoder, baseOffset uintptr) {
for k, v := range dec.fieldMap {
if _, exists := conflictedMap[k]; exists {
// already conflicted key
continue
}
set, exists := fieldMap[k]
if !exists {
fieldSet := &structFieldSet{
dec: v.dec,
offset: baseOffset + v.offset,
isTaggedKey: v.isTaggedKey,
}
fieldMap[k] = fieldSet
fieldMap[strings.ToLower(k)] = fieldSet
continue
}
if set.isTaggedKey {
if v.isTaggedKey {
// conflict tag key
delete(fieldMap, k)
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
} else {
if v.isTaggedKey {
fieldSet := &structFieldSet{
dec: v.dec,
offset: baseOffset + v.offset,
isTaggedKey: v.isTaggedKey,
}
fieldMap[k] = fieldSet
fieldMap[strings.ToLower(k)] = fieldSet
} else {
// conflict tag key
delete(fieldMap, k)
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
}
}
}
func (d *Decoder) compileStruct(typ *rtype, structName, fieldName string) (decoder, error) { func (d *Decoder) compileStruct(typ *rtype, structName, fieldName string) (decoder, error) {
fieldNum := typ.NumField() fieldNum := typ.NumField()
conflictedMap := map[string]struct{}{}
fieldMap := map[string]*structFieldSet{} fieldMap := map[string]*structFieldSet{}
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
if dec, exists := d.structTypeToDecoder[typeptr]; exists { if dec, exists := d.structTypeToDecoder[typeptr]; exists {
@ -242,13 +286,75 @@ func (d *Decoder) compileStruct(typ *rtype, structName, fieldName string) (decod
if err != nil { if err != nil {
return nil, err return nil, err
} }
if field.Anonymous && !tag.isTaggedKey {
if stDec, ok := dec.(*structDecoder); ok {
if type2rtype(field.Type) == typ {
// recursive definition
continue
}
d.removeConflictFields(fieldMap, conflictedMap, stDec, uintptr(field.Offset))
} else if pdec, ok := dec.(*ptrDecoder); ok {
contentDec := pdec.contentDecoder()
if pdec.typ == typ {
// recursive definition
continue
}
if dec, ok := contentDec.(*structDecoder); ok {
for k, v := range dec.fieldMap {
if _, exists := conflictedMap[k]; exists {
// already conflicted key
continue
}
set, exists := fieldMap[k]
if !exists {
fieldSet := &structFieldSet{
dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec),
offset: uintptr(field.Offset),
isTaggedKey: v.isTaggedKey,
}
fieldMap[k] = fieldSet
fieldMap[strings.ToLower(k)] = fieldSet
continue
}
if set.isTaggedKey {
if v.isTaggedKey {
// conflict tag key
delete(fieldMap, k)
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
} else {
if v.isTaggedKey {
fieldSet := &structFieldSet{
dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec),
offset: uintptr(field.Offset),
isTaggedKey: v.isTaggedKey,
}
fieldMap[k] = fieldSet
fieldMap[strings.ToLower(k)] = fieldSet
} else {
// conflict tag key
delete(fieldMap, k)
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
}
}
}
}
} else {
if tag.isString { if tag.isString {
dec = newWrappedStringDecoder(dec, structName, field.Name) dec = newWrappedStringDecoder(dec, structName, field.Name)
} }
fieldSet := &structFieldSet{dec: dec, offset: field.Offset} fieldSet := &structFieldSet{dec: dec, offset: field.Offset, isTaggedKey: tag.isTaggedKey}
fieldMap[field.Name] = fieldSet if tag.key != "" {
fieldMap[tag.key] = fieldSet fieldMap[tag.key] = fieldSet
fieldMap[strings.ToLower(tag.key)] = fieldSet fieldMap[strings.ToLower(tag.key)] = fieldSet
} else {
fieldMap[field.Name] = fieldSet
fieldMap[strings.ToLower(field.Name)] = fieldSet
}
}
} }
delete(d.structTypeToDecoder, typeptr) delete(d.structTypeToDecoder, typeptr)
return structDec, nil return structDec, nil

View File

@ -34,6 +34,9 @@ var (
interfaceMapType = type2rtype( interfaceMapType = type2rtype(
reflect.TypeOf((*map[string]interface{})(nil)).Elem(), reflect.TypeOf((*map[string]interface{})(nil)).Elem(),
) )
stringType = type2rtype(
reflect.TypeOf(""),
)
) )
func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
@ -45,7 +48,9 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
ptr := unsafe.Pointer(&v) ptr := unsafe.Pointer(&v)
if err := newMapDecoder( if err := newMapDecoder(
interfaceMapType, interfaceMapType,
stringType,
newStringDecoder(d.structName, d.fieldName), newStringDecoder(d.structName, d.fieldName),
interfaceMapType.Elem(),
newInterfaceDecoder(d.typ, d.structName, d.fieldName), newInterfaceDecoder(d.typ, d.structName, d.fieldName),
d.structName, d.structName,
d.fieldName, d.fieldName,
@ -129,7 +134,9 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i
ptr := unsafe.Pointer(&v) ptr := unsafe.Pointer(&v)
dec := newMapDecoder( dec := newMapDecoder(
interfaceMapType, interfaceMapType,
stringType,
newStringDecoder(d.structName, d.fieldName), newStringDecoder(d.structName, d.fieldName),
interfaceMapType.Elem(),
newInterfaceDecoder(d.typ, d.structName, d.fieldName), newInterfaceDecoder(d.typ, d.structName, d.fieldName),
d.structName, d.fieldName, d.structName, d.fieldName,
) )

View File

@ -6,16 +6,20 @@ import (
type mapDecoder struct { type mapDecoder struct {
mapType *rtype mapType *rtype
keyType *rtype
valueType *rtype
keyDecoder decoder keyDecoder decoder
valueDecoder decoder valueDecoder decoder
structName string structName string
fieldName string fieldName string
} }
func newMapDecoder(mapType *rtype, keyDec decoder, valueDec decoder, structName, fieldName string) *mapDecoder { func newMapDecoder(mapType *rtype, keyType *rtype, keyDec decoder, valueType *rtype, valueDec decoder, structName, fieldName string) *mapDecoder {
return &mapDecoder{ return &mapDecoder{
mapType: mapType, mapType: mapType,
keyDecoder: keyDec, keyDecoder: keyDec,
keyType: keyType,
valueType: valueType,
valueDecoder: valueDec, valueDecoder: valueDec,
structName: structName, structName: structName,
fieldName: fieldName, fieldName: fieldName,
@ -39,16 +43,6 @@ func (d *mapDecoder) setValue(buf []byte, cursor int64, key interface{}) (int64,
return d.valueDecoder.decode(buf, cursor, header.ptr) return d.valueDecoder.decode(buf, cursor, header.ptr)
} }
func (d *mapDecoder) setKeyStream(s *stream, key interface{}) error {
header := (*interfaceHeader)(unsafe.Pointer(&key))
return d.keyDecoder.decodeStream(s, header.ptr)
}
func (d *mapDecoder) setValueStream(s *stream, key interface{}) error {
header := (*interfaceHeader)(unsafe.Pointer(&key))
return d.valueDecoder.decodeStream(s, header.ptr)
}
func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
s.skipWhiteSpace() s.skipWhiteSpace()
switch s.char() { switch s.char() {
@ -70,8 +64,8 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
} }
for { for {
s.cursor++ s.cursor++
var key interface{} k := unsafe_New(d.keyType)
if err := d.setKeyStream(s, &key); err != nil { if err := d.keyDecoder.decodeStream(s, k); err != nil {
return err return err
} }
s.skipWhiteSpace() s.skipWhiteSpace()
@ -82,11 +76,11 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
return errExpected("colon after object key", s.totalOffset()) return errExpected("colon after object key", s.totalOffset())
} }
s.cursor++ s.cursor++
var value interface{} v := unsafe_New(d.valueType)
if err := d.setValueStream(s, &value); err != nil { if err := d.valueDecoder.decodeStream(s, v); err != nil {
return err return err
} }
mapassign(d.mapType, mapValue, unsafe.Pointer(&key), unsafe.Pointer(&value)) mapassign(d.mapType, mapValue, k, v)
s.skipWhiteSpace() s.skipWhiteSpace()
if s.char() == nul { if s.char() == nul {
s.read() s.read()

View File

@ -20,6 +20,14 @@ func newPtrDecoder(dec decoder, typ *rtype, structName, fieldName string) *ptrDe
} }
} }
func (d *ptrDecoder) contentDecoder() decoder {
dec, ok := d.dec.(*ptrDecoder)
if !ok {
return d.dec
}
return dec.contentDecoder()
}
//go:linkname unsafe_New reflect.unsafe_New //go:linkname unsafe_New reflect.unsafe_New
func unsafe_New(*rtype) unsafe.Pointer func unsafe_New(*rtype) unsafe.Pointer

View File

@ -8,6 +8,7 @@ import (
type structFieldSet struct { type structFieldSet struct {
dec decoder dec decoder
offset uintptr offset uintptr
isTaggedKey bool
} }
type structDecoder struct { type structDecoder struct {