Fix recursive call

This commit is contained in:
Masaaki Goshima 2021-01-15 16:25:00 +09:00
parent 66ced55701
commit ddfae9189e
5 changed files with 88 additions and 71 deletions

View File

@ -27,7 +27,10 @@ type Encoder struct {
} }
type compiledCode struct { type compiledCode struct {
code *opcode code *opcode
linked bool // whether recursive code already have linked
curLen uintptr
nextLen uintptr
} }
const ( const (

View File

@ -34,6 +34,7 @@ func (e *Encoder) compileHead(ctx *encodeCompileContext) (*opcode, error) {
} }
e.convertHeadOnlyCode(code, isPtr) e.convertHeadOnlyCode(code, isPtr)
e.optimizeStructEnd(code) e.optimizeStructEnd(code)
e.linkRecursiveCode(code)
return code, nil return code, nil
} else if isPtr && typ.Implements(marshalTextType) { } else if isPtr && typ.Implements(marshalTextType) {
typ = orgType typ = orgType
@ -46,9 +47,54 @@ func (e *Encoder) compileHead(ctx *encodeCompileContext) (*opcode, error) {
} }
e.convertHeadOnlyCode(code, isPtr) e.convertHeadOnlyCode(code, isPtr)
e.optimizeStructEnd(code) e.optimizeStructEnd(code)
e.linkRecursiveCode(code)
return code, nil return code, nil
} }
func (e *Encoder) linkRecursiveCode(c *opcode) {
for code := c; code.op != opEnd && code.op != opStructFieldRecursiveEnd; {
switch code.op {
case opStructFieldRecursive,
opStructFieldPtrAnonymousHeadRecursive,
opStructFieldAnonymousHeadRecursive:
if code.jmp.linked {
code = code.next
continue
}
code.jmp.code = copyOpcode(code.jmp.code)
c := code.jmp.code
c.end.next = newEndOp(&encodeCompileContext{})
c.op = c.op.ptrHeadToHead()
beforeLastCode := c.end
lastCode := beforeLastCode.next
lastCode.idx = beforeLastCode.idx + uintptrSize
lastCode.elemIdx = lastCode.idx + uintptrSize
// extend length to alloc slot for elemIdx
totalLength := uintptr(code.totalLength() + 1)
nextTotalLength := uintptr(c.totalLength() + 1)
c.end.next.op = opStructFieldRecursiveEnd
code.jmp.curLen = totalLength
code.jmp.nextLen = nextTotalLength
code.jmp.linked = true
e.linkRecursiveCode(code.jmp.code)
code = code.next
continue
}
switch code.op.codeType() {
case codeArrayElem, codeSliceElem, codeMapKey:
code = code.end
default:
code = code.next
}
}
}
func (e *Encoder) optimizeStructEnd(c *opcode) { func (e *Encoder) optimizeStructEnd(c *opcode) {
for code := c; code.op != opEnd; { for code := c; code.op != opEnd; {
if code.op == opStructFieldRecursive { if code.op == opStructFieldRecursive {

View File

@ -86,6 +86,7 @@ func (c *encodeCompileContext) decPtrIndex() {
type encodeRuntimeContext struct { type encodeRuntimeContext struct {
ptrs []uintptr ptrs []uintptr
keepRefs []unsafe.Pointer keepRefs []unsafe.Pointer
seenPtr []uintptr
} }
func (c *encodeRuntimeContext) init(p uintptr, codelen int) { func (c *encodeRuntimeContext) init(p uintptr, codelen int) {
@ -94,6 +95,7 @@ func (c *encodeRuntimeContext) init(p uintptr, codelen int) {
} }
c.ptrs[0] = p c.ptrs[0] = p
c.keepRefs = c.keepRefs[:0] c.keepRefs = c.keepRefs[:0]
c.seenPtr = c.seenPtr[:0]
} }
func (c *encodeRuntimeContext) ptr() uintptr { func (c *encodeRuntimeContext) ptr() uintptr {

View File

@ -50,7 +50,6 @@ func errMarshaler(code *opcode, err error) *MarshalerError {
func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) { func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
recursiveLevel := 0 recursiveLevel := 0
var seenPtr map[uintptr]struct{}
ptrOffset := uintptr(0) ptrOffset := uintptr(0)
ctxptr := ctx.ptr() ctxptr := ctx.ptr()
@ -200,13 +199,12 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte
code = code.next code = code.next
break break
} }
if seenPtr == nil { for _, seen := range ctx.seenPtr {
seenPtr = map[uintptr]struct{}{} if ptr == seen {
return nil, errUnsupportedValue(code, ptr)
}
} }
if _, exists := seenPtr[ptr]; exists { ctx.seenPtr = append(ctx.seenPtr, ptr)
return nil, errUnsupportedValue(code, ptr)
}
seenPtr[ptr] = struct{}{}
v := e.ptrToInterface(code, ptr) v := e.ptrToInterface(code, ptr)
ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v)) ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v))
rv := reflect.ValueOf(v) rv := reflect.ValueOf(v)
@ -540,46 +538,29 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte
ptr := load(ctxptr, code.idx) ptr := load(ctxptr, code.idx)
if ptr != 0 { if ptr != 0 {
if recursiveLevel > startDetectingCyclesAfter { if recursiveLevel > startDetectingCyclesAfter {
if _, exists := seenPtr[ptr]; exists { for _, seen := range ctx.seenPtr {
return nil, errUnsupportedValue(code, ptr) if ptr == seen {
return nil, errUnsupportedValue(code, ptr)
}
} }
} }
} }
if seenPtr == nil { ctx.seenPtr = append(ctx.seenPtr, ptr)
seenPtr = map[uintptr]struct{}{}
}
seenPtr[ptr] = struct{}{}
c := code.jmp.code c := code.jmp.code
c.end.next = newEndOp(&encodeCompileContext{})
c.op = c.op.ptrHeadToHead()
beforeLastCode := c.end
lastCode := beforeLastCode.next
lastCode.idx = beforeLastCode.idx + uintptrSize
lastCode.elemIdx = lastCode.idx + uintptrSize
// extend length to alloc slot for elemIdx
totalLength := uintptr(code.totalLength() + 1)
nextTotalLength := uintptr(c.totalLength() + 1)
curlen := uintptr(len(ctx.ptrs)) curlen := uintptr(len(ctx.ptrs))
offsetNum := ptrOffset / uintptrSize offsetNum := ptrOffset / uintptrSize
oldOffset := ptrOffset oldOffset := ptrOffset
ptrOffset += totalLength * uintptrSize ptrOffset += code.jmp.curLen * uintptrSize
newLen := offsetNum + totalLength + nextTotalLength newLen := offsetNum + code.jmp.curLen + code.jmp.nextLen
if curlen < newLen { if curlen < newLen {
ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...) ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...)
} }
ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr
store(ctxptr, c.idx, ptr) store(ctxptr, c.idx, ptr)
store(ctxptr, lastCode.idx, oldOffset) store(ctxptr, c.end.next.idx, oldOffset)
store(ctxptr, lastCode.elemIdx, uintptr(unsafe.Pointer(code.next))) store(ctxptr, c.end.next.elemIdx, uintptr(unsafe.Pointer(code.next)))
// link lastCode ( opStructFieldRecursiveEnd ) => code.next
lastCode.op = opStructFieldRecursiveEnd
code = c code = c
recursiveLevel++ recursiveLevel++
case opStructFieldRecursiveEnd: case opStructFieldRecursiveEnd:
@ -587,8 +568,10 @@ func (e *Encoder) run(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte
// restore ctxptr // restore ctxptr
offset := load(ctxptr, code.idx) offset := load(ctxptr, code.idx)
ptr := load(ctxptr, code.elemIdx) ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1]
code = (*opcode)(e.ptrToUnsafePtr(ptr))
codePtr := load(ctxptr, code.elemIdx)
code = (*opcode)(e.ptrToUnsafePtr(codePtr))
ctxptr = ctx.ptr() + offset ctxptr = ctx.ptr() + offset
ptrOffset = offset ptrOffset = offset
case opStructFieldPtrHead: case opStructFieldPtrHead:

View File

@ -12,7 +12,6 @@ import (
func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) { func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode) ([]byte, error) {
recursiveLevel := 0 recursiveLevel := 0
var seenPtr map[uintptr]struct{}
ptrOffset := uintptr(0) ptrOffset := uintptr(0)
ctxptr := ctx.ptr() ctxptr := ctx.ptr()
@ -162,13 +161,12 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
code = code.next code = code.next
break break
} }
if seenPtr == nil { for _, seen := range ctx.seenPtr {
seenPtr = map[uintptr]struct{}{} if ptr == seen {
return nil, errUnsupportedValue(code, ptr)
}
} }
if _, exists := seenPtr[ptr]; exists { ctx.seenPtr = append(ctx.seenPtr, ptr)
return nil, errUnsupportedValue(code, ptr)
}
seenPtr[ptr] = struct{}{}
v := e.ptrToInterface(code, ptr) v := e.ptrToInterface(code, ptr)
ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v)) ctx.keepRefs = append(ctx.keepRefs, unsafe.Pointer(&v))
rv := reflect.ValueOf(v) rv := reflect.ValueOf(v)
@ -502,46 +500,29 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
ptr := load(ctxptr, code.idx) ptr := load(ctxptr, code.idx)
if ptr != 0 { if ptr != 0 {
if recursiveLevel > startDetectingCyclesAfter { if recursiveLevel > startDetectingCyclesAfter {
if _, exists := seenPtr[ptr]; exists { for _, seen := range ctx.seenPtr {
return nil, errUnsupportedValue(code, ptr) if ptr == seen {
return nil, errUnsupportedValue(code, ptr)
}
} }
} }
} }
if seenPtr == nil { ctx.seenPtr = append(ctx.seenPtr, ptr)
seenPtr = map[uintptr]struct{}{}
}
seenPtr[ptr] = struct{}{}
c := code.jmp.code c := code.jmp.code
c.end.next = newEndOp(&encodeCompileContext{})
c.op = c.op.ptrHeadToHead()
beforeLastCode := c.end
lastCode := beforeLastCode.next
lastCode.idx = beforeLastCode.idx + uintptrSize
lastCode.elemIdx = lastCode.idx + uintptrSize
// extend length to alloc slot for elemIdx
totalLength := uintptr(code.totalLength() + 1)
nextTotalLength := uintptr(c.totalLength() + 1)
curlen := uintptr(len(ctx.ptrs)) curlen := uintptr(len(ctx.ptrs))
offsetNum := ptrOffset / uintptrSize offsetNum := ptrOffset / uintptrSize
oldOffset := ptrOffset oldOffset := ptrOffset
ptrOffset += totalLength * uintptrSize ptrOffset += code.jmp.curLen * uintptrSize
newLen := offsetNum + totalLength + nextTotalLength newLen := offsetNum + code.jmp.curLen + code.jmp.nextLen
if curlen < newLen { if curlen < newLen {
ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...) ctx.ptrs = append(ctx.ptrs, make([]uintptr, newLen-curlen)...)
} }
ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr ctxptr = ctx.ptr() + ptrOffset // assign new ctxptr
store(ctxptr, c.idx, ptr) store(ctxptr, c.idx, ptr)
store(ctxptr, lastCode.idx, oldOffset) store(ctxptr, c.end.next.idx, oldOffset)
store(ctxptr, lastCode.elemIdx, uintptr(unsafe.Pointer(code.next))) store(ctxptr, c.end.next.elemIdx, uintptr(unsafe.Pointer(code.next)))
// link lastCode ( opStructFieldRecursiveEnd ) => code.next
lastCode.op = opStructFieldRecursiveEnd
code = c code = c
recursiveLevel++ recursiveLevel++
case opStructFieldRecursiveEnd: case opStructFieldRecursiveEnd:
@ -549,8 +530,10 @@ func (e *Encoder) runEscaped(ctx *encodeRuntimeContext, b []byte, code *opcode)
// restore ctxptr // restore ctxptr
offset := load(ctxptr, code.idx) offset := load(ctxptr, code.idx)
ptr := load(ctxptr, code.elemIdx) ctx.seenPtr = ctx.seenPtr[:len(ctx.seenPtr)-1]
code = (*opcode)(e.ptrToUnsafePtr(ptr))
codePtr := load(ctxptr, code.elemIdx)
code = (*opcode)(e.ptrToUnsafePtr(codePtr))
ctxptr = ctx.ptr() + offset ctxptr = ctx.ptr() + offset
ptrOffset = offset ptrOffset = offset
case opStructFieldPtrHead: case opStructFieldPtrHead: