From 95bfc8c5499493e5dd87bcbd87629e4dc48c6473 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Fri, 21 Aug 2020 11:07:55 +0900 Subject: [PATCH 1/5] Add validation for null value --- encode.go | 4 ++++ encode_vm.go | 22 +++++++++++++++++++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/encode.go b/encode.go index f41114e..58f019a 100644 --- a/encode.go +++ b/encode.go @@ -148,6 +148,10 @@ func (e *Encoder) encodeForMarshal(v interface{}) ([]byte, error) { } func (e *Encoder) encode(v interface{}) error { + if v == nil { + e.encodeNull() + return nil + } header := (*interfaceHeader)(unsafe.Pointer(&v)) typ := header.typ diff --git a/encode_vm.go b/encode_vm.go index e2d0971..a18cf5d 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -148,7 +148,9 @@ func (e *Encoder) run(code *opcode) error { ptr := code.ptr isPtr := code.typ.Kind() == reflect.Ptr p := unsafe.Pointer(ptr) - if isPtr && *(*unsafe.Pointer)(p) == nil { + if p == nil { + e.encodeNull() + } else if isPtr && *(*unsafe.Pointer)(p) == nil { e.encodeBytes([]byte{'"', '"'}) } else { if isPtr && code.typ.Elem().Implements(marshalTextType) { @@ -1027,7 +1029,14 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - b, err := v.(Marshaler).MarshalJSON() + marshaler, ok := v.(Marshaler) + if !ok { + // invalid marshaler + e.encodeNull() + code = field.end + break + } + b, err := marshaler.MarshalJSON() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), @@ -1099,7 +1108,14 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - bytes, err := v.(encoding.TextMarshaler).MarshalText() + marshaler, ok := v.(encoding.TextMarshaler) + if !ok { + // invalid marshaler + e.encodeNull() + code = field.end + break + } + bytes, err := marshaler.MarshalText() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), From 78fe23fc6428aa4e74f5cb97a896092d2d6b3a70 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Fri, 21 Aug 2020 11:08:30 +0900 Subject: [PATCH 2/5] Add test cases --- encode_test.go | 159 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) diff --git a/encode_test.go b/encode_test.go index 4ef90f9..fd85bb2 100644 --- a/encode_test.go +++ b/encode_test.go @@ -2,6 +2,7 @@ package json_test import ( "bytes" + "encoding" "errors" "fmt" "log" @@ -1233,3 +1234,161 @@ func TestHTMLEscape(t *testing.T) { t.Errorf("HTMLEscape(&b, []byte(m)) = %s; want %s", b.Bytes(), want.Bytes()) } } + +type BugA struct { + S string +} + +type BugB struct { + BugA + S string +} + +type BugC struct { + S string +} + +// Legal Go: We never use the repeated embedded field (S). +type BugX struct { + A int + BugA + BugB +} + +// golang.org/issue/16042. +// Even if a nil interface value is passed in, as long as +// it implements Marshaler, it should be marshaled. +type nilJSONMarshaler string + +func (nm *nilJSONMarshaler) MarshalJSON() ([]byte, error) { + if nm == nil { + return json.Marshal("0zenil0") + } + return json.Marshal("zenil:" + string(*nm)) +} + +// golang.org/issue/34235. +// Even if a nil interface value is passed in, as long as +// it implements encoding.TextMarshaler, it should be marshaled. +type nilTextMarshaler string + +func (nm *nilTextMarshaler) MarshalText() ([]byte, error) { + if nm == nil { + return []byte("0zenil0"), nil + } + return []byte("zenil:" + string(*nm)), nil +} + +// See golang.org/issue/16042 and golang.org/issue/34235. +func TestNilMarshal(t *testing.T) { + testCases := []struct { + v interface{} + want string + }{ + {v: nil, want: `null`}, + {v: new(float64), want: `0`}, + {v: []interface{}(nil), want: `null`}, + {v: []string(nil), want: `null`}, + {v: map[string]string(nil), want: `null`}, + {v: []byte(nil), want: `null`}, + {v: struct{ M string }{"gopher"}, want: `{"M":"gopher"}`}, + {v: struct{ M json.Marshaler }{}, want: `{"M":null}`}, + {v: struct{ M json.Marshaler }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, + {v: struct{ M interface{} }{(*nilJSONMarshaler)(nil)}, want: `{"M":null}`}, + {v: struct{ M encoding.TextMarshaler }{}, want: `{"M":null}`}, + {v: struct{ M encoding.TextMarshaler }{(*nilTextMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, + {v: struct{ M interface{} }{(*nilTextMarshaler)(nil)}, want: `{"M":null}`}, + } + + for i, tt := range testCases { + out, err := json.Marshal(tt.v) + if err != nil || string(out) != tt.want { + t.Errorf("%d: Marshal(%#v) = %#q, %#v, want %#q, nil", i, tt.v, out, err, tt.want) + continue + } + } +} + +// Issue 5245. +func TestEmbeddedBug(t *testing.T) { + v := BugB{ + BugA{"A"}, + "B", + } + b, err := json.Marshal(v) + if err != nil { + t.Fatal("Marshal:", err) + } + want := `{"S":"B"}` + got := string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } + // Now check that the duplicate field, S, does not appear. + x := BugX{ + A: 23, + } + b, err = json.Marshal(x) + if err != nil { + t.Fatal("Marshal:", err) + } + want = `{"A":23}` + got = string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } +} + +type BugD struct { // Same as BugA after tagging. + XXX string `json:"S"` +} + +// BugD's tagged S field should dominate BugA's. +type BugY struct { + BugA + BugD +} + +// Test that a field with a tag dominates untagged fields. +func TestTaggedFieldDominates(t *testing.T) { + v := BugY{ + BugA{"BugA"}, + BugD{"BugD"}, + } + b, err := json.Marshal(v) + if err != nil { + t.Fatal("Marshal:", err) + } + want := `{"S":"BugD"}` + got := string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } +} + +// There are no tags here, so S should not appear. +type BugZ struct { + BugA + BugC + BugY // Contains a tagged S field through BugD; should not dominate. +} + +func TestDuplicatedFieldDisappears(t *testing.T) { + v := BugZ{ + BugA{"BugA"}, + BugC{"BugC"}, + BugY{ + BugA{"nested BugA"}, + BugD{"nested BugD"}, + }, + } + b, err := json.Marshal(v) + if err != nil { + t.Fatal("Marshal:", err) + } + want := `{}` + got := string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } +} From 3e03bdc53f1f7060a8081a4711879a4b5d0d3e58 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Fri, 21 Aug 2020 11:51:33 +0900 Subject: [PATCH 3/5] Fix null validation --- encode_test.go | 2 +- encode_vm.go | 48 ++++++++++++++++++++++++++++++++---------------- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/encode_test.go b/encode_test.go index fd85bb2..10bf27e 100644 --- a/encode_test.go +++ b/encode_test.go @@ -1294,7 +1294,7 @@ func TestNilMarshal(t *testing.T) { {v: struct{ M string }{"gopher"}, want: `{"M":"gopher"}`}, {v: struct{ M json.Marshaler }{}, want: `{"M":null}`}, {v: struct{ M json.Marshaler }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, - {v: struct{ M interface{} }{(*nilJSONMarshaler)(nil)}, want: `{"M":null}`}, + {v: struct{ M interface{} }{(*nilJSONMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, // doesn't compatible with encoding/json {v: struct{ M encoding.TextMarshaler }{}, want: `{"M":null}`}, {v: struct{ M encoding.TextMarshaler }{(*nilTextMarshaler)(nil)}, want: `{"M":"0zenil0"}`}, {v: struct{ M interface{} }{(*nilTextMarshaler)(nil)}, want: `{"M":null}`}, diff --git a/encode_vm.go b/encode_vm.go index a18cf5d..edd3edc 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -69,10 +69,16 @@ func (e *Encoder) run(code *opcode) error { e.encodeBool(e.ptrToBool(code.ptr)) code = code.next case opBytes: - s := base64.StdEncoding.EncodeToString(e.ptrToBytes(code.ptr)) - e.encodeByte('"') - e.encodeBytes(*(*[]byte)(unsafe.Pointer(&s))) - e.encodeByte('"') + ptr := code.ptr + header := (*reflect.SliceHeader)(unsafe.Pointer(ptr)) + if ptr == 0 || header.Data == 0 { + e.encodeNull() + } else { + s := base64.StdEncoding.EncodeToString(e.ptrToBytes(code.ptr)) + e.encodeByte('"') + e.encodeBytes(*(*[]byte)(unsafe.Pointer(&s))) + e.encodeByte('"') + } code = code.next case opInterface: ifaceCode := code.toInterfaceCode() @@ -174,12 +180,12 @@ func (e *Encoder) run(code *opcode) error { case opSliceHead: p := code.ptr headerCode := code.toSliceHeaderCode() - if p == 0 { + header := (*reflect.SliceHeader)(unsafe.Pointer(p)) + if p == 0 || header.Data == 0 { e.encodeNull() code = headerCode.end.next } else { e.encodeByte('[') - header := (*reflect.SliceHeader)(unsafe.Pointer(p)) headerCode.elem.set(header) if header.Len > 0 { code = code.next @@ -1029,14 +1035,13 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - marshaler, ok := v.(Marshaler) - if !ok { - // invalid marshaler + rv := reflect.ValueOf(v) + if rv.Type().Kind() == reflect.Interface && rv.IsNil() { e.encodeNull() code = field.end break } - b, err := marshaler.MarshalJSON() + b, err := rv.Interface().(Marshaler).MarshalJSON() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), @@ -1071,7 +1076,13 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - b, err := v.(Marshaler).MarshalJSON() + rv := reflect.ValueOf(v) + if rv.Type().Kind() == reflect.Interface && rv.IsNil() { + e.encodeNull() + code = field.end + break + } + b, err := rv.Interface().(Marshaler).MarshalJSON() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), @@ -1108,14 +1119,13 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - marshaler, ok := v.(encoding.TextMarshaler) - if !ok { - // invalid marshaler + rv := reflect.ValueOf(v) + if rv.Type().Kind() == reflect.Interface && rv.IsNil() { e.encodeNull() code = field.end break } - bytes, err := marshaler.MarshalText() + bytes, err := rv.Interface().(encoding.TextMarshaler).MarshalText() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), @@ -1140,7 +1150,13 @@ func (e *Encoder) run(code *opcode) error { typ: code.typ, ptr: unsafe.Pointer(ptr), })) - bytes, err := v.(encoding.TextMarshaler).MarshalText() + rv := reflect.ValueOf(v) + if rv.Type().Kind() == reflect.Interface && rv.IsNil() { + e.encodeNull() + code = field.end + break + } + bytes, err := rv.Interface().(encoding.TextMarshaler).MarshalText() if err != nil { return &MarshalerError{ Type: rtype2type(code.typ), From a718a9a1efa98d61035dbdb4755a40dc6f242c1f Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Sat, 22 Aug 2020 12:58:34 +0900 Subject: [PATCH 4/5] Fix anonymous fields --- encode_compile.go | 112 ++++++++++++++++++++++++++++++++++++++++++++-- encode_opcode.go | 50 +++++++++++++++++---- encode_vm.go | 14 +++++- struct_field.go | 14 +++++- 4 files changed, 176 insertions(+), 14 deletions(-) diff --git a/encode_compile.go b/encode_compile.go index 040e0be..4733700 100644 --- a/encode_compile.go +++ b/encode_compile.go @@ -566,6 +566,60 @@ func (e *Encoder) structField(fieldCode *structFieldCode, valueCode *opcode, tag } return code } + +func (e *Encoder) isNotExistsField(head *structFieldCode) bool { + if head == nil { + return false + } + if head.op != opStructFieldAnonymousHead { + return false + } + if head.next == nil { + return false + } + if head.nextField == nil { + return false + } + if head.nextField.op != opStructAnonymousEnd { + return false + } + if head.next.op == opStructAnonymousEnd { + return true + } + if head.next.op.codeType() != codeStructField { + return false + } + return e.isNotExistsField(head.next.toStructFieldCode()) +} + +func (e *Encoder) optimizeAnonymousFields(head *structFieldCode) { + code := head + var prev *structFieldCode + for { + if code.op == opStructEnd || code.op == opStructEndIndent { + break + } + if code.op == opStructField || code.op == opStructFieldIndent { + codeType := code.next.op.codeType() + if codeType == codeStructField { + if e.isNotExistsField(code.next.toStructFieldCode()) { + code.next = code.nextField + linkPrevToNextField(prev, code) + code = prev + } + } + } + prev = code + code = code.nextField.toStructFieldCode() + } +} + +type structFieldPair struct { + prevField *structFieldCode + curField *structFieldCode + linked bool +} + func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opcode, error) { if code := e.compiledCode(typ, withIndent); code != nil { return code, nil @@ -588,12 +642,17 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco prevField *structFieldCode ) e.indent++ + tags := structTags{} + anonymousFields := map[string]structFieldPair{} for i := 0; i < fieldNum; i++ { field := typ.Field(i) if isIgnoredStructField(field) { continue } - tag := structTagFromField(field) + tags = append(tags, structTagFromField(field)) + } + for i, tag := range tags { + field := tag.field fieldType := type2rtype(field.Type) if isPtr && i == 0 { // head field of pointer structure at top level @@ -610,14 +669,57 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco } if field.Anonymous { f := valueCode.toStructFieldCode() + var prevAnonymousField *structFieldCode for { - f.op = f.op.headToAnonymousHead() - if f.op == opStructEnd { + existsKey := tags.existsKey(f.displayKey) + op := f.op.headToAnonymousHead() + if op != f.op { + if existsKey { + f.op = opStructFieldAnonymousHead + } else { + f.op = op + } + } else if f.op == opStructEnd { f.op = opStructAnonymousEnd + } else if existsKey { + linkPrevToNextField(prevAnonymousField, f) + } + + if f.displayKey == "" { + if f.nextField == nil { + break + } + prevAnonymousField = f + f = f.nextField.toStructFieldCode() + continue + } + + // conflict anonymous fields + if existsFieldSet, exists := anonymousFields[f.displayKey]; exists { + if !existsFieldSet.linked { + if existsFieldSet.prevField == nil { + // head operation + existsFieldSet.curField.op = opStructFieldAnonymousHead + } else { + linkPrevToNextField(existsFieldSet.prevField, existsFieldSet.curField) + } + existsFieldSet.linked = true + } + if prevAnonymousField == nil { + // head operation + f.op = opStructFieldAnonymousHead + } else { + linkPrevToNextField(prevAnonymousField, f) + } + } + anonymousFields[f.displayKey] = structFieldPair{ + prevField: prevAnonymousField, + curField: f, } if f.nextField == nil { break } + prevAnonymousField = f f = f.nextField.toStructFieldCode() } } @@ -640,6 +742,7 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco }, anonymousKey: field.Anonymous, key: []byte(key), + displayKey: tag.key, offset: field.Offset, } if fieldIdx == 0 { @@ -690,6 +793,9 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco } head.end = structEndCode code.next = structEndCode + + e.optimizeAnonymousFields(head) + ret := (*opcode)(unsafe.Pointer(head)) compiled.code = ret diff --git a/encode_opcode.go b/encode_opcode.go index 9fef8bf..59deec9 100644 --- a/encode_opcode.go +++ b/encode_opcode.go @@ -53,12 +53,12 @@ func (c *opcode) beforeLastCode() *opcode { code := c for { var nextCode *opcode - switch code.op { - case opArrayElem, opArrayElemIndent: + switch code.op.codeType() { + case codeArrayElem: nextCode = code.toArrayElemCode().end - case opSliceElem, opSliceElemIndent, opRootSliceElemIndent: + case codeSliceElem: nextCode = code.toSliceElemCode().end - case opMapKey, opMapKeyIndent, opRootMapKeyIndent: + case codeMapKey: nextCode = code.toMapKeyCode().end default: nextCode = code.next @@ -112,13 +112,13 @@ func (c *opcode) dump() string { codes := []string{} for code := c; code.op != opEnd; { indent := strings.Repeat(" ", code.indent) - codes = append(codes, fmt.Sprintf("%s%s", indent, code.op)) - switch code.op { - case opArrayElem, opArrayElemIndent: + codes = append(codes, fmt.Sprintf("%s%s ( %p )", indent, code.op, unsafe.Pointer(code))) + switch code.op.codeType() { + case codeArrayElem: code = code.toArrayElemCode().end - case opSliceElem, opSliceElemIndent, opRootSliceElemIndent: + case codeSliceElem: code = code.toSliceElemCode().end - case opMapKey, opMapKeyIndent, opRootMapKeyIndent: + case codeMapKey: code = code.toMapKeyCode().end default: code = code.next @@ -305,12 +305,43 @@ func (c *arrayElemCode) copy(codeMap map[uintptr]*opcode) *opcode { type structFieldCode struct { *opcodeHeader key []byte + displayKey string offset uintptr anonymousKey bool nextField *opcode end *opcode } +func linkPrevToNextField(prev, cur *structFieldCode) { + prev.nextField = cur.nextField + code := prev.toOpcode() + fcode := cur.toOpcode() + for { + var nextCode *opcode + switch code.op.codeType() { + case codeArrayElem: + nextCode = code.toArrayElemCode().end + case codeSliceElem: + nextCode = code.toSliceElemCode().end + case codeMapKey: + nextCode = code.toMapKeyCode().end + default: + nextCode = code.next + } + if nextCode == fcode { + code.next = fcode.next + break + } else if nextCode.op == opEnd { + break + } + code = nextCode + } +} + +func (c *structFieldCode) toOpcode() *opcode { + return (*opcode)(unsafe.Pointer(c)) +} + func (c *structFieldCode) copy(codeMap map[uintptr]*opcode) *opcode { if c == nil { return nil @@ -321,6 +352,7 @@ func (c *structFieldCode) copy(codeMap map[uintptr]*opcode) *opcode { } field := &structFieldCode{ key: c.key, + displayKey: c.displayKey, anonymousKey: c.anonymousKey, offset: c.offset, } diff --git a/encode_vm.go b/encode_vm.go index edd3edc..8265cde 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -549,6 +549,16 @@ func (e *Encoder) run(code *opcode) error { code.ptr = ptr field.nextField.ptr = ptr } + case opStructFieldAnonymousHead: + field := code.toStructFieldCode() + ptr := field.ptr + if ptr == 0 { + code = field.end.next + } else { + code = field.next + code.ptr = ptr + field.nextField.ptr = ptr + } case opStructFieldPtrHeadInt: code.ptr = e.ptrToPtr(code.ptr) fallthrough @@ -3872,7 +3882,9 @@ func (e *Encoder) run(code *opcode) error { e.encodeByte(',') } c := code.toStructFieldCode() - e.encodeBytes(c.key) + if !c.anonymousKey { + e.encodeBytes(c.key) + } code = code.next code.ptr = c.ptr + c.offset c.nextField.ptr = c.ptr diff --git a/struct_field.go b/struct_field.go index 3352da1..95de4bf 100644 --- a/struct_field.go +++ b/struct_field.go @@ -25,6 +25,18 @@ type structTag struct { key string isOmitEmpty bool isString bool + field reflect.StructField +} + +type structTags []*structTag + +func (t structTags) existsKey(key string) bool { + for _, tt := range t { + if tt.key == key { + return true + } + } + return false } func structTagFromField(field reflect.StructField) *structTag { @@ -36,7 +48,7 @@ func structTagFromField(field reflect.StructField) *structTag { keyName = opts[0] } } - st := &structTag{key: keyName} + st := &structTag{key: keyName, field: field} if len(opts) > 1 { st.isOmitEmpty = opts[1] == "omitempty" st.isString = opts[1] == "string" From 7ada1b2467b51d4ce27f869da00943b434ceb97f Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Sat, 22 Aug 2020 15:40:18 +0900 Subject: [PATCH 5/5] Fix conflicted anonymous fields --- encode_compile.go | 156 +++++++++++++++++++++++++++++----------------- encode_opcode.go | 2 + struct_field.go | 5 +- 3 files changed, 105 insertions(+), 58 deletions(-) diff --git a/encode_compile.go b/encode_compile.go index 4733700..8d81da9 100644 --- a/encode_compile.go +++ b/encode_compile.go @@ -615,9 +615,100 @@ func (e *Encoder) optimizeAnonymousFields(head *structFieldCode) { } type structFieldPair struct { - prevField *structFieldCode - curField *structFieldCode - linked bool + prevField *structFieldCode + curField *structFieldCode + isTaggedKey bool + linked bool +} + +func (e *Encoder) anonymousStructFieldPairMap(typ *rtype, tags structTags, valueCode *structFieldCode) map[string][]structFieldPair { + //fmt.Println("type = ", typ, "valueCode = ", valueCode.dump()) + anonymousFields := map[string][]structFieldPair{} + f := valueCode + var prevAnonymousField *structFieldCode + for { + existsKey := tags.existsKey(f.displayKey) + op := f.op.headToAnonymousHead() + if op != f.op { + if existsKey { + f.op = opStructFieldAnonymousHead + } else { + f.op = op + } + } else if f.op == opStructEnd { + f.op = opStructAnonymousEnd + } else if existsKey { + linkPrevToNextField(prevAnonymousField, f) + } + + if f.displayKey == "" { + if f.nextField == nil { + break + } + prevAnonymousField = f + f = f.nextField.toStructFieldCode() + continue + } + + anonymousFields[f.displayKey] = append(anonymousFields[f.displayKey], structFieldPair{ + prevField: prevAnonymousField, + curField: f, + isTaggedKey: f.isTaggedKey, + }) + if f.next != nil && f.nextField != f.next && f.next.op.codeType() == codeStructField { + for k, v := range e.anonymousStructFieldPairMap(typ, tags, f.next.toStructFieldCode()) { + anonymousFields[k] = append(anonymousFields[k], v...) + } + } + if f.nextField == nil { + break + } + prevAnonymousField = f + f = f.nextField.toStructFieldCode() + } + return anonymousFields +} + +func (e *Encoder) optimizeConflictAnonymousFields(anonymousFields map[string][]structFieldPair) { + for _, fieldPairs := range anonymousFields { + if len(fieldPairs) == 1 { + continue + } + // conflict anonymous fields + taggedPairs := []structFieldPair{} + for _, fieldPair := range fieldPairs { + if fieldPair.isTaggedKey { + taggedPairs = append(taggedPairs, fieldPair) + } else { + if !fieldPair.linked { + if fieldPair.prevField == nil { + // head operation + fieldPair.curField.op = opStructFieldAnonymousHead + } else { + linkPrevToNextField(fieldPair.prevField, fieldPair.curField) + } + fieldPair.linked = true + } + } + } + if len(taggedPairs) > 1 { + for _, fieldPair := range taggedPairs { + if !fieldPair.linked { + if fieldPair.prevField == nil { + // head operation + fieldPair.curField.op = opStructFieldAnonymousHead + } else { + linkPrevToNextField(fieldPair.prevField, fieldPair.curField) + } + fieldPair.linked = true + } + } + } else { + for _, fieldPair := range taggedPairs { + fieldPair.curField.isTaggedKey = false + } + } + } } func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opcode, error) { @@ -643,7 +734,7 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco ) e.indent++ tags := structTags{} - anonymousFields := map[string]structFieldPair{} + anonymousFields := map[string][]structFieldPair{} for i := 0; i < fieldNum; i++ { field := typ.Field(i) if isIgnoredStructField(field) { @@ -668,59 +759,8 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco return nil, err } if field.Anonymous { - f := valueCode.toStructFieldCode() - var prevAnonymousField *structFieldCode - for { - existsKey := tags.existsKey(f.displayKey) - op := f.op.headToAnonymousHead() - if op != f.op { - if existsKey { - f.op = opStructFieldAnonymousHead - } else { - f.op = op - } - } else if f.op == opStructEnd { - f.op = opStructAnonymousEnd - } else if existsKey { - linkPrevToNextField(prevAnonymousField, f) - } - - if f.displayKey == "" { - if f.nextField == nil { - break - } - prevAnonymousField = f - f = f.nextField.toStructFieldCode() - continue - } - - // conflict anonymous fields - if existsFieldSet, exists := anonymousFields[f.displayKey]; exists { - if !existsFieldSet.linked { - if existsFieldSet.prevField == nil { - // head operation - existsFieldSet.curField.op = opStructFieldAnonymousHead - } else { - linkPrevToNextField(existsFieldSet.prevField, existsFieldSet.curField) - } - existsFieldSet.linked = true - } - if prevAnonymousField == nil { - // head operation - f.op = opStructFieldAnonymousHead - } else { - linkPrevToNextField(prevAnonymousField, f) - } - } - anonymousFields[f.displayKey] = structFieldPair{ - prevField: prevAnonymousField, - curField: f, - } - if f.nextField == nil { - break - } - prevAnonymousField = f - f = f.nextField.toStructFieldCode() + for k, v := range e.anonymousStructFieldPairMap(typ, tags, valueCode.toStructFieldCode()) { + anonymousFields[k] = append(anonymousFields[k], v...) } } if fieldNum == 1 && valueCode.op == opPtr { @@ -742,6 +782,7 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco }, anonymousKey: field.Anonymous, key: []byte(key), + isTaggedKey: tag.isTaggedKey, displayKey: tag.key, offset: field.Offset, } @@ -794,6 +835,7 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco head.end = structEndCode code.next = structEndCode + e.optimizeConflictAnonymousFields(anonymousFields) e.optimizeAnonymousFields(head) ret := (*opcode)(unsafe.Pointer(head)) diff --git a/encode_opcode.go b/encode_opcode.go index 59deec9..b43dc41 100644 --- a/encode_opcode.go +++ b/encode_opcode.go @@ -306,6 +306,7 @@ type structFieldCode struct { *opcodeHeader key []byte displayKey string + isTaggedKey bool offset uintptr anonymousKey bool nextField *opcode @@ -352,6 +353,7 @@ func (c *structFieldCode) copy(codeMap map[uintptr]*opcode) *opcode { } field := &structFieldCode{ key: c.key, + isTaggedKey: c.isTaggedKey, displayKey: c.displayKey, anonymousKey: c.anonymousKey, offset: c.offset, diff --git a/struct_field.go b/struct_field.go index 95de4bf..b2a6f50 100644 --- a/struct_field.go +++ b/struct_field.go @@ -23,6 +23,7 @@ func isIgnoredStructField(field reflect.StructField) bool { type structTag struct { key string + isTaggedKey bool isOmitEmpty bool isString bool field reflect.StructField @@ -42,13 +43,15 @@ func (t structTags) existsKey(key string) bool { func structTagFromField(field reflect.StructField) *structTag { keyName := field.Name tag := getTag(field) + st := &structTag{field: field} opts := strings.Split(tag, ",") if len(opts) > 0 { if opts[0] != "" { keyName = opts[0] + st.isTaggedKey = true } } - st := &structTag{key: keyName, field: field} + st.key = keyName if len(opts) > 1 { st.isOmitEmpty = opts[1] == "omitempty" st.isString = opts[1] == "string"