Fix seenPtr

This commit is contained in:
Masaaki Goshima 2020-08-21 01:01:24 +09:00
parent 23dbdf7fbd
commit d7518e3151
3 changed files with 8 additions and 8 deletions

View File

@ -22,7 +22,6 @@ type Encoder struct {
indent int
structTypeToCompiledCode map[uintptr]*compiledCode
structTypeToCompiledIndentCode map[uintptr]*compiledCode
seenPtr map[uintptr]struct{}
}
type compiledCode struct {
@ -68,7 +67,6 @@ func init() {
buf: make([]byte, 0, bufSize),
structTypeToCompiledCode: map[uintptr]*compiledCode{},
structTypeToCompiledIndentCode: map[uintptr]*compiledCode{},
seenPtr: map[uintptr]struct{}{},
}
},
}

View File

@ -478,7 +478,8 @@ func (c *interfaceCode) copy(codeMap map[uintptr]*opcode) *opcode {
type recursiveCode struct {
*opcodeHeader
jmp *compiledCode
jmp *compiledCode
seenPtr uintptr
}
func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode {
@ -489,7 +490,7 @@ func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode {
if code, exists := codeMap[addr]; exists {
return code
}
recur := &recursiveCode{}
recur := &recursiveCode{seenPtr: c.seenPtr}
code := (*opcode)(unsafe.Pointer(recur))
codeMap[addr] = code

View File

@ -12,6 +12,7 @@ import (
)
func (e *Encoder) run(code *opcode) error {
seenPtr := map[uintptr]struct{}{}
for {
switch code.op {
case opPtr:
@ -80,13 +81,13 @@ func (e *Encoder) run(code *opcode) error {
typ: ifaceCode.typ,
ptr: unsafe.Pointer(ptr),
}))
if _, exists := e.seenPtr[ptr]; exists {
if _, exists := seenPtr[ptr]; exists {
return &UnsupportedValueError{
Value: reflect.ValueOf(v),
Str: fmt.Sprintf("encountered a cycle via %s", code.typ),
}
}
e.seenPtr[ptr] = struct{}{}
seenPtr[ptr] = struct{}{}
rv := reflect.ValueOf(v)
if rv.IsNil() {
e.encodeNull()
@ -505,7 +506,7 @@ func (e *Encoder) run(code *opcode) error {
code = c.next
case opStructFieldRecursive:
recursive := code.toRecursiveCode()
if _, exists := e.seenPtr[recursive.ptr]; exists {
if recursive.seenPtr != 0 && recursive.seenPtr == recursive.ptr {
v := *(*interface{})(unsafe.Pointer(&interfaceHeader{
typ: code.typ,
ptr: unsafe.Pointer(recursive.ptr),
@ -515,7 +516,7 @@ func (e *Encoder) run(code *opcode) error {
Str: fmt.Sprintf("encountered a cycle via %s", code.typ),
}
}
e.seenPtr[recursive.ptr] = struct{}{}
recursive.seenPtr = recursive.ptr
if err := e.run(newRecursiveCode(recursive)); err != nil {
return err
}