From b71f7da8bc4f97d9c246f78f10c1b952ab70447e Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Wed, 12 Aug 2020 18:42:29 +0900 Subject: [PATCH] Fix recursive type definition --- encode.go | 28 ++++++++++------- encode_compile.go | 36 +++++++++++++++++++++- encode_opcode.go | 76 ++++++++++++++++++++++++++++++++++------------- encode_test.go | 22 ++++++++++++++ encode_vm.go | 8 +++++ 5 files changed, 138 insertions(+), 32 deletions(-) diff --git a/encode.go b/encode.go index 58c5b3e..37c8b7d 100644 --- a/encode.go +++ b/encode.go @@ -12,14 +12,20 @@ import ( // An Encoder writes JSON values to an output stream. type Encoder struct { - w io.Writer - buf []byte - pool sync.Pool - enabledIndent bool - enabledHTMLEscape bool - prefix []byte - indentStr []byte - indent int + w io.Writer + buf []byte + pool sync.Pool + enabledIndent bool + enabledHTMLEscape bool + prefix []byte + indentStr []byte + indent int + structTypeToCompiledCode map[uintptr]*compiledCode + structTypeToCompiledIndentCode map[uintptr]*compiledCode +} + +type compiledCode struct { + code *opcode } const ( @@ -58,8 +64,10 @@ func init() { encPool = sync.Pool{ New: func() interface{} { return &Encoder{ - buf: make([]byte, 0, bufSize), - pool: encPool, + buf: make([]byte, 0, bufSize), + pool: encPool, + structTypeToCompiledCode: map[uintptr]*compiledCode{}, + structTypeToCompiledIndentCode: map[uintptr]*compiledCode{}, } }, } diff --git a/encode_compile.go b/encode_compile.go index 1f86e8d..c131624 100644 --- a/encode_compile.go +++ b/encode_compile.go @@ -747,6 +747,38 @@ func (e *Encoder) optimizeStructField(op opType, isOmitEmpty, withIndent bool) o } func (e *Encoder) compileStruct(typ *rtype, root, withIndent bool) (*opcode, error) { + typeptr := uintptr(unsafe.Pointer(typ)) + if withIndent { + if compiled, exists := e.structTypeToCompiledCode[typeptr]; exists { + return (*opcode)(unsafe.Pointer(&recursiveCode{ + opcodeHeader: &opcodeHeader{ + op: opStructFieldRecursive, + typ: typ, + indent: e.indent, + next: newEndOp(e.indent), + }, + jmp: compiled, + })), nil + } + } else { + if compiled, exists := e.structTypeToCompiledIndentCode[typeptr]; exists { + return (*opcode)(unsafe.Pointer(&recursiveCode{ + opcodeHeader: &opcodeHeader{ + op: opStructFieldRecursive, + typ: typ, + indent: e.indent, + next: newEndOp(e.indent), + }, + jmp: compiled, + })), nil + } + } + compiled := &compiledCode{} + if withIndent { + e.structTypeToCompiledCode[typeptr] = compiled + } else { + e.structTypeToCompiledIndentCode[typeptr] = compiled + } // header => code => structField => code => end // ^ | // |__________| @@ -851,5 +883,7 @@ func (e *Encoder) compileStruct(typ *rtype, root, withIndent bool) (*opcode, err head.end = structEndCode code.next = structEndCode structEndCode.next = newEndOp(e.indent) - return (*opcode)(unsafe.Pointer(head)), nil + ret := (*opcode)(unsafe.Pointer(head)) + compiled.code = ret + return ret, nil } diff --git a/encode_opcode.go b/encode_opcode.go index 7db664b..9bd8e4b 100644 --- a/encode_opcode.go +++ b/encode_opcode.go @@ -145,6 +145,8 @@ const ( opStructFieldPtrHeadString opStructFieldPtrHeadBool + opStructFieldRecursive + opStructFieldPtrHeadIndent opStructFieldPtrHeadIntIndent opStructFieldPtrHeadInt8Indent @@ -362,6 +364,8 @@ func (t opType) String() string { case opMapEndIndent: return "MAP_END_INDENT" + case opStructFieldRecursive: + return "STRUCT_FIELD_RECURSIVE" case opStructFieldHead: return "STRUCT_FIELD_HEAD" case opStructFieldHeadInt: @@ -831,6 +835,8 @@ func (c *opcode) copy(codeMap map[uintptr]*opcode) *opcode { code = c.toMapKeyCode().copy(codeMap) case opMapValue, opMapValueIndent: code = c.toMapValueCode().copy(codeMap) + case opStructFieldRecursive: + code = c.toRecursiveCode().copy(codeMap) case opStructFieldHead, opStructFieldHeadInt, opStructFieldHeadInt8, @@ -1076,6 +1082,10 @@ func (c *opcode) toInterfaceCode() *interfaceCode { return (*interfaceCode)(unsafe.Pointer(c)) } +func (c *opcode) toRecursiveCode() *recursiveCode { + return (*recursiveCode)(unsafe.Pointer(c)) +} + type sliceHeaderCode struct { *opcodeHeader elem *sliceElemCode @@ -1301,27 +1311,6 @@ func (c *mapKeyCode) set(len int, iter unsafe.Pointer) { c.iter = iter } -type interfaceCode struct { - *opcodeHeader - root bool -} - -func (c *interfaceCode) copy(codeMap map[uintptr]*opcode) *opcode { - if c == nil { - return nil - } - addr := uintptr(unsafe.Pointer(c)) - if code, exists := codeMap[addr]; exists { - return code - } - iface := &interfaceCode{} - code := (*opcode)(unsafe.Pointer(iface)) - codeMap[addr] = code - - iface.opcodeHeader = c.opcodeHeader.copy(codeMap) - return code -} - type mapValueCode struct { *opcodeHeader iter unsafe.Pointer @@ -1382,3 +1371,48 @@ func newMapValueCode(indent int) *mapValueCode { }, } } + +type interfaceCode struct { + *opcodeHeader + root bool +} + +func (c *interfaceCode) copy(codeMap map[uintptr]*opcode) *opcode { + if c == nil { + return nil + } + addr := uintptr(unsafe.Pointer(c)) + if code, exists := codeMap[addr]; exists { + return code + } + iface := &interfaceCode{} + code := (*opcode)(unsafe.Pointer(iface)) + codeMap[addr] = code + + iface.opcodeHeader = c.opcodeHeader.copy(codeMap) + return code +} + +type recursiveCode struct { + *opcodeHeader + jmp *compiledCode +} + +func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode { + if c == nil { + return nil + } + addr := uintptr(unsafe.Pointer(c)) + if code, exists := codeMap[addr]; exists { + return code + } + recur := &recursiveCode{} + code := (*opcode)(unsafe.Pointer(recur)) + codeMap[addr] = code + + recur.opcodeHeader = c.opcodeHeader.copy(codeMap) + recur.jmp = &compiledCode{ + code: c.jmp.code.copy(codeMap), + } + return code +} diff --git a/encode_test.go b/encode_test.go index a7e518a..8ba12d7 100644 --- a/encode_test.go +++ b/encode_test.go @@ -9,6 +9,15 @@ import ( "github.com/goccy/go-json" ) +type recursiveT struct { + A *recursiveT `json:"a,omitempty"` + B *recursiveU `json:"b,omitempty"` + C string `json:"c,omitempty"` +} +type recursiveU struct { + T *recursiveT `json:"t,omitempty"` +} + func Test_Marshal(t *testing.T) { t.Run("int", func(t *testing.T) { bytes, err := json.Marshal(-10) @@ -103,6 +112,19 @@ func Test_Marshal(t *testing.T) { assertErr(t, err) assertEq(t, "struct", `{"a":null}`, string(bytes)) }) + t.Run("recursive", func(t *testing.T) { + bytes, err := json.Marshal(recursiveT{ + A: &recursiveT{ + B: &recursiveU{ + T: &recursiveT{ + C: "hello", + }, + }, + }, + }) + assertErr(t, err) + assertEq(t, "recursive", `{"a":{"b":{"t":{"c":"hello"}}}}`, string(bytes)) + }) t.Run("omitempty", func(t *testing.T) { type T struct { A int `json:",omitempty"` diff --git a/encode_vm.go b/encode_vm.go index 8fa39ca..d678079 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -458,6 +458,14 @@ func (e *Encoder) run(code *opcode) error { c.next.ptr = uintptr(value) mapiternext(c.iter) code = c.next + case opStructFieldRecursive: + recursive := code.toRecursiveCode() + c := copyOpcode(recursive.jmp.code) + c.ptr = recursive.ptr + if err := e.run(c); err != nil { + return err + } + code = recursive.next case opStructFieldPtrHead: if code.ptr != 0 { code.ptr = e.ptrToPtr(code.ptr)