diff --git a/encode.go b/encode.go index 58f019a..2247322 100644 --- a/encode.go +++ b/encode.go @@ -179,11 +179,19 @@ func (e *Encoder) encode(v interface{}) error { // noescape trick for header.typ ( reflect.*rtype ) copiedType := (*rtype)(unsafe.Pointer(typeptr)) - codeIndent, err := e.compileHead(copiedType, true) + codeIndent, err := e.compileHead(&encodeCompileContext{ + typ: copiedType, + root: true, + withIndent: true, + }) if err != nil { return err } - code, err := e.compileHead(copiedType, false) + code, err := e.compileHead(&encodeCompileContext{ + typ: copiedType, + root: true, + withIndent: false, + }) if err != nil { return err } diff --git a/encode_compile.go b/encode_compile.go index 9dc7ec2..229c485 100644 --- a/encode_compile.go +++ b/encode_compile.go @@ -6,8 +6,8 @@ import ( "unsafe" ) -func (e *Encoder) compileHead(typ *rtype, withIndent bool) (*opcode, error) { - root := true +func (e *Encoder) compileHead(ctx *encodeCompileContext) (*opcode, error) { + typ := ctx.typ switch { case typ.Implements(marshalJSONType): return newOpCode(opMarshalJSON, typ, e.indent, newEndOp(e.indent)), nil @@ -24,11 +24,11 @@ func (e *Encoder) compileHead(typ *rtype, withIndent bool) (*opcode, error) { isPtr = true } if typ.Kind() == reflect.Map { - return e.compileMap(typ, isPtr, root, withIndent) + return e.compileMap(ctx.withType(typ), isPtr) } else if typ.Kind() == reflect.Struct { - return e.compileStruct(typ, isPtr, root, withIndent) + return e.compileStruct(ctx.withType(typ), isPtr) } - return e.compile(typ, root, withIndent) + return e.compile(ctx.withType(typ)) } func (e *Encoder) implementsMarshaler(typ *rtype) bool { @@ -45,7 +45,8 @@ func (e *Encoder) implementsMarshaler(typ *rtype) bool { return false } -func (e *Encoder) compile(typ *rtype, root, withIndent bool) (*opcode, error) { +func (e *Encoder) compile(ctx *encodeCompileContext) (*opcode, error) { + typ := ctx.typ switch { case typ.Implements(marshalJSONType): return newOpCode(opMarshalJSON, typ, e.indent, newEndOp(e.indent)), nil @@ -58,56 +59,57 @@ func (e *Encoder) compile(typ *rtype, root, withIndent bool) (*opcode, error) { } switch typ.Kind() { case reflect.Ptr: - return e.compilePtr(typ, root, withIndent) + return e.compilePtr(ctx) case reflect.Slice: elem := typ.Elem() if !e.implementsMarshaler(elem) && elem.Kind() == reflect.Uint8 { - return e.compileBytes(typ) + return e.compileBytes(ctx) } - return e.compileSlice(typ, root, withIndent) + return e.compileSlice(ctx) case reflect.Array: - return e.compileArray(typ, root, withIndent) + return e.compileArray(ctx) case reflect.Map: - return e.compileMap(typ, true, root, withIndent) + return e.compileMap(ctx, true) case reflect.Struct: - return e.compileStruct(typ, false, root, withIndent) + return e.compileStruct(ctx, false) case reflect.Interface: - return e.compileInterface(typ, root) + return e.compileInterface(ctx) case reflect.Int: - return e.compileInt(typ) + return e.compileInt(ctx) case reflect.Int8: - return e.compileInt8(typ) + return e.compileInt8(ctx) case reflect.Int16: - return e.compileInt16(typ) + return e.compileInt16(ctx) case reflect.Int32: - return e.compileInt32(typ) + return e.compileInt32(ctx) case reflect.Int64: - return e.compileInt64(typ) + return e.compileInt64(ctx) case reflect.Uint: - return e.compileUint(typ) + return e.compileUint(ctx) case reflect.Uint8: - return e.compileUint8(typ) + return e.compileUint8(ctx) case reflect.Uint16: - return e.compileUint16(typ) + return e.compileUint16(ctx) case reflect.Uint32: - return e.compileUint32(typ) + return e.compileUint32(ctx) case reflect.Uint64: - return e.compileUint64(typ) + return e.compileUint64(ctx) case reflect.Uintptr: - return e.compileUint(typ) + return e.compileUint(ctx) case reflect.Float32: - return e.compileFloat32(typ) + return e.compileFloat32(ctx) case reflect.Float64: - return e.compileFloat64(typ) + return e.compileFloat64(ctx) case reflect.String: - return e.compileString(typ) + return e.compileString(ctx) case reflect.Bool: - return e.compileBool(typ) + return e.compileBool(ctx) } return nil, &UnsupportedTypeError{Type: rtype2type(typ)} } -func (e *Encoder) compileKey(typ *rtype, root, withIndent bool) (*opcode, error) { +func (e *Encoder) compileKey(ctx *encodeCompileContext) (*opcode, error) { + typ := ctx.typ switch { case typ.Implements(marshalJSONType): return newOpCode(opMarshalJSON, typ, e.indent, newEndOp(e.indent)), nil @@ -120,44 +122,17 @@ func (e *Encoder) compileKey(typ *rtype, root, withIndent bool) (*opcode, error) } switch typ.Kind() { case reflect.Ptr: - return e.compilePtr(typ, root, withIndent) + return e.compilePtr(ctx) case reflect.Interface: - return e.compileInterface(typ, root) - case reflect.Int: - return e.compileInt(typ) - case reflect.Int8: - return e.compileInt8(typ) - case reflect.Int16: - return e.compileInt16(typ) - case reflect.Int32: - return e.compileInt32(typ) - case reflect.Int64: - return e.compileInt64(typ) - case reflect.Uint: - return e.compileUint(typ) - case reflect.Uint8: - return e.compileUint8(typ) - case reflect.Uint16: - return e.compileUint16(typ) - case reflect.Uint32: - return e.compileUint32(typ) - case reflect.Uint64: - return e.compileUint64(typ) - case reflect.Uintptr: - return e.compileUint(typ) - case reflect.Float32: - return e.compileFloat32(typ) - case reflect.Float64: - return e.compileFloat64(typ) + return e.compileInterface(ctx) case reflect.String: - return e.compileString(typ) - case reflect.Bool: - return e.compileBool(typ) + return e.compileString(ctx) } return nil, &UnsupportedTypeError{Type: rtype2type(typ)} } -func (e *Encoder) optimizeStructFieldPtrHead(typ *rtype, code *opcode) *opcode { +func (e *Encoder) optimizeStructFieldPtrHead(ctx *encodeCompileContext, code *opcode) *opcode { + typ := ctx.typ ptrHeadOp := code.op.headToPtrHead() if code.op != ptrHeadOp { code.op = ptrHeadOp @@ -166,92 +141,93 @@ func (e *Encoder) optimizeStructFieldPtrHead(typ *rtype, code *opcode) *opcode { return newOpCode(opPtr, typ, e.indent, code) } -func (e *Encoder) compilePtr(typ *rtype, root, withIndent bool) (*opcode, error) { - code, err := e.compile(typ.Elem(), root, withIndent) +func (e *Encoder) compilePtr(ctx *encodeCompileContext) (*opcode, error) { + code, err := e.compile(ctx.withType(ctx.typ.Elem())) if err != nil { return nil, err } - return e.optimizeStructFieldPtrHead(typ, code), nil + return e.optimizeStructFieldPtrHead(ctx, code), nil } -func (e *Encoder) compileInt(typ *rtype) (*opcode, error) { - return newOpCode(opInt, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileInt(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opInt, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileInt8(typ *rtype) (*opcode, error) { - return newOpCode(opInt8, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileInt8(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opInt8, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileInt16(typ *rtype) (*opcode, error) { - return newOpCode(opInt16, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileInt16(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opInt16, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileInt32(typ *rtype) (*opcode, error) { - return newOpCode(opInt32, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileInt32(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opInt32, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileInt64(typ *rtype) (*opcode, error) { - return newOpCode(opInt64, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileInt64(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opInt64, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileUint(typ *rtype) (*opcode, error) { - return newOpCode(opUint, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileUint(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opUint, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileUint8(typ *rtype) (*opcode, error) { - return newOpCode(opUint8, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileUint8(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opUint8, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileUint16(typ *rtype) (*opcode, error) { - return newOpCode(opUint16, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileUint16(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opUint16, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileUint32(typ *rtype) (*opcode, error) { - return newOpCode(opUint32, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileUint32(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opUint32, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileUint64(typ *rtype) (*opcode, error) { - return newOpCode(opUint64, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileUint64(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opUint64, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileFloat32(typ *rtype) (*opcode, error) { - return newOpCode(opFloat32, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileFloat32(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opFloat32, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileFloat64(typ *rtype) (*opcode, error) { - return newOpCode(opFloat64, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileFloat64(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opFloat64, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileString(typ *rtype) (*opcode, error) { - return newOpCode(opString, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileString(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opString, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileBool(typ *rtype) (*opcode, error) { - return newOpCode(opBool, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileBool(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opBool, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileBytes(typ *rtype) (*opcode, error) { - return newOpCode(opBytes, typ, e.indent, newEndOp(e.indent)), nil +func (e *Encoder) compileBytes(ctx *encodeCompileContext) (*opcode, error) { + return newOpCode(opBytes, ctx.typ, e.indent, newEndOp(e.indent)), nil } -func (e *Encoder) compileInterface(typ *rtype, root bool) (*opcode, error) { +func (e *Encoder) compileInterface(ctx *encodeCompileContext) (*opcode, error) { return (*opcode)(unsafe.Pointer(&interfaceCode{ opcodeHeader: &opcodeHeader{ op: opInterface, - typ: typ, + typ: ctx.typ, indent: e.indent, next: newEndOp(e.indent), }, - root: root, + root: ctx.root, })), nil } -func (e *Encoder) compileSlice(typ *rtype, root, withIndent bool) (*opcode, error) { - elem := typ.Elem() +func (e *Encoder) compileSlice(ctx *encodeCompileContext) (*opcode, error) { + ctx.root = false + elem := ctx.typ.Elem() size := elem.Size() e.indent++ - code, err := e.compile(elem, false, withIndent) + code, err := e.compile(ctx.withType(ctx.typ.Elem())) e.indent-- if err != nil { @@ -271,8 +247,8 @@ func (e *Encoder) compileSlice(typ *rtype, root, withIndent bool) (*opcode, erro size: size, } end := newOpCode(opSliceEnd, nil, e.indent, newEndOp(e.indent)) - if withIndent { - if root { + if ctx.withIndent { + if ctx.root { header.op = opRootSliceHeadIndent elemCode.op = opRootSliceElemIndent } else { @@ -291,13 +267,15 @@ func (e *Encoder) compileSlice(typ *rtype, root, withIndent bool) (*opcode, erro return (*opcode)(unsafe.Pointer(header)), nil } -func (e *Encoder) compileArray(typ *rtype, root, withIndent bool) (*opcode, error) { +func (e *Encoder) compileArray(ctx *encodeCompileContext) (*opcode, error) { + ctx.root = false + typ := ctx.typ elem := typ.Elem() alen := typ.Len() size := elem.Size() e.indent++ - code, err := e.compile(elem, false, withIndent) + code, err := e.compile(ctx.withType(elem)) e.indent-- if err != nil { @@ -317,7 +295,7 @@ func (e *Encoder) compileArray(typ *rtype, root, withIndent bool) (*opcode, erro } end := newOpCode(opArrayEnd, nil, e.indent, newEndOp(e.indent)) - if withIndent { + if ctx.withIndent { header.op = opArrayHeadIndent elemCode.op = opArrayElemIndent end.op = opArrayEndIndent @@ -348,18 +326,19 @@ func mapiternext(it unsafe.Pointer) //go:noescape func maplen(m unsafe.Pointer) int -func (e *Encoder) compileMap(typ *rtype, withLoad, root, withIndent bool) (*opcode, error) { +func (e *Encoder) compileMap(ctx *encodeCompileContext, withLoad bool) (*opcode, error) { // header => code => value => code => key => code => value => code => end // ^ | // |_______________________| e.indent++ - keyType := typ.Key() - keyCode, err := e.compileKey(keyType, false, withIndent) + typ := ctx.typ + keyType := ctx.typ.Key() + keyCode, err := e.compileKey(ctx.withType(keyType)) if err != nil { return nil, err } valueType := typ.Elem() - valueCode, err := e.compile(valueType, false, withIndent) + valueCode, err := e.compile(ctx.withType(valueType)) if err != nil { return nil, err } @@ -374,9 +353,9 @@ func (e *Encoder) compileMap(typ *rtype, withLoad, root, withIndent bool) (*opco header.value = value end := newOpCode(opMapEnd, nil, e.indent, newEndOp(e.indent)) - if withIndent { + if ctx.withIndent { if header.op == opMapHead { - if root { + if ctx.root { header.op = opRootMapHeadIndent } else { header.op = opMapHeadIndent @@ -384,7 +363,7 @@ func (e *Encoder) compileMap(typ *rtype, withLoad, root, withIndent bool) (*opco } else { header.op = opMapHeadLoadIndent } - if root { + if ctx.root { key.op = opRootMapKeyIndent } else { key.op = opMapKeyIndent @@ -549,11 +528,11 @@ func (e *Encoder) optimizeStructField(op opType, tag *structTag, withIndent bool return fieldType } -func (e *Encoder) recursiveCode(typ *rtype, code *compiledCode) *opcode { +func (e *Encoder) recursiveCode(ctx *encodeCompileContext, code *compiledCode) *opcode { return (*opcode)(unsafe.Pointer(&recursiveCode{ opcodeHeader: &opcodeHeader{ op: opStructFieldRecursive, - typ: typ, + typ: ctx.typ, indent: e.indent, next: newEndOp(e.indent), }, @@ -561,15 +540,16 @@ func (e *Encoder) recursiveCode(typ *rtype, code *compiledCode) *opcode { })) } -func (e *Encoder) compiledCode(typ *rtype, withIndent bool) *opcode { +func (e *Encoder) compiledCode(ctx *encodeCompileContext) *opcode { + typ := ctx.typ typeptr := uintptr(unsafe.Pointer(typ)) - if withIndent { + if ctx.withIndent { if compiledCode, exists := e.structTypeToCompiledIndentCode[typeptr]; exists { - return e.recursiveCode(typ, compiledCode) + return e.recursiveCode(ctx, compiledCode) } } else { if compiledCode, exists := e.structTypeToCompiledCode[typeptr]; exists { - return e.recursiveCode(typ, compiledCode) + return e.recursiveCode(ctx, compiledCode) } } return nil @@ -790,13 +770,15 @@ func (e *Encoder) optimizeConflictAnonymousFields(anonymousFields map[string][]s } } -func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opcode, error) { - if code := e.compiledCode(typ, withIndent); code != nil { +func (e *Encoder) compileStruct(ctx *encodeCompileContext, isPtr bool) (*opcode, error) { + ctx.root = false + if code := e.compiledCode(ctx); code != nil { return code, nil } + typ := ctx.typ typeptr := uintptr(unsafe.Pointer(typ)) compiled := &compiledCode{} - if withIndent { + if ctx.withIndent { e.structTypeToCompiledIndentCode[typeptr] = compiled } else { e.structTypeToCompiledCode[typeptr] = compiled @@ -833,7 +815,7 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco fieldType = rtype_ptrTo(fieldType) } } - valueCode, err := e.compile(fieldType, false, withIndent) + valueCode, err := e.compile(ctx.withType(fieldType)) if err != nil { return nil, err } @@ -866,13 +848,13 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco offset: field.Offset, } if fieldIdx == 0 { - code = e.structHeader(fieldCode, valueCode, tag, withIndent) + code = e.structHeader(fieldCode, valueCode, tag, ctx.withIndent) head = fieldCode prevField = fieldCode } else { fcode := (*opcode)(unsafe.Pointer(fieldCode)) code.next = fcode - code = e.structField(fieldCode, valueCode, tag, withIndent) + code = e.structField(fieldCode, valueCode, tag, ctx.withIndent) prevField.nextField = fcode prevField = fieldCode } @@ -888,7 +870,7 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco }, })) structEndCode.next = newEndOp(e.indent) - if withIndent { + if ctx.withIndent { structEndCode.op = opStructEndIndent } @@ -906,7 +888,7 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco }, nextField: structEndCode, } - if withIndent { + if ctx.withIndent { head.op = opStructFieldHeadIndent } code = (*opcode)(unsafe.Pointer(head)) @@ -920,7 +902,7 @@ func (e *Encoder) compileStruct(typ *rtype, isPtr, root, withIndent bool) (*opco ret := (*opcode)(unsafe.Pointer(head)) compiled.code = ret - if withIndent { + if ctx.withIndent { delete(e.structTypeToCompiledIndentCode, typeptr) } else { delete(e.structTypeToCompiledCode, typeptr) diff --git a/encode_vm.go b/encode_vm.go index 7126476..1873cdc 100644 --- a/encode_vm.go +++ b/encode_vm.go @@ -112,13 +112,21 @@ func (e *Encoder) run(code *opcode) error { e.indent = ifaceCode.indent var c *opcode if typ.Kind() == reflect.Map { - code, err := e.compileMap(typ, false, ifaceCode.root, e.enabledIndent) + code, err := e.compileMap(&encodeCompileContext{ + typ: typ, + root: ifaceCode.root, + withIndent: e.enabledIndent, + }, false) if err != nil { return err } c = code } else { - code, err := e.compile(typ, ifaceCode.root, e.enabledIndent) + code, err := e.compile(&encodeCompileContext{ + typ: typ, + root: ifaceCode.root, + withIndent: e.enabledIndent, + }) if err != nil { return err }