Fix encoding of recursive slice/map

This commit is contained in:
Masaaki Goshima 2021-04-02 14:03:00 +09:00
parent 7d4316b94a
commit 7007d6ee41
5 changed files with 95 additions and 5 deletions

View File

@ -7,6 +7,11 @@ import (
"github.com/goccy/go-json" "github.com/goccy/go-json"
) )
type recursiveMap struct {
A int
B map[string]*recursiveMap
}
func TestCoverMap(t *testing.T) { func TestCoverMap(t *testing.T) {
type structMap struct { type structMap struct {
A map[string]int `json:"a"` A map[string]int `json:"a"`
@ -54,6 +59,38 @@ func TestCoverMap(t *testing.T) {
name: "NestedMap", name: "NestedMap",
data: map[string]map[string]int{"a": {"b": 1}}, data: map[string]map[string]int{"a": {"b": 1}},
}, },
{
name: "RecursiveMap",
data: map[string]*recursiveMap{
"keyA": {
A: 1,
B: map[string]*recursiveMap{
"keyB": {
A: 2,
B: map[string]*recursiveMap{
"keyC": {
A: 3,
},
},
},
},
},
"keyD": {
A: 4,
B: map[string]*recursiveMap{
"keyE": {
A: 5,
B: map[string]*recursiveMap{
"keyF": {
A: 6,
},
},
},
},
},
},
},
// HeadMapZero // HeadMapZero
{ {
name: "HeadMapZero", name: "HeadMapZero",

View File

@ -23,6 +23,11 @@ func (coverSliceMarshalText) MarshalText() ([]byte, error) {
return []byte(`"hello"`), nil return []byte(`"hello"`), nil
} }
type recursiveSlice struct {
A int
B []*recursiveSlice
}
func TestCoverSlice(t *testing.T) { func TestCoverSlice(t *testing.T) {
type structSlice struct { type structSlice struct {
A []int `json:"a"` A []int `json:"a"`
@ -226,6 +231,33 @@ func TestCoverSlice(t *testing.T) {
name: "SliceStructPtr", name: "SliceStructPtr",
data: []*struct{ A int }{&struct{ A int }{A: 1}, &struct{ A int }{A: 2}}, data: []*struct{ A int }{&struct{ A int }{A: 1}, &struct{ A int }{A: 2}},
}, },
{
name: "RecursiveSlice",
data: []*recursiveSlice{
{
A: 1, B: []*recursiveSlice{
{
A: 2, B: []*recursiveSlice{
{
A: 3,
},
},
},
},
},
{
A: 4, B: []*recursiveSlice{
{
A: 5, B: []*recursiveSlice{
{
A: 6,
},
},
},
},
},
},
},
// HeadSliceZero // HeadSliceZero
{ {

View File

@ -107,12 +107,24 @@ func compileHead(ctx *compileContext) (*Opcode, error) {
return compileBytes(ctx) return compileBytes(ctx)
} }
} }
return compileSlice(ctx) code, err := compileSlice(ctx)
if err != nil {
return nil, err
}
optimizeStructEnd(code)
linkRecursiveCode(code)
return code, nil
case reflect.Map: case reflect.Map:
if isPtr { if isPtr {
return compilePtr(ctx.withType(runtime.PtrTo(typ))) return compilePtr(ctx.withType(runtime.PtrTo(typ)))
} }
return compileMap(ctx.withType(typ)) code, err := compileMap(ctx.withType(typ))
if err != nil {
return nil, err
}
optimizeStructEnd(code)
linkRecursiveCode(code)
return code, nil
case reflect.Struct: case reflect.Struct:
code, err := compileStruct(ctx.withType(typ), isPtr) code, err := compileStruct(ctx.withType(typ), isPtr)
if err != nil { if err != nil {
@ -243,10 +255,11 @@ func linkRecursiveCode(c *Opcode) {
lastCode.Idx = beforeLastCode.Idx + uintptrSize lastCode.Idx = beforeLastCode.Idx + uintptrSize
lastCode.ElemIdx = lastCode.Idx + uintptrSize lastCode.ElemIdx = lastCode.Idx + uintptrSize
lastCode.Length = lastCode.Idx + 2*uintptrSize
// extend length to alloc slot for elemIdx // extend length to alloc slot for elemIdx + length
totalLength := uintptr(code.TotalLength() + 1) totalLength := uintptr(code.TotalLength() + 2)
nextTotalLength := uintptr(c.TotalLength() + 1) nextTotalLength := uintptr(c.TotalLength() + 2)
c.End.Next.Op = OpRecursiveEnd c.End.Next.Op = OpRecursiveEnd

View File

@ -563,6 +563,8 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
offsetNum := ptrOffset / uintptrSize offsetNum := ptrOffset / uintptrSize
oldOffset := ptrOffset oldOffset := ptrOffset
ptrOffset += code.Jmp.CurLen * uintptrSize ptrOffset += code.Jmp.CurLen * uintptrSize
oldBaseIndent := ctx.BaseIndent
ctx.BaseIndent += code.Indent - 1
newLen := offsetNum + code.Jmp.CurLen + code.Jmp.NextLen newLen := offsetNum + code.Jmp.CurLen + code.Jmp.NextLen
if curlen < newLen { if curlen < newLen {
@ -573,12 +575,14 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
store(ctxptr, c.Idx, ptr) store(ctxptr, c.Idx, ptr)
store(ctxptr, c.End.Next.Idx, oldOffset) store(ctxptr, c.End.Next.Idx, oldOffset)
store(ctxptr, c.End.Next.ElemIdx, uintptr(unsafe.Pointer(code.Next))) store(ctxptr, c.End.Next.ElemIdx, uintptr(unsafe.Pointer(code.Next)))
store(ctxptr, c.End.Next.Length, uintptr(oldBaseIndent))
code = c code = c
recursiveLevel++ recursiveLevel++
case encoder.OpRecursiveEnd: case encoder.OpRecursiveEnd:
recursiveLevel-- recursiveLevel--
// restore ctxptr // restore ctxptr
ctx.BaseIndent = int(load(ctxptr, code.Length))
offset := load(ctxptr, code.Idx) offset := load(ctxptr, code.Idx)
ctx.SeenPtr = ctx.SeenPtr[:len(ctx.SeenPtr)-1] ctx.SeenPtr = ctx.SeenPtr[:len(ctx.SeenPtr)-1]

View File

@ -563,6 +563,8 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
offsetNum := ptrOffset / uintptrSize offsetNum := ptrOffset / uintptrSize
oldOffset := ptrOffset oldOffset := ptrOffset
ptrOffset += code.Jmp.CurLen * uintptrSize ptrOffset += code.Jmp.CurLen * uintptrSize
oldBaseIndent := ctx.BaseIndent
ctx.BaseIndent += code.Indent - 1
newLen := offsetNum + code.Jmp.CurLen + code.Jmp.NextLen newLen := offsetNum + code.Jmp.CurLen + code.Jmp.NextLen
if curlen < newLen { if curlen < newLen {
@ -573,12 +575,14 @@ func Run(ctx *encoder.RuntimeContext, b []byte, codeSet *encoder.OpcodeSet, opt
store(ctxptr, c.Idx, ptr) store(ctxptr, c.Idx, ptr)
store(ctxptr, c.End.Next.Idx, oldOffset) store(ctxptr, c.End.Next.Idx, oldOffset)
store(ctxptr, c.End.Next.ElemIdx, uintptr(unsafe.Pointer(code.Next))) store(ctxptr, c.End.Next.ElemIdx, uintptr(unsafe.Pointer(code.Next)))
store(ctxptr, c.End.Next.Length, uintptr(oldBaseIndent))
code = c code = c
recursiveLevel++ recursiveLevel++
case encoder.OpRecursiveEnd: case encoder.OpRecursiveEnd:
recursiveLevel-- recursiveLevel--
// restore ctxptr // restore ctxptr
ctx.BaseIndent = int(load(ctxptr, code.Length))
offset := load(ctxptr, code.Idx) offset := load(ctxptr, code.Idx)
ctx.SeenPtr = ctx.SeenPtr[:len(ctx.SeenPtr)-1] ctx.SeenPtr = ctx.SeenPtr[:len(ctx.SeenPtr)-1]