From d7518e31513731a843d78b4d55e2c9a989de7610 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Fri, 21 Aug 2020 01:01:24 +0900 Subject: [PATCH] Fix seenPtr --- encode.go | 2 -- encode_opcode.go | 5 +++-- encode_vm.go | 9 +++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/encode.go b/encode.go index d3590d5..f41114e 100644 --- a/encode.go +++ b/encode.go @@ -22,7 +22,6 @@ type Encoder struct { indent int structTypeToCompiledCode map[uintptr]*compiledCode structTypeToCompiledIndentCode map[uintptr]*compiledCode - seenPtr map[uintptr]struct{} } type compiledCode struct { @@ -68,7 +67,6 @@ func init() { buf: make([]byte, 0, bufSize), structTypeToCompiledCode: map[uintptr]*compiledCode{}, structTypeToCompiledIndentCode: map[uintptr]*compiledCode{}, - seenPtr: map[uintptr]struct{}{}, } }, } diff --git a/encode_opcode.go b/encode_opcode.go index f67065e..9fef8bf 100644 --- a/encode_opcode.go +++ b/encode_opcode.go @@ -478,7 +478,8 @@ func (c *interfaceCode) copy(codeMap map[uintptr]*opcode) *opcode { type recursiveCode struct { *opcodeHeader - jmp *compiledCode + jmp *compiledCode + seenPtr uintptr } func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode { @@ -489,7 +490,7 @@ func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode { if code, exists := codeMap[addr]; exists { return code } - recur := &recursiveCode{} + recur := &recursiveCode{seenPtr: c.seenPtr} code := (*opcode)(unsafe.Pointer(recur)) codeMap[addr] = code diff --git a/encode_vm.go b/encode_vm.go index ab81e07..e2d0971 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -12,6 +12,7 @@ import ( ) func (e *Encoder) run(code *opcode) error { + seenPtr := map[uintptr]struct{}{} for { switch code.op { case opPtr: @@ -80,13 +81,13 @@ func (e *Encoder) run(code *opcode) error { typ: ifaceCode.typ, ptr: unsafe.Pointer(ptr), })) - if _, exists := e.seenPtr[ptr]; exists { + if _, exists := seenPtr[ptr]; exists { return &UnsupportedValueError{ Value: reflect.ValueOf(v), Str: fmt.Sprintf("encountered a cycle via %s", code.typ), } } - e.seenPtr[ptr] = struct{}{} + seenPtr[ptr] = struct{}{} rv := reflect.ValueOf(v) if rv.IsNil() { e.encodeNull() @@ -505,7 +506,7 @@ func (e *Encoder) run(code *opcode) error { code = c.next case opStructFieldRecursive: recursive := code.toRecursiveCode() - if _, exists := e.seenPtr[recursive.ptr]; exists { + if recursive.seenPtr != 0 && recursive.seenPtr == recursive.ptr { v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ typ: code.typ, ptr: unsafe.Pointer(recursive.ptr), @@ -515,7 +516,7 @@ func (e *Encoder) run(code *opcode) error { Str: fmt.Sprintf("encountered a cycle via %s", code.typ), } } - e.seenPtr[recursive.ptr] = struct{}{} + recursive.seenPtr = recursive.ptr if err := e.run(newRecursiveCode(recursive)); err != nil { return err }