Fix logic of removing struct field for decoder

This commit is contained in:
Masaaki Goshima 2022-01-14 20:18:18 +09:00
parent d5a9e00a5e
commit 50b494bc5f
No known key found for this signature in database
GPG Key ID: 6A53785055537153
2 changed files with 116 additions and 106 deletions

View File

@ -3789,3 +3789,42 @@ func TestDecodeStructFieldMap(t *testing.T) {
t.Fatalf("failed to decode v.Bar = %+v", v.Bar) t.Fatalf("failed to decode v.Bar = %+v", v.Bar)
} }
} }
type issue303 struct {
Count int
Type string
Value interface{}
}
func (t *issue303) UnmarshalJSON(b []byte) error {
type tmpType issue303
wrapped := struct {
Value json.RawMessage
tmpType
}{}
if err := json.Unmarshal(b, &wrapped); err != nil {
return err
}
*t = issue303(wrapped.tmpType)
switch wrapped.Type {
case "string":
var str string
if err := json.Unmarshal(wrapped.Value, &str); err != nil {
return err
}
t.Value = str
}
return nil
}
func TestIssue303(t *testing.T) {
var v issue303
if err := json.Unmarshal([]byte(`{"Count":7,"Type":"string","Value":"hello"}`), &v); err != nil {
t.Fatal(err)
}
if v.Count != 7 || v.Type != "string" || v.Value != "hello" {
t.Fatalf("failed to decode. count = %d type = %s value = %v", v.Count, v.Type, v.Value)
}
}

View File

@ -307,64 +307,21 @@ func compileFunc(typ *runtime.Type, strutName, fieldName string) (Decoder, error
return newFuncDecoder(typ, strutName, fieldName), nil return newFuncDecoder(typ, strutName, fieldName), nil
} }
func removeConflictFields(fieldMap map[string]*structFieldSet, conflictedMap map[string]struct{}, dec *structDecoder, field reflect.StructField) { func typeToStructTags(typ *runtime.Type) runtime.StructTags {
for k, v := range dec.fieldMap { tags := runtime.StructTags{}
if _, exists := conflictedMap[k]; exists { fieldNum := typ.NumField()
// already conflicted key for i := 0; i < fieldNum; i++ {
field := typ.Field(i)
if runtime.IsIgnoredStructField(field) {
continue continue
} }
set, exists := fieldMap[k] tags = append(tags, runtime.StructTagFromField(field))
if !exists {
fieldSet := &structFieldSet{
dec: v.dec,
offset: field.Offset + v.offset,
isTaggedKey: v.isTaggedKey,
key: k,
keyLen: int64(len(k)),
}
fieldMap[k] = fieldSet
lower := strings.ToLower(k)
if _, exists := fieldMap[lower]; !exists {
fieldMap[lower] = fieldSet
}
continue
}
if set.isTaggedKey {
if v.isTaggedKey {
// conflict tag key
delete(fieldMap, k)
delete(fieldMap, strings.ToLower(k))
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
} else {
if v.isTaggedKey {
fieldSet := &structFieldSet{
dec: v.dec,
offset: field.Offset + v.offset,
isTaggedKey: v.isTaggedKey,
key: k,
keyLen: int64(len(k)),
}
fieldMap[k] = fieldSet
lower := strings.ToLower(k)
if _, exists := fieldMap[lower]; !exists {
fieldMap[lower] = fieldSet
}
} else {
// conflict tag key
delete(fieldMap, k)
delete(fieldMap, strings.ToLower(k))
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
}
} }
return tags
} }
func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) { func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
fieldNum := typ.NumField() fieldNum := typ.NumField()
conflictedMap := map[string]struct{}{}
fieldMap := map[string]*structFieldSet{} fieldMap := map[string]*structFieldSet{}
typeptr := uintptr(unsafe.Pointer(typ)) typeptr := uintptr(unsafe.Pointer(typ))
if dec, exists := structTypeToDecoder[typeptr]; exists { if dec, exists := structTypeToDecoder[typeptr]; exists {
@ -373,6 +330,8 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
structDec := newStructDecoder(structName, fieldName, fieldMap) structDec := newStructDecoder(structName, fieldName, fieldMap)
structTypeToDecoder[typeptr] = structDec structTypeToDecoder[typeptr] = structDec
structName = typ.Name() structName = typ.Name()
tags := typeToStructTags(typ)
allFields := []*structFieldSet{}
for i := 0; i < fieldNum; i++ { for i := 0; i < fieldNum; i++ {
field := typ.Field(i) field := typ.Field(i)
if runtime.IsIgnoredStructField(field) { if runtime.IsIgnoredStructField(field) {
@ -390,7 +349,19 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
// recursive definition // recursive definition
continue continue
} }
removeConflictFields(fieldMap, conflictedMap, stDec, field) for k, v := range stDec.fieldMap {
if tags.ExistsKey(k) {
continue
}
fieldSet := &structFieldSet{
dec: v.dec,
offset: field.Offset + v.offset,
isTaggedKey: v.isTaggedKey,
key: k,
keyLen: int64(len(k)),
}
allFields = append(allFields, fieldSet)
}
} else if pdec, ok := dec.(*ptrDecoder); ok { } else if pdec, ok := dec.(*ptrDecoder); ok {
contentDec := pdec.contentDecoder() contentDec := pdec.contentDecoder()
if pdec.typ == typ { if pdec.typ == typ {
@ -406,12 +377,9 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
} }
if dec, ok := contentDec.(*structDecoder); ok { if dec, ok := contentDec.(*structDecoder); ok {
for k, v := range dec.fieldMap { for k, v := range dec.fieldMap {
if _, exists := conflictedMap[k]; exists { if tags.ExistsKey(k) {
// already conflicted key
continue continue
} }
set, exists := fieldMap[k]
if !exists {
fieldSet := &structFieldSet{ fieldSet := &structFieldSet{
dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec), dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec),
offset: field.Offset, offset: field.Offset,
@ -420,44 +388,7 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
keyLen: int64(len(k)), keyLen: int64(len(k)),
err: fieldSetErr, err: fieldSetErr,
} }
fieldMap[k] = fieldSet allFields = append(allFields, fieldSet)
lower := strings.ToLower(k)
if _, exists := fieldMap[lower]; !exists {
fieldMap[lower] = fieldSet
}
continue
}
if set.isTaggedKey {
if v.isTaggedKey {
// conflict tag key
delete(fieldMap, k)
delete(fieldMap, strings.ToLower(k))
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
} else {
if v.isTaggedKey {
fieldSet := &structFieldSet{
dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec),
offset: field.Offset,
isTaggedKey: v.isTaggedKey,
key: k,
keyLen: int64(len(k)),
err: fieldSetErr,
}
fieldMap[k] = fieldSet
lower := strings.ToLower(k)
if _, exists := fieldMap[lower]; !exists {
fieldMap[lower] = fieldSet
}
} else {
// conflict tag key
delete(fieldMap, k)
delete(fieldMap, strings.ToLower(k))
conflictedMap[k] = struct{}{}
conflictedMap[strings.ToLower(k)] = struct{}{}
}
}
} }
} }
} }
@ -478,18 +409,58 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
key: key, key: key,
keyLen: int64(len(key)), keyLen: int64(len(key)),
} }
fieldMap[key] = fieldSet allFields = append(allFields, fieldSet)
lower := strings.ToLower(key)
if _, exists := fieldMap[lower]; !exists {
fieldMap[lower] = fieldSet
} }
} }
for _, set := range filterDuplicatedFields(allFields) {
fieldMap[set.key] = set
lower := strings.ToLower(set.key)
if _, exists := fieldMap[lower]; !exists {
// first win
fieldMap[lower] = set
}
} }
delete(structTypeToDecoder, typeptr) delete(structTypeToDecoder, typeptr)
structDec.tryOptimize() structDec.tryOptimize()
return structDec, nil return structDec, nil
} }
func filterDuplicatedFields(allFields []*structFieldSet) []*structFieldSet {
fieldMap := map[string][]*structFieldSet{}
for _, field := range allFields {
fieldMap[field.key] = append(fieldMap[field.key], field)
}
duplicatedFieldMap := map[string]struct{}{}
for k, sets := range fieldMap {
sets = filterFieldSets(sets)
if len(sets) != 1 {
duplicatedFieldMap[k] = struct{}{}
}
}
filtered := make([]*structFieldSet, 0, len(allFields))
for _, field := range allFields {
if _, exists := duplicatedFieldMap[field.key]; exists {
continue
}
filtered = append(filtered, field)
}
return filtered
}
func filterFieldSets(sets []*structFieldSet) []*structFieldSet {
if len(sets) == 1 {
return sets
}
filtered := make([]*structFieldSet, 0, len(sets))
for _, set := range sets {
if set.isTaggedKey {
filtered = append(filtered, set)
}
}
return filtered
}
func implementsUnmarshalJSONType(typ *runtime.Type) bool { func implementsUnmarshalJSONType(typ *runtime.Type) bool {
return typ.Implements(unmarshalJSONType) || typ.Implements(unmarshalJSONContextType) return typ.Implements(unmarshalJSONType) || typ.Implements(unmarshalJSONContextType)
} }