Fix embedded field conflict behavior

This commit is contained in:
Masaaki Goshima 2021-11-18 19:51:29 +09:00
parent a89c9e30df
commit 86a671f3bb
No known key found for this signature in database
GPG Key ID: 6A53785055537153
3 changed files with 119 additions and 78 deletions

View File

@ -2167,4 +2167,78 @@ func TestIssue290(t *testing.T) {
if !bytes.Equal(expected, got) {
t.Fatalf("failed to encode non empty interface. expected = %q but got %q", expected, got)
}
}
}
func TestIssue299(t *testing.T) {
t.Run("conflict second field", func(t *testing.T) {
type Embedded struct {
ID string `json:"id"`
Name map[string]string `json:"name"`
}
type Container struct {
Embedded
Name string `json:"name"`
}
c := &Container{
Embedded: Embedded{
ID: "1",
Name: map[string]string{"en": "Hello", "es": "Hola"},
},
Name: "Hi",
}
expected, _ := stdjson.Marshal(c)
got, err := json.Marshal(c)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected, got) {
t.Fatalf("expected %q but got %q", expected, got)
}
})
t.Run("conflict map field", func(t *testing.T) {
type Embedded struct {
Name map[string]string `json:"name"`
}
type Container struct {
Embedded
Name string `json:"name"`
}
c := &Container{
Embedded: Embedded{
Name: map[string]string{"en": "Hello", "es": "Hola"},
},
Name: "Hi",
}
expected, _ := stdjson.Marshal(c)
got, err := json.Marshal(c)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected, got) {
t.Fatalf("expected %q but got %q", expected, got)
}
})
t.Run("conflict slice field", func(t *testing.T) {
type Embedded struct {
Name []string `json:"name"`
}
type Container struct {
Embedded
Name string `json:"name"`
}
c := &Container{
Embedded: Embedded{
Name: []string{"Hello"},
},
Name: "Hi",
}
expected, _ := stdjson.Marshal(c)
got, err := json.Marshal(c)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(expected, got) {
t.Fatalf("expected %q but got %q", expected, got)
}
})
}

View File

@ -1102,61 +1102,6 @@ func structField(ctx *compileContext, fieldCode *Opcode, valueCode *Opcode, tag
return fieldCode
}
func isNotExistsField(head *Opcode) bool {
if head == nil {
return false
}
if head.Op != OpStructHead {
return false
}
if (head.Flags & AnonymousHeadFlags) == 0 {
return false
}
if head.Next == nil {
return false
}
if head.NextField == nil {
return false
}
if head.NextField.Op != OpStructAnonymousEnd {
return false
}
if head.Next.Op == OpStructAnonymousEnd {
return true
}
if head.Next.Op.CodeType() != CodeStructField {
return false
}
return isNotExistsField(head.Next)
}
func optimizeAnonymousFields(head *Opcode) {
code := head
var prev *Opcode
removedFields := map[*Opcode]struct{}{}
for {
if code.Op == OpStructEnd {
break
}
if code.Op == OpStructField {
codeType := code.Next.Op.CodeType()
if codeType == CodeStructField {
if isNotExistsField(code.Next) {
code.Next = code.NextField
diff := code.Next.DisplayIdx - code.DisplayIdx
for i := uint32(0); i < diff; i++ {
code.Next.decOpcodeIndex()
}
linkPrevToNextField(code, removedFields)
code = prev
}
}
}
prev = code
code = code.NextField
}
}
type structFieldPair struct {
prevField *Opcode
curField *Opcode
@ -1164,35 +1109,43 @@ type structFieldPair struct {
linked bool
}
func anonymousStructFieldPairMap(tags runtime.StructTags, named string, valueCode *Opcode) map[string][]structFieldPair {
func filterAnonymousStructFieldsByTags(value *Opcode, tags runtime.StructTags) *Opcode {
head := value
curField := head
removedFields := map[*Opcode]struct{}{}
for curField != nil {
existsKey := tags.ExistsKey(curField.DisplayKey)
if !existsKey || curField.Next.IsRecursiveOp() {
curField = curField.NextField
continue
}
diff := curField.NextField.DisplayIdx - curField.DisplayIdx
for i := uint32(0); i < diff; i++ {
curField.NextField.decOpcodeIndex()
}
if curField.IsStructHeadOp() || head == curField {
head = curField.NextField
} else {
linkPrevToNextField(curField, removedFields)
}
curField = curField.NextField
}
return head
}
func anonymousStructFieldPairMap(named string, valueCode *Opcode) map[string][]structFieldPair {
anonymousFields := map[string][]structFieldPair{}
f := valueCode
var prevAnonymousField *Opcode
removedFields := map[*Opcode]struct{}{}
for {
existsKey := tags.ExistsKey(f.DisplayKey)
isHeadOp := strings.Contains(f.Op.String(), "Head")
if existsKey && f.Next != nil && strings.Contains(f.Next.Op.String(), "Recursive") {
// through
} else if isHeadOp && (f.Flags&AnonymousHeadFlags) == 0 {
if existsKey {
// TODO: need to remove this head
f.Op = OpStructHead
f.Flags |= AnonymousKeyFlags
f.Flags |= AnonymousHeadFlags
} else if named == "" {
if isHeadOp && (f.Flags&AnonymousHeadFlags) == 0 {
if named == "" {
f.Flags |= AnonymousHeadFlags
}
} else if named == "" && f.Op == OpStructEnd {
f.Op = OpStructAnonymousEnd
} else if existsKey {
diff := f.NextField.DisplayIdx - f.DisplayIdx
for i := uint32(0); i < diff; i++ {
f.NextField.decOpcodeIndex()
}
linkPrevToNextField(f, removedFields)
}
if f.DisplayKey == "" {
if f.NextField == nil {
break
@ -1422,7 +1375,8 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) {
if tag.IsTaggedKey {
tagKey = tag.Key
}
for k, v := range anonymousStructFieldPairMap(tags, tagKey, valueCode) {
valueCode = filterAnonymousStructFieldsByTags(valueCode, tags)
for k, v := range anonymousStructFieldPairMap(tagKey, valueCode) {
anonymousFields[k] = append(anonymousFields[k], v...)
}
@ -1540,7 +1494,6 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) {
head.End = structEndCode
code.Next = structEndCode
optimizeConflictAnonymousFields(anonymousFields)
optimizeAnonymousFields(head)
ret := (*Opcode)(unsafe.Pointer(head))
compiled.Code = ret

View File

@ -50,6 +50,20 @@ type Opcode struct {
DisplayKey string // key text to display
}
func (c *Opcode) IsStructHeadOp() bool {
if c == nil {
return false
}
return strings.Contains(c.Op.String(), "Head")
}
func (c *Opcode) IsRecursiveOp() bool {
if c == nil {
return false
}
return strings.Contains(c.Op.String(), "Recursive")
}
func (c *Opcode) MaxIdx() uint32 {
max := uint32(0)
for _, value := range []uint32{
@ -621,7 +635,7 @@ func linkPrevToNextField(cur *Opcode, removedFields map[*Opcode]struct{}) {
nextCode = code.Next
}
if nextCode == fcode {
code.Next = fcode.Next
code.Next = fcode.NextField
break
} else if nextCode.Op == OpEnd {
break