From 7ac966b81e7f3560cf8adab61821ed3c4b2be277 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 20 Aug 2020 23:56:12 +0900 Subject: [PATCH 1/3] Fix cycle pointer value --- encode.go | 2 ++ encode_vm.go | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/encode.go b/encode.go index f41114e..d3590d5 100644 --- a/encode.go +++ b/encode.go @@ -22,6 +22,7 @@ type Encoder struct { indent int structTypeToCompiledCode map[uintptr]*compiledCode structTypeToCompiledIndentCode map[uintptr]*compiledCode + seenPtr map[uintptr]struct{} } type compiledCode struct { @@ -67,6 +68,7 @@ func init() { buf: make([]byte, 0, bufSize), structTypeToCompiledCode: map[uintptr]*compiledCode{}, structTypeToCompiledIndentCode: map[uintptr]*compiledCode{}, + seenPtr: map[uintptr]struct{}{}, } }, } diff --git a/encode_vm.go b/encode_vm.go index 71e42ff..ab81e07 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -80,6 +80,13 @@ func (e *Encoder) run(code *opcode) error { typ: ifaceCode.typ, ptr: unsafe.Pointer(ptr), })) + if _, exists := e.seenPtr[ptr]; exists { + return &UnsupportedValueError{ + Value: reflect.ValueOf(v), + Str: fmt.Sprintf("encountered a cycle via %s", code.typ), + } + } + e.seenPtr[ptr] = struct{}{} rv := reflect.ValueOf(v) if rv.IsNil() { e.encodeNull() @@ -498,6 +505,17 @@ func (e *Encoder) run(code *opcode) error { code = c.next case opStructFieldRecursive: recursive := code.toRecursiveCode() + if _, exists := e.seenPtr[recursive.ptr]; exists { + v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ + typ: code.typ, + ptr: unsafe.Pointer(recursive.ptr), + })) + return &UnsupportedValueError{ + Value: reflect.ValueOf(v), + Str: fmt.Sprintf("encountered a cycle via %s", code.typ), + } + } + e.seenPtr[recursive.ptr] = struct{}{} if err := e.run(newRecursiveCode(recursive)); err != nil { return err } From 23dbdf7fbde7711572e19ed1495506620e2081dc Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Thu, 20 Aug 2020 23:56:50 +0900 Subject: [PATCH 2/3] Add test case --- encode_test.go | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/encode_test.go b/encode_test.go index 9fd5668..d3f622b 100644 --- a/encode_test.go +++ b/encode_test.go @@ -1158,3 +1158,56 @@ func TestEncodePointerString(t *testing.T) { t.Fatalf("*N = %d; want 42", *back.N) } } + +type SamePointerNoCycle struct { + Ptr1, Ptr2 *SamePointerNoCycle +} + +var samePointerNoCycle = &SamePointerNoCycle{} + +type PointerCycle struct { + Ptr *PointerCycle +} + +var pointerCycle = &PointerCycle{} + +type PointerCycleIndirect struct { + Ptrs []interface{} +} + +var pointerCycleIndirect = &PointerCycleIndirect{} + +func init() { + ptr := &SamePointerNoCycle{} + samePointerNoCycle.Ptr1 = ptr + samePointerNoCycle.Ptr2 = ptr + + pointerCycle.Ptr = pointerCycle + pointerCycleIndirect.Ptrs = []interface{}{pointerCycleIndirect} +} + +func TestSamePointerNoCycle(t *testing.T) { + if _, err := json.Marshal(samePointerNoCycle); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +var unsupportedValues = []interface{}{ + math.NaN(), + math.Inf(-1), + math.Inf(1), + pointerCycle, + pointerCycleIndirect, +} + +func TestUnsupportedValues(t *testing.T) { + for _, v := range unsupportedValues { + if _, err := json.Marshal(v); err != nil { + if _, ok := err.(*json.UnsupportedValueError); !ok { + t.Errorf("for %v, got %T want UnsupportedValueError", v, err) + } + } else { + t.Errorf("for %v, expected error", v) + } + } +} From d7518e31513731a843d78b4d55e2c9a989de7610 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Fri, 21 Aug 2020 01:01:24 +0900 Subject: [PATCH 3/3] Fix seenPtr --- encode.go | 2 -- encode_opcode.go | 5 +++-- encode_vm.go | 9 +++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/encode.go b/encode.go index d3590d5..f41114e 100644 --- a/encode.go +++ b/encode.go @@ -22,7 +22,6 @@ type Encoder struct { indent int structTypeToCompiledCode map[uintptr]*compiledCode structTypeToCompiledIndentCode map[uintptr]*compiledCode - seenPtr map[uintptr]struct{} } type compiledCode struct { @@ -68,7 +67,6 @@ func init() { buf: make([]byte, 0, bufSize), structTypeToCompiledCode: map[uintptr]*compiledCode{}, structTypeToCompiledIndentCode: map[uintptr]*compiledCode{}, - seenPtr: map[uintptr]struct{}{}, } }, } diff --git a/encode_opcode.go b/encode_opcode.go index f67065e..9fef8bf 100644 --- a/encode_opcode.go +++ b/encode_opcode.go @@ -478,7 +478,8 @@ func (c *interfaceCode) copy(codeMap map[uintptr]*opcode) *opcode { type recursiveCode struct { *opcodeHeader - jmp *compiledCode + jmp *compiledCode + seenPtr uintptr } func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode { @@ -489,7 +490,7 @@ func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode { if code, exists := codeMap[addr]; exists { return code } - recur := &recursiveCode{} + recur := &recursiveCode{seenPtr: c.seenPtr} code := (*opcode)(unsafe.Pointer(recur)) codeMap[addr] = code diff --git a/encode_vm.go b/encode_vm.go index ab81e07..e2d0971 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -12,6 +12,7 @@ import ( ) func (e *Encoder) run(code *opcode) error { + seenPtr := map[uintptr]struct{}{} for { switch code.op { case opPtr: @@ -80,13 +81,13 @@ func (e *Encoder) run(code *opcode) error { typ: ifaceCode.typ, ptr: unsafe.Pointer(ptr), })) - if _, exists := e.seenPtr[ptr]; exists { + if _, exists := seenPtr[ptr]; exists { return &UnsupportedValueError{ Value: reflect.ValueOf(v), Str: fmt.Sprintf("encountered a cycle via %s", code.typ), } } - e.seenPtr[ptr] = struct{}{} + seenPtr[ptr] = struct{}{} rv := reflect.ValueOf(v) if rv.IsNil() { e.encodeNull() @@ -505,7 +506,7 @@ func (e *Encoder) run(code *opcode) error { code = c.next case opStructFieldRecursive: recursive := code.toRecursiveCode() - if _, exists := e.seenPtr[recursive.ptr]; exists { + if recursive.seenPtr != 0 && recursive.seenPtr == recursive.ptr { v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ typ: code.typ, ptr: unsafe.Pointer(recursive.ptr), @@ -515,7 +516,7 @@ func (e *Encoder) run(code *opcode) error { Str: fmt.Sprintf("encountered a cycle via %s", code.typ), } } - e.seenPtr[recursive.ptr] = struct{}{} + recursive.seenPtr = recursive.ptr if err := e.run(newRecursiveCode(recursive)); err != nil { return err }