diff --git a/internal/encoder/code.go b/internal/encoder/code.go index 68303ea..3e03f97 100644 --- a/internal/encoder/code.go +++ b/internal/encoder/code.go @@ -14,6 +14,10 @@ type Code interface { ToOpcode(*compileContext) Opcodes } +type AnonymousCode interface { + ToAnonymousOpcode(*compileContext) Opcodes +} + type Opcodes []*Opcode func (o Opcodes) First() *Opcode { @@ -225,7 +229,7 @@ func (c *SliceCode) ToOpcode(ctx *compileContext) Opcodes { header := newSliceHeaderCode(ctx) ctx.incIndex() codes := c.value.ToOpcode(ctx) - elemCode := newSliceElemCode(ctx, header, size) + elemCode := newSliceElemCode(ctx.withType(c.typ.Elem()).incIndent(), header, size) ctx.incIndex() end := newOpCode(ctx, OpSliceEnd) ctx.incIndex() @@ -259,7 +263,7 @@ func (c *ArrayCode) ToOpcode(ctx *compileContext) Opcodes { codes := c.value.ToOpcode(ctx) - elemCode := newArrayElemCode(ctx, header, alen, size) + elemCode := newArrayElemCode(ctx.withType(c.typ.Elem()).incIndent(), header, alen, size) ctx.incIndex() end := newOpCode(ctx, OpArrayEnd) @@ -347,6 +351,7 @@ func (c *StructCode) ToOpcode(ctx *compileContext) Opcodes { } codes := Opcodes{} var prevField *Opcode + ctx = ctx.incIndent() for idx, field := range c.fields { isFirstField := idx == 0 isEndField := idx == len(c.fields)-1 @@ -358,10 +363,69 @@ func (c *StructCode) ToOpcode(ctx *compileContext) Opcodes { } if len(codes) > 0 { codes.Last().Next = fieldCodes.First() + fieldCodes.First().Idx = codes.First().Idx } if prevField != nil { prevField.NextField = fieldCodes.First() } + if isEndField { + if len(codes) > 0 { + codes.First().End = fieldCodes.Last() + } else { + fieldCodes.First().End = fieldCodes.Last() + } + } + prevField = fieldCodes.First() + codes = append(codes, fieldCodes...) + } + ctx = ctx.decIndent() + if c.isRecursive { + c.recursiveCodes = codes + c.linkRecursiveCode(compiled, c.recursiveCodes) + return Opcodes{recursive} + } + return codes +} + +func (c *StructCode) ToAnonymousOpcode(ctx *compileContext) Opcodes { + // header => code => structField => code => end + // ^ | + // |__________| + var recursive *Opcode + compiled := &CompiledCode{} + if c.isRecursive { + recursive = newRecursiveCode(ctx, compiled) + ctx.incIndex() + } + if len(c.recursiveCodes) > 0 { + c.linkRecursiveCode(compiled, c.recursiveCodes) + return Opcodes{recursive} + } + codes := Opcodes{} + var prevField *Opcode + for idx, field := range c.fields { + isFirstField := idx == 0 + isEndField := idx == len(c.fields)-1 + fieldCodes := field.ToAnonymousOpcode(ctx, isFirstField, isEndField) + for _, code := range fieldCodes { + if c.isIndirect { + code.Flags |= IndirectFlags + } + } + if len(codes) > 0 { + codes.Last().Next = fieldCodes.First() + fieldCodes.First().Idx = codes.First().Idx + } + if prevField != nil { + prevField.NextField = fieldCodes.First() + } + if isEndField { + if len(codes) > 0 { + codes.First().End = fieldCodes.Last() + } else { + fieldCodes.First().End = fieldCodes.Last() + } + } prevField = fieldCodes.First() codes = append(codes, fieldCodes...) } @@ -387,7 +451,7 @@ func (c *StructCode) linkRecursiveCode(compiled *CompiledCode, codes Opcodes) { lastCode.Length = lastCode.Idx + 2*uintptrSize // extend length to alloc slot for elemIdx + length - totalLength := uintptr(codes.First().TotalLength() + 3) + totalLength := uintptr((codes.Last().MaxIdx() / uintptrSize) + 3) nextTotalLength := uintptr(code.TotalLength() + 3) code.End.Next.Op = OpRecursiveEnd @@ -505,7 +569,12 @@ func (c *StructFieldCode) ToOpcode(ctx *compileContext, isFirstField, isEndField DisplayKey: c.key, } ctx.incIndex() - codes := c.value.ToOpcode(ctx) + var codes Opcodes + if c.isAnonymous { + codes = c.value.(AnonymousCode).ToAnonymousOpcode(ctx) + } else { + codes = c.value.ToOpcode(ctx) + } if isFirstField { op := optimizeStructHeader(codes.First(), c.tag) field.Op = op @@ -518,7 +587,7 @@ func (c *StructFieldCode) ToOpcode(ctx *compileContext, isFirstField, isEndField } else { ctx.decIndex() } - if isEndField && !c.isAnonymous { + if isEndField { end := &Opcode{ Op: OpStructEnd, Idx: opcodeOffset(ctx.ptrIndex), @@ -526,6 +595,7 @@ func (c *StructFieldCode) ToOpcode(ctx *compileContext, isFirstField, isEndField Indent: ctx.indent, } fieldCodes.Last().Next = end + fieldCodes.First().NextField = end fieldCodes = append(fieldCodes, end) ctx.incIndex() } @@ -555,6 +625,7 @@ func (c *StructFieldCode) ToOpcode(ctx *compileContext, isFirstField, isEndField Indent: ctx.indent, } fieldCodes.Last().Next = end + fieldCodes.First().NextField = end fieldCodes = append(fieldCodes, end) ctx.incIndex() } @@ -562,6 +633,79 @@ func (c *StructFieldCode) ToOpcode(ctx *compileContext, isFirstField, isEndField return fieldCodes } +func (c *StructFieldCode) ToAnonymousOpcode(ctx *compileContext, isFirstField, isEndField bool) Opcodes { + var key string + if ctx.escapeKey { + rctx := &RuntimeContext{Option: &Option{Flag: HTMLEscapeOption}} + key = fmt.Sprintf(`%s:`, string(AppendString(rctx, []byte{}, c.key))) + } else { + key = fmt.Sprintf(`"%s":`, c.key) + } + var flags OpFlags + flags |= AnonymousHeadFlags + flags |= AnonymousKeyFlags + if c.isTaggedKey { + flags |= IsTaggedKeyFlags + } + if c.isNilableType { + flags |= IsNilableTypeFlags + } + if c.isNilCheck { + flags |= NilCheckFlags + } + if c.isAddrForMarshaler { + flags |= AddrForMarshalerFlags + } + if c.isNextOpPtrType { + flags |= IsNextOpPtrTypeFlags + } + field := &Opcode{ + Idx: opcodeOffset(ctx.ptrIndex), + Flags: flags, + Key: key, + Offset: uint32(c.offset), + Type: c.typ, + DisplayIdx: ctx.opcodeIndex, + Indent: ctx.indent, + DisplayKey: c.key, + } + ctx.incIndex() + var codes Opcodes + if c.isAnonymous { + codes = c.value.(AnonymousCode).ToAnonymousOpcode(ctx) + } else { + codes = c.value.ToOpcode(ctx) + } + if isFirstField { + op := optimizeStructHeader(codes.First(), c.tag) + field.Op = op + field.NumBitSize = codes.First().NumBitSize + field.PtrNum = codes.First().PtrNum + fieldCodes := Opcodes{field} + if op.IsMultipleOpHead() { + field.Next = codes.First() + fieldCodes = append(fieldCodes, codes...) + } else { + ctx.decIndex() + } + return fieldCodes + } + op := optimizeStructField(codes.First(), c.tag) + field.Op = op + field.NumBitSize = codes.First().NumBitSize + field.PtrNum = codes.First().PtrNum + + fieldCodes := Opcodes{field} + if op.IsMultipleOpField() { + field.Next = codes.First() + fieldCodes = append(fieldCodes, codes...) + } else { + // optimize codes + ctx.decIndex() + } + return fieldCodes +} + func isEnableStructEndOptimizationType(typ CodeType2) bool { switch typ { case CodeTypeInt, CodeTypeUint, CodeTypeFloat, CodeTypeString, CodeTypeBool: @@ -653,6 +797,14 @@ func (c *PtrCode) Type() CodeType2 { func (c *PtrCode) ToOpcode(ctx *compileContext) Opcodes { codes := c.value.ToOpcode(ctx) + codes.First().Op = convertPtrOp(codes.First()) + codes.First().PtrNum = c.ptrNum + return codes +} + +func (c *PtrCode) ToAnonymousOpcode(ctx *compileContext) Opcodes { + codes := c.value.(AnonymousCode).ToAnonymousOpcode(ctx) + codes.First().Op = convertPtrOp(codes.First()) codes.First().PtrNum = c.ptrNum return codes } @@ -1101,6 +1253,7 @@ func compileStruct2(ctx *compileContext, isPtr bool) (*StructCode, error) { if !code.disableIndirectConversion && !indirect && isPtr { code.enableIndirect() } + delete(ctx.structTypeToCode, typeptr) return code, nil } diff --git a/internal/encoder/opcode.go b/internal/encoder/opcode.go index 171ce7d..618b37d 100644 --- a/internal/encoder/opcode.go +++ b/internal/encoder/opcode.go @@ -83,7 +83,7 @@ func (c *Opcode) IsEnd() bool { if c == nil { return false } - return c.Op == OpEnd || c.Op == OpInterfaceEnd + return c.Op == OpEnd || c.Op == OpInterfaceEnd || c.Op == OpRecursiveEnd } func (c *Opcode) IsStructHeadOp() bool { diff --git a/test/cover/cover_int_test.go b/test/cover/cover_int_test.go index b21a528..c2377d4 100644 --- a/test/cover/cover_int_test.go +++ b/test/cover/cover_int_test.go @@ -2,6 +2,7 @@ package json_test import ( "bytes" + "fmt" "testing" "github.com/goccy/go-json" @@ -2410,19 +2411,21 @@ func TestCoverInt(t *testing.T) { for _, test := range tests { for _, indent := range []bool{true, false} { for _, htmlEscape := range []bool{true, false} { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.SetEscapeHTML(htmlEscape) - if indent { - enc.SetIndent("", " ") - } - if err := enc.Encode(test.data); err != nil { - t.Fatalf("%s(htmlEscape:%v,indent:%v): %+v: %s", test.name, htmlEscape, indent, test.data, err) - } - stdresult := encodeByEncodingJSON(test.data, indent, htmlEscape) - if buf.String() != stdresult { - t.Errorf("%s(htmlEscape:%v,indent:%v): doesn't compatible with encoding/json. expected %q but got %q", test.name, htmlEscape, indent, stdresult, buf.String()) - } + t.Run(fmt.Sprintf("%s_indent_%t_escape_%t", test.name, indent, htmlEscape), func(t *testing.T) { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(htmlEscape) + if indent { + enc.SetIndent("", " ") + } + if err := enc.Encode(test.data); err != nil { + t.Fatalf("%s(htmlEscape:%v,indent:%v): %+v: %s", test.name, htmlEscape, indent, test.data, err) + } + stdresult := encodeByEncodingJSON(test.data, indent, htmlEscape) + if buf.String() != stdresult { + t.Errorf("%s(htmlEscape:%v,indent:%v): doesn't compatible with encoding/json. expected %q but got %q", test.name, htmlEscape, indent, stdresult, buf.String()) + } + }) } } }