Merge pull request #19 from goccy/feature/fix-recursive-struct

Fix recursive type definition
This commit is contained in:
Masaaki Goshima 2020-08-12 18:44:21 +09:00 committed by GitHub
commit dad89cea0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 138 additions and 32 deletions

View File

@ -12,14 +12,20 @@ import (
// An Encoder writes JSON values to an output stream.
type Encoder struct {
w io.Writer
buf []byte
pool sync.Pool
enabledIndent bool
enabledHTMLEscape bool
prefix []byte
indentStr []byte
indent int
w io.Writer
buf []byte
pool sync.Pool
enabledIndent bool
enabledHTMLEscape bool
prefix []byte
indentStr []byte
indent int
structTypeToCompiledCode map[uintptr]*compiledCode
structTypeToCompiledIndentCode map[uintptr]*compiledCode
}
type compiledCode struct {
code *opcode
}
const (
@ -58,8 +64,10 @@ func init() {
encPool = sync.Pool{
New: func() interface{} {
return &Encoder{
buf: make([]byte, 0, bufSize),
pool: encPool,
buf: make([]byte, 0, bufSize),
pool: encPool,
structTypeToCompiledCode: map[uintptr]*compiledCode{},
structTypeToCompiledIndentCode: map[uintptr]*compiledCode{},
}
},
}

View File

@ -747,6 +747,38 @@ func (e *Encoder) optimizeStructField(op opType, isOmitEmpty, withIndent bool) o
}
func (e *Encoder) compileStruct(typ *rtype, root, withIndent bool) (*opcode, error) {
typeptr := uintptr(unsafe.Pointer(typ))
if withIndent {
if compiled, exists := e.structTypeToCompiledCode[typeptr]; exists {
return (*opcode)(unsafe.Pointer(&recursiveCode{
opcodeHeader: &opcodeHeader{
op: opStructFieldRecursive,
typ: typ,
indent: e.indent,
next: newEndOp(e.indent),
},
jmp: compiled,
})), nil
}
} else {
if compiled, exists := e.structTypeToCompiledIndentCode[typeptr]; exists {
return (*opcode)(unsafe.Pointer(&recursiveCode{
opcodeHeader: &opcodeHeader{
op: opStructFieldRecursive,
typ: typ,
indent: e.indent,
next: newEndOp(e.indent),
},
jmp: compiled,
})), nil
}
}
compiled := &compiledCode{}
if withIndent {
e.structTypeToCompiledCode[typeptr] = compiled
} else {
e.structTypeToCompiledIndentCode[typeptr] = compiled
}
// header => code => structField => code => end
// ^ |
// |__________|
@ -851,5 +883,7 @@ func (e *Encoder) compileStruct(typ *rtype, root, withIndent bool) (*opcode, err
head.end = structEndCode
code.next = structEndCode
structEndCode.next = newEndOp(e.indent)
return (*opcode)(unsafe.Pointer(head)), nil
ret := (*opcode)(unsafe.Pointer(head))
compiled.code = ret
return ret, nil
}

View File

@ -145,6 +145,8 @@ const (
opStructFieldPtrHeadString
opStructFieldPtrHeadBool
opStructFieldRecursive
opStructFieldPtrHeadIndent
opStructFieldPtrHeadIntIndent
opStructFieldPtrHeadInt8Indent
@ -362,6 +364,8 @@ func (t opType) String() string {
case opMapEndIndent:
return "MAP_END_INDENT"
case opStructFieldRecursive:
return "STRUCT_FIELD_RECURSIVE"
case opStructFieldHead:
return "STRUCT_FIELD_HEAD"
case opStructFieldHeadInt:
@ -831,6 +835,8 @@ func (c *opcode) copy(codeMap map[uintptr]*opcode) *opcode {
code = c.toMapKeyCode().copy(codeMap)
case opMapValue, opMapValueIndent:
code = c.toMapValueCode().copy(codeMap)
case opStructFieldRecursive:
code = c.toRecursiveCode().copy(codeMap)
case opStructFieldHead,
opStructFieldHeadInt,
opStructFieldHeadInt8,
@ -1076,6 +1082,10 @@ func (c *opcode) toInterfaceCode() *interfaceCode {
return (*interfaceCode)(unsafe.Pointer(c))
}
func (c *opcode) toRecursiveCode() *recursiveCode {
return (*recursiveCode)(unsafe.Pointer(c))
}
type sliceHeaderCode struct {
*opcodeHeader
elem *sliceElemCode
@ -1301,27 +1311,6 @@ func (c *mapKeyCode) set(len int, iter unsafe.Pointer) {
c.iter = iter
}
type interfaceCode struct {
*opcodeHeader
root bool
}
func (c *interfaceCode) copy(codeMap map[uintptr]*opcode) *opcode {
if c == nil {
return nil
}
addr := uintptr(unsafe.Pointer(c))
if code, exists := codeMap[addr]; exists {
return code
}
iface := &interfaceCode{}
code := (*opcode)(unsafe.Pointer(iface))
codeMap[addr] = code
iface.opcodeHeader = c.opcodeHeader.copy(codeMap)
return code
}
type mapValueCode struct {
*opcodeHeader
iter unsafe.Pointer
@ -1382,3 +1371,48 @@ func newMapValueCode(indent int) *mapValueCode {
},
}
}
type interfaceCode struct {
*opcodeHeader
root bool
}
func (c *interfaceCode) copy(codeMap map[uintptr]*opcode) *opcode {
if c == nil {
return nil
}
addr := uintptr(unsafe.Pointer(c))
if code, exists := codeMap[addr]; exists {
return code
}
iface := &interfaceCode{}
code := (*opcode)(unsafe.Pointer(iface))
codeMap[addr] = code
iface.opcodeHeader = c.opcodeHeader.copy(codeMap)
return code
}
type recursiveCode struct {
*opcodeHeader
jmp *compiledCode
}
func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode {
if c == nil {
return nil
}
addr := uintptr(unsafe.Pointer(c))
if code, exists := codeMap[addr]; exists {
return code
}
recur := &recursiveCode{}
code := (*opcode)(unsafe.Pointer(recur))
codeMap[addr] = code
recur.opcodeHeader = c.opcodeHeader.copy(codeMap)
recur.jmp = &compiledCode{
code: c.jmp.code.copy(codeMap),
}
return code
}

View File

@ -9,6 +9,15 @@ import (
"github.com/goccy/go-json"
)
type recursiveT struct {
A *recursiveT `json:"a,omitempty"`
B *recursiveU `json:"b,omitempty"`
C string `json:"c,omitempty"`
}
type recursiveU struct {
T *recursiveT `json:"t,omitempty"`
}
func Test_Marshal(t *testing.T) {
t.Run("int", func(t *testing.T) {
bytes, err := json.Marshal(-10)
@ -103,6 +112,19 @@ func Test_Marshal(t *testing.T) {
assertErr(t, err)
assertEq(t, "struct", `{"a":null}`, string(bytes))
})
t.Run("recursive", func(t *testing.T) {
bytes, err := json.Marshal(recursiveT{
A: &recursiveT{
B: &recursiveU{
T: &recursiveT{
C: "hello",
},
},
},
})
assertErr(t, err)
assertEq(t, "recursive", `{"a":{"b":{"t":{"c":"hello"}}}}`, string(bytes))
})
t.Run("omitempty", func(t *testing.T) {
type T struct {
A int `json:",omitempty"`

View File

@ -458,6 +458,14 @@ func (e *Encoder) run(code *opcode) error {
c.next.ptr = uintptr(value)
mapiternext(c.iter)
code = c.next
case opStructFieldRecursive:
recursive := code.toRecursiveCode()
c := copyOpcode(recursive.jmp.code)
c.ptr = recursive.ptr
if err := e.run(c); err != nil {
return err
}
code = recursive.next
case opStructFieldPtrHead:
if code.ptr != 0 {
code.ptr = e.ptrToPtr(code.ptr)