Support encoding for map type

This commit is contained in:
Masaaki Goshima 2020-04-21 00:06:27 +09:00
parent f1ad87fd71
commit ed5ee07fdf
2 changed files with 92 additions and 17 deletions

View File

@ -123,10 +123,11 @@ type interfaceHeader struct {
func (e *Encoder) Encode(v interface{}) ([]byte, error) { func (e *Encoder) Encode(v interface{}) ([]byte, error) {
header := (*interfaceHeader)(unsafe.Pointer(&v)) header := (*interfaceHeader)(unsafe.Pointer(&v))
return e.encode(reflect.TypeOf(v), header.ptr) return e.encode(reflect.ValueOf(v), header.ptr)
} }
func (e *Encoder) encode(typ reflect.Type, ptr unsafe.Pointer) ([]byte, error) { func (e *Encoder) encode(v reflect.Value, ptr unsafe.Pointer) ([]byte, error) {
typ := v.Type()
name := typ.String() name := typ.String()
if op, exists := cachedEncodeOp[name]; exists { if op, exists := cachedEncodeOp[name]; exists {
op(e, uintptr(ptr)) op(e, uintptr(ptr))
@ -135,9 +136,9 @@ func (e *Encoder) encode(typ reflect.Type, ptr unsafe.Pointer) ([]byte, error) {
return copied, nil return copied, nil
} }
if typ.Kind() == reflect.Ptr { if typ.Kind() == reflect.Ptr {
typ = typ.Elem() v = v.Elem()
} }
op, err := e.compile(typ) op, err := e.compile(v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -150,14 +151,16 @@ func (e *Encoder) encode(typ reflect.Type, ptr unsafe.Pointer) ([]byte, error) {
return copied, nil return copied, nil
} }
func (e *Encoder) compile(typ reflect.Type) (EncodeOp, error) { func (e *Encoder) compile(v reflect.Value) (EncodeOp, error) {
switch typ.Kind() { switch v.Type().Kind() {
case reflect.Ptr: case reflect.Ptr:
return e.compilePtr(typ) return e.compilePtr(v)
case reflect.Slice: case reflect.Slice:
return e.compileSlice(typ) return e.compileSlice(v)
case reflect.Struct: case reflect.Struct:
return e.compileStruct(typ) return e.compileStruct(v)
case reflect.Map:
return e.compileMap(v)
case reflect.Int: case reflect.Int:
return e.compileInt() return e.compileInt()
case reflect.Int8: case reflect.Int8:
@ -187,12 +190,11 @@ func (e *Encoder) compile(typ reflect.Type) (EncodeOp, error) {
case reflect.Bool: case reflect.Bool:
return e.compileBool() return e.compileBool()
} }
return nil, xerrors.Errorf("failed to compile %s: %w", typ, ErrUnknownType) return nil, xerrors.Errorf("failed to compile %s: %w", v.Type(), ErrUnknownType)
} }
func (e *Encoder) compilePtr(typ reflect.Type) (EncodeOp, error) { func (e *Encoder) compilePtr(v reflect.Value) (EncodeOp, error) {
elem := typ.Elem() op, err := e.compile(v.Elem())
op, err := e.compile(elem)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -257,9 +259,14 @@ func (e *Encoder) compileBool() (EncodeOp, error) {
return func(enc *Encoder, p uintptr) { enc.EncodeBool(e.ptrToBool(p)) }, nil return func(enc *Encoder, p uintptr) { enc.EncodeBool(e.ptrToBool(p)) }, nil
} }
func (e *Encoder) compileSlice(typ reflect.Type) (EncodeOp, error) { func (e *Encoder) zeroValue(typ reflect.Type) reflect.Value {
return reflect.New(typ).Elem()
}
func (e *Encoder) compileSlice(v reflect.Value) (EncodeOp, error) {
typ := v.Type()
size := typ.Elem().Size() size := typ.Elem().Size()
op, err := e.compile(typ.Elem()) op, err := e.compile(e.zeroValue(typ.Elem()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -282,7 +289,8 @@ func (e *Encoder) compileSlice(typ reflect.Type) (EncodeOp, error) {
} }
func (e *Encoder) compileStruct(typ reflect.Type) (EncodeOp, error) { func (e *Encoder) compileStruct(v reflect.Value) (EncodeOp, error) {
typ := v.Type()
fieldNum := typ.NumField() fieldNum := typ.NumField()
opQueue := make([]EncodeOp, 0, fieldNum) opQueue := make([]EncodeOp, 0, fieldNum)
for i := 0; i < fieldNum; i++ { for i := 0; i < fieldNum; i++ {
@ -295,7 +303,7 @@ func (e *Encoder) compileStruct(typ reflect.Type) (EncodeOp, error) {
keyName = opts[0] keyName = opts[0]
} }
} }
op, err := e.compile(typ.Field(i).Type) op, err := e.compile(v.Field(i))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -322,6 +330,59 @@ func (e *Encoder) compileStruct(typ reflect.Type) (EncodeOp, error) {
}, nil }, nil
} }
//go:linkname mapiterinit reflect.mapiterinit
func mapiterinit(mapType unsafe.Pointer, m unsafe.Pointer) unsafe.Pointer
//go:linkname mapiterkey reflect.mapiterkey
func mapiterkey(it unsafe.Pointer) unsafe.Pointer
//go:linkname mapitervalue reflect.mapitervalue
func mapitervalue(it unsafe.Pointer) unsafe.Pointer
//go:linkname mapiternext reflect.mapiternext
func mapiternext(it unsafe.Pointer)
//go:linkname maplen reflect.maplen
func maplen(m unsafe.Pointer) int
type valueType struct {
typ unsafe.Pointer
ptr unsafe.Pointer
}
func (e *Encoder) compileMap(v reflect.Value) (EncodeOp, error) {
mapType := (*valueType)(unsafe.Pointer(&v)).typ
keyOp, err := e.compile(e.zeroValue(v.Type().Key()))
if err != nil {
return nil, err
}
valueOp, err := e.compile(e.zeroValue(v.Type().Elem()))
if err != nil {
return nil, err
}
return func(enc *Encoder, base uintptr) {
if base == 0 {
enc.EncodeString("null")
return
}
enc.EncodeByte('{')
mlen := maplen(unsafe.Pointer(base))
iter := mapiterinit(mapType, unsafe.Pointer(base))
for i := 0; i < mlen; i++ {
key := mapiterkey(iter)
if i != 0 {
enc.EncodeByte(',')
}
value := mapitervalue(iter)
keyOp(enc, uintptr(key))
enc.EncodeByte(':')
valueOp(enc, uintptr(value))
mapiternext(iter)
}
enc.EncodeByte('}')
}, nil
}
func (e *Encoder) ptrToPtr(p uintptr) uintptr { return *(*uintptr)(unsafe.Pointer(p)) } func (e *Encoder) ptrToPtr(p uintptr) uintptr { return *(*uintptr)(unsafe.Pointer(p)) }
func (e *Encoder) ptrToInt(p uintptr) int { return *(*int)(unsafe.Pointer(p)) } func (e *Encoder) ptrToInt(p uintptr) int { return *(*int)(unsafe.Pointer(p)) }
func (e *Encoder) ptrToInt8(p uintptr) int8 { return *(*int8)(unsafe.Pointer(p)) } func (e *Encoder) ptrToInt8(p uintptr) int8 { return *(*int8)(unsafe.Pointer(p)) }

View File

@ -159,4 +159,18 @@ func Test_Encoder(t *testing.T) {
t.Fatal("failed to encode slice of int") t.Fatal("failed to encode slice of int")
} }
}) })
t.Run("map", func(t *testing.T) {
bytes, err := json.Marshal(map[string]int{
"a": 1,
"b": 2,
"c": 3,
"d": 4,
})
if err != nil {
t.Fatalf("%+v", err)
}
if len(string(bytes)) != len(`{"a":1,"b":2,"c":3,"d":4}`) {
t.Fatal("failed to encode map of string/int")
}
})
} }