Fix recursive call

This commit is contained in:
Masaaki Goshima 2021-01-15 16:25:00 +09:00
parent 66ced55701
commit ddfae9189e
5 changed files with 88 additions and 71 deletions

View File

@ -27,7 +27,10 @@ type Encoder struct {
}
type compiledCode struct {
code *opcode
code *opcode
linked bool // whether recursive code already have linked
curLen uintptr
nextLen uintptr
}
const (

View File

@ -34,6 +34,7 @@ func (e *Encoder) compileHead(ctx *encodeCompileContext) (*opcode, error) {
}
e.convertHeadOnlyCode(code, isPtr)
e.optimizeStructEnd(code)
e.linkRecursiveCode(code)
return code, nil
} else if isPtr && typ.Implements(marshalTextType) {
typ = orgType
@ -46,9 +47,54 @@ func (e *Encoder) compileHead(ctx *encodeCompileContext) (*opcode, error) {
}
e.convertHeadOnlyCode(code, isPtr)
e.optimizeStructEnd(code)
e.linkRecursiveCode(code)
return code, nil
}
func (e *Encoder) linkRecursiveCode(c *opcode) {
for code := c; code.op != opEnd && code.op != opStructFieldRecursiveEnd; {
switch code.op {
case opStructFieldRecursive,
opStructFieldPtrAnonymousHeadRecursive,
opStructFieldAnonymousHeadRecursive:
if code.jmp.linked {
code = code.next
continue
}
code.jmp.code = copyOpcode(code.jmp.code)
c := code.jmp.code
c.end.next = newEndOp(&encodeCompileContext{})
c.op = c.op.ptrHeadToHead()
beforeLastCode := c.end
lastCode := beforeLastCode.next
lastCode.idx = beforeLastCode.idx + uintptrSize
lastCode.elemIdx = lastCode.idx + uintptrSize
// extend length to alloc slot for elemIdx
totalLength := uintptr(code.totalLength() + 1)
nextTotalLength := uintptr(c.totalLength() + 1)
c.end.next.op = opStructFieldRecursiveEnd
code.jmp.curLen = totalLength
code.jmp.nextLen = nextTotalLength
code.jmp.linked = true
e.linkRecursiveCode(code.jmp.code)
code = code.next
continue
}
switch code.op.codeType() {
case codeArrayElem, codeSliceElem, codeMapKey:
code = code.end
default:
code = code.next
}
}
}
func (e *Encoder) optimizeStructEnd(c *opcode) {
for code := c; code.op != opEnd; {
if code.op == opStructFieldRecursive {

View File

@ -86,6 +86,7 @@ func (c *encodeCompileContext) decPtrIndex() {
type encodeRuntimeContext struct {
ptrs []uintptr
keepRefs []unsafe.Pointer
seenPtr []uintptr
}
func (c *encodeRuntimeContext) init(p uintptr, codelen int) {
@ -94,6 +95,7 @@ func (c *encodeRuntimeContext) init(p uintptr, codelen int) {
}
c.ptrs[0] = p
c.keepRefs = c.keepRefs[:0]
c.seenPtr = c.seenPtr[:0]
}
func (c *encodeRuntimeContext) ptr() uintptr {

View File

@ -50,7 +50,6 @@ func errMarshaler(code *opcode, err error) *MarshalerError {
func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
recursiveLevel := 0
var seenPtr map[uintptr]struct{}
ptrOffset := uintptr(0)
ctxptr := ctx.ptr()
@ -200,13 +199,12 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte
code = code.next
break
}
if seenPtr == nil {
seenPtr = map[uintptr]struct{}{}
for _, seen := range ctx.seenPtr {
if ptr == seen {
return nil, errUnsupportedValue(code, ptr)
}
}
if _, exists := seenPtr[ptr]; exists {
return nil, errUnsupportedValue(code, ptr)
}
seenPtr[ptr] = struct{}{}
ctx.seenPtr = append(ctx.seenPtr, ptr)
v := e.ptrToInterface(code, ptr)
ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v))
rv := reflect.ValueOf(v)
@ -540,46 +538,29 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte
ptr := load(ctxptr, code.idx)
if ptr != 0 {
if recursiveLevel > startDetectingCyclesAfter {
if _, exists := seenPtr[ptr]; exists {
return nil, errUnsupportedValue(code, ptr)
for _, seen := range ctx.seenPtr {
if ptr == seen {
return nil, errUnsupportedValue(code, ptr)
}
}
}
}
if seenPtr == nil {
seenPtr = map[uintptr]struct{}{}
}
seenPtr[ptr] = struct{}{}
ctx.seenPtr = append(ctx.seenPtr, ptr)
c := code.jmp.code
c.end.next = newEndOp(&encodeCompileContext{})
c.op = c.op.ptrHeadToHead()
beforeLastCode := c.end
lastCode := beforeLastCode.next
lastCode.idx = beforeLastCode.idx + uintptrSize
lastCode.elemIdx = lastCode.idx + uintptrSize
// extend length to alloc slot for elemIdx
totalLength := uintptr(code.totalLength() + 1)
nextTotalLength := uintptr(c.totalLength() + 1)
curlen := uintptr(len(ctx.ptrs))
offsetNum := ptrOffset / uintptrSize
oldOffset := ptrOffset
ptrOffset += totalLength * uintptrSize
ptrOffset += code.jmp.curLen * uintptrSize
newLen := offsetNum + totalLength + nextTotalLength
newLen := offsetNum + code.jmp.curLen + code.jmp.nextLen
if curlen < newLen {
ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...)
}
ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr
store(ctxptr, c.idx, ptr)
store(ctxptr, lastCode.idx, oldOffset)
store(ctxptr, lastCode.elemIdx, uintptr(unsafe.Pointer(code.next)))
// link lastCode ( opStructFieldRecursiveEnd ) => code.next
lastCode.op = opStructFieldRecursiveEnd
store(ctxptr, c.end.next.idx, oldOffset)
store(ctxptr, c.end.next.elemIdx, uintptr(unsafe.Pointer(code.next)))
code = c
recursiveLevel++
case opStructFieldRecursiveEnd:
@ -587,8 +568,10 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte
// restore ctxptr
offset := load(ctxptr, code.idx)
ptr := load(ctxptr, code.elemIdx)
code = (*opcode)(e.ptrToUnsafePtr(ptr))
ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1]
codePtr := load(ctxptr, code.elemIdx)
code = (*opcode)(e.ptrToUnsafePtr(codePtr))
ctxptr = ctx.ptr() + offset
ptrOffset = offset
case opStructFieldPtrHead:

View File

@ -12,7 +12,6 @@ import (
func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
recursiveLevel := 0
var seenPtr map[uintptr]struct{}
ptrOffset := uintptr(0)
ctxptr := ctx.ptr()
@ -162,13 +161,12 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
code = code.next
break
}
if seenPtr == nil {
seenPtr = map[uintptr]struct{}{}
for _, seen := range ctx.seenPtr {
if ptr == seen {
return nil, errUnsupportedValue(code, ptr)
}
}
if _, exists := seenPtr[ptr]; exists {
return nil, errUnsupportedValue(code, ptr)
}
seenPtr[ptr] = struct{}{}
ctx.seenPtr = append(ctx.seenPtr, ptr)
v := e.ptrToInterface(code, ptr)
ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v))
rv := reflect.ValueOf(v)
@ -502,46 +500,29 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
ptr := load(ctxptr, code.idx)
if ptr != 0 {
if recursiveLevel > startDetectingCyclesAfter {
if _, exists := seenPtr[ptr]; exists {
return nil, errUnsupportedValue(code, ptr)
for _, seen := range ctx.seenPtr {
if ptr == seen {
return nil, errUnsupportedValue(code, ptr)
}
}
}
}
if seenPtr == nil {
seenPtr = map[uintptr]struct{}{}
}
seenPtr[ptr] = struct{}{}
ctx.seenPtr = append(ctx.seenPtr, ptr)
c := code.jmp.code
c.end.next = newEndOp(&encodeCompileContext{})
c.op = c.op.ptrHeadToHead()
beforeLastCode := c.end
lastCode := beforeLastCode.next
lastCode.idx = beforeLastCode.idx + uintptrSize
lastCode.elemIdx = lastCode.idx + uintptrSize
// extend length to alloc slot for elemIdx
totalLength := uintptr(code.totalLength() + 1)
nextTotalLength := uintptr(c.totalLength() + 1)
curlen := uintptr(len(ctx.ptrs))
offsetNum := ptrOffset / uintptrSize
oldOffset := ptrOffset
ptrOffset += totalLength * uintptrSize
ptrOffset += code.jmp.curLen * uintptrSize
newLen := offsetNum + totalLength + nextTotalLength
newLen := offsetNum + code.jmp.curLen + code.jmp.nextLen
if curlen < newLen {
ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...)
}
ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr
store(ctxptr, c.idx, ptr)
store(ctxptr, lastCode.idx, oldOffset)
store(ctxptr, lastCode.elemIdx, uintptr(unsafe.Pointer(code.next)))
// link lastCode ( opStructFieldRecursiveEnd ) => code.next
lastCode.op = opStructFieldRecursiveEnd
store(ctxptr, c.end.next.idx, oldOffset)
store(ctxptr, c.end.next.elemIdx, uintptr(unsafe.Pointer(code.next)))
code = c
recursiveLevel++
case opStructFieldRecursiveEnd:
@ -549,8 +530,10 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
// restore ctxptr
offset := load(ctxptr, code.idx)
ptr := load(ctxptr, code.elemIdx)
code = (*opcode)(e.ptrToUnsafePtr(ptr))
ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1]
codePtr := load(ctxptr, code.elemIdx)
code = (*opcode)(e.ptrToUnsafePtr(codePtr))
ctxptr = ctx.ptr() + offset
ptrOffset = offset
case opStructFieldPtrHead: