diff --git a/internal/encoder/code.go b/internal/encoder/code.go index 209e94e..7986345 100644 --- a/internal/encoder/code.go +++ b/internal/encoder/code.go @@ -168,6 +168,27 @@ func (c *StructCode) ToOpcode() []*Opcode { return []*Opcode{} } +func (c *StructCode) removeFieldsByTags(tags runtime.StructTags) { + fields := make([]*StructFieldCode, 0, len(c.fields)) + for _, field := range c.fields { + if field.isAnonymous { + structCode := field.getAnonymousStruct() + if structCode != nil { + structCode.removeFieldsByTags(tags) + if len(structCode.fields) > 0 { + fields = append(fields, field) + } + continue + } + } + if tags.ExistsKey(field.key) { + continue + } + fields = append(fields, field) + } + c.fields = fields +} + type StructFieldCode struct { typ *runtime.Type key string @@ -181,7 +202,14 @@ type StructFieldCode struct { isNextOpPtrType bool } -func (c *StructFieldCode) toInlineCode() []*StructFieldCode { +func (c *StructFieldCode) getAnonymousStruct() *StructCode { + if !c.isAnonymous { + return nil + } + code, ok := c.value.(*StructCode) + if ok { + return code + } return nil } @@ -486,50 +514,87 @@ func compileStruct2(ctx *compileContext, isPtr bool) (*StructCode, error) { return nil, err } if field.isAnonymous { - structCode, ok := field.value.(*StructCode) - if ok { - for _, field := range structCode.fields { - if tags.ExistsKey(field.key) { - continue - } - fields = append(fields, field) - } - } else { - fields = append(fields, field) + structCode := field.getAnonymousStruct() + if structCode != nil { + structCode.removeFieldsByTags(tags) } - } else { - fields = append(fields, field) } + fields = append(fields, field) } + fieldMap := getFieldMap(fields) + duplicatedFieldMap := getDuplicatedFieldMap(fieldMap) + code.fields = filteredDuplicatedFields(fields, duplicatedFieldMap) + return code, nil +} + +func getFieldMap(fields []*StructFieldCode) map[string][]*StructFieldCode { fieldMap := map[string][]*StructFieldCode{} for _, field := range fields { + if field.isAnonymous { + structCode := field.getAnonymousStruct() + if structCode != nil { + for k, v := range getFieldMap(structCode.fields) { + fieldMap[k] = append(fieldMap[k], v...) + } + continue + } + } fieldMap[field.key] = append(fieldMap[field.key], field) } - removeFieldKey := map[string]struct{}{} + return fieldMap +} + +func getDuplicatedFieldMap(fieldMap map[string][]*StructFieldCode) map[*StructFieldCode]struct{} { + duplicatedFieldMap := map[*StructFieldCode]struct{}{} for _, fields := range fieldMap { if len(fields) == 1 { continue } - var foundTaggedKey bool - for _, field := range fields { - if field.isTaggedKey { - if foundTaggedKey { - removeFieldKey[field.key] = struct{}{} - break + if isTaggedKeyOnly(fields) { + for _, field := range fields { + if field.isTaggedKey { + continue } - foundTaggedKey = true + duplicatedFieldMap[field] = struct{}{} + } + } else { + for _, field := range fields { + duplicatedFieldMap[field] = struct{}{} } } } + return duplicatedFieldMap +} + +func filteredDuplicatedFields(fields []*StructFieldCode, duplicatedFieldMap map[*StructFieldCode]struct{}) []*StructFieldCode { filteredFields := make([]*StructFieldCode, 0, len(fields)) for _, field := range fields { - if _, exists := removeFieldKey[field.key]; exists { + if field.isAnonymous { + structCode := field.getAnonymousStruct() + if structCode != nil { + structCode.fields = filteredDuplicatedFields(structCode.fields, duplicatedFieldMap) + if len(structCode.fields) > 0 { + filteredFields = append(filteredFields, field) + } + continue + } + } + if _, exists := duplicatedFieldMap[field]; exists { continue } filteredFields = append(filteredFields, field) } - code.fields = filteredFields - return code, nil + return filteredFields +} + +func isTaggedKeyOnly(fields []*StructFieldCode) bool { + var taggedKeyFieldCount int + for _, field := range fields { + if field.isTaggedKey { + taggedKeyFieldCount++ + } + } + return taggedKeyFieldCount == 1 } func typeToStructTags(typ *runtime.Type) runtime.StructTags { diff --git a/internal/encoder/compiler.go b/internal/encoder/compiler.go index fda3672..91c6181 100644 --- a/internal/encoder/compiler.go +++ b/internal/encoder/compiler.go @@ -1316,7 +1316,6 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) { fieldPtrIndex := ctx.ptrIndex ctx.incIndex() - fmt.Println("fieldOpcodeIndex = ", fieldOpcodeIndex) nilcheck := true addrForMarshaler := false isIndirectSpecialCase := isPtr && i == 0 && fieldNum == 1