diff --git a/encode.go b/encode.go index fa4b588..30bf31f 100644 --- a/encode.go +++ b/encode.go @@ -123,10 +123,11 @@ type interfaceHeader struct { func (e *Encoder) Encode(v interface{}) ([]byte, error) { header := (*interfaceHeader)(unsafe.Pointer(&v)) - return e.encode(reflect.TypeOf(v), header.ptr) + return e.encode(reflect.ValueOf(v), header.ptr) } -func (e *Encoder) encode(typ reflect.Type, ptr unsafe.Pointer) ([]byte, error) { +func (e *Encoder) encode(v reflect.Value, ptr unsafe.Pointer) ([]byte, error) { + typ := v.Type() name := typ.String() if op, exists := cachedEncodeOp[name]; exists { op(e, uintptr(ptr)) @@ -135,9 +136,9 @@ func (e *Encoder) encode(typ reflect.Type, ptr unsafe.Pointer) ([]byte, error) { return copied, nil } if typ.Kind() == reflect.Ptr { - typ = typ.Elem() + v = v.Elem() } - op, err := e.compile(typ) + op, err := e.compile(v) if err != nil { return nil, err } @@ -150,14 +151,16 @@ func (e *Encoder) encode(typ reflect.Type, ptr unsafe.Pointer) ([]byte, error) { return copied, nil } -func (e *Encoder) compile(typ reflect.Type) (EncodeOp, error) { - switch typ.Kind() { +func (e *Encoder) compile(v reflect.Value) (EncodeOp, error) { + switch v.Type().Kind() { case reflect.Ptr: - return e.compilePtr(typ) + return e.compilePtr(v) case reflect.Slice: - return e.compileSlice(typ) + return e.compileSlice(v) case reflect.Struct: - return e.compileStruct(typ) + return e.compileStruct(v) + case reflect.Map: + return e.compileMap(v) case reflect.Int: return e.compileInt() case reflect.Int8: @@ -187,12 +190,11 @@ func (e *Encoder) compile(typ reflect.Type) (EncodeOp, error) { case reflect.Bool: return e.compileBool() } - return nil, xerrors.Errorf("failed to compile %s: %w", typ, ErrUnknownType) + return nil, xerrors.Errorf("failed to compile %s: %w", v.Type(), ErrUnknownType) } -func (e *Encoder) compilePtr(typ reflect.Type) (EncodeOp, error) { - elem := typ.Elem() - op, err := e.compile(elem) +func (e *Encoder) compilePtr(v reflect.Value) (EncodeOp, error) { + op, err := e.compile(v.Elem()) if err != nil { return nil, err } @@ -257,9 +259,14 @@ func (e *Encoder) compileBool() (EncodeOp, error) { return func(enc *Encoder, p uintptr) { enc.EncodeBool(e.ptrToBool(p)) }, nil } -func (e *Encoder) compileSlice(typ reflect.Type) (EncodeOp, error) { +func (e *Encoder) zeroValue(typ reflect.Type) reflect.Value { + return reflect.New(typ).Elem() +} + +func (e *Encoder) compileSlice(v reflect.Value) (EncodeOp, error) { + typ := v.Type() size := typ.Elem().Size() - op, err := e.compile(typ.Elem()) + op, err := e.compile(e.zeroValue(typ.Elem())) if err != nil { return nil, err } @@ -282,7 +289,8 @@ func (e *Encoder) compileSlice(typ reflect.Type) (EncodeOp, error) { } -func (e *Encoder) compileStruct(typ reflect.Type) (EncodeOp, error) { +func (e *Encoder) compileStruct(v reflect.Value) (EncodeOp, error) { + typ := v.Type() fieldNum := typ.NumField() opQueue := make([]EncodeOp, 0, fieldNum) for i := 0; i < fieldNum; i++ { @@ -295,7 +303,7 @@ func (e *Encoder) compileStruct(typ reflect.Type) (EncodeOp, error) { keyName = opts[0] } } - op, err := e.compile(typ.Field(i).Type) + op, err := e.compile(v.Field(i)) if err != nil { return nil, err } @@ -322,6 +330,59 @@ func (e *Encoder) compileStruct(typ reflect.Type) (EncodeOp, error) { }, nil } +//go:linkname mapiterinit reflect.mapiterinit +func mapiterinit(mapType unsafe.Pointer, m unsafe.Pointer) unsafe.Pointer + +//go:linkname mapiterkey reflect.mapiterkey +func mapiterkey(it unsafe.Pointer) unsafe.Pointer + +//go:linkname mapitervalue reflect.mapitervalue +func mapitervalue(it unsafe.Pointer) unsafe.Pointer + +//go:linkname mapiternext reflect.mapiternext +func mapiternext(it unsafe.Pointer) + +//go:linkname maplen reflect.maplen +func maplen(m unsafe.Pointer) int + +type valueType struct { + typ unsafe.Pointer + ptr unsafe.Pointer +} + +func (e *Encoder) compileMap(v reflect.Value) (EncodeOp, error) { + mapType := (*valueType)(unsafe.Pointer(&v)).typ + keyOp, err := e.compile(e.zeroValue(v.Type().Key())) + if err != nil { + return nil, err + } + valueOp, err := e.compile(e.zeroValue(v.Type().Elem())) + if err != nil { + return nil, err + } + return func(enc *Encoder, base uintptr) { + if base == 0 { + enc.EncodeString("null") + return + } + enc.EncodeByte('{') + mlen := maplen(unsafe.Pointer(base)) + iter := mapiterinit(mapType, unsafe.Pointer(base)) + for i := 0; i < mlen; i++ { + key := mapiterkey(iter) + if i != 0 { + enc.EncodeByte(',') + } + value := mapitervalue(iter) + keyOp(enc, uintptr(key)) + enc.EncodeByte(':') + valueOp(enc, uintptr(value)) + mapiternext(iter) + } + enc.EncodeByte('}') + }, nil +} + func (e *Encoder) ptrToPtr(p uintptr) uintptr { return *(*uintptr)(unsafe.Pointer(p)) } func (e *Encoder) ptrToInt(p uintptr) int { return *(*int)(unsafe.Pointer(p)) } func (e *Encoder) ptrToInt8(p uintptr) int8 { return *(*int8)(unsafe.Pointer(p)) } diff --git a/encode_test.go b/encode_test.go index 06a27f6..12ddd6b 100644 --- a/encode_test.go +++ b/encode_test.go @@ -159,4 +159,18 @@ func Test_Encoder(t *testing.T) { t.Fatal("failed to encode slice of int") } }) + t.Run("map", func(t *testing.T) { + bytes, err := json.Marshal(map[string]int{ + "a": 1, + "b": 2, + "c": 3, + "d": 4, + }) + if err != nil { + t.Fatalf("%+v", err) + } + if len(string(bytes)) != len(`{"a":1,"b":2,"c":3,"d":4}`) { + t.Fatal("failed to encode map of string/int") + } + }) }