Fix encoding of recursive slice/map

This commit is contained in:
Masaaki Goshima 2021-04-02 14:03:00 +09:00
parent 7d4316b94a
commit 7007d6ee41
5 changed files with 95 additions and 5 deletions

View File

@ -7,6 +7,11 @@ import (
"github.com/goccy/go-json"
)
type recursiveMap struct {
A int
B map[string]*recursiveMap
}
func TestCoverMap(t *testing.T) {
type structMap struct {
A map[string]int `json:"a"`
@ -54,6 +59,38 @@ func TestCoverMap(t *testing.T) {
name: "NestedMap",
data: map[string]map[string]int{"a": {"b": 1}},
},
{
name: "RecursiveMap",
data: map[string]*recursiveMap{
"keyA": {
A: 1,
B: map[string]*recursiveMap{
"keyB": {
A: 2,
B: map[string]*recursiveMap{
"keyC": {
A: 3,
},
},
},
},
},
"keyD": {
A: 4,
B: map[string]*recursiveMap{
"keyE": {
A: 5,
B: map[string]*recursiveMap{
"keyF": {
A: 6,
},
},
},
},
},
},
},
// HeadMapZero
{
name: "HeadMapZero",

View File

@ -23,6 +23,11 @@ func (coverSliceMarshalText) MarshalText() ([]byte, error) {
return []byte(`"hello"`), nil
}
type recursiveSlice struct {
A int
B []*recursiveSlice
}
func TestCoverSlice(t *testing.T) {
type structSlice struct {
A []int `json:"a"`
@ -226,6 +231,33 @@ func TestCoverSlice(t *testing.T) {
name: "SliceStructPtr",
data: []*struct{ A int }{&struct{ A int }{A: 1}, &struct{ A int }{A: 2}},
},
{
name: "RecursiveSlice",
data: []*recursiveSlice{
{
A: 1, B: []*recursiveSlice{
{
A: 2, B: []*recursiveSlice{
{
A: 3,
},
},
},
},
},
{
A: 4, B: []*recursiveSlice{
{
A: 5, B: []*recursiveSlice{
{
A: 6,
},
},
},
},
},
},
},
// HeadSliceZero
{

View File

@ -107,12 +107,24 @@ func compileHead(ctx *compileContext) (*Opcode, error) {
return compileBytes(ctx)
}
}
return compileSlice(ctx)
code, err := compileSlice(ctx)
if err != nil {
return nil, err
}
optimizeStructEnd(code)
linkRecursiveCode(code)
return code, nil
case reflect.Map:
if isPtr {
return compilePtr(ctx.withType(runtime.PtrTo(typ)))
}
return compileMap(ctx.withType(typ))
code, err := compileMap(ctx.withType(typ))
if err != nil {
return nil, err
}
optimizeStructEnd(code)
linkRecursiveCode(code)
return code, nil
case reflect.Struct:
code, err := compileStruct(ctx.withType(typ), isPtr)
if err != nil {
@ -243,10 +255,11 @@ func linkRecursiveCode(c *Opcode) {
lastCode.Idx = beforeLastCode.Idx + uintptrSize
lastCode.ElemIdx = lastCode.Idx + uintptrSize
lastCode.Length = lastCode.Idx + 2*uintptrSize
// extend length to alloc slot for elemIdx
totalLength := uintptr(code.TotalLength() + 1)
nextTotalLength := uintptr(c.TotalLength() + 1)
// extend length to alloc slot for elemIdx + length
totalLength := uintptr(code.TotalLength() + 2)
nextTotalLength := uintptr(c.TotalLength() + 2)
c.End.Next.Op = OpRecursiveEnd

View File

@ -563,6 +563,8 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
offsetNum := ptrOffset / uintptrSize
oldOffset := ptrOffset
ptrOffset += code.Jmp.CurLen * uintptrSize
oldBaseIndent := ctx.BaseIndent
ctx.BaseIndent += code.Indent - 1
newLen := offsetNum + code.Jmp.CurLen + code.Jmp.NextLen
if curlen < newLen {
@ -573,12 +575,14 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
store(ctxptr, c.Idx, ptr)
store(ctxptr, c.End.Next.Idx, oldOffset)
store(ctxptr, c.End.Next.ElemIdx, uintptr(unsafe.Pointer(code.Next)))
store(ctxptr, c.End.Next.Length, uintptr(oldBaseIndent))
code = c
recursiveLevel++
case encoder.OpRecursiveEnd:
recursiveLevel--
// restore ctxptr
ctx.BaseIndent = int(load(ctxptr, code.Length))
offset := load(ctxptr, code.Idx)
ctx.SeenPtr = ctx.SeenPtr[:len(ctx.SeenPtr)-1]

View File

@ -563,6 +563,8 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
offsetNum := ptrOffset / uintptrSize
oldOffset := ptrOffset
ptrOffset += code.Jmp.CurLen * uintptrSize
oldBaseIndent := ctx.BaseIndent
ctx.BaseIndent += code.Indent - 1
newLen := offsetNum + code.Jmp.CurLen + code.Jmp.NextLen
if curlen < newLen {
@ -573,12 +575,14 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
store(ctxptr, c.Idx, ptr)
store(ctxptr, c.End.Next.Idx, oldOffset)
store(ctxptr, c.End.Next.ElemIdx, uintptr(unsafe.Pointer(code.Next)))
store(ctxptr, c.End.Next.Length, uintptr(oldBaseIndent))
code = c
recursiveLevel++
case encoder.OpRecursiveEnd:
recursiveLevel--
// restore ctxptr
ctx.BaseIndent = int(load(ctxptr, code.Length))
offset := load(ctxptr, code.Idx)
ctx.SeenPtr = ctx.SeenPtr[:len(ctx.SeenPtr)-1]