Merge pull request #37 from goccy/feature/fix-cycle-pointer

Fix cycle pointer value
This commit is contained in:
Masaaki Goshima 2020-08-21 01:07:41 +09:00 committed by GitHub
commit a257c9b964
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 75 additions and 2 deletions

View File

@ -479,6 +479,7 @@ func (c *interfaceCode) copy(codeMap map[uintptr]*opcode) *opcode {
type recursiveCode struct {
*opcodeHeader
jmp *compiledCode
seenPtr uintptr
}
func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode {
@ -489,7 +490,7 @@ func (c *recursiveCode) copy(codeMap map[uintptr]*opcode) *opcode {
if code, exists := codeMap[addr]; exists {
return code
}
recur := &recursiveCode{}
recur := &recursiveCode{seenPtr: c.seenPtr}
code := (*opcode)(unsafe.Pointer(recur))
codeMap[addr] = code

View File

@ -1158,3 +1158,56 @@ func TestEncodePointerString(t *testing.T) {
t.Fatalf("*N = %d; want 42", *back.N)
}
}
type SamePointerNoCycle struct {
Ptr1, Ptr2 *SamePointerNoCycle
}
var samePointerNoCycle = &SamePointerNoCycle{}
type PointerCycle struct {
Ptr *PointerCycle
}
var pointerCycle = &PointerCycle{}
type PointerCycleIndirect struct {
Ptrs []interface{}
}
var pointerCycleIndirect = &PointerCycleIndirect{}
func init() {
ptr := &SamePointerNoCycle{}
samePointerNoCycle.Ptr1 = ptr
samePointerNoCycle.Ptr2 = ptr
pointerCycle.Ptr = pointerCycle
pointerCycleIndirect.Ptrs = []interface{}{pointerCycleIndirect}
}
func TestSamePointerNoCycle(t *testing.T) {
if _, err := json.Marshal(samePointerNoCycle); err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
var unsupportedValues = []interface{}{
math.NaN(),
math.Inf(-1),
math.Inf(1),
pointerCycle,
pointerCycleIndirect,
}
func TestUnsupportedValues(t *testing.T) {
for _, v := range unsupportedValues {
if _, err := json.Marshal(v); err != nil {
if _, ok := err.(*json.UnsupportedValueError); !ok {
t.Errorf("for %v, got %T want UnsupportedValueError", v, err)
}
} else {
t.Errorf("for %v, expected error", v)
}
}
}

View File

@ -12,6 +12,7 @@ import (
)
func (e *Encoder) run(code *opcode) error {
seenPtr := map[uintptr]struct{}{}
for {
switch code.op {
case opPtr:
@ -80,6 +81,13 @@ func (e *Encoder) run(code *opcode) error {
typ: ifaceCode.typ,
ptr: unsafe.Pointer(ptr),
}))
if _, exists := seenPtr[ptr]; exists {
return &UnsupportedValueError{
Value: reflect.ValueOf(v),
Str: fmt.Sprintf("encountered a cycle via %s", code.typ),
}
}
seenPtr[ptr] = struct{}{}
rv := reflect.ValueOf(v)
if rv.IsNil() {
e.encodeNull()
@ -498,6 +506,17 @@ func (e *Encoder) run(code *opcode) error {
code = c.next
case opStructFieldRecursive:
recursive := code.toRecursiveCode()
if recursive.seenPtr != 0 && recursive.seenPtr == recursive.ptr {
v := *(*interface{})(unsafe.Pointer(&interfaceHeader{
typ: code.typ,
ptr: unsafe.Pointer(recursive.ptr),
}))
return &UnsupportedValueError{
Value: reflect.ValueOf(v),
Str: fmt.Sprintf("encountered a cycle via %s", code.typ),
}
}
recursive.seenPtr = recursive.ptr
if err := e.run(newRecursiveCode(recursive)); err != nil {
return err
}