diff --git a/decode_test.go b/decode_test.go index 9be2ba0..5df2c35 100644 --- a/decode_test.go +++ b/decode_test.go @@ -3789,3 +3789,42 @@ func TestDecodeStructFieldMap(t *testing.T) { t.Fatalf("failed to decode v.Bar = %+v", v.Bar) } } + +type issue303 struct { + Count int + Type string + Value interface{} +} + +func (t *issue303) UnmarshalJSON(b []byte) error { + type tmpType issue303 + + wrapped := struct { + Value json.RawMessage + tmpType + }{} + if err := json.Unmarshal(b, &wrapped); err != nil { + return err + } + *t = issue303(wrapped.tmpType) + + switch wrapped.Type { + case "string": + var str string + if err := json.Unmarshal(wrapped.Value, &str); err != nil { + return err + } + t.Value = str + } + return nil +} + +func TestIssue303(t *testing.T) { + var v issue303 + if err := json.Unmarshal([]byte(`{"Count":7,"Type":"string","Value":"hello"}`), &v); err != nil { + t.Fatal(err) + } + if v.Count != 7 || v.Type != "string" || v.Value != "hello" { + t.Fatalf("failed to decode. count = %d type = %s value = %v", v.Count, v.Type, v.Value) + } +} diff --git a/internal/decoder/compile.go b/internal/decoder/compile.go index 9cafc06..b2e9f57 100644 --- a/internal/decoder/compile.go +++ b/internal/decoder/compile.go @@ -307,64 +307,21 @@ func compileFunc(typ *runtime.Type, strutName, fieldName string) (Decoder, error return newFuncDecoder(typ, strutName, fieldName), nil } -func removeConflictFields(fieldMap map[string]*structFieldSet, conflictedMap map[string]struct{}, dec *structDecoder, field reflect.StructField) { - for k, v := range dec.fieldMap { - if _, exists := conflictedMap[k]; exists { - // already conflicted key +func typeToStructTags(typ *runtime.Type) runtime.StructTags { + tags := runtime.StructTags{} + fieldNum := typ.NumField() + for i := 0; i < fieldNum; i++ { + field := typ.Field(i) + if runtime.IsIgnoredStructField(field) { continue } - set, exists := fieldMap[k] - if !exists { - fieldSet := &structFieldSet{ - dec: v.dec, - offset: field.Offset + v.offset, - isTaggedKey: v.isTaggedKey, - key: k, - keyLen: int64(len(k)), - } - fieldMap[k] = fieldSet - lower := strings.ToLower(k) - if _, exists := fieldMap[lower]; !exists { - fieldMap[lower] = fieldSet - } - continue - } - if set.isTaggedKey { - if v.isTaggedKey { - // conflict tag key - delete(fieldMap, k) - delete(fieldMap, strings.ToLower(k)) - conflictedMap[k] = struct{}{} - conflictedMap[strings.ToLower(k)] = struct{}{} - } - } else { - if v.isTaggedKey { - fieldSet := &structFieldSet{ - dec: v.dec, - offset: field.Offset + v.offset, - isTaggedKey: v.isTaggedKey, - key: k, - keyLen: int64(len(k)), - } - fieldMap[k] = fieldSet - lower := strings.ToLower(k) - if _, exists := fieldMap[lower]; !exists { - fieldMap[lower] = fieldSet - } - } else { - // conflict tag key - delete(fieldMap, k) - delete(fieldMap, strings.ToLower(k)) - conflictedMap[k] = struct{}{} - conflictedMap[strings.ToLower(k)] = struct{}{} - } - } + tags = append(tags, runtime.StructTagFromField(field)) } + return tags } func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) { fieldNum := typ.NumField() - conflictedMap := map[string]struct{}{} fieldMap := map[string]*structFieldSet{} typeptr := uintptr(unsafe.Pointer(typ)) if dec, exists := structTypeToDecoder[typeptr]; exists { @@ -373,6 +330,8 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo structDec := newStructDecoder(structName, fieldName, fieldMap) structTypeToDecoder[typeptr] = structDec structName = typ.Name() + tags := typeToStructTags(typ) + allFields := []*structFieldSet{} for i := 0; i < fieldNum; i++ { field := typ.Field(i) if runtime.IsIgnoredStructField(field) { @@ -390,7 +349,19 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo // recursive definition continue } - removeConflictFields(fieldMap, conflictedMap, stDec, field) + for k, v := range stDec.fieldMap { + if tags.ExistsKey(k) { + continue + } + fieldSet := &structFieldSet{ + dec: v.dec, + offset: field.Offset + v.offset, + isTaggedKey: v.isTaggedKey, + key: k, + keyLen: int64(len(k)), + } + allFields = append(allFields, fieldSet) + } } else if pdec, ok := dec.(*ptrDecoder); ok { contentDec := pdec.contentDecoder() if pdec.typ == typ { @@ -406,58 +377,18 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo } if dec, ok := contentDec.(*structDecoder); ok { for k, v := range dec.fieldMap { - if _, exists := conflictedMap[k]; exists { - // already conflicted key + if tags.ExistsKey(k) { continue } - set, exists := fieldMap[k] - if !exists { - fieldSet := &structFieldSet{ - dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec), - offset: field.Offset, - isTaggedKey: v.isTaggedKey, - key: k, - keyLen: int64(len(k)), - err: fieldSetErr, - } - fieldMap[k] = fieldSet - lower := strings.ToLower(k) - if _, exists := fieldMap[lower]; !exists { - fieldMap[lower] = fieldSet - } - continue - } - if set.isTaggedKey { - if v.isTaggedKey { - // conflict tag key - delete(fieldMap, k) - delete(fieldMap, strings.ToLower(k)) - conflictedMap[k] = struct{}{} - conflictedMap[strings.ToLower(k)] = struct{}{} - } - } else { - if v.isTaggedKey { - fieldSet := &structFieldSet{ - dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec), - offset: field.Offset, - isTaggedKey: v.isTaggedKey, - key: k, - keyLen: int64(len(k)), - err: fieldSetErr, - } - fieldMap[k] = fieldSet - lower := strings.ToLower(k) - if _, exists := fieldMap[lower]; !exists { - fieldMap[lower] = fieldSet - } - } else { - // conflict tag key - delete(fieldMap, k) - delete(fieldMap, strings.ToLower(k)) - conflictedMap[k] = struct{}{} - conflictedMap[strings.ToLower(k)] = struct{}{} - } + fieldSet := &structFieldSet{ + dec: newAnonymousFieldDecoder(pdec.typ, v.offset, v.dec), + offset: field.Offset, + isTaggedKey: v.isTaggedKey, + key: k, + keyLen: int64(len(k)), + err: fieldSetErr, } + allFields = append(allFields, fieldSet) } } } @@ -478,11 +409,15 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo key: key, keyLen: int64(len(key)), } - fieldMap[key] = fieldSet - lower := strings.ToLower(key) - if _, exists := fieldMap[lower]; !exists { - fieldMap[lower] = fieldSet - } + allFields = append(allFields, fieldSet) + } + } + for _, set := range filterDuplicatedFields(allFields) { + fieldMap[set.key] = set + lower := strings.ToLower(set.key) + if _, exists := fieldMap[lower]; !exists { + // first win + fieldMap[lower] = set } } delete(structTypeToDecoder, typeptr) @@ -490,6 +425,42 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo return structDec, nil } +func filterDuplicatedFields(allFields []*structFieldSet) []*structFieldSet { + fieldMap := map[string][]*structFieldSet{} + for _, field := range allFields { + fieldMap[field.key] = append(fieldMap[field.key], field) + } + duplicatedFieldMap := map[string]struct{}{} + for k, sets := range fieldMap { + sets = filterFieldSets(sets) + if len(sets) != 1 { + duplicatedFieldMap[k] = struct{}{} + } + } + + filtered := make([]*structFieldSet, 0, len(allFields)) + for _, field := range allFields { + if _, exists := duplicatedFieldMap[field.key]; exists { + continue + } + filtered = append(filtered, field) + } + return filtered +} + +func filterFieldSets(sets []*structFieldSet) []*structFieldSet { + if len(sets) == 1 { + return sets + } + filtered := make([]*structFieldSet, 0, len(sets)) + for _, set := range sets { + if set.isTaggedKey { + filtered = append(filtered, set) + } + } + return filtered +} + func implementsUnmarshalJSONType(typ *runtime.Type) bool { return typ.Implements(unmarshalJSONType) || typ.Implements(unmarshalJSONContextType) }