diff --git a/encode_opcode.go b/encode_opcode.go index f67065e..9fef8bf 100644 --- a/encode_opcode.go +++ b/encode_opcode.go @@ -478,7 +478,8 @@ func (c *interfaceCode) copy(codeMap map[uintptr]*opcode) *opcode { type recursiveCode struct { *opcodeHeader - jmp *compiledCode + jmp *compiledCode + seenPtr uintptr } func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode { @@ -489,7 +490,7 @@ func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode { if code, exists := codeMap[addr]; exists { return code } - recur := &recursiveCode{} + recur := &recursiveCode{seenPtr: c.seenPtr} code := (*opcode)(unsafe.Pointer(recur)) codeMap[addr] = code diff --git a/encode_test.go b/encode_test.go index 9fd5668..d3f622b 100644 --- a/encode_test.go +++ b/encode_test.go @@ -1158,3 +1158,56 @@ func TestEncodePointerString(t *testing.T) { t.Fatalf("*N = %d; want 42", *back.N) } } + +type SamePointerNoCycle struct { + Ptr1, Ptr2 *SamePointerNoCycle +} + +var samePointerNoCycle = &SamePointerNoCycle{} + +type PointerCycle struct { + Ptr *PointerCycle +} + +var pointerCycle = &PointerCycle{} + +type PointerCycleIndirect struct { + Ptrs []interface{} +} + +var pointerCycleIndirect = &PointerCycleIndirect{} + +func init() { + ptr := &SamePointerNoCycle{} + samePointerNoCycle.Ptr1 = ptr + samePointerNoCycle.Ptr2 = ptr + + pointerCycle.Ptr = pointerCycle + pointerCycleIndirect.Ptrs = []interface{}{pointerCycleIndirect} +} + +func TestSamePointerNoCycle(t *testing.T) { + if _, err := json.Marshal(samePointerNoCycle); err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +var unsupportedValues = []interface{}{ + math.NaN(), + math.Inf(-1), + math.Inf(1), + pointerCycle, + pointerCycleIndirect, +} + +func TestUnsupportedValues(t *testing.T) { + for _, v := range unsupportedValues { + if _, err := json.Marshal(v); err != nil { + if _, ok := err.(*json.UnsupportedValueError); !ok { + t.Errorf("for %v, got %T want UnsupportedValueError", v, err) + } + } else { + t.Errorf("for %v, expected error", v) + } + } +} diff --git a/encode_vm.go b/encode_vm.go index 71e42ff..e2d0971 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -12,6 +12,7 @@ import ( ) func (e *Encoder) run(code *opcode) error { + seenPtr := map[uintptr]struct{}{} for { switch code.op { case opPtr: @@ -80,6 +81,13 @@ func (e *Encoder) run(code *opcode) error { typ: ifaceCode.typ, ptr: unsafe.Pointer(ptr), })) + if _, exists := seenPtr[ptr]; exists { + return &UnsupportedValueError{ + Value: reflect.ValueOf(v), + Str: fmt.Sprintf("encountered a cycle via %s", code.typ), + } + } + seenPtr[ptr] = struct{}{} rv := reflect.ValueOf(v) if rv.IsNil() { e.encodeNull() @@ -498,6 +506,17 @@ func (e *Encoder) run(code *opcode) error { code = c.next case opStructFieldRecursive: recursive := code.toRecursiveCode() + if recursive.seenPtr != 0 && recursive.seenPtr == recursive.ptr { + v := *(*interface{})(unsafe.Pointer(&interfaceHeader{ + typ: code.typ, + ptr: unsafe.Pointer(recursive.ptr), + })) + return &UnsupportedValueError{ + Value: reflect.ValueOf(v), + Str: fmt.Sprintf("encountered a cycle via %s", code.typ), + } + } + recursive.seenPtr = recursive.ptr if err := e.run(newRecursiveCode(recursive)); err != nil { return err }