From 2a997045316637a9676326212fea438692d6fdbf Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 13 Aug 2020 15:26:35 +0900 Subject: [PATCH] Fix recursive definition of struct --- encode_compile.go | 8 +-- encode_opcode.go | 131 ++++++++++++++++++++++++++++++++++++++++++++++ encode_test.go | 13 +++-- encode_vm.go | 4 +- 4 files changed, 146 insertions(+), 10 deletions(-) diff --git a/encode_compile.go b/encode_compile.go index c131624..ea61682 100644 --- a/encode_compile.go +++ b/encode_compile.go @@ -749,7 +749,7 @@ func (e *Encoder) optimizeStructField(op opType, isOmitEmpty, withIndent bool) o func (e *Encoder) compileStruct(typ *rtype, root, withIndent bool) (*opcode, error) { typeptr := uintptr(unsafe.Pointer(typ)) if withIndent { - if compiled, exists := e.structTypeToCompiledCode[typeptr]; exists { + if compiled, exists := e.structTypeToCompiledIndentCode[typeptr]; exists { return (*opcode)(unsafe.Pointer(&recursiveCode{ opcodeHeader: &opcodeHeader{ op: opStructFieldRecursive, @@ -761,7 +761,7 @@ func (e *Encoder) compileStruct(typ *rtype, root, withIndent bool) (*opcode, err })), nil } } else { - if compiled, exists := e.structTypeToCompiledIndentCode[typeptr]; exists { + if compiled, exists := e.structTypeToCompiledCode[typeptr]; exists { return (*opcode)(unsafe.Pointer(&recursiveCode{ opcodeHeader: &opcodeHeader{ op: opStructFieldRecursive, @@ -775,9 +775,9 @@ func (e *Encoder) compileStruct(typ *rtype, root, withIndent bool) (*opcode, err } compiled := &compiledCode{} if withIndent { - e.structTypeToCompiledCode[typeptr] = compiled - } else { e.structTypeToCompiledIndentCode[typeptr] = compiled + } else { + e.structTypeToCompiledCode[typeptr] = compiled } // header => code => structField => code => end // ^ | diff --git a/encode_opcode.go b/encode_opcode.go index 9bd8e4b..f254706 100644 --- a/encode_opcode.go +++ b/encode_opcode.go @@ -1416,3 +1416,134 @@ func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode { } return code } + +func newRecursiveCode(recursive *recursiveCode) *opcode { + code := copyOpcode(recursive.jmp.code) + head := (*structFieldCode)(unsafe.Pointer(code)) + head.end.next = newEndOp(0) + code.ptr = recursive.ptr + + switch code.op { + case opStructFieldPtrHead: + code.op = opStructFieldHead + case opStructFieldPtrHeadInt: + code.op = opStructFieldHeadInt + case opStructFieldPtrHeadInt8: + code.op = opStructFieldHeadInt8 + case opStructFieldPtrHeadInt16: + code.op = opStructFieldHeadInt16 + case opStructFieldPtrHeadInt32: + code.op = opStructFieldHeadInt32 + case opStructFieldPtrHeadInt64: + code.op = opStructFieldHeadInt64 + case opStructFieldPtrHeadUint: + code.op = opStructFieldHeadUint + case opStructFieldPtrHeadUint8: + code.op = opStructFieldHeadUint8 + case opStructFieldPtrHeadUint16: + code.op = opStructFieldHeadUint16 + case opStructFieldPtrHeadUint32: + code.op = opStructFieldHeadUint32 + case opStructFieldPtrHeadUint64: + code.op = opStructFieldHeadUint64 + case opStructFieldPtrHeadFloat32: + code.op = opStructFieldHeadFloat32 + case opStructFieldPtrHeadFloat64: + code.op = opStructFieldHeadFloat64 + case opStructFieldPtrHeadString: + code.op = opStructFieldHeadString + case opStructFieldPtrHeadBool: + code.op = opStructFieldHeadBool + case opStructFieldPtrHeadIndent: + code.op = opStructFieldHeadIndent + case opStructFieldPtrHeadIntIndent: + code.op = opStructFieldHeadIntIndent + case opStructFieldPtrHeadInt8Indent: + code.op = opStructFieldHeadInt8Indent + case opStructFieldPtrHeadInt16Indent: + code.op = opStructFieldHeadInt16Indent + case opStructFieldPtrHeadInt32Indent: + code.op = opStructFieldHeadInt32Indent + case opStructFieldPtrHeadInt64Indent: + code.op = opStructFieldHeadInt64Indent + case opStructFieldPtrHeadUintIndent: + code.op = opStructFieldHeadUintIndent + case opStructFieldPtrHeadUint8Indent: + code.op = opStructFieldHeadUint8Indent + case opStructFieldPtrHeadUint16Indent: + code.op = opStructFieldHeadUint16Indent + case opStructFieldPtrHeadUint32Indent: + code.op = opStructFieldHeadUint32Indent + case opStructFieldPtrHeadUint64Indent: + code.op = opStructFieldHeadUint64Indent + case opStructFieldPtrHeadFloat32Indent: + code.op = opStructFieldHeadFloat32Indent + case opStructFieldPtrHeadFloat64Indent: + code.op = opStructFieldHeadFloat64Indent + case opStructFieldPtrHeadStringIndent: + code.op = opStructFieldHeadStringIndent + case opStructFieldPtrHeadBoolIndent: + code.op = opStructFieldHeadBoolIndent + case opStructFieldPtrHeadOmitEmpty: + code.op = opStructFieldHeadOmitEmpty + case opStructFieldPtrHeadIntOmitEmpty: + code.op = opStructFieldHeadIntOmitEmpty + case opStructFieldPtrHeadInt8OmitEmpty: + code.op = opStructFieldHeadInt8OmitEmpty + case opStructFieldPtrHeadInt16OmitEmpty: + code.op = opStructFieldHeadInt16OmitEmpty + case opStructFieldPtrHeadInt32OmitEmpty: + code.op = opStructFieldHeadInt32OmitEmpty + case opStructFieldPtrHeadInt64OmitEmpty: + code.op = opStructFieldHeadInt64OmitEmpty + case opStructFieldPtrHeadUintOmitEmpty: + code.op = opStructFieldHeadUintOmitEmpty + case opStructFieldPtrHeadUint8OmitEmpty: + code.op = opStructFieldHeadUint8OmitEmpty + case opStructFieldPtrHeadUint16OmitEmpty: + code.op = opStructFieldHeadUint16OmitEmpty + case opStructFieldPtrHeadUint32OmitEmpty: + code.op = opStructFieldHeadUint32OmitEmpty + case opStructFieldPtrHeadUint64OmitEmpty: + code.op = opStructFieldHeadUint64OmitEmpty + case opStructFieldPtrHeadFloat32OmitEmpty: + code.op = opStructFieldHeadFloat32OmitEmpty + case opStructFieldPtrHeadFloat64OmitEmpty: + code.op = opStructFieldHeadFloat64OmitEmpty + case opStructFieldPtrHeadStringOmitEmpty: + code.op = opStructFieldHeadStringOmitEmpty + case opStructFieldPtrHeadBoolOmitEmpty: + code.op = opStructFieldHeadBoolOmitEmpty + case opStructFieldPtrHeadOmitEmptyIndent: + code.op = opStructFieldHeadOmitEmptyIndent + case opStructFieldPtrHeadIntOmitEmptyIndent: + code.op = opStructFieldHeadIntOmitEmptyIndent + case opStructFieldPtrHeadInt8OmitEmptyIndent: + code.op = opStructFieldHeadInt8OmitEmptyIndent + case opStructFieldPtrHeadInt16OmitEmptyIndent: + code.op = opStructFieldHeadInt16OmitEmptyIndent + case opStructFieldPtrHeadInt32OmitEmptyIndent: + code.op = opStructFieldHeadInt32OmitEmptyIndent + case opStructFieldPtrHeadInt64OmitEmptyIndent: + code.op = opStructFieldHeadInt64OmitEmptyIndent + case opStructFieldPtrHeadUintOmitEmptyIndent: + code.op = opStructFieldHeadUintOmitEmptyIndent + case opStructFieldPtrHeadUint8OmitEmptyIndent: + code.op = opStructFieldHeadUint8OmitEmptyIndent + case opStructFieldPtrHeadUint16OmitEmptyIndent: + code.op = opStructFieldHeadUint16OmitEmptyIndent + case opStructFieldPtrHeadUint32OmitEmptyIndent: + code.op = opStructFieldHeadUint32OmitEmptyIndent + case opStructFieldPtrHeadUint64OmitEmptyIndent: + code.op = opStructFieldHeadUint64OmitEmptyIndent + case opStructFieldPtrHeadFloat32OmitEmptyIndent: + code.op = opStructFieldHeadFloat32OmitEmptyIndent + case opStructFieldPtrHeadFloat64OmitEmptyIndent: + code.op = opStructFieldHeadFloat64OmitEmptyIndent + case opStructFieldPtrHeadStringOmitEmptyIndent: + code.op = opStructFieldHeadStringOmitEmptyIndent + case opStructFieldPtrHeadBoolOmitEmptyIndent: + code.op = opStructFieldHeadBoolOmitEmptyIndent + } + return code +} diff --git a/encode_test.go b/encode_test.go index 8ba12d7..831bc0d 100644 --- a/encode_test.go +++ b/encode_test.go @@ -12,8 +12,10 @@ import ( type recursiveT struct { A *recursiveT `json:"a,omitempty"` B *recursiveU `json:"b,omitempty"` - C string `json:"c,omitempty"` + C *recursiveU `json:"c,omitempty"` + D string `json:"d,omitempty"` } + type recursiveU struct { T *recursiveT `json:"t,omitempty"` } @@ -117,13 +119,18 @@ func Test_Marshal(t *testing.T) { A: &recursiveT{ B: &recursiveU{ T: &recursiveT{ - C: "hello", + D: "hello", + }, + }, + C: &recursiveU{ + T: &recursiveT{ + D: "world", }, }, }, }) assertErr(t, err) - assertEq(t, "recursive", `{"a":{"b":{"t":{"c":"hello"}}}}`, string(bytes)) + assertEq(t, "recursive", `{"a":{"b":{"t":{"d":"hello"}},"c":{"t":{"d":"world"}}}}`, string(bytes)) }) t.Run("omitempty", func(t *testing.T) { type T struct { diff --git a/encode_vm.go b/encode_vm.go index d678079..a591a8d 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -460,9 +460,7 @@ func (e *Encoder) run(code *opcode) error { code = c.next case opStructFieldRecursive: recursive := code.toRecursiveCode() - c := copyOpcode(recursive.jmp.code) - c.ptr = recursive.ptr - if err := e.run(c); err != nil { + if err := e.run(newRecursiveCode(recursive)); err != nil { return err } code = recursive.next