Fix recursive type definition

This commit is contained in:
Masaaki Goshima 2020-08-12 18:42:29 +09:00
parent 18e30e0106
commit b71f7da8bc
5 changed files with 138 additions and 32 deletions

View File

@ -12,14 +12,20 @@ import (
// An Encoder writes JSON values to an output stream. // An Encoder writes JSON values to an output stream.
type Encoder struct { type Encoder struct {
w io.Writer w io.Writer
buf []byte buf []byte
pool sync.Pool pool sync.Pool
enabledIndent bool enabledIndent bool
enabledHTMLEscape bool enabledHTMLEscape bool
prefix []byte prefix []byte
indentStr []byte indentStr []byte
indent int indent int
structTypeToCompiledCode map[uintptr]*compiledCode
structTypeToCompiledIndentCode map[uintptr]*compiledCode
}
type compiledCode struct {
code *opcode
} }
const ( const (
@ -58,8 +64,10 @@ func init() {
encPool = sync.Pool{ encPool = sync.Pool{
New: func() interface{} { New: func() interface{} {
return &Encoder{ return &Encoder{
buf: make([]byte, 0, bufSize), buf: make([]byte, 0, bufSize),
pool: encPool, pool: encPool,
structTypeToCompiledCode: map[uintptr]*compiledCode{},
structTypeToCompiledIndentCode: map[uintptr]*compiledCode{},
} }
}, },
} }

View File

@ -747,6 +747,38 @@ func (e *Encoder) optimizeStructField(op opType, isOmitEmpty, withIndent bool) o
} }
func (e *Encoder) compileStruct(typ *rtype, root, withIndent bool) (*opcode, error) { func (e *Encoder) compileStruct(typ *rtype, root, withIndent bool) (*opcode, error) {
typeptr := uintptr(unsafe.Pointer(typ))
if withIndent {
if compiled, exists := e.structTypeToCompiledCode[typeptr]; exists {
return (*opcode)(unsafe.Pointer(&recursiveCode{
opcodeHeader: &opcodeHeader{
op: opStructFieldRecursive,
typ: typ,
indent: e.indent,
next: newEndOp(e.indent),
},
jmp: compiled,
})), nil
}
} else {
if compiled, exists := e.structTypeToCompiledIndentCode[typeptr]; exists {
return (*opcode)(unsafe.Pointer(&recursiveCode{
opcodeHeader: &opcodeHeader{
op: opStructFieldRecursive,
typ: typ,
indent: e.indent,
next: newEndOp(e.indent),
},
jmp: compiled,
})), nil
}
}
compiled := &compiledCode{}
if withIndent {
e.structTypeToCompiledCode[typeptr] = compiled
} else {
e.structTypeToCompiledIndentCode[typeptr] = compiled
}
// header => code => structField => code => end // header => code => structField => code => end
// ^ | // ^ |
// |__________| // |__________|
@ -851,5 +883,7 @@ func (e *Encoder) compileStruct(typ *rtype, root, withIndent bool) (*opcode, err
head.end = structEndCode head.end = structEndCode
code.next = structEndCode code.next = structEndCode
structEndCode.next = newEndOp(e.indent) structEndCode.next = newEndOp(e.indent)
return (*opcode)(unsafe.Pointer(head)), nil ret := (*opcode)(unsafe.Pointer(head))
compiled.code = ret
return ret, nil
} }

View File

@ -145,6 +145,8 @@ const (
opStructFieldPtrHeadString opStructFieldPtrHeadString
opStructFieldPtrHeadBool opStructFieldPtrHeadBool
opStructFieldRecursive
opStructFieldPtrHeadIndent opStructFieldPtrHeadIndent
opStructFieldPtrHeadIntIndent opStructFieldPtrHeadIntIndent
opStructFieldPtrHeadInt8Indent opStructFieldPtrHeadInt8Indent
@ -362,6 +364,8 @@ func (t opType) String() string {
case opMapEndIndent: case opMapEndIndent:
return "MAP_END_INDENT" return "MAP_END_INDENT"
case opStructFieldRecursive:
return "STRUCT_FIELD_RECURSIVE"
case opStructFieldHead: case opStructFieldHead:
return "STRUCT_FIELD_HEAD" return "STRUCT_FIELD_HEAD"
case opStructFieldHeadInt: case opStructFieldHeadInt:
@ -831,6 +835,8 @@ func (c *opcode) copy(codeMap map[uintptr]*opcode) *opcode {
code = c.toMapKeyCode().copy(codeMap) code = c.toMapKeyCode().copy(codeMap)
case opMapValue, opMapValueIndent: case opMapValue, opMapValueIndent:
code = c.toMapValueCode().copy(codeMap) code = c.toMapValueCode().copy(codeMap)
case opStructFieldRecursive:
code = c.toRecursiveCode().copy(codeMap)
case opStructFieldHead, case opStructFieldHead,
opStructFieldHeadInt, opStructFieldHeadInt,
opStructFieldHeadInt8, opStructFieldHeadInt8,
@ -1076,6 +1082,10 @@ func (c *opcode) toInterfaceCode() *interfaceCode {
return (*interfaceCode)(unsafe.Pointer(c)) return (*interfaceCode)(unsafe.Pointer(c))
} }
func (c *opcode) toRecursiveCode() *recursiveCode {
return (*recursiveCode)(unsafe.Pointer(c))
}
type sliceHeaderCode struct { type sliceHeaderCode struct {
*opcodeHeader *opcodeHeader
elem *sliceElemCode elem *sliceElemCode
@ -1301,27 +1311,6 @@ func (c *mapKeyCode) set(len int, iter unsafe.Pointer) {
c.iter = iter c.iter = iter
} }
type interfaceCode struct {
*opcodeHeader
root bool
}
func (c *interfaceCode) copy(codeMap map[uintptr]*opcode) *opcode {
if c == nil {
return nil
}
addr := uintptr(unsafe.Pointer(c))
if code, exists := codeMap[addr]; exists {
return code
}
iface := &interfaceCode{}
code := (*opcode)(unsafe.Pointer(iface))
codeMap[addr] = code
iface.opcodeHeader = c.opcodeHeader.copy(codeMap)
return code
}
type mapValueCode struct { type mapValueCode struct {
*opcodeHeader *opcodeHeader
iter unsafe.Pointer iter unsafe.Pointer
@ -1382,3 +1371,48 @@ func newMapValueCode(indent int) *mapValueCode {
}, },
} }
} }
type interfaceCode struct {
*opcodeHeader
root bool
}
func (c *interfaceCode) copy(codeMap map[uintptr]*opcode) *opcode {
if c == nil {
return nil
}
addr := uintptr(unsafe.Pointer(c))
if code, exists := codeMap[addr]; exists {
return code
}
iface := &interfaceCode{}
code := (*opcode)(unsafe.Pointer(iface))
codeMap[addr] = code
iface.opcodeHeader = c.opcodeHeader.copy(codeMap)
return code
}
type recursiveCode struct {
*opcodeHeader
jmp *compiledCode
}
func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode {
if c == nil {
return nil
}
addr := uintptr(unsafe.Pointer(c))
if code, exists := codeMap[addr]; exists {
return code
}
recur := &recursiveCode{}
code := (*opcode)(unsafe.Pointer(recur))
codeMap[addr] = code
recur.opcodeHeader = c.opcodeHeader.copy(codeMap)
recur.jmp = &compiledCode{
code: c.jmp.code.copy(codeMap),
}
return code
}

View File

@ -9,6 +9,15 @@ import (
"github.com/goccy/go-json" "github.com/goccy/go-json"
) )
type recursiveT struct {
A *recursiveT `json:"a,omitempty"`
B *recursiveU `json:"b,omitempty"`
C string `json:"c,omitempty"`
}
type recursiveU struct {
T *recursiveT `json:"t,omitempty"`
}
func Test_Marshal(t *testing.T) { func Test_Marshal(t *testing.T) {
t.Run("int", func(t *testing.T) { t.Run("int", func(t *testing.T) {
bytes, err := json.Marshal(-10) bytes, err := json.Marshal(-10)
@ -103,6 +112,19 @@ func Test_Marshal(t *testing.T) {
assertErr(t, err) assertErr(t, err)
assertEq(t, "struct", `{"a":null}`, string(bytes)) assertEq(t, "struct", `{"a":null}`, string(bytes))
}) })
t.Run("recursive", func(t *testing.T) {
bytes, err := json.Marshal(recursiveT{
A: &recursiveT{
B: &recursiveU{
T: &recursiveT{
C: "hello",
},
},
},
})
assertErr(t, err)
assertEq(t, "recursive", `{"a":{"b":{"t":{"c":"hello"}}}}`, string(bytes))
})
t.Run("omitempty", func(t *testing.T) { t.Run("omitempty", func(t *testing.T) {
type T struct { type T struct {
A int `json:",omitempty"` A int `json:",omitempty"`

View File

@ -458,6 +458,14 @@ func (e *Encoder) run(code *opcode) error {
c.next.ptr = uintptr(value) c.next.ptr = uintptr(value)
mapiternext(c.iter) mapiternext(c.iter)
code = c.next code = c.next
case opStructFieldRecursive:
recursive := code.toRecursiveCode()
c := copyOpcode(recursive.jmp.code)
c.ptr = recursive.ptr
if err := e.run(c); err != nil {
return err
}
code = recursive.next
case opStructFieldPtrHead: case opStructFieldPtrHead:
if code.ptr != 0 { if code.ptr != 0 {
code.ptr = e.ptrToPtr(code.ptr) code.ptr = e.ptrToPtr(code.ptr)