From 86a671f3bbbabb5c6a2c96822a5f38cd2947c40c Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 18 Nov 2021 19:51:29 +0900 Subject: [PATCH] Fix embedded field conflict behavior --- encode_test.go | 76 ++++++++++++++++++++++++- internal/encoder/compiler.go | 105 ++++++++++------------------------- internal/encoder/opcode.go | 16 +++++- 3 files changed, 119 insertions(+), 78 deletions(-) diff --git a/encode_test.go b/encode_test.go index 90cb67b..8bb3598 100644 --- a/encode_test.go +++ b/encode_test.go @@ -2167,4 +2167,78 @@ func TestIssue290(t *testing.T) { if !bytes.Equal(expected, got) { t.Fatalf("failed to encode non empty interface. expected = %q but got %q", expected, got) } -} \ No newline at end of file +} + +func TestIssue299(t *testing.T) { + t.Run("conflict second field", func(t *testing.T) { + type Embedded struct { + ID string `json:"id"` + Name map[string]string `json:"name"` + } + type Container struct { + Embedded + Name string `json:"name"` + } + c := &Container{ + Embedded: Embedded{ + ID: "1", + Name: map[string]string{"en": "Hello", "es": "Hola"}, + }, + Name: "Hi", + } + expected, _ := stdjson.Marshal(c) + got, err := json.Marshal(c) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(expected, got) { + t.Fatalf("expected %q but got %q", expected, got) + } + }) + t.Run("conflict map field", func(t *testing.T) { + type Embedded struct { + Name map[string]string `json:"name"` + } + type Container struct { + Embedded + Name string `json:"name"` + } + c := &Container{ + Embedded: Embedded{ + Name: map[string]string{"en": "Hello", "es": "Hola"}, + }, + Name: "Hi", + } + expected, _ := stdjson.Marshal(c) + got, err := json.Marshal(c) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(expected, got) { + t.Fatalf("expected %q but got %q", expected, got) + } + }) + t.Run("conflict slice field", func(t *testing.T) { + type Embedded struct { + Name []string `json:"name"` + } + type Container struct { + Embedded + Name string `json:"name"` + } + c := &Container{ + Embedded: Embedded{ + Name: []string{"Hello"}, + }, + Name: "Hi", + } + expected, _ := stdjson.Marshal(c) + got, err := json.Marshal(c) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(expected, got) { + t.Fatalf("expected %q but got %q", expected, got) + } + }) +} diff --git a/internal/encoder/compiler.go b/internal/encoder/compiler.go index c627ed3..9fcec21 100644 --- a/internal/encoder/compiler.go +++ b/internal/encoder/compiler.go @@ -1102,61 +1102,6 @@ func structField(ctx *compileContext, fieldCode *Opcode, valueCode *Opcode, tag return fieldCode } -func isNotExistsField(head *Opcode) bool { - if head == nil { - return false - } - if head.Op != OpStructHead { - return false - } - if (head.Flags & AnonymousHeadFlags) == 0 { - return false - } - if head.Next == nil { - return false - } - if head.NextField == nil { - return false - } - if head.NextField.Op != OpStructAnonymousEnd { - return false - } - if head.Next.Op == OpStructAnonymousEnd { - return true - } - if head.Next.Op.CodeType() != CodeStructField { - return false - } - return isNotExistsField(head.Next) -} - -func optimizeAnonymousFields(head *Opcode) { - code := head - var prev *Opcode - removedFields := map[*Opcode]struct{}{} - for { - if code.Op == OpStructEnd { - break - } - if code.Op == OpStructField { - codeType := code.Next.Op.CodeType() - if codeType == CodeStructField { - if isNotExistsField(code.Next) { - code.Next = code.NextField - diff := code.Next.DisplayIdx - code.DisplayIdx - for i := uint32(0); i < diff; i++ { - code.Next.decOpcodeIndex() - } - linkPrevToNextField(code, removedFields) - code = prev - } - } - } - prev = code - code = code.NextField - } -} - type structFieldPair struct { prevField *Opcode curField *Opcode @@ -1164,35 +1109,43 @@ type structFieldPair struct { linked bool } -func anonymousStructFieldPairMap(tags runtime.StructTags, named string, valueCode *Opcode) map[string][]structFieldPair { +func filterAnonymousStructFieldsByTags(value *Opcode, tags runtime.StructTags) *Opcode { + head := value + curField := head + removedFields := map[*Opcode]struct{}{} + for curField != nil { + existsKey := tags.ExistsKey(curField.DisplayKey) + if !existsKey || curField.Next.IsRecursiveOp() { + curField = curField.NextField + continue + } + diff := curField.NextField.DisplayIdx - curField.DisplayIdx + for i := uint32(0); i < diff; i++ { + curField.NextField.decOpcodeIndex() + } + if curField.IsStructHeadOp() || head == curField { + head = curField.NextField + } else { + linkPrevToNextField(curField, removedFields) + } + curField = curField.NextField + } + return head +} + +func anonymousStructFieldPairMap(named string, valueCode *Opcode) map[string][]structFieldPair { anonymousFields := map[string][]structFieldPair{} f := valueCode var prevAnonymousField *Opcode - removedFields := map[*Opcode]struct{}{} for { - existsKey := tags.ExistsKey(f.DisplayKey) isHeadOp := strings.Contains(f.Op.String(), "Head") - if existsKey && f.Next != nil && strings.Contains(f.Next.Op.String(), "Recursive") { - // through - } else if isHeadOp && (f.Flags&AnonymousHeadFlags) == 0 { - if existsKey { - // TODO: need to remove this head - f.Op = OpStructHead - f.Flags |= AnonymousKeyFlags - f.Flags |= AnonymousHeadFlags - } else if named == "" { + if isHeadOp && (f.Flags&AnonymousHeadFlags) == 0 { + if named == "" { f.Flags |= AnonymousHeadFlags } } else if named == "" && f.Op == OpStructEnd { f.Op = OpStructAnonymousEnd - } else if existsKey { - diff := f.NextField.DisplayIdx - f.DisplayIdx - for i := uint32(0); i < diff; i++ { - f.NextField.decOpcodeIndex() - } - linkPrevToNextField(f, removedFields) } - if f.DisplayKey == "" { if f.NextField == nil { break @@ -1422,7 +1375,8 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) { if tag.IsTaggedKey { tagKey = tag.Key } - for k, v := range anonymousStructFieldPairMap(tags, tagKey, valueCode) { + valueCode = filterAnonymousStructFieldsByTags(valueCode, tags) + for k, v := range anonymousStructFieldPairMap(tagKey, valueCode) { anonymousFields[k] = append(anonymousFields[k], v...) } @@ -1540,7 +1494,6 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) { head.End = structEndCode code.Next = structEndCode optimizeConflictAnonymousFields(anonymousFields) - optimizeAnonymousFields(head) ret := (*Opcode)(unsafe.Pointer(head)) compiled.Code = ret diff --git a/internal/encoder/opcode.go b/internal/encoder/opcode.go index 7c50eef..446baa2 100644 --- a/internal/encoder/opcode.go +++ b/internal/encoder/opcode.go @@ -50,6 +50,20 @@ type Opcode struct { DisplayKey string // key text to display } +func (c *Opcode) IsStructHeadOp() bool { + if c == nil { + return false + } + return strings.Contains(c.Op.String(), "Head") +} + +func (c *Opcode) IsRecursiveOp() bool { + if c == nil { + return false + } + return strings.Contains(c.Op.String(), "Recursive") +} + func (c *Opcode) MaxIdx() uint32 { max := uint32(0) for _, value := range []uint32{ @@ -621,7 +635,7 @@ func linkPrevToNextField(cur *Opcode, removedFields map[*Opcode]struct{}) { nextCode = code.Next } if nextCode == fcode { - code.Next = fcode.Next + code.Next = fcode.NextField break } else if nextCode.Op == OpEnd { break