From fee54d48735c36116c7ef5e567979a0fdf449174 Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Fri, 26 Nov 2021 18:47:45 +0900 Subject: [PATCH] rename compiler API --- internal/encoder/code.go | 610 +---------- internal/encoder/compiler.go | 1861 +++++++++++----------------------- 2 files changed, 580 insertions(+), 1891 deletions(-) diff --git a/internal/encoder/code.go b/internal/encoder/code.go index 19797ef..eb2f7f0 100644 --- a/internal/encoder/code.go +++ b/internal/encoder/code.go @@ -2,10 +2,8 @@ package encoder import ( "fmt" - "reflect" "unsafe" - "github.com/goccy/go-json/internal/errors" "github.com/goccy/go-json/internal/runtime" ) @@ -470,35 +468,6 @@ func (c *StructCode) ToAnonymousOpcode(ctx *compileContext) Opcodes { return codes } -func linkRecursiveCode2(ctx *compileContext) { - for _, recursive := range *ctx.recursiveCodes { - typeptr := uintptr(unsafe.Pointer(recursive.Type)) - codes := ctx.structTypeToCodes[typeptr] - compiled := recursive.Jmp - compiled.Code = copyOpcode(codes.First()) - code := compiled.Code - code.End.Next = newEndOp(&compileContext{}) - code.Op = code.Op.PtrHeadToHead() - - beforeLastCode := code.End - lastCode := beforeLastCode.Next - - totalLength := code.TotalLength() - lastCode.Idx = uint32((totalLength + 1) * uintptrSize) - lastCode.ElemIdx = lastCode.Idx + uintptrSize - lastCode.Length = lastCode.Idx + 2*uintptrSize - code.End.Next.Op = OpRecursiveEnd - - - // extend length to alloc slot for elemIdx + length - curTotalLength := uintptr(recursive.TotalLength()) + 3 - nextTotalLength := uintptr(totalLength) + 3 - compiled.CurLen = curTotalLength - compiled.NextLen = nextTotalLength - compiled.Linked = true - } -} - func (c *StructCode) removeFieldsByTags(tags runtime.StructTags) { fields := make([]*StructFieldCode, 0, len(c.fields)) for _, field := range c.fields { @@ -855,577 +824,6 @@ func (c *PtrCode) ToAnonymousOpcode(ctx *compileContext) Opcodes { return codes } -func type2code(ctx *compileContext) (Code, error) { - typ := ctx.typ - switch { - case implementsMarshalJSON(typ): - return compileMarshalJSON2(ctx) - case implementsMarshalText(typ): - return compileMarshalText2(ctx) - } - - isPtr := false - orgType := typ - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - isPtr = true - } - switch { - case implementsMarshalJSON(typ): - return compileMarshalJSON2(ctx) - case implementsMarshalText(typ): - return compileMarshalText2(ctx) - } - switch typ.Kind() { - case reflect.Slice: - ctx := ctx.withType(typ) - elem := typ.Elem() - if elem.Kind() == reflect.Uint8 { - p := runtime.PtrTo(elem) - if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { - return compileBytes2(ctx, isPtr) - } - } - return compileSlice2(ctx) - case reflect.Map: - if isPtr { - return compilePtr2(ctx.withType(runtime.PtrTo(typ))) - } - return compileMap2(ctx.withType(typ)) - case reflect.Struct: - return compileStruct2(ctx.withType(typ), isPtr) - case reflect.Int: - return compileInt2(ctx.withType(typ), isPtr) - case reflect.Int8: - return compileInt82(ctx.withType(typ), isPtr) - case reflect.Int16: - return compileInt162(ctx.withType(typ), isPtr) - case reflect.Int32: - return compileInt322(ctx.withType(typ), isPtr) - case reflect.Int64: - return compileInt642(ctx.withType(typ), isPtr) - case reflect.Uint, reflect.Uintptr: - return compileUint2(ctx.withType(typ), isPtr) - case reflect.Uint8: - return compileUint82(ctx.withType(typ), isPtr) - case reflect.Uint16: - return compileUint162(ctx.withType(typ), isPtr) - case reflect.Uint32: - return compileUint322(ctx.withType(typ), isPtr) - case reflect.Uint64: - return compileUint642(ctx.withType(typ), isPtr) - case reflect.Float32: - return compileFloat322(ctx.withType(typ), isPtr) - case reflect.Float64: - return compileFloat642(ctx.withType(typ), isPtr) - case reflect.String: - return compileString2(ctx.withType(typ), isPtr) - case reflect.Bool: - return compileBool2(ctx.withType(typ), isPtr) - case reflect.Interface: - return compileInterface2(ctx.withType(typ), isPtr) - default: - if isPtr && typ.Implements(marshalTextType) { - typ = orgType - } - return type2codeWithPtr(ctx.withType(typ), isPtr) - } -} - -func type2codeWithPtr(ctx *compileContext, isPtr bool) (Code, error) { - typ := ctx.typ - switch { - case implementsMarshalJSON(typ): - return compileMarshalJSON2(ctx) - case implementsMarshalText(typ): - return compileMarshalText2(ctx) - } - switch typ.Kind() { - case reflect.Ptr: - return compilePtr2(ctx) - case reflect.Slice: - elem := typ.Elem() - if elem.Kind() == reflect.Uint8 { - p := runtime.PtrTo(elem) - if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { - return compileBytes2(ctx, false) - } - } - return compileSlice2(ctx) - case reflect.Array: - return compileArray2(ctx) - case reflect.Map: - return compileMap2(ctx) - case reflect.Struct: - return compileStruct2(ctx, isPtr) - case reflect.Interface: - return compileInterface2(ctx, false) - case reflect.Int: - return compileInt2(ctx, false) - case reflect.Int8: - return compileInt82(ctx, false) - case reflect.Int16: - return compileInt162(ctx, false) - case reflect.Int32: - return compileInt322(ctx, false) - case reflect.Int64: - return compileInt642(ctx, false) - case reflect.Uint: - return compileUint2(ctx, false) - case reflect.Uint8: - return compileUint82(ctx, false) - case reflect.Uint16: - return compileUint162(ctx, false) - case reflect.Uint32: - return compileUint322(ctx, false) - case reflect.Uint64: - return compileUint642(ctx, false) - case reflect.Uintptr: - return compileUint2(ctx, false) - case reflect.Float32: - return compileFloat322(ctx, false) - case reflect.Float64: - return compileFloat642(ctx, false) - case reflect.String: - return compileString2(ctx, false) - case reflect.Bool: - return compileBool2(ctx, false) - } - return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)} -} - -func compileInt2(ctx *compileContext, isPtr bool) (*IntCode, error) { - return &IntCode{typ: ctx.typ, bitSize: intSize, isPtr: isPtr}, nil -} - -func compileInt82(ctx *compileContext, isPtr bool) (*IntCode, error) { - return &IntCode{typ: ctx.typ, bitSize: 8, isPtr: isPtr}, nil -} - -func compileInt162(ctx *compileContext, isPtr bool) (*IntCode, error) { - return &IntCode{typ: ctx.typ, bitSize: 16, isPtr: isPtr}, nil -} - -func compileInt322(ctx *compileContext, isPtr bool) (*IntCode, error) { - return &IntCode{typ: ctx.typ, bitSize: 32, isPtr: isPtr}, nil -} - -func compileInt642(ctx *compileContext, isPtr bool) (*IntCode, error) { - return &IntCode{typ: ctx.typ, bitSize: 64, isPtr: isPtr}, nil -} - -func compileUint2(ctx *compileContext, isPtr bool) (*UintCode, error) { - return &UintCode{typ: ctx.typ, bitSize: intSize, isPtr: isPtr}, nil -} - -func compileUint82(ctx *compileContext, isPtr bool) (*UintCode, error) { - return &UintCode{typ: ctx.typ, bitSize: 8, isPtr: isPtr}, nil -} - -func compileUint162(ctx *compileContext, isPtr bool) (*UintCode, error) { - return &UintCode{typ: ctx.typ, bitSize: 16, isPtr: isPtr}, nil -} - -func compileUint322(ctx *compileContext, isPtr bool) (*UintCode, error) { - return &UintCode{typ: ctx.typ, bitSize: 32, isPtr: isPtr}, nil -} - -func compileUint642(ctx *compileContext, isPtr bool) (*UintCode, error) { - return &UintCode{typ: ctx.typ, bitSize: 64, isPtr: isPtr}, nil -} - -func compileFloat322(ctx *compileContext, isPtr bool) (*FloatCode, error) { - return &FloatCode{typ: ctx.typ, bitSize: 32, isPtr: isPtr}, nil -} - -func compileFloat642(ctx *compileContext, isPtr bool) (*FloatCode, error) { - return &FloatCode{typ: ctx.typ, bitSize: 64, isPtr: isPtr}, nil -} - -func compileString2(ctx *compileContext, isPtr bool) (*StringCode, error) { - return &StringCode{typ: ctx.typ, isPtr: isPtr}, nil -} - -func compileBool2(ctx *compileContext, isPtr bool) (*BoolCode, error) { - return &BoolCode{typ: ctx.typ, isPtr: isPtr}, nil -} - -func compileIntString2(ctx *compileContext) (*IntCode, error) { - return &IntCode{typ: ctx.typ, bitSize: intSize, isString: true}, nil -} - -func compileInt8String2(ctx *compileContext) (*IntCode, error) { - return &IntCode{typ: ctx.typ, bitSize: 8, isString: true}, nil -} - -func compileInt16String2(ctx *compileContext) (*IntCode, error) { - return &IntCode{typ: ctx.typ, bitSize: 16, isString: true}, nil -} - -func compileInt32String2(ctx *compileContext) (*IntCode, error) { - return &IntCode{typ: ctx.typ, bitSize: 32, isString: true}, nil -} - -func compileInt64String2(ctx *compileContext) (*IntCode, error) { - return &IntCode{typ: ctx.typ, bitSize: 64, isString: true}, nil -} - -func compileUintString2(ctx *compileContext) (*UintCode, error) { - return &UintCode{typ: ctx.typ, bitSize: intSize, isString: true}, nil -} - -func compileUint8String2(ctx *compileContext) (*UintCode, error) { - return &UintCode{typ: ctx.typ, bitSize: 8, isString: true}, nil -} - -func compileUint16String2(ctx *compileContext) (*UintCode, error) { - return &UintCode{typ: ctx.typ, bitSize: 16, isString: true}, nil -} - -func compileUint32String2(ctx *compileContext) (*UintCode, error) { - return &UintCode{typ: ctx.typ, bitSize: 32, isString: true}, nil -} - -func compileUint64String2(ctx *compileContext) (*UintCode, error) { - return &UintCode{typ: ctx.typ, bitSize: 64, isString: true}, nil -} - -func compileSlice2(ctx *compileContext) (*SliceCode, error) { - elem := ctx.typ.Elem() - code, err := compileListElem2(ctx.withType(elem)) - if err != nil { - return nil, err - } - if code.Type() == CodeTypeStruct { - structCode := code.(*StructCode) - structCode.enableIndirect() - } - return &SliceCode{typ: ctx.typ, value: code}, nil -} - -func compileArray2(ctx *compileContext) (*ArrayCode, error) { - typ := ctx.typ - elem := typ.Elem() - code, err := compileListElem2(ctx.withType(elem)) - if err != nil { - return nil, err - } - if code.Type() == CodeTypeStruct { - structCode := code.(*StructCode) - structCode.enableIndirect() - } - return &ArrayCode{typ: ctx.typ, value: code}, nil -} - -func compileMap2(ctx *compileContext) (*MapCode, error) { - typ := ctx.typ - keyCode, err := compileMapKey(ctx.withType(typ.Key())) - if err != nil { - return nil, err - } - valueCode, err := compileMapValue2(ctx.withType(typ.Elem())) - if err != nil { - return nil, err - } - if valueCode.Type() == CodeTypeStruct { - structCode := valueCode.(*StructCode) - structCode.enableIndirect() - } - return &MapCode{typ: ctx.typ, key: keyCode, value: valueCode}, nil -} - -func compileBytes2(ctx *compileContext, isPtr bool) (*BytesCode, error) { - return &BytesCode{typ: ctx.typ, isPtr: isPtr}, nil -} - -func compileInterface2(ctx *compileContext, isPtr bool) (*InterfaceCode, error) { - return &InterfaceCode{typ: ctx.typ, isPtr: isPtr}, nil -} - -func compileMarshalJSON2(ctx *compileContext) (*MarshalJSONCode, error) { - return &MarshalJSONCode{typ: ctx.typ}, nil -} - -func compileMarshalText2(ctx *compileContext) (*MarshalTextCode, error) { - return &MarshalTextCode{typ: ctx.typ}, nil -} - -func compilePtr2(ctx *compileContext) (*PtrCode, error) { - code, err := type2codeWithPtr(ctx.withType(ctx.typ.Elem()), true) - if err != nil { - return nil, err - } - ptr, ok := code.(*PtrCode) - if ok { - return &PtrCode{typ: ctx.typ, value: ptr.value, ptrNum: ptr.ptrNum + 1}, nil - } - return &PtrCode{typ: ctx.typ, value: code, ptrNum: 1}, nil -} - -func compileListElem2(ctx *compileContext) (Code, error) { - typ := ctx.typ - switch { - case isPtrMarshalJSONType(typ): - return compileMarshalJSON2(ctx) - case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType): - return compileMarshalText2(ctx) - case typ.Kind() == reflect.Map: - return compilePtr2(ctx.withType(runtime.PtrTo(typ))) - default: - code, err := type2codeWithPtr(ctx, false) - if err != nil { - return nil, err - } - ptr, ok := code.(*PtrCode) - if ok { - if ptr.value.Type() == CodeTypeMap { - ptr.ptrNum++ - } - } - return code, nil - } -} - -func compileMapKey(ctx *compileContext) (Code, error) { - typ := ctx.typ - switch { - case implementsMarshalJSON(typ): - return compileMarshalJSON2(ctx) - case implementsMarshalText(typ): - return compileMarshalText2(ctx) - } - switch typ.Kind() { - case reflect.Ptr: - return compilePtr2(ctx) - case reflect.String: - return compileString2(ctx, false) - case reflect.Int: - return compileIntString2(ctx) - case reflect.Int8: - return compileInt8String2(ctx) - case reflect.Int16: - return compileInt16String2(ctx) - case reflect.Int32: - return compileInt32String2(ctx) - case reflect.Int64: - return compileInt64String2(ctx) - case reflect.Uint: - return compileUintString2(ctx) - case reflect.Uint8: - return compileUint8String2(ctx) - case reflect.Uint16: - return compileUint16String2(ctx) - case reflect.Uint32: - return compileUint32String2(ctx) - case reflect.Uint64: - return compileUint64String2(ctx) - case reflect.Uintptr: - return compileUintString2(ctx) - } - return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)} -} - -func compileMapValue2(ctx *compileContext) (Code, error) { - switch ctx.typ.Kind() { - case reflect.Map: - return compilePtr2(ctx.withType(runtime.PtrTo(ctx.typ))) - default: - code, err := type2codeWithPtr(ctx, false) - if err != nil { - return nil, err - } - ptr, ok := code.(*PtrCode) - if ok { - if ptr.value.Type() == CodeTypeMap { - ptr.ptrNum++ - } - } - return code, nil - } -} - -func compileStruct2(ctx *compileContext, isPtr bool) (*StructCode, error) { - typ := ctx.typ - typeptr := uintptr(unsafe.Pointer(typ)) - if code, exists := ctx.structTypeToCode[typeptr]; exists { - derefCode := *code - derefCode.isRecursive = true - return &derefCode, nil - } - indirect := runtime.IfaceIndir(typ) - code := &StructCode{typ: typ, isPtr: isPtr, isIndirect: indirect} - ctx.structTypeToCode[typeptr] = code - - fieldNum := typ.NumField() - tags := typeToStructTags(typ) - fields := []*StructFieldCode{} - for i, tag := range tags { - isOnlyOneFirstField := i == 0 && fieldNum == 1 - field, err := code.compileStructField(ctx, tag, isPtr, isOnlyOneFirstField) - if err != nil { - return nil, err - } - if field.isAnonymous { - structCode := field.getAnonymousStruct() - if structCode != nil { - structCode.removeFieldsByTags(tags) - if isAssignableIndirect(field, isPtr) { - if indirect { - structCode.isIndirect = true - } else { - structCode.isIndirect = false - } - } - } - } else { - structCode := field.getStruct() - if structCode != nil { - if indirect { - // if parent is indirect type, set child indirect property to true - structCode.isIndirect = true - } else { - // if parent is not indirect type, set child indirect property to false. - // but if parent's indirect is false and isPtr is true, then indirect must be true. - // Do this only if indirectConversion is enabled at the end of compileStruct. - structCode.isIndirect = false - } - } - } - fields = append(fields, field) - } - fieldMap := getFieldMap(fields) - duplicatedFieldMap := getDuplicatedFieldMap(fieldMap) - code.fields = filteredDuplicatedFields(fields, duplicatedFieldMap) - if !code.disableIndirectConversion && !indirect && isPtr { - code.enableIndirect() - } - delete(ctx.structTypeToCode, typeptr) - return code, nil -} - -func getFieldMap(fields []*StructFieldCode) map[string][]*StructFieldCode { - fieldMap := map[string][]*StructFieldCode{} - for _, field := range fields { - if field.isAnonymous { - for k, v := range getAnonymousFieldMap(field) { - fieldMap[k] = append(fieldMap[k], v...) - } - continue - } - fieldMap[field.key] = append(fieldMap[field.key], field) - } - return fieldMap -} - -func getAnonymousFieldMap(field *StructFieldCode) map[string][]*StructFieldCode { - fieldMap := map[string][]*StructFieldCode{} - structCode := field.getAnonymousStruct() - if structCode == nil || structCode.isRecursive { - fieldMap[field.key] = append(fieldMap[field.key], field) - return fieldMap - } - for k, v := range getFieldMapFromAnonymousParent(structCode.fields) { - fieldMap[k] = append(fieldMap[k], v...) - } - return fieldMap -} - -func getFieldMapFromAnonymousParent(fields []*StructFieldCode) map[string][]*StructFieldCode { - fieldMap := map[string][]*StructFieldCode{} - for _, field := range fields { - if field.isAnonymous { - for k, v := range getAnonymousFieldMap(field) { - // Do not handle tagged key when embedding more than once - for _, vv := range v { - vv.isTaggedKey = false - } - fieldMap[k] = append(fieldMap[k], v...) - } - continue - } - fieldMap[field.key] = append(fieldMap[field.key], field) - } - return fieldMap -} - -func getDuplicatedFieldMap(fieldMap map[string][]*StructFieldCode) map[*StructFieldCode]struct{} { - duplicatedFieldMap := map[*StructFieldCode]struct{}{} - for _, fields := range fieldMap { - if len(fields) == 1 { - continue - } - if isTaggedKeyOnly(fields) { - for _, field := range fields { - if field.isTaggedKey { - continue - } - duplicatedFieldMap[field] = struct{}{} - } - } else { - for _, field := range fields { - duplicatedFieldMap[field] = struct{}{} - } - } - } - return duplicatedFieldMap -} - -func filteredDuplicatedFields(fields []*StructFieldCode, duplicatedFieldMap map[*StructFieldCode]struct{}) []*StructFieldCode { - filteredFields := make([]*StructFieldCode, 0, len(fields)) - for _, field := range fields { - if field.isAnonymous { - structCode := field.getAnonymousStruct() - if structCode != nil && !structCode.isRecursive { - structCode.fields = filteredDuplicatedFields(structCode.fields, duplicatedFieldMap) - if len(structCode.fields) > 0 { - filteredFields = append(filteredFields, field) - } - continue - } - } - if _, exists := duplicatedFieldMap[field]; exists { - continue - } - filteredFields = append(filteredFields, field) - } - return filteredFields -} - -func isTaggedKeyOnly(fields []*StructFieldCode) bool { - var taggedKeyFieldCount int - for _, field := range fields { - if field.isTaggedKey { - taggedKeyFieldCount++ - } - } - return taggedKeyFieldCount == 1 -} - -func typeToStructTags(typ *runtime.Type) runtime.StructTags { - tags := runtime.StructTags{} - fieldNum := typ.NumField() - for i := 0; i < fieldNum; i++ { - field := typ.Field(i) - if runtime.IsIgnoredStructField(field) { - continue - } - tags = append(tags, runtime.StructTagFromField(field)) - } - return tags -} - -// *struct{ field T } => struct { field *T } -// func (*T) MarshalJSON() ([]byte, error) -func isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool { - return isIndirectSpecialCase && !isNilableType(typ) && isPtrMarshalJSONType(typ) -} - -// *struct{ field T } => struct { field *T } -// func (*T) MarshalText() ([]byte, error) -func isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool { - return isIndirectSpecialCase && !isNilableType(typ) && isPtrMarshalTextType(typ) -} - func (c *StructCode) compileStructField(ctx *compileContext, tag *runtime.StructTag, isPtr, isOnlyOneFirstField bool) (*StructFieldCode, error) { field := tag.Field fieldType := runtime.Type2RType(field.Type) @@ -1442,7 +840,7 @@ func (c *StructCode) compileStructField(ctx *compileContext, tag *runtime.Struct } switch { case isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(fieldType, isIndirectSpecialCase): - code, err := compileMarshalJSON2(ctx.withType(fieldType)) + code, err := compileMarshalJSON(ctx.withType(fieldType)) if err != nil { return nil, err } @@ -1452,7 +850,7 @@ func (c *StructCode) compileStructField(ctx *compileContext, tag *runtime.Struct c.isIndirect = false c.disableIndirectConversion = true case isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(fieldType, isIndirectSpecialCase): - code, err := compileMarshalText2(ctx.withType(fieldType)) + code, err := compileMarshalText(ctx.withType(fieldType)) if err != nil { return nil, err } @@ -1464,7 +862,7 @@ func (c *StructCode) compileStructField(ctx *compileContext, tag *runtime.Struct case isPtr && isPtrMarshalJSONType(fieldType): // *struct{ field T } // func (*T) MarshalJSON() ([]byte, error) - code, err := compileMarshalJSON2(ctx.withType(fieldType)) + code, err := compileMarshalJSON(ctx.withType(fieldType)) if err != nil { return nil, err } @@ -1474,7 +872,7 @@ func (c *StructCode) compileStructField(ctx *compileContext, tag *runtime.Struct case isPtr && isPtrMarshalTextType(fieldType): // *struct{ field T } // func (*T) MarshalText() ([]byte, error) - code, err := compileMarshalText2(ctx.withType(fieldType)) + code, err := compileMarshalText(ctx.withType(fieldType)) if err != nil { return nil, err } diff --git a/internal/encoder/compiler.go b/internal/encoder/compiler.go index b0a4461..a71b4f4 100644 --- a/internal/encoder/compiler.go +++ b/internal/encoder/compiler.go @@ -4,9 +4,7 @@ import ( "context" "encoding" "encoding/json" - "fmt" "reflect" - "strings" "sync/atomic" "unsafe" @@ -101,269 +99,12 @@ func compileHead(ctx *compileContext) (*Opcode, error) { if err != nil { return nil, err } - //pp.Println(code) derefctx := *ctx newCtx := &derefctx codes := code.ToOpcode(newCtx) codes.Last().Next = newEndOp(newCtx) - //pp.Println(codes.First()) - linkRecursiveCode2(newCtx) - //fmt.Println(codes.First().Dump()) + linkRecursiveCode(newCtx) return codes.First(), nil - - typ := ctx.typ - switch { - case implementsMarshalJSON(typ): - return compileMarshalJSON(ctx) - case implementsMarshalText(typ): - return compileMarshalText(ctx) - } - - isPtr := false - orgType := typ - if typ.Kind() == reflect.Ptr { - typ = typ.Elem() - isPtr = true - } - switch { - case implementsMarshalJSON(typ): - return compileMarshalJSON(ctx) - case implementsMarshalText(typ): - return compileMarshalText(ctx) - } - switch typ.Kind() { - case reflect.Slice: - ctx := ctx.withType(typ) - elem := typ.Elem() - if elem.Kind() == reflect.Uint8 { - p := runtime.PtrTo(elem) - if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { - if isPtr { - return compileBytesPtr(ctx) - } - return compileBytes(ctx) - } - } - code, err := compileSlice(ctx) - if err != nil { - return nil, err - } - optimizeStructEnd(code) - linkRecursiveCode(code) - return code, nil - case reflect.Map: - if isPtr { - return compilePtr(ctx.withType(runtime.PtrTo(typ))) - } - code, err := compileMap(ctx.withType(typ)) - if err != nil { - return nil, err - } - optimizeStructEnd(code) - linkRecursiveCode(code) - return code, nil - case reflect.Struct: - code, err := compileStruct(ctx.withType(typ), isPtr) - if err != nil { - return nil, err - } - optimizeStructEnd(code) - linkRecursiveCode(code) - return code, nil - case reflect.Int: - ctx := ctx.withType(typ) - if isPtr { - return compileIntPtr(ctx) - } - return compileInt(ctx) - case reflect.Int8: - ctx := ctx.withType(typ) - if isPtr { - return compileInt8Ptr(ctx) - } - return compileInt8(ctx) - case reflect.Int16: - ctx := ctx.withType(typ) - if isPtr { - return compileInt16Ptr(ctx) - } - return compileInt16(ctx) - case reflect.Int32: - ctx := ctx.withType(typ) - if isPtr { - return compileInt32Ptr(ctx) - } - return compileInt32(ctx) - case reflect.Int64: - ctx := ctx.withType(typ) - if isPtr { - return compileInt64Ptr(ctx) - } - return compileInt64(ctx) - case reflect.Uint, reflect.Uintptr: - ctx := ctx.withType(typ) - if isPtr { - return compileUintPtr(ctx) - } - return compileUint(ctx) - case reflect.Uint8: - ctx := ctx.withType(typ) - if isPtr { - return compileUint8Ptr(ctx) - } - return compileUint8(ctx) - case reflect.Uint16: - ctx := ctx.withType(typ) - if isPtr { - return compileUint16Ptr(ctx) - } - return compileUint16(ctx) - case reflect.Uint32: - ctx := ctx.withType(typ) - if isPtr { - return compileUint32Ptr(ctx) - } - return compileUint32(ctx) - case reflect.Uint64: - ctx := ctx.withType(typ) - if isPtr { - return compileUint64Ptr(ctx) - } - return compileUint64(ctx) - case reflect.Float32: - ctx := ctx.withType(typ) - if isPtr { - return compileFloat32Ptr(ctx) - } - return compileFloat32(ctx) - case reflect.Float64: - ctx := ctx.withType(typ) - if isPtr { - return compileFloat64Ptr(ctx) - } - return compileFloat64(ctx) - case reflect.String: - ctx := ctx.withType(typ) - if isPtr { - return compileStringPtr(ctx) - } - return compileString(ctx) - case reflect.Bool: - ctx := ctx.withType(typ) - if isPtr { - return compileBoolPtr(ctx) - } - return compileBool(ctx) - case reflect.Interface: - ctx := ctx.withType(typ) - if isPtr { - return compileInterfacePtr(ctx) - } - return compileInterface(ctx) - default: - if isPtr && typ.Implements(marshalTextType) { - typ = orgType - } - code, err := compile(ctx.withType(typ), isPtr) - if err != nil { - return nil, err - } - optimizeStructEnd(code) - linkRecursiveCode(code) - return code, nil - } -} - -func linkRecursiveCode(c *Opcode) { - for code := c; code.Op != OpEnd && code.Op != OpRecursiveEnd; { - switch code.Op { - case OpRecursive, OpRecursivePtr: - if code.Jmp.Linked { - code = code.Next - continue - } - code.Jmp.Code = copyOpcode(code.Jmp.Code) - - c := code.Jmp.Code - c.End.Next = newEndOp(&compileContext{}) - c.Op = c.Op.PtrHeadToHead() - - beforeLastCode := c.End - lastCode := beforeLastCode.Next - - lastCode.Idx = beforeLastCode.Idx + uintptrSize - lastCode.ElemIdx = lastCode.Idx + uintptrSize - lastCode.Length = lastCode.Idx + 2*uintptrSize - - // extend length to alloc slot for elemIdx + length - totalLength := uintptr(code.TotalLength() + 3) - nextTotalLength := uintptr(c.TotalLength() + 3) - - c.End.Next.Op = OpRecursiveEnd - - code.Jmp.CurLen = totalLength - code.Jmp.NextLen = nextTotalLength - code.Jmp.Linked = true - - linkRecursiveCode(code.Jmp.Code) - - code = code.Next - continue - } - switch code.Op.CodeType() { - case CodeArrayElem, CodeSliceElem, CodeMapKey: - code = code.End - default: - code = code.Next - } - } -} - -func optimizeStructEnd(c *Opcode) { - for code := c; code.Op != OpEnd; { - if code.Op == OpRecursive || code.Op == OpRecursivePtr { - // ignore if exists recursive operation - return - } - switch code.Op.CodeType() { - case CodeArrayElem, CodeSliceElem, CodeMapKey: - code = code.End - default: - code = code.Next - } - } - - for code := c; code.Op != OpEnd; { - switch code.Op.CodeType() { - case CodeArrayElem, CodeSliceElem, CodeMapKey: - code = code.End - case CodeStructEnd: - switch code.Op { - case OpStructEnd: - prev := code.PrevField - prevOp := prev.Op.String() - if strings.Contains(prevOp, "Head") || - strings.Contains(prevOp, "Slice") || - strings.Contains(prevOp, "Array") || - strings.Contains(prevOp, "Map") || - strings.Contains(prevOp, "MarshalJSON") || - strings.Contains(prevOp, "MarshalText") { - // not exists field - code = code.Next - break - } - if prev.Op != prev.Op.FieldToEnd() { - prev.Op = prev.Op.FieldToEnd() - prev.Next = code.Next - } - code = code.Next - default: - code = code.Next - } - default: - code = code.Next - } - } } func implementsMarshalJSON(typ *runtime.Type) bool { @@ -396,68 +137,6 @@ func implementsMarshalText(typ *runtime.Type) bool { return false } -func compile(ctx *compileContext, isPtr bool) (*Opcode, error) { - typ := ctx.typ - switch { - case implementsMarshalJSON(typ): - return compileMarshalJSON(ctx) - case implementsMarshalText(typ): - return compileMarshalText(ctx) - } - switch typ.Kind() { - case reflect.Ptr: - return compilePtr(ctx) - case reflect.Slice: - elem := typ.Elem() - if elem.Kind() == reflect.Uint8 { - p := runtime.PtrTo(elem) - if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { - return compileBytes(ctx) - } - } - return compileSlice(ctx) - case reflect.Array: - return compileArray(ctx) - case reflect.Map: - return compileMap(ctx) - case reflect.Struct: - return compileStruct(ctx, isPtr) - case reflect.Interface: - return compileInterface(ctx) - case reflect.Int: - return compileInt(ctx) - case reflect.Int8: - return compileInt8(ctx) - case reflect.Int16: - return compileInt16(ctx) - case reflect.Int32: - return compileInt32(ctx) - case reflect.Int64: - return compileInt64(ctx) - case reflect.Uint: - return compileUint(ctx) - case reflect.Uint8: - return compileUint8(ctx) - case reflect.Uint16: - return compileUint16(ctx) - case reflect.Uint32: - return compileUint32(ctx) - case reflect.Uint64: - return compileUint64(ctx) - case reflect.Uintptr: - return compileUint(ctx) - case reflect.Float32: - return compileFloat32(ctx) - case reflect.Float64: - return compileFloat64(ctx) - case reflect.String: - return compileString(ctx) - case reflect.Bool: - return compileBool(ctx) - } - return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)} -} - func convertPtrOp(code *Opcode) OpType { ptrHeadOp := code.Op.HeadToPtrHead() if code.Op != ptrHeadOp { @@ -502,7 +181,409 @@ func convertPtrOp(code *Opcode) OpType { return code.Op } -func compileKey(ctx *compileContext) (*Opcode, error) { +const intSize = 32 << (^uint(0) >> 63) + +func optimizeStructHeader(code *Opcode, tag *runtime.StructTag) OpType { + headType := code.ToHeaderType(tag.IsString) + if tag.IsOmitEmpty { + headType = headType.HeadToOmitEmptyHead() + } + return headType +} + +func optimizeStructField(code *Opcode, tag *runtime.StructTag) OpType { + fieldType := code.ToFieldType(tag.IsString) + if tag.IsOmitEmpty { + fieldType = fieldType.FieldToOmitEmptyField() + } + return fieldType +} + +func isNilableType(typ *runtime.Type) bool { + switch typ.Kind() { + case reflect.Ptr: + return true + case reflect.Map: + return true + case reflect.Func: + return true + default: + return false + } +} + +func implementsMarshalJSONType(typ *runtime.Type) bool { + return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType) +} + +func isPtrMarshalJSONType(typ *runtime.Type) bool { + return !implementsMarshalJSONType(typ) && implementsMarshalJSONType(runtime.PtrTo(typ)) +} + +func isPtrMarshalTextType(typ *runtime.Type) bool { + return !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType) +} + +func linkRecursiveCode(ctx *compileContext) { + for _, recursive := range *ctx.recursiveCodes { + typeptr := uintptr(unsafe.Pointer(recursive.Type)) + codes := ctx.structTypeToCodes[typeptr] + compiled := recursive.Jmp + compiled.Code = copyOpcode(codes.First()) + code := compiled.Code + code.End.Next = newEndOp(&compileContext{}) + code.Op = code.Op.PtrHeadToHead() + + beforeLastCode := code.End + lastCode := beforeLastCode.Next + + totalLength := code.TotalLength() + lastCode.Idx = uint32((totalLength + 1) * uintptrSize) + lastCode.ElemIdx = lastCode.Idx + uintptrSize + lastCode.Length = lastCode.Idx + 2*uintptrSize + code.End.Next.Op = OpRecursiveEnd + + // extend length to alloc slot for elemIdx + length + curTotalLength := uintptr(recursive.TotalLength()) + 3 + nextTotalLength := uintptr(totalLength) + 3 + compiled.CurLen = curTotalLength + compiled.NextLen = nextTotalLength + compiled.Linked = true + } +} + +func type2code(ctx *compileContext) (Code, error) { + typ := ctx.typ + switch { + case implementsMarshalJSON(typ): + return compileMarshalJSON(ctx) + case implementsMarshalText(typ): + return compileMarshalText(ctx) + } + + isPtr := false + orgType := typ + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + isPtr = true + } + switch { + case implementsMarshalJSON(typ): + return compileMarshalJSON(ctx) + case implementsMarshalText(typ): + return compileMarshalText(ctx) + } + switch typ.Kind() { + case reflect.Slice: + ctx := ctx.withType(typ) + elem := typ.Elem() + if elem.Kind() == reflect.Uint8 { + p := runtime.PtrTo(elem) + if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { + return compileBytes(ctx, isPtr) + } + } + return compileSlice(ctx) + case reflect.Map: + if isPtr { + return compilePtr(ctx.withType(runtime.PtrTo(typ))) + } + return compileMap(ctx.withType(typ)) + case reflect.Struct: + return compileStruct(ctx.withType(typ), isPtr) + case reflect.Int: + return compileInt(ctx.withType(typ), isPtr) + case reflect.Int8: + return compileInt8(ctx.withType(typ), isPtr) + case reflect.Int16: + return compileInt16(ctx.withType(typ), isPtr) + case reflect.Int32: + return compileInt32(ctx.withType(typ), isPtr) + case reflect.Int64: + return compileInt64(ctx.withType(typ), isPtr) + case reflect.Uint, reflect.Uintptr: + return compileUint(ctx.withType(typ), isPtr) + case reflect.Uint8: + return compileUint8(ctx.withType(typ), isPtr) + case reflect.Uint16: + return compileUint16(ctx.withType(typ), isPtr) + case reflect.Uint32: + return compileUint32(ctx.withType(typ), isPtr) + case reflect.Uint64: + return compileUint64(ctx.withType(typ), isPtr) + case reflect.Float32: + return compileFloat32(ctx.withType(typ), isPtr) + case reflect.Float64: + return compileFloat64(ctx.withType(typ), isPtr) + case reflect.String: + return compileString(ctx.withType(typ), isPtr) + case reflect.Bool: + return compileBool(ctx.withType(typ), isPtr) + case reflect.Interface: + return compileInterface(ctx.withType(typ), isPtr) + default: + if isPtr && typ.Implements(marshalTextType) { + typ = orgType + } + return type2codeWithPtr(ctx.withType(typ), isPtr) + } +} + +func type2codeWithPtr(ctx *compileContext, isPtr bool) (Code, error) { + typ := ctx.typ + switch { + case implementsMarshalJSON(typ): + return compileMarshalJSON(ctx) + case implementsMarshalText(typ): + return compileMarshalText(ctx) + } + switch typ.Kind() { + case reflect.Ptr: + return compilePtr(ctx) + case reflect.Slice: + elem := typ.Elem() + if elem.Kind() == reflect.Uint8 { + p := runtime.PtrTo(elem) + if !implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { + return compileBytes(ctx, false) + } + } + return compileSlice(ctx) + case reflect.Array: + return compileArray(ctx) + case reflect.Map: + return compileMap(ctx) + case reflect.Struct: + return compileStruct(ctx, isPtr) + case reflect.Interface: + return compileInterface(ctx, false) + case reflect.Int: + return compileInt(ctx, false) + case reflect.Int8: + return compileInt8(ctx, false) + case reflect.Int16: + return compileInt16(ctx, false) + case reflect.Int32: + return compileInt32(ctx, false) + case reflect.Int64: + return compileInt64(ctx, false) + case reflect.Uint: + return compileUint(ctx, false) + case reflect.Uint8: + return compileUint8(ctx, false) + case reflect.Uint16: + return compileUint16(ctx, false) + case reflect.Uint32: + return compileUint32(ctx, false) + case reflect.Uint64: + return compileUint64(ctx, false) + case reflect.Uintptr: + return compileUint(ctx, false) + case reflect.Float32: + return compileFloat32(ctx, false) + case reflect.Float64: + return compileFloat64(ctx, false) + case reflect.String: + return compileString(ctx, false) + case reflect.Bool: + return compileBool(ctx, false) + } + return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)} +} + +func compileInt(ctx *compileContext, isPtr bool) (*IntCode, error) { + return &IntCode{typ: ctx.typ, bitSize: intSize, isPtr: isPtr}, nil +} + +func compileInt8(ctx *compileContext, isPtr bool) (*IntCode, error) { + return &IntCode{typ: ctx.typ, bitSize: 8, isPtr: isPtr}, nil +} + +func compileInt16(ctx *compileContext, isPtr bool) (*IntCode, error) { + return &IntCode{typ: ctx.typ, bitSize: 16, isPtr: isPtr}, nil +} + +func compileInt32(ctx *compileContext, isPtr bool) (*IntCode, error) { + return &IntCode{typ: ctx.typ, bitSize: 32, isPtr: isPtr}, nil +} + +func compileInt64(ctx *compileContext, isPtr bool) (*IntCode, error) { + return &IntCode{typ: ctx.typ, bitSize: 64, isPtr: isPtr}, nil +} + +func compileUint(ctx *compileContext, isPtr bool) (*UintCode, error) { + return &UintCode{typ: ctx.typ, bitSize: intSize, isPtr: isPtr}, nil +} + +func compileUint8(ctx *compileContext, isPtr bool) (*UintCode, error) { + return &UintCode{typ: ctx.typ, bitSize: 8, isPtr: isPtr}, nil +} + +func compileUint16(ctx *compileContext, isPtr bool) (*UintCode, error) { + return &UintCode{typ: ctx.typ, bitSize: 16, isPtr: isPtr}, nil +} + +func compileUint32(ctx *compileContext, isPtr bool) (*UintCode, error) { + return &UintCode{typ: ctx.typ, bitSize: 32, isPtr: isPtr}, nil +} + +func compileUint64(ctx *compileContext, isPtr bool) (*UintCode, error) { + return &UintCode{typ: ctx.typ, bitSize: 64, isPtr: isPtr}, nil +} + +func compileFloat32(ctx *compileContext, isPtr bool) (*FloatCode, error) { + return &FloatCode{typ: ctx.typ, bitSize: 32, isPtr: isPtr}, nil +} + +func compileFloat64(ctx *compileContext, isPtr bool) (*FloatCode, error) { + return &FloatCode{typ: ctx.typ, bitSize: 64, isPtr: isPtr}, nil +} + +func compileString(ctx *compileContext, isPtr bool) (*StringCode, error) { + return &StringCode{typ: ctx.typ, isPtr: isPtr}, nil +} + +func compileBool(ctx *compileContext, isPtr bool) (*BoolCode, error) { + return &BoolCode{typ: ctx.typ, isPtr: isPtr}, nil +} + +func compileIntString(ctx *compileContext) (*IntCode, error) { + return &IntCode{typ: ctx.typ, bitSize: intSize, isString: true}, nil +} + +func compileInt8String(ctx *compileContext) (*IntCode, error) { + return &IntCode{typ: ctx.typ, bitSize: 8, isString: true}, nil +} + +func compileInt16String(ctx *compileContext) (*IntCode, error) { + return &IntCode{typ: ctx.typ, bitSize: 16, isString: true}, nil +} + +func compileInt32String(ctx *compileContext) (*IntCode, error) { + return &IntCode{typ: ctx.typ, bitSize: 32, isString: true}, nil +} + +func compileInt64String(ctx *compileContext) (*IntCode, error) { + return &IntCode{typ: ctx.typ, bitSize: 64, isString: true}, nil +} + +func compileUintString(ctx *compileContext) (*UintCode, error) { + return &UintCode{typ: ctx.typ, bitSize: intSize, isString: true}, nil +} + +func compileUint8String(ctx *compileContext) (*UintCode, error) { + return &UintCode{typ: ctx.typ, bitSize: 8, isString: true}, nil +} + +func compileUint16String(ctx *compileContext) (*UintCode, error) { + return &UintCode{typ: ctx.typ, bitSize: 16, isString: true}, nil +} + +func compileUint32String(ctx *compileContext) (*UintCode, error) { + return &UintCode{typ: ctx.typ, bitSize: 32, isString: true}, nil +} + +func compileUint64String(ctx *compileContext) (*UintCode, error) { + return &UintCode{typ: ctx.typ, bitSize: 64, isString: true}, nil +} + +func compileSlice(ctx *compileContext) (*SliceCode, error) { + elem := ctx.typ.Elem() + code, err := compileListElem(ctx.withType(elem)) + if err != nil { + return nil, err + } + if code.Type() == CodeTypeStruct { + structCode := code.(*StructCode) + structCode.enableIndirect() + } + return &SliceCode{typ: ctx.typ, value: code}, nil +} + +func compileArray(ctx *compileContext) (*ArrayCode, error) { + typ := ctx.typ + elem := typ.Elem() + code, err := compileListElem(ctx.withType(elem)) + if err != nil { + return nil, err + } + if code.Type() == CodeTypeStruct { + structCode := code.(*StructCode) + structCode.enableIndirect() + } + return &ArrayCode{typ: ctx.typ, value: code}, nil +} + +func compileMap(ctx *compileContext) (*MapCode, error) { + typ := ctx.typ + keyCode, err := compileMapKey(ctx.withType(typ.Key())) + if err != nil { + return nil, err + } + valueCode, err := compileMapValue(ctx.withType(typ.Elem())) + if err != nil { + return nil, err + } + if valueCode.Type() == CodeTypeStruct { + structCode := valueCode.(*StructCode) + structCode.enableIndirect() + } + return &MapCode{typ: ctx.typ, key: keyCode, value: valueCode}, nil +} + +func compileBytes(ctx *compileContext, isPtr bool) (*BytesCode, error) { + return &BytesCode{typ: ctx.typ, isPtr: isPtr}, nil +} + +func compileInterface(ctx *compileContext, isPtr bool) (*InterfaceCode, error) { + return &InterfaceCode{typ: ctx.typ, isPtr: isPtr}, nil +} + +func compileMarshalJSON(ctx *compileContext) (*MarshalJSONCode, error) { + return &MarshalJSONCode{typ: ctx.typ}, nil +} + +func compileMarshalText(ctx *compileContext) (*MarshalTextCode, error) { + return &MarshalTextCode{typ: ctx.typ}, nil +} + +func compilePtr(ctx *compileContext) (*PtrCode, error) { + code, err := type2codeWithPtr(ctx.withType(ctx.typ.Elem()), true) + if err != nil { + return nil, err + } + ptr, ok := code.(*PtrCode) + if ok { + return &PtrCode{typ: ctx.typ, value: ptr.value, ptrNum: ptr.ptrNum + 1}, nil + } + return &PtrCode{typ: ctx.typ, value: code, ptrNum: 1}, nil +} + +func compileListElem(ctx *compileContext) (Code, error) { + typ := ctx.typ + switch { + case isPtrMarshalJSONType(typ): + return compileMarshalJSON(ctx) + case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType): + return compileMarshalText(ctx) + case typ.Kind() == reflect.Map: + return compilePtr(ctx.withType(runtime.PtrTo(typ))) + default: + code, err := type2codeWithPtr(ctx, false) + if err != nil { + return nil, err + } + ptr, ok := code.(*PtrCode) + if ok { + if ptr.value.Type() == CodeTypeMap { + ptr.ptrNum++ + } + } + return code, nil + } +} + +func compileMapKey(ctx *compileContext) (Code, error) { typ := ctx.typ switch { case implementsMarshalJSON(typ): @@ -514,7 +595,7 @@ func compileKey(ctx *compileContext) (*Opcode, error) { case reflect.Ptr: return compilePtr(ctx) case reflect.String: - return compileString(ctx) + return compileString(ctx, false) case reflect.Int: return compileIntString(ctx) case reflect.Int8: @@ -541,774 +622,185 @@ func compileKey(ctx *compileContext) (*Opcode, error) { return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)} } -func compilePtr(ctx *compileContext) (*Opcode, error) { - code, err := compile(ctx.withType(ctx.typ.Elem()), true) - if err != nil { - return nil, err - } - code.Op = convertPtrOp(code) - code.PtrNum++ - return code, nil -} - -func compileMarshalJSON(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpMarshalJSON) - typ := ctx.typ - if isPtrMarshalJSONType(typ) { - code.Flags |= AddrForMarshalerFlags - } - if typ.Implements(marshalJSONContextType) || runtime.PtrTo(typ).Implements(marshalJSONContextType) { - code.Flags |= MarshalerContextFlags - } - if isNilableType(typ) { - code.Flags |= IsNilableTypeFlags - } else { - code.Flags &= ^IsNilableTypeFlags - } - ctx.incIndex() - return code, nil -} - -func compileMarshalText(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpMarshalText) - typ := ctx.typ - if !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType) { - code.Flags |= AddrForMarshalerFlags - } - if isNilableType(typ) { - code.Flags |= IsNilableTypeFlags - } else { - code.Flags &= ^IsNilableTypeFlags - } - ctx.incIndex() - return code, nil -} - -const intSize = 32 << (^uint(0) >> 63) - -func compileInt(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpInt) - code.NumBitSize = intSize - ctx.incIndex() - return code, nil -} - -func compileIntPtr(ctx *compileContext) (*Opcode, error) { - code, err := compileInt(ctx) - if err != nil { - return nil, err - } - code.Op = OpIntPtr - return code, nil -} - -func compileInt8(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpInt) - code.NumBitSize = 8 - ctx.incIndex() - return code, nil -} - -func compileInt8Ptr(ctx *compileContext) (*Opcode, error) { - code, err := compileInt8(ctx) - if err != nil { - return nil, err - } - code.Op = OpIntPtr - return code, nil -} - -func compileInt16(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpInt) - code.NumBitSize = 16 - ctx.incIndex() - return code, nil -} - -func compileInt16Ptr(ctx *compileContext) (*Opcode, error) { - code, err := compileInt16(ctx) - if err != nil { - return nil, err - } - code.Op = OpIntPtr - return code, nil -} - -func compileInt32(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpInt) - code.NumBitSize = 32 - ctx.incIndex() - return code, nil -} - -func compileInt32Ptr(ctx *compileContext) (*Opcode, error) { - code, err := compileInt32(ctx) - if err != nil { - return nil, err - } - code.Op = OpIntPtr - return code, nil -} - -func compileInt64(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpInt) - code.NumBitSize = 64 - ctx.incIndex() - return code, nil -} - -func compileInt64Ptr(ctx *compileContext) (*Opcode, error) { - code, err := compileInt64(ctx) - if err != nil { - return nil, err - } - code.Op = OpIntPtr - return code, nil -} - -func compileUint(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpUint) - code.NumBitSize = intSize - ctx.incIndex() - return code, nil -} - -func compileUintPtr(ctx *compileContext) (*Opcode, error) { - code, err := compileUint(ctx) - if err != nil { - return nil, err - } - code.Op = OpUintPtr - return code, nil -} - -func compileUint8(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpUint) - code.NumBitSize = 8 - ctx.incIndex() - return code, nil -} - -func compileUint8Ptr(ctx *compileContext) (*Opcode, error) { - code, err := compileUint8(ctx) - if err != nil { - return nil, err - } - code.Op = OpUintPtr - return code, nil -} - -func compileUint16(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpUint) - code.NumBitSize = 16 - ctx.incIndex() - return code, nil -} - -func compileUint16Ptr(ctx *compileContext) (*Opcode, error) { - code, err := compileUint16(ctx) - if err != nil { - return nil, err - } - code.Op = OpUintPtr - return code, nil -} - -func compileUint32(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpUint) - code.NumBitSize = 32 - ctx.incIndex() - return code, nil -} - -func compileUint32Ptr(ctx *compileContext) (*Opcode, error) { - code, err := compileUint32(ctx) - if err != nil { - return nil, err - } - code.Op = OpUintPtr - return code, nil -} - -func compileUint64(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpUint) - code.NumBitSize = 64 - ctx.incIndex() - return code, nil -} - -func compileUint64Ptr(ctx *compileContext) (*Opcode, error) { - code, err := compileUint64(ctx) - if err != nil { - return nil, err - } - code.Op = OpUintPtr - return code, nil -} - -func compileIntString(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpIntString) - code.NumBitSize = intSize - ctx.incIndex() - return code, nil -} - -func compileInt8String(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpIntString) - code.NumBitSize = 8 - ctx.incIndex() - return code, nil -} - -func compileInt16String(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpIntString) - code.NumBitSize = 16 - ctx.incIndex() - return code, nil -} - -func compileInt32String(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpIntString) - code.NumBitSize = 32 - ctx.incIndex() - return code, nil -} - -func compileInt64String(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpIntString) - code.NumBitSize = 64 - ctx.incIndex() - return code, nil -} - -func compileUintString(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpUintString) - code.NumBitSize = intSize - ctx.incIndex() - return code, nil -} - -func compileUint8String(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpUintString) - code.NumBitSize = 8 - ctx.incIndex() - return code, nil -} - -func compileUint16String(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpUintString) - code.NumBitSize = 16 - ctx.incIndex() - return code, nil -} - -func compileUint32String(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpUintString) - code.NumBitSize = 32 - ctx.incIndex() - return code, nil -} - -func compileUint64String(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpUintString) - code.NumBitSize = 64 - ctx.incIndex() - return code, nil -} - -func compileFloat32(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpFloat32) - ctx.incIndex() - return code, nil -} - -func compileFloat32Ptr(ctx *compileContext) (*Opcode, error) { - code, err := compileFloat32(ctx) - if err != nil { - return nil, err - } - code.Op = OpFloat32Ptr - return code, nil -} - -func compileFloat64(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpFloat64) - ctx.incIndex() - return code, nil -} - -func compileFloat64Ptr(ctx *compileContext) (*Opcode, error) { - code, err := compileFloat64(ctx) - if err != nil { - return nil, err - } - code.Op = OpFloat64Ptr - return code, nil -} - -func compileString(ctx *compileContext) (*Opcode, error) { - var op OpType - if ctx.typ == runtime.Type2RType(jsonNumberType) { - op = OpNumber - } else { - op = OpString - } - code := newOpCode(ctx, op) - ctx.incIndex() - return code, nil -} - -func compileStringPtr(ctx *compileContext) (*Opcode, error) { - code, err := compileString(ctx) - if err != nil { - return nil, err - } - if code.Op == OpNumber { - code.Op = OpNumberPtr - } else { - code.Op = OpStringPtr - } - return code, nil -} - -func compileBool(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpBool) - ctx.incIndex() - return code, nil -} - -func compileBoolPtr(ctx *compileContext) (*Opcode, error) { - code, err := compileBool(ctx) - if err != nil { - return nil, err - } - code.Op = OpBoolPtr - return code, nil -} - -func compileBytes(ctx *compileContext) (*Opcode, error) { - code := newOpCode(ctx, OpBytes) - ctx.incIndex() - return code, nil -} - -func compileBytesPtr(ctx *compileContext) (*Opcode, error) { - code, err := compileBytes(ctx) - if err != nil { - return nil, err - } - code.Op = OpBytesPtr - return code, nil -} - -func compileInterface(ctx *compileContext) (*Opcode, error) { - code := newInterfaceCode(ctx) - ctx.incIndex() - return code, nil -} - -func compileInterfacePtr(ctx *compileContext) (*Opcode, error) { - code, err := compileInterface(ctx) - if err != nil { - return nil, err - } - code.Op = OpInterfacePtr - return code, nil -} - -func compileSlice(ctx *compileContext) (*Opcode, error) { - elem := ctx.typ.Elem() - size := elem.Size() - - header := newSliceHeaderCode(ctx) - ctx.incIndex() - - code, err := compileListElem(ctx.withType(elem).incIndent()) - if err != nil { - return nil, err - } - code.Flags |= IndirectFlags - - // header => opcode => elem => end - // ^ | - // |________| - - elemCode := newSliceElemCode(ctx, header, size) - ctx.incIndex() - - end := newOpCode(ctx, OpSliceEnd) - ctx.incIndex() - - header.End = end - header.Next = code - code.BeforeLastCode().Next = (*Opcode)(unsafe.Pointer(elemCode)) - elemCode.Next = code - elemCode.End = end - return (*Opcode)(unsafe.Pointer(header)), nil -} - -func compileListElem(ctx *compileContext) (*Opcode, error) { - typ := ctx.typ - switch { - case isPtrMarshalJSONType(typ): - return compileMarshalJSON(ctx) - case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType): - return compileMarshalText(ctx) - case typ.Kind() == reflect.Map: - return compilePtr(ctx.withType(runtime.PtrTo(typ))) - default: - code, err := compile(ctx, false) - if err != nil { - return nil, err - } - if code.Op == OpMapPtr { - code.PtrNum++ - } - return code, nil - } -} - -func compileArray(ctx *compileContext) (*Opcode, error) { - typ := ctx.typ - elem := typ.Elem() - alen := typ.Len() - size := elem.Size() - - header := newArrayHeaderCode(ctx, alen) - ctx.incIndex() - - code, err := compileListElem(ctx.withType(elem).incIndent()) - if err != nil { - return nil, err - } - code.Flags |= IndirectFlags - // header => opcode => elem => end - // ^ | - // |________| - - elemCode := newArrayElemCode(ctx, header, alen, size) - ctx.incIndex() - - end := newOpCode(ctx, OpArrayEnd) - ctx.incIndex() - - header.End = end - header.Next = code - code.BeforeLastCode().Next = (*Opcode)(unsafe.Pointer(elemCode)) - elemCode.Next = code - elemCode.End = end - return (*Opcode)(unsafe.Pointer(header)), nil -} - -func compileMap(ctx *compileContext) (*Opcode, error) { - // header => code => value => code => key => code => value => code => end - // ^ | - // |_______________________| - ctx = ctx.incIndent() - header := newMapHeaderCode(ctx) - ctx.incIndex() - - typ := ctx.typ - keyType := ctx.typ.Key() - keyCode, err := compileKey(ctx.withType(keyType)) - if err != nil { - return nil, err - } - - value := newMapValueCode(ctx, header) - ctx.incIndex() - - valueCode, err := compileMapValue(ctx.withType(typ.Elem())) - if err != nil { - return nil, err - } - valueCode.Flags |= IndirectFlags - - key := newMapKeyCode(ctx, header) - ctx.incIndex() - - ctx = ctx.decIndent() - - end := newMapEndCode(ctx, header) - ctx.incIndex() - - header.Next = keyCode - keyCode.BeforeLastCode().Next = (*Opcode)(unsafe.Pointer(value)) - value.Next = valueCode - valueCode.BeforeLastCode().Next = (*Opcode)(unsafe.Pointer(key)) - key.Next = keyCode - - header.End = end - key.End = end - value.End = end - - return (*Opcode)(unsafe.Pointer(header)), nil -} - -func compileMapValue(ctx *compileContext) (*Opcode, error) { +func compileMapValue(ctx *compileContext) (Code, error) { switch ctx.typ.Kind() { case reflect.Map: return compilePtr(ctx.withType(runtime.PtrTo(ctx.typ))) default: - code, err := compile(ctx, false) + code, err := type2codeWithPtr(ctx, false) if err != nil { return nil, err } - if code.Op == OpMapPtr { - code.PtrNum++ + ptr, ok := code.(*PtrCode) + if ok { + if ptr.value.Type() == CodeTypeMap { + ptr.ptrNum++ + } } return code, nil } } -func optimizeStructHeader(code *Opcode, tag *runtime.StructTag) OpType { - headType := code.ToHeaderType(tag.IsString) - if tag.IsOmitEmpty { - headType = headType.HeadToOmitEmptyHead() - } - return headType -} - -func optimizeStructField(code *Opcode, tag *runtime.StructTag) OpType { - fieldType := code.ToFieldType(tag.IsString) - if tag.IsOmitEmpty { - fieldType = fieldType.FieldToOmitEmptyField() - } - return fieldType -} - -func recursiveCode(ctx *compileContext, jmp *CompiledCode) *Opcode { - code := newRecursiveCode(ctx, jmp) - ctx.incIndex() - return code -} - -func compiledCode(ctx *compileContext) *Opcode { +func compileStruct(ctx *compileContext, isPtr bool) (*StructCode, error) { typ := ctx.typ typeptr := uintptr(unsafe.Pointer(typ)) - if cc, exists := ctx.structTypeToCompiledCode[typeptr]; exists { - return recursiveCode(ctx, cc) + if code, exists := ctx.structTypeToCode[typeptr]; exists { + derefCode := *code + derefCode.isRecursive = true + return &derefCode, nil } - return nil -} - -func structHeader(ctx *compileContext, fieldCode *Opcode, valueCode *Opcode, tag *runtime.StructTag) *Opcode { - op := optimizeStructHeader(valueCode, tag) - fieldCode.Op = op - fieldCode.NumBitSize = valueCode.NumBitSize - fieldCode.PtrNum = valueCode.PtrNum - if op.IsMultipleOpHead() { - return valueCode.BeforeLastCode() - } - ctx.decOpcodeIndex() - return fieldCode -} - -func structField(ctx *compileContext, fieldCode *Opcode, valueCode *Opcode, tag *runtime.StructTag) *Opcode { - op := optimizeStructField(valueCode, tag) - fieldCode.Op = op - fieldCode.NumBitSize = valueCode.NumBitSize - fieldCode.PtrNum = valueCode.PtrNum - if op.IsMultipleOpField() { - return valueCode.BeforeLastCode() - } - ctx.decIndex() - return fieldCode -} - -type structFieldPair struct { - prevField *Opcode - curField *Opcode - isTaggedKey bool - linked bool -} - -func filterAnonymousStructFieldsByTags(value *Opcode, tags runtime.StructTags) *Opcode { - head := value - curField := head - removedFields := map[*Opcode]struct{}{} - for curField != nil { - existsKey := tags.ExistsKey(curField.DisplayKey) - if !existsKey || curField.Next.IsRecursiveOp() { - curField = curField.NextField - continue - } - diff := curField.NextField.DisplayIdx - curField.DisplayIdx - for i := uint32(0); i < diff; i++ { - curField.NextField.decOpcodeIndex() - } - if curField.IsStructHeadOp() || head == curField { - head = curField.NextField - } else { - linkPrevToNextField(curField, removedFields) - } - curField = curField.NextField - } - return head -} - -func anonymousStructFieldPairMap(named string, valueCode *Opcode) map[string][]structFieldPair { - anonymousFields := map[string][]structFieldPair{} - f := valueCode - var prevAnonymousField *Opcode - for { - isHeadOp := strings.Contains(f.Op.String(), "Head") - if isHeadOp && (f.Flags&AnonymousHeadFlags) == 0 { - if named == "" { - f.Flags |= AnonymousHeadFlags - } - } else if named == "" && f.Op == OpStructEnd { - f.Op = OpStructAnonymousEnd - } - if f.DisplayKey == "" { - if f.NextField == nil { - break - } - prevAnonymousField = f - f = f.NextField - continue - } - - key := fmt.Sprintf("%s.%s", named, f.DisplayKey) - anonymousFields[key] = append(anonymousFields[key], structFieldPair{ - prevField: prevAnonymousField, - curField: f, - isTaggedKey: (f.Flags & IsTaggedKeyFlags) != 0, - }) - if f.Next != nil && f.NextField != f.Next && f.Next.Op.CodeType() == CodeStructField { - for k, v := range anonymousFieldPairRecursively(named, f.Next) { - anonymousFields[k] = append(anonymousFields[k], v...) - } - } - if f.NextField == nil { - break - } - prevAnonymousField = f - f = f.NextField - } - return anonymousFields -} - -func anonymousFieldPairRecursively(named string, valueCode *Opcode) map[string][]structFieldPair { - anonymousFields := map[string][]structFieldPair{} - f := valueCode - var prevAnonymousField *Opcode - for { - if f.DisplayKey != "" && (f.Flags&AnonymousHeadFlags) != 0 { - key := fmt.Sprintf("%s.%s", named, f.DisplayKey) - anonymousFields[key] = append(anonymousFields[key], structFieldPair{ - prevField: prevAnonymousField, - curField: f, - isTaggedKey: (f.Flags & IsTaggedKeyFlags) != 0, - }) - if f.Next != nil && f.NextField != f.Next && f.Next.Op.CodeType() == CodeStructField { - for k, v := range anonymousFieldPairRecursively(named, f.Next) { - anonymousFields[k] = append(anonymousFields[k], v...) - } - } - } - if f.NextField == nil { - break - } - prevAnonymousField = f - f = f.NextField - } - return anonymousFields -} - -func optimizeConflictAnonymousFields(anonymousFields map[string][]structFieldPair) { - removedFields := map[*Opcode]struct{}{} - for _, fieldPairs := range anonymousFields { - if len(fieldPairs) == 1 { - continue - } - // conflict anonymous fields - taggedPairs := []structFieldPair{} - for _, fieldPair := range fieldPairs { - if fieldPair.isTaggedKey { - taggedPairs = append(taggedPairs, fieldPair) - } else { - if !fieldPair.linked { - if fieldPair.prevField == nil { - // head operation - fieldPair.curField.Op = OpStructHead - fieldPair.curField.Flags |= AnonymousHeadFlags - fieldPair.curField.Flags |= AnonymousKeyFlags - } else { - diff := fieldPair.curField.NextField.DisplayIdx - fieldPair.curField.DisplayIdx - for i := uint32(0); i < diff; i++ { - fieldPair.curField.NextField.decOpcodeIndex() - } - removedFields[fieldPair.curField] = struct{}{} - linkPrevToNextField(fieldPair.curField, removedFields) - } - fieldPair.linked = true - } - } - } - if len(taggedPairs) > 1 { - for _, fieldPair := range taggedPairs { - if !fieldPair.linked { - if fieldPair.prevField == nil { - // head operation - fieldPair.curField.Op = OpStructHead - fieldPair.curField.Flags |= AnonymousHeadFlags - fieldPair.curField.Flags |= AnonymousKeyFlags - } else { - diff := fieldPair.curField.NextField.DisplayIdx - fieldPair.curField.DisplayIdx - removedFields[fieldPair.curField] = struct{}{} - for i := uint32(0); i < diff; i++ { - fieldPair.curField.NextField.decOpcodeIndex() - } - linkPrevToNextField(fieldPair.curField, removedFields) - } - fieldPair.linked = true - } - } - } else { - for _, fieldPair := range taggedPairs { - fieldPair.curField.Flags &= ^IsTaggedKeyFlags - } - } - } -} - -func isNilableType(typ *runtime.Type) bool { - switch typ.Kind() { - case reflect.Ptr: - return true - case reflect.Map: - return true - case reflect.Func: - return true - default: - return false - } -} - -func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) { - if code := compiledCode(ctx); code != nil { - return code, nil - } - typ := ctx.typ - typeptr := uintptr(unsafe.Pointer(typ)) - compiled := &CompiledCode{} - ctx.structTypeToCompiledCode[typeptr] = compiled - // header => code => structField => code => end - // ^ | - // |__________| - fieldNum := typ.NumField() indirect := runtime.IfaceIndir(typ) - fieldIdx := 0 - disableIndirectConversion := false - var ( - head *Opcode - code *Opcode - prevField *Opcode - ) - ctx = ctx.incIndent() + code := &StructCode{typ: typ, isPtr: isPtr, isIndirect: indirect} + ctx.structTypeToCode[typeptr] = code + + fieldNum := typ.NumField() + tags := typeToStructTags(typ) + fields := []*StructFieldCode{} + for i, tag := range tags { + isOnlyOneFirstField := i == 0 && fieldNum == 1 + field, err := code.compileStructField(ctx, tag, isPtr, isOnlyOneFirstField) + if err != nil { + return nil, err + } + if field.isAnonymous { + structCode := field.getAnonymousStruct() + if structCode != nil { + structCode.removeFieldsByTags(tags) + if isAssignableIndirect(field, isPtr) { + if indirect { + structCode.isIndirect = true + } else { + structCode.isIndirect = false + } + } + } + } else { + structCode := field.getStruct() + if structCode != nil { + if indirect { + // if parent is indirect type, set child indirect property to true + structCode.isIndirect = true + } else { + // if parent is not indirect type, set child indirect property to false. + // but if parent's indirect is false and isPtr is true, then indirect must be true. + // Do this only if indirectConversion is enabled at the end of compileStruct. + structCode.isIndirect = false + } + } + } + fields = append(fields, field) + } + fieldMap := getFieldMap(fields) + duplicatedFieldMap := getDuplicatedFieldMap(fieldMap) + code.fields = filteredDuplicatedFields(fields, duplicatedFieldMap) + if !code.disableIndirectConversion && !indirect && isPtr { + code.enableIndirect() + } + delete(ctx.structTypeToCode, typeptr) + return code, nil +} + +func getFieldMap(fields []*StructFieldCode) map[string][]*StructFieldCode { + fieldMap := map[string][]*StructFieldCode{} + for _, field := range fields { + if field.isAnonymous { + for k, v := range getAnonymousFieldMap(field) { + fieldMap[k] = append(fieldMap[k], v...) + } + continue + } + fieldMap[field.key] = append(fieldMap[field.key], field) + } + return fieldMap +} + +func getAnonymousFieldMap(field *StructFieldCode) map[string][]*StructFieldCode { + fieldMap := map[string][]*StructFieldCode{} + structCode := field.getAnonymousStruct() + if structCode == nil || structCode.isRecursive { + fieldMap[field.key] = append(fieldMap[field.key], field) + return fieldMap + } + for k, v := range getFieldMapFromAnonymousParent(structCode.fields) { + fieldMap[k] = append(fieldMap[k], v...) + } + return fieldMap +} + +func getFieldMapFromAnonymousParent(fields []*StructFieldCode) map[string][]*StructFieldCode { + fieldMap := map[string][]*StructFieldCode{} + for _, field := range fields { + if field.isAnonymous { + for k, v := range getAnonymousFieldMap(field) { + // Do not handle tagged key when embedding more than once + for _, vv := range v { + vv.isTaggedKey = false + } + fieldMap[k] = append(fieldMap[k], v...) + } + continue + } + fieldMap[field.key] = append(fieldMap[field.key], field) + } + return fieldMap +} + +func getDuplicatedFieldMap(fieldMap map[string][]*StructFieldCode) map[*StructFieldCode]struct{} { + duplicatedFieldMap := map[*StructFieldCode]struct{}{} + for _, fields := range fieldMap { + if len(fields) == 1 { + continue + } + if isTaggedKeyOnly(fields) { + for _, field := range fields { + if field.isTaggedKey { + continue + } + duplicatedFieldMap[field] = struct{}{} + } + } else { + for _, field := range fields { + duplicatedFieldMap[field] = struct{}{} + } + } + } + return duplicatedFieldMap +} + +func filteredDuplicatedFields(fields []*StructFieldCode, duplicatedFieldMap map[*StructFieldCode]struct{}) []*StructFieldCode { + filteredFields := make([]*StructFieldCode, 0, len(fields)) + for _, field := range fields { + if field.isAnonymous { + structCode := field.getAnonymousStruct() + if structCode != nil && !structCode.isRecursive { + structCode.fields = filteredDuplicatedFields(structCode.fields, duplicatedFieldMap) + if len(structCode.fields) > 0 { + filteredFields = append(filteredFields, field) + } + continue + } + } + if _, exists := duplicatedFieldMap[field]; exists { + continue + } + filteredFields = append(filteredFields, field) + } + return filteredFields +} + +func isTaggedKeyOnly(fields []*StructFieldCode) bool { + var taggedKeyFieldCount int + for _, field := range fields { + if field.isTaggedKey { + taggedKeyFieldCount++ + } + } + return taggedKeyFieldCount == 1 +} + +func typeToStructTags(typ *runtime.Type) runtime.StructTags { tags := runtime.StructTags{} - anonymousFields := map[string][]structFieldPair{} + fieldNum := typ.NumField() for i := 0; i < fieldNum; i++ { field := typ.Field(i) if runtime.IsIgnoredStructField(field) { @@ -1316,218 +808,17 @@ func compileStruct(ctx *compileContext, isPtr bool) (*Opcode, error) { } tags = append(tags, runtime.StructTagFromField(field)) } - for i, tag := range tags { - field := tag.Field - fieldType := runtime.Type2RType(field.Type) - fieldOpcodeIndex := ctx.opcodeIndex - fieldPtrIndex := ctx.ptrIndex - ctx.incIndex() - - nilcheck := true - addrForMarshaler := false - isIndirectSpecialCase := isPtr && i == 0 && fieldNum == 1 - isNilableType := isNilableType(fieldType) - - var valueCode *Opcode - switch { - case isIndirectSpecialCase && !isNilableType && isPtrMarshalJSONType(fieldType): - // *struct{ field T } => struct { field *T } - // func (*T) MarshalJSON() ([]byte, error) - // move pointer position from head to first field - code, err := compileMarshalJSON(ctx.withType(fieldType)) - if err != nil { - return nil, err - } - addrForMarshaler = true - valueCode = code - nilcheck = false - indirect = false - disableIndirectConversion = true - case isIndirectSpecialCase && !isNilableType && isPtrMarshalTextType(fieldType): - // *struct{ field T } => struct { field *T } - // func (*T) MarshalText() ([]byte, error) - // move pointer position from head to first field - code, err := compileMarshalText(ctx.withType(fieldType)) - if err != nil { - return nil, err - } - addrForMarshaler = true - valueCode = code - nilcheck = false - indirect = false - disableIndirectConversion = true - case isPtr && isPtrMarshalJSONType(fieldType): - // *struct{ field T } - // func (*T) MarshalJSON() ([]byte, error) - code, err := compileMarshalJSON(ctx.withType(fieldType)) - if err != nil { - return nil, err - } - addrForMarshaler = true - nilcheck = false - valueCode = code - case isPtr && isPtrMarshalTextType(fieldType): - // *struct{ field T } - // func (*T) MarshalText() ([]byte, error) - code, err := compileMarshalText(ctx.withType(fieldType)) - if err != nil { - return nil, err - } - addrForMarshaler = true - nilcheck = false - valueCode = code - default: - code, err := compile(ctx.withType(fieldType), isPtr) - if err != nil { - return nil, err - } - valueCode = code - } - - if field.Anonymous && !tag.IsTaggedKey { - valueCode = filterAnonymousStructFieldsByTags(valueCode, tags) - for k, v := range anonymousStructFieldPairMap("", valueCode) { - anonymousFields[k] = append(anonymousFields[k], v...) - } - - valueCode.decIndent() - - // fix issue144 - if !(isPtr && strings.Contains(valueCode.Op.String(), "Marshal")) { - if indirect { - valueCode.Flags |= IndirectFlags - } else { - valueCode.Flags &= ^IndirectFlags - } - } - } else { - if indirect { - // if parent is indirect type, set child indirect property to true - valueCode.Flags |= IndirectFlags - } else { - // if parent is not indirect type, set child indirect property to false. - // but if parent's indirect is false and isPtr is true, then indirect must be true. - // Do this only if indirectConversion is enabled at the end of compileStruct. - if i == 0 { - valueCode.Flags &= ^IndirectFlags - } - } - } - var flags OpFlags - if indirect { - flags |= IndirectFlags - } - if field.Anonymous { - flags |= AnonymousKeyFlags - } - if tag.IsTaggedKey { - flags |= IsTaggedKeyFlags - } - if nilcheck { - flags |= NilCheckFlags - } - if addrForMarshaler { - flags |= AddrForMarshalerFlags - } - if strings.Contains(valueCode.Op.String(), "Ptr") || valueCode.Op == OpInterface { - flags |= IsNextOpPtrTypeFlags - } - if isNilableType { - flags |= IsNilableTypeFlags - } - var key string - if ctx.escapeKey { - rctx := &RuntimeContext{Option: &Option{Flag: HTMLEscapeOption}} - key = fmt.Sprintf(`%s:`, string(AppendString(rctx, []byte{}, tag.Key))) - } else { - key = fmt.Sprintf(`"%s":`, tag.Key) - } - fieldCode := &Opcode{ - Idx: opcodeOffset(fieldPtrIndex), - Next: valueCode, - Flags: flags, - Key: key, - Offset: uint32(field.Offset), - Type: valueCode.Type, - DisplayIdx: fieldOpcodeIndex, - Indent: ctx.indent, - DisplayKey: tag.Key, - } - if fieldIdx == 0 { - code = structHeader(ctx, fieldCode, valueCode, tag) - head = fieldCode - prevField = fieldCode - } else { - fieldCode.Idx = head.Idx - code.Next = fieldCode - code = structField(ctx, fieldCode, valueCode, tag) - prevField.NextField = fieldCode - fieldCode.PrevField = prevField - prevField = fieldCode - } - fieldIdx++ - } - - structEndCode := &Opcode{ - Op: OpStructEnd, - Type: nil, - Indent: ctx.indent, - } - - ctx = ctx.decIndent() - - // no struct field - if head == nil { - head = &Opcode{ - Op: OpStructHead, - Idx: opcodeOffset(ctx.ptrIndex), - NextField: structEndCode, - Type: typ, - DisplayIdx: ctx.opcodeIndex, - Indent: ctx.indent, - } - structEndCode.PrevField = head - ctx.incIndex() - code = head - } - - structEndCode.DisplayIdx = ctx.opcodeIndex - structEndCode.Idx = opcodeOffset(ctx.ptrIndex) - ctx.incIndex() - structEndCode.Next = newEndOp(ctx) - - if prevField != nil && prevField.NextField == nil { - prevField.NextField = structEndCode - structEndCode.PrevField = prevField - } - - head.End = structEndCode - code.Next = structEndCode - optimizeConflictAnonymousFields(anonymousFields) - ret := (*Opcode)(unsafe.Pointer(head)) - compiled.Code = ret - - delete(ctx.structTypeToCompiledCode, typeptr) - - if !disableIndirectConversion && (head.Flags&IndirectFlags == 0) && isPtr { - headCode := head - for strings.Contains(headCode.Op.String(), "Head") { - headCode.Flags |= IndirectFlags - headCode = headCode.Next - } - } - - return ret, nil + return tags } -func implementsMarshalJSONType(typ *runtime.Type) bool { - return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType) +// *struct{ field T } => struct { field *T } +// func (*T) MarshalJSON() ([]byte, error) +func isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool { + return isIndirectSpecialCase && !isNilableType(typ) && isPtrMarshalJSONType(typ) } -func isPtrMarshalJSONType(typ *runtime.Type) bool { - return !implementsMarshalJSONType(typ) && implementsMarshalJSONType(runtime.PtrTo(typ)) -} - -func isPtrMarshalTextType(typ *runtime.Type) bool { - return !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType) +// *struct{ field T } => struct { field *T } +// func (*T) MarshalText() ([]byte, error) +func isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool { + return isIndirectSpecialCase && !isNilableType(typ) && isPtrMarshalTextType(typ) }