Add reflect.rtype trick

This commit is contained in:
Masaaki Goshima 2020-04-24 20:23:26 +09:00
parent cfde002d29
commit 015eb040ee
5 changed files with 268 additions and 41 deletions

View File

@ -54,8 +54,8 @@ func (d *Decoder) Buffered() io.Reader {
}
func (d *Decoder) decodeForUnmarshal(src []byte, v interface{}) error {
rv := reflect.ValueOf(v)
typ := rv.Type()
header := (*interfaceHeader)(unsafe.Pointer(&v))
typ := header.typ
if typ.Kind() != reflect.Ptr {
return ErrDecodePointer
}
@ -71,7 +71,7 @@ func (d *Decoder) decodeForUnmarshal(src []byte, v interface{}) error {
}
dec = compiledDec
}
ptr := rv.Pointer()
ptr := uintptr(header.ptr)
ctx := ctxPool.Get().(*context)
ctx.setBuf(src)
if err := dec.decode(ctx, ptr); err != nil {
@ -83,8 +83,8 @@ func (d *Decoder) decodeForUnmarshal(src []byte, v interface{}) error {
}
func (d *Decoder) Decode(v interface{}) error {
rv := reflect.ValueOf(v)
typ := rv.Type()
header := (*interfaceHeader)(unsafe.Pointer(&v))
typ := header.typ
if typ.Kind() != reflect.Ptr {
return ErrDecodePointer
}
@ -100,7 +100,7 @@ func (d *Decoder) Decode(v interface{}) error {
}
dec = compiledDec
}
ptr := rv.Pointer()
ptr := uintptr(header.ptr)
ctx := ctxPool.Get().(*context)
defer ctxPool.Put(ctx)
for {
@ -120,7 +120,7 @@ func (d *Decoder) Decode(v interface{}) error {
return nil
}
func (d *Decoder) compile(typ reflect.Type) (decoder, error) {
func (d *Decoder) compile(typ *rtype) (decoder, error) {
switch typ.Kind() {
case reflect.Ptr:
return d.compilePtr(typ)
@ -158,7 +158,7 @@ func (d *Decoder) compile(typ reflect.Type) (decoder, error) {
return nil, nil
}
func (d *Decoder) compilePtr(typ reflect.Type) (decoder, error) {
func (d *Decoder) compilePtr(typ *rtype) (decoder, error) {
dec, err := d.compile(typ.Elem())
if err != nil {
return nil, err
@ -262,7 +262,7 @@ func (d *Decoder) isIgnoredStructField(field reflect.StructField) bool {
return false
}
func (d *Decoder) compileStruct(typ reflect.Type) (decoder, error) {
func (d *Decoder) compileStruct(typ *rtype) (decoder, error) {
fieldNum := typ.NumField()
fieldMap := map[string]*structFieldSet{}
for i := 0; i < fieldNum; i++ {
@ -278,7 +278,7 @@ func (d *Decoder) compileStruct(typ reflect.Type) (decoder, error) {
keyName = opts[0]
}
}
dec, err := d.compile(field.Type)
dec, err := d.compile(type2rtype(field.Type))
if err != nil {
return nil, err
}

View File

@ -1,21 +1,23 @@
package json
import (
"reflect"
"unsafe"
)
type ptrDecoder struct {
dec decoder
typ reflect.Type
typ *rtype
}
func newPtrDecoder(dec decoder, typ reflect.Type) *ptrDecoder {
func newPtrDecoder(dec decoder, typ *rtype) *ptrDecoder {
return &ptrDecoder{dec: dec, typ: typ}
}
//go:linkname unsafe_New reflect.unsafe_New
func unsafe_New(*rtype) uintptr
func (d *ptrDecoder) decode(ctx *context, p uintptr) error {
newptr := uintptr(reflect.New(d.typ).Pointer())
newptr := unsafe_New(d.typ)
if err := d.dec.decode(ctx, newptr); err != nil {
return err
}

View File

@ -54,8 +54,7 @@ func NewEncoder(w io.Writer) *Encoder {
//
// See the documentation for Marshal for details about the conversion of Go values to JSON.
func (e *Encoder) Encode(v interface{}) error {
header := (*interfaceHeader)(unsafe.Pointer(&v))
if err := e.encode(reflect.ValueOf(v), header.ptr); err != nil {
if err := e.encode(v); err != nil {
return err
}
if _, err := e.w.Write(e.buf); err != nil {
@ -65,8 +64,7 @@ func (e *Encoder) Encode(v interface{}) error {
}
func (e *Encoder) encodeForMarshal(v interface{}) ([]byte, error) {
header := (*interfaceHeader)(unsafe.Pointer(&v))
if err := e.encode(reflect.ValueOf(v), header.ptr); err != nil {
if err := e.encode(v); err != nil {
return nil, err
}
copied := make([]byte, len(e.buf))
@ -158,18 +156,12 @@ func (e *Encoder) encodeByte(b byte) {
e.buf = append(e.buf, b)
}
type rtype struct{}
type interfaceHeader struct {
typ *rtype
ptr unsafe.Pointer
}
func (e *Encoder) encode(v reflect.Value, ptr unsafe.Pointer) error {
typ := v.Type()
func (e *Encoder) encode(v interface{}) error {
header := (*interfaceHeader)(unsafe.Pointer(&v))
typ := header.typ
name := typ.String()
if op, exists := cachedEncodeOp[name]; exists {
op(e, uintptr(ptr))
op(e, uintptr(header.ptr))
return nil
}
if typ.Kind() == reflect.Ptr {
@ -182,11 +174,11 @@ func (e *Encoder) encode(v reflect.Value, ptr unsafe.Pointer) error {
if name != "" {
cachedEncodeOp[name] = op
}
op(e, uintptr(ptr))
op(e, uintptr(header.ptr))
return nil
}
func (e *Encoder) compile(typ reflect.Type) (EncodeOp, error) {
func (e *Encoder) compile(typ *rtype) (EncodeOp, error) {
switch typ.Kind() {
case reflect.Ptr:
return e.compilePtr(typ)
@ -231,10 +223,10 @@ func (e *Encoder) compile(typ reflect.Type) (EncodeOp, error) {
case reflect.Interface:
return nil, ErrCompileSlowPath
}
return nil, xerrors.Errorf("failed to encode type %s: %w", typ, ErrUnsupportedType)
return nil, xerrors.Errorf("failed to encode type %s: %w", typ.String(), ErrUnsupportedType)
}
func (e *Encoder) compilePtr(typ reflect.Type) (EncodeOp, error) {
func (e *Encoder) compilePtr(typ *rtype) (EncodeOp, error) {
op, err := e.compile(typ.Elem())
if err != nil {
return nil, err
@ -300,7 +292,7 @@ func (e *Encoder) compileBool() (EncodeOp, error) {
return func(enc *Encoder, p uintptr) { enc.encodeBool(e.ptrToBool(p)) }, nil
}
func (e *Encoder) compileSlice(typ reflect.Type) (EncodeOp, error) {
func (e *Encoder) compileSlice(typ *rtype) (EncodeOp, error) {
size := typ.Elem().Size()
op, err := e.compile(typ.Elem())
if err != nil {
@ -324,7 +316,7 @@ func (e *Encoder) compileSlice(typ reflect.Type) (EncodeOp, error) {
}, nil
}
func (e *Encoder) compileArray(typ reflect.Type) (EncodeOp, error) {
func (e *Encoder) compileArray(typ *rtype) (EncodeOp, error) {
alen := typ.Len()
size := typ.Elem().Size()
op, err := e.compile(typ.Elem())
@ -363,7 +355,7 @@ func (e *Encoder) isIgnoredStructField(field reflect.StructField) bool {
return false
}
func (e *Encoder) compileStruct(typ reflect.Type) (EncodeOp, error) {
func (e *Encoder) compileStruct(typ *rtype) (EncodeOp, error) {
fieldNum := typ.NumField()
opQueue := make([]EncodeOp, 0, fieldNum)
for i := 0; i < fieldNum; i++ {
@ -379,7 +371,8 @@ func (e *Encoder) compileStruct(typ reflect.Type) (EncodeOp, error) {
keyName = opts[0]
}
}
op, err := e.compile(field.Type)
fieldType := type2rtype(field.Type)
op, err := e.compile(fieldType)
if err != nil {
return nil, err
}
@ -426,9 +419,8 @@ type valueType struct {
ptr unsafe.Pointer
}
func (e *Encoder) compileMap(typ reflect.Type) (EncodeOp, error) {
v := reflect.New(typ).Elem()
mapType := (*valueType)(unsafe.Pointer(&v)).typ
func (e *Encoder) compileMap(typ *rtype) (EncodeOp, error) {
mapType := unsafe.Pointer(typ)
keyOp, err := e.compile(typ.Key())
if err != nil {
return nil, err

View File

@ -5,8 +5,13 @@ import "bytes"
func Marshal(v interface{}) ([]byte, error) {
var b *bytes.Buffer
enc := NewEncoder(b)
defer enc.release()
return enc.encodeForMarshal(v)
bytes, err := enc.encodeForMarshal(v)
if err != nil {
enc.release()
return nil, err
}
enc.release()
return bytes, nil
}
func Unmarshal(data []byte, v interface{}) error {

228
rtype.go Normal file
View File

@ -0,0 +1,228 @@
package json
import (
"reflect"
"unsafe"
)
// rtype representing reflect.rtype for noescape trick
type rtype struct{}
//go:linkname rtype_Align reflect.(*rtype).Align
func rtype_Align(*rtype) int
func (t *rtype) Align() int {
return rtype_Align(t)
}
//go:linkname rtype_FieldAlign reflect.(*rtype).FieldAlign
func rtype_FieldAlign(*rtype) int
func (t *rtype) FieldAlign() int {
return rtype_FieldAlign(t)
}
//go:linkname rtype_Method reflect.(*rtype).Method
func rtype_Method(*rtype, int) reflect.Method
func (t *rtype) Method(a0 int) reflect.Method {
return rtype_Method(t, a0)
}
//go:linkname rtype_MethodByName reflect.(*rtype).MethodByName
func rtype_MethodByName(*rtype, string) (reflect.Method, bool)
func (t *rtype) MethodByName(a0 string) (reflect.Method, bool) {
return rtype_MethodByName(t, a0)
}
//go:linkname rtype_NumMethod reflect.(*rtype).NumMethod
func rtype_NumMethod(*rtype) int
func (t *rtype) NumMethod() int {
return rtype_NumMethod(t)
}
//go:linkname rtype_Name reflect.(*rtype).Name
//go:noescape
func rtype_Name(*rtype) string
func (t *rtype) Name() string {
return rtype_Name(t)
}
//go:linkname rtype_PkgPath reflect.(*rtype).PkgPath
func rtype_PkgPath(*rtype) string
func (t *rtype) PkgPath() string {
return rtype_PkgPath(t)
}
//go:linkname rtype_Size reflect.(*rtype).Size
//go:noescape
func rtype_Size(*rtype) uintptr
func (t *rtype) Size() uintptr {
return rtype_Size(t)
}
//go:linkname rtype_String reflect.(*rtype).String
//go:noescape
func rtype_String(*rtype) string
func (t *rtype) String() string {
return rtype_String(t)
}
//go:linkname rtype_Kind reflect.(*rtype).Kind
//go:noescape
func rtype_Kind(*rtype) reflect.Kind
func (t *rtype) Kind() reflect.Kind {
return rtype_Kind(t)
}
//go:linkname rtype_Implements reflect.(*rtype).Implements
func rtype_Implements(*rtype, reflect.Type) bool
func (t *rtype) Implements(u reflect.Type) bool {
return rtype_Implements(t, u)
}
//go:linkname rtype_AssignableTo reflect.(*rtype).AssignableTo
func rtype_AssignableTo(*rtype, reflect.Type) bool
func (t *rtype) AssignableTo(u reflect.Type) bool {
return rtype_AssignableTo(t, u)
}
//go:linkname rtype_ConvertibleTo reflect.(*rtype).ConvertibleTo
func rtype_ConvertibleTo(*rtype, reflect.Type) bool
func (t *rtype) ConvertibleTo(u reflect.Type) bool {
return rtype_ConvertibleTo(t, u)
}
//go:linkname rtype_Comparable reflect.(*rtype).Comparable
func rtype_Comparable(*rtype) bool
func (t *rtype) Comparable() bool {
return rtype_Comparable(t)
}
//go:linkname rtype_Bits reflect.(*rtype).Bits
func rtype_Bits(*rtype) int
func (t *rtype) Bits() int {
return rtype_Bits(t)
}
//go:linkname rtype_ChanDir reflect.(*rtype).ChanDir
func rtype_ChanDir(*rtype) reflect.ChanDir
func (t *rtype) ChanDir() reflect.ChanDir {
return rtype_ChanDir(t)
}
//go:linkname rtype_IsVariadic reflect.(*rtype).IsVariadic
func rtype_IsVariadic(*rtype) bool
func (t *rtype) IsVariadic() bool {
return rtype_IsVariadic(t)
}
//go:linkname rtype_Elem reflect.(*rtype).Elem
//go:noescape
func rtype_Elem(*rtype) reflect.Type
func (t *rtype) Elem() *rtype {
return type2rtype(rtype_Elem(t))
}
//go:linkname rtype_Field reflect.(*rtype).Field
//go:noescape
func rtype_Field(*rtype, int) reflect.StructField
func (t *rtype) Field(i int) reflect.StructField {
return rtype_Field(t, i)
}
//go:linkname rtype_FieldByIndex reflect.(*rtype).FieldByIndex
func rtype_FieldByIndex(*rtype, []int) reflect.StructField
func (t *rtype) FieldByIndex(index []int) reflect.StructField {
return rtype_FieldByIndex(t, index)
}
//go:linkname rtype_FieldByName reflect.(*rtype).FieldByName
func rtype_FieldByName(*rtype, string) (reflect.StructField, bool)
func (t *rtype) FieldByName(name string) (reflect.StructField, bool) {
return rtype_FieldByName(t, name)
}
//go:linkname rtype_FieldByNameFunc reflect.(*rtype).FieldByNameFunc
func rtype_FieldByNameFunc(*rtype, func(string) bool) (reflect.StructField, bool)
func (t *rtype) FieldByNameFunc(match func(string) bool) (reflect.StructField, bool) {
return rtype_FieldByNameFunc(t, match)
}
//go:linkname rtype_In reflect.(*rtype).In
func rtype_In(*rtype, int) reflect.Type
func (t *rtype) In(i int) reflect.Type {
return rtype_In(t, i)
}
//go:linkname rtype_Key reflect.(*rtype).Key
func rtype_Key(*rtype) reflect.Type
func (t *rtype) Key() *rtype {
return type2rtype(rtype_Key(t))
}
//go:linkname rtype_Len reflect.(*rtype).Len
//go:noescape
func rtype_Len(*rtype) int
func (t *rtype) Len() int {
return rtype_Len(t)
}
//go:linkname rtype_NumField reflect.(*rtype).NumField
func rtype_NumField(*rtype) int
func (t *rtype) NumField() int {
return rtype_NumField(t)
}
//go:linkname rtype_NumIn reflect.(*rtype).NumIn
func rtype_NumIn(*rtype) int
func (t *rtype) NumIn() int {
return rtype_NumIn(t)
}
//go:linkname rtype_NumOut reflect.(*rtype).NumOut
func rtype_NumOut(*rtype) int
func (t *rtype) NumOut() int {
return rtype_NumOut(t)
}
//go:linkname rtype_Out reflect.(*rtype).Out
func rtype_Out(*rtype, int) reflect.Type
func (t *rtype) Out(i int) reflect.Type {
return rtype_Out(t, i)
}
type interfaceHeader struct {
typ *rtype
ptr unsafe.Pointer
}
func type2rtype(t reflect.Type) *rtype {
return (*rtype)(((*interfaceHeader)(unsafe.Pointer(&t))).ptr)
}