mirror of https://github.com/goccy/go-json.git
Fix recursive call
This commit is contained in:
parent
66ced55701
commit
ddfae9189e
|
@ -28,6 +28,9 @@ type Encoder struct {
|
|||
|
||||
type compiledCode struct {
|
||||
code *opcode
|
||||
linked bool // whether recursive code already have linked
|
||||
curLen uintptr
|
||||
nextLen uintptr
|
||||
}
|
||||
|
||||
const (
|
||||
|
|
|
@ -34,6 +34,7 @@ func (e *Encoder) compileHead(ctx *encodeCompileContext) (*opcode, error) {
|
|||
}
|
||||
e.convertHeadOnlyCode(code, isPtr)
|
||||
e.optimizeStructEnd(code)
|
||||
e.linkRecursiveCode(code)
|
||||
return code, nil
|
||||
} else if isPtr && typ.Implements(marshalTextType) {
|
||||
typ = orgType
|
||||
|
@ -46,9 +47,54 @@ func (e *Encoder) compileHead(ctx *encodeCompileContext) (*opcode, error) {
|
|||
}
|
||||
e.convertHeadOnlyCode(code, isPtr)
|
||||
e.optimizeStructEnd(code)
|
||||
e.linkRecursiveCode(code)
|
||||
return code, nil
|
||||
}
|
||||
|
||||
func (e *Encoder) linkRecursiveCode(c *opcode) {
|
||||
for code := c; code.op != opEnd && code.op != opStructFieldRecursiveEnd; {
|
||||
switch code.op {
|
||||
case opStructFieldRecursive,
|
||||
opStructFieldPtrAnonymousHeadRecursive,
|
||||
opStructFieldAnonymousHeadRecursive:
|
||||
if code.jmp.linked {
|
||||
code = code.next
|
||||
continue
|
||||
}
|
||||
code.jmp.code = copyOpcode(code.jmp.code)
|
||||
c := code.jmp.code
|
||||
c.end.next = newEndOp(&encodeCompileContext{})
|
||||
c.op = c.op.ptrHeadToHead()
|
||||
|
||||
beforeLastCode := c.end
|
||||
lastCode := beforeLastCode.next
|
||||
|
||||
lastCode.idx = beforeLastCode.idx + uintptrSize
|
||||
lastCode.elemIdx = lastCode.idx + uintptrSize
|
||||
|
||||
// extend length to alloc slot for elemIdx
|
||||
totalLength := uintptr(code.totalLength() + 1)
|
||||
nextTotalLength := uintptr(c.totalLength() + 1)
|
||||
|
||||
c.end.next.op = opStructFieldRecursiveEnd
|
||||
|
||||
code.jmp.curLen = totalLength
|
||||
code.jmp.nextLen = nextTotalLength
|
||||
code.jmp.linked = true
|
||||
|
||||
e.linkRecursiveCode(code.jmp.code)
|
||||
code = code.next
|
||||
continue
|
||||
}
|
||||
switch code.op.codeType() {
|
||||
case codeArrayElem, codeSliceElem, codeMapKey:
|
||||
code = code.end
|
||||
default:
|
||||
code = code.next
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Encoder) optimizeStructEnd(c *opcode) {
|
||||
for code := c; code.op != opEnd; {
|
||||
if code.op == opStructFieldRecursive {
|
||||
|
|
|
@ -86,6 +86,7 @@ func (c *encodeCompileContext) decPtrIndex() {
|
|||
type encodeRuntimeContext struct {
|
||||
ptrs []uintptr
|
||||
keepRefs []unsafe.Pointer
|
||||
seenPtr []uintptr
|
||||
}
|
||||
|
||||
func (c *encodeRuntimeContext) init(p uintptr, codelen int) {
|
||||
|
@ -94,6 +95,7 @@ func (c *encodeRuntimeContext) init(p uintptr, codelen int) {
|
|||
}
|
||||
c.ptrs[0] = p
|
||||
c.keepRefs = c.keepRefs[:0]
|
||||
c.seenPtr = c.seenPtr[:0]
|
||||
}
|
||||
|
||||
func (c *encodeRuntimeContext) ptr() uintptr {
|
||||
|
|
47
encode_vm.go
47
encode_vm.go
|
@ -50,7 +50,6 @@ func errMarshaler(code *opcode, err error) *MarshalerError {
|
|||
|
||||
func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
|
||||
recursiveLevel := 0
|
||||
var seenPtr map[uintptr]struct{}
|
||||
ptrOffset := uintptr(0)
|
||||
ctxptr := ctx.ptr()
|
||||
|
||||
|
@ -200,13 +199,12 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte
|
|||
code = code.next
|
||||
break
|
||||
}
|
||||
if seenPtr == nil {
|
||||
seenPtr = map[uintptr]struct{}{}
|
||||
}
|
||||
if _, exists := seenPtr[ptr]; exists {
|
||||
for _, seen := range ctx.seenPtr {
|
||||
if ptr == seen {
|
||||
return nil, errUnsupportedValue(code, ptr)
|
||||
}
|
||||
seenPtr[ptr] = struct{}{}
|
||||
}
|
||||
ctx.seenPtr = append(ctx.seenPtr, ptr)
|
||||
v := e.ptrToInterface(code, ptr)
|
||||
ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v))
|
||||
rv := reflect.ValueOf(v)
|
||||
|
@ -540,46 +538,29 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte
|
|||
ptr := load(ctxptr, code.idx)
|
||||
if ptr != 0 {
|
||||
if recursiveLevel > startDetectingCyclesAfter {
|
||||
if _, exists := seenPtr[ptr]; exists {
|
||||
for _, seen := range ctx.seenPtr {
|
||||
if ptr == seen {
|
||||
return nil, errUnsupportedValue(code, ptr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if seenPtr == nil {
|
||||
seenPtr = map[uintptr]struct{}{}
|
||||
}
|
||||
seenPtr[ptr] = struct{}{}
|
||||
ctx.seenPtr = append(ctx.seenPtr, ptr)
|
||||
c := code.jmp.code
|
||||
c.end.next = newEndOp(&encodeCompileContext{})
|
||||
c.op = c.op.ptrHeadToHead()
|
||||
|
||||
beforeLastCode := c.end
|
||||
lastCode := beforeLastCode.next
|
||||
|
||||
lastCode.idx = beforeLastCode.idx + uintptrSize
|
||||
lastCode.elemIdx = lastCode.idx + uintptrSize
|
||||
|
||||
// extend length to alloc slot for elemIdx
|
||||
totalLength := uintptr(code.totalLength() + 1)
|
||||
nextTotalLength := uintptr(c.totalLength() + 1)
|
||||
|
||||
curlen := uintptr(len(ctx.ptrs))
|
||||
offsetNum := ptrOffset / uintptrSize
|
||||
oldOffset := ptrOffset
|
||||
ptrOffset += totalLength * uintptrSize
|
||||
ptrOffset += code.jmp.curLen * uintptrSize
|
||||
|
||||
newLen := offsetNum + totalLength + nextTotalLength
|
||||
newLen := offsetNum + code.jmp.curLen + code.jmp.nextLen
|
||||
if curlen < newLen {
|
||||
ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...)
|
||||
}
|
||||
ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr
|
||||
|
||||
store(ctxptr, c.idx, ptr)
|
||||
store(ctxptr, lastCode.idx, oldOffset)
|
||||
store(ctxptr, lastCode.elemIdx, uintptr(unsafe.Pointer(code.next)))
|
||||
|
||||
// link lastCode ( opStructFieldRecursiveEnd ) => code.next
|
||||
lastCode.op = opStructFieldRecursiveEnd
|
||||
store(ctxptr, c.end.next.idx, oldOffset)
|
||||
store(ctxptr, c.end.next.elemIdx, uintptr(unsafe.Pointer(code.next)))
|
||||
code = c
|
||||
recursiveLevel++
|
||||
case opStructFieldRecursiveEnd:
|
||||
|
@ -587,8 +568,10 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte
|
|||
|
||||
// restore ctxptr
|
||||
offset := load(ctxptr, code.idx)
|
||||
ptr := load(ctxptr, code.elemIdx)
|
||||
code = (*opcode)(e.ptrToUnsafePtr(ptr))
|
||||
ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1]
|
||||
|
||||
codePtr := load(ctxptr, code.elemIdx)
|
||||
code = (*opcode)(e.ptrToUnsafePtr(codePtr))
|
||||
ctxptr = ctx.ptr() + offset
|
||||
ptrOffset = offset
|
||||
case opStructFieldPtrHead:
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
|
||||
func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
|
||||
recursiveLevel := 0
|
||||
var seenPtr map[uintptr]struct{}
|
||||
ptrOffset := uintptr(0)
|
||||
ctxptr := ctx.ptr()
|
||||
|
||||
|
@ -162,13 +161,12 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
|
|||
code = code.next
|
||||
break
|
||||
}
|
||||
if seenPtr == nil {
|
||||
seenPtr = map[uintptr]struct{}{}
|
||||
}
|
||||
if _, exists := seenPtr[ptr]; exists {
|
||||
for _, seen := range ctx.seenPtr {
|
||||
if ptr == seen {
|
||||
return nil, errUnsupportedValue(code, ptr)
|
||||
}
|
||||
seenPtr[ptr] = struct{}{}
|
||||
}
|
||||
ctx.seenPtr = append(ctx.seenPtr, ptr)
|
||||
v := e.ptrToInterface(code, ptr)
|
||||
ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v))
|
||||
rv := reflect.ValueOf(v)
|
||||
|
@ -502,46 +500,29 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
|
|||
ptr := load(ctxptr, code.idx)
|
||||
if ptr != 0 {
|
||||
if recursiveLevel > startDetectingCyclesAfter {
|
||||
if _, exists := seenPtr[ptr]; exists {
|
||||
for _, seen := range ctx.seenPtr {
|
||||
if ptr == seen {
|
||||
return nil, errUnsupportedValue(code, ptr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if seenPtr == nil {
|
||||
seenPtr = map[uintptr]struct{}{}
|
||||
}
|
||||
seenPtr[ptr] = struct{}{}
|
||||
ctx.seenPtr = append(ctx.seenPtr, ptr)
|
||||
c := code.jmp.code
|
||||
c.end.next = newEndOp(&encodeCompileContext{})
|
||||
c.op = c.op.ptrHeadToHead()
|
||||
|
||||
beforeLastCode := c.end
|
||||
lastCode := beforeLastCode.next
|
||||
|
||||
lastCode.idx = beforeLastCode.idx + uintptrSize
|
||||
lastCode.elemIdx = lastCode.idx + uintptrSize
|
||||
|
||||
// extend length to alloc slot for elemIdx
|
||||
totalLength := uintptr(code.totalLength() + 1)
|
||||
nextTotalLength := uintptr(c.totalLength() + 1)
|
||||
|
||||
curlen := uintptr(len(ctx.ptrs))
|
||||
offsetNum := ptrOffset / uintptrSize
|
||||
oldOffset := ptrOffset
|
||||
ptrOffset += totalLength * uintptrSize
|
||||
ptrOffset += code.jmp.curLen * uintptrSize
|
||||
|
||||
newLen := offsetNum + totalLength + nextTotalLength
|
||||
newLen := offsetNum + code.jmp.curLen + code.jmp.nextLen
|
||||
if curlen < newLen {
|
||||
ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...)
|
||||
}
|
||||
ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr
|
||||
|
||||
store(ctxptr, c.idx, ptr)
|
||||
store(ctxptr, lastCode.idx, oldOffset)
|
||||
store(ctxptr, lastCode.elemIdx, uintptr(unsafe.Pointer(code.next)))
|
||||
|
||||
// link lastCode ( opStructFieldRecursiveEnd ) => code.next
|
||||
lastCode.op = opStructFieldRecursiveEnd
|
||||
store(ctxptr, c.end.next.idx, oldOffset)
|
||||
store(ctxptr, c.end.next.elemIdx, uintptr(unsafe.Pointer(code.next)))
|
||||
code = c
|
||||
recursiveLevel++
|
||||
case opStructFieldRecursiveEnd:
|
||||
|
@ -549,8 +530,10 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
|
|||
|
||||
// restore ctxptr
|
||||
offset := load(ctxptr, code.idx)
|
||||
ptr := load(ctxptr, code.elemIdx)
|
||||
code = (*opcode)(e.ptrToUnsafePtr(ptr))
|
||||
ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1]
|
||||
|
||||
codePtr := load(ctxptr, code.elemIdx)
|
||||
code = (*opcode)(e.ptrToUnsafePtr(codePtr))
|
||||
ctxptr = ctx.ptr() + offset
|
||||
ptrOffset = offset
|
||||
case opStructFieldPtrHead:
|
||||
|
|
Loading…
Reference in New Issue