diff --git a/decode.go b/decode.go index 2b18a56..3459da5 100644 --- a/decode.go +++ b/decode.go @@ -154,6 +154,8 @@ func (d *Decoder) compile(typ *rtype) (decoder, error) { return d.compileSlice(typ) case reflect.Array: return d.compileArray(typ) + case reflect.Map: + return d.compileMap(typ) case reflect.Int: return d.compileInt() case reflect.Int8: @@ -292,6 +294,18 @@ func (d *Decoder) compileArray(typ *rtype) (decoder, error) { return newArrayDecoder(decoder, elem, typ.Len()), nil } +func (d *Decoder) compileMap(typ *rtype) (decoder, error) { + keyDec, err := d.compile(typ.Key()) + if err != nil { + return nil, err + } + valueDec, err := d.compile(typ.Elem()) + if err != nil { + return nil, err + } + return newMapDecoder(typ, keyDec, valueDec), nil +} + func (d *Decoder) getTag(field reflect.StructField) string { return field.Tag.Get("json") } diff --git a/decode_map.go b/decode_map.go new file mode 100644 index 0000000..a4a1d72 --- /dev/null +++ b/decode_map.go @@ -0,0 +1,90 @@ +package json + +import ( + "errors" + "fmt" + "unsafe" +) + +type mapDecoder struct { + mapType *rtype + keyDecoder decoder + valueDecoder decoder +} + +func newMapDecoder(mapType *rtype, keyDec decoder, valueDec decoder) *mapDecoder { + return &mapDecoder{ + mapType: mapType, + keyDecoder: keyDec, + valueDecoder: valueDec, + } +} + +//go:linkname makemap reflect.makemap +func makemap(*rtype, int) unsafe.Pointer + +//go:linkname mapassign reflect.mapassign +//go:noescape +func mapassign(t *rtype, m unsafe.Pointer, key, val unsafe.Pointer) + +func (d *mapDecoder) setKey(ctx *context, key interface{}) error { + header := (*interfaceHeader)(unsafe.Pointer(&key)) + if err := d.keyDecoder.decode(ctx, uintptr(header.ptr)); err != nil { + return err + } + //fmt.Println("key = ", *(*string)(header.ptr)) + //fmt.Println("Key = ", key.(*string)) + //fmt.Println("key = ", *(*string)(unsafe.Pointer(key))) + return nil +} + +func (d *mapDecoder) setValue(ctx *context, key interface{}) error { + header := (*interfaceHeader)(unsafe.Pointer(&key)) + return d.valueDecoder.decode(ctx, uintptr(header.ptr)) +} + +func (d *mapDecoder) decode(ctx *context, p uintptr) error { + ctx.skipWhiteSpace() + buf := ctx.buf + buflen := ctx.buflen + cursor := ctx.cursor + if buflen < 2 { + return errors.New("unexpected error {}") + } + if buf[cursor] != '{' { + return errors.New("unexpected error {") + } + cursor++ + mapValue := makemap(d.mapType, 0) + fmt.Println("mapValue = ", mapValue) + for ; cursor < buflen; cursor++ { + ctx.cursor = cursor + var key interface{} + if err := d.setKey(ctx, &key); err != nil { + return err + } + cursor = ctx.skipWhiteSpace() + if buf[cursor] != ':' { + return errors.New("unexpected error invalid delimiter for object") + } + cursor++ + if cursor >= buflen { + return errors.New("unexpected error missing value") + } + ctx.cursor = cursor + var value interface{} + if err := d.setValue(ctx, &value); err != nil { + return err + } + mapassign(d.mapType, mapValue, unsafe.Pointer(&key), unsafe.Pointer(&value)) + cursor = ctx.skipWhiteSpace() + if buf[cursor] == '}' { + *(*unsafe.Pointer)(unsafe.Pointer(p)) = mapValue + return nil + } + if buf[cursor] != ',' { + return errors.New("unexpected error ,") + } + } + return nil +} diff --git a/decode_test.go b/decode_test.go index 9cac903..e72e0b9 100644 --- a/decode_test.go +++ b/decode_test.go @@ -95,6 +95,14 @@ func Test_Decoder(t *testing.T) { assertErr(t, json.Unmarshal([]byte(` [ 1 , 2 , 3 , 4 ] `), &v)) assertEq(t, "array", fmt.Sprint([4]int{1, 2, 3, 4}), fmt.Sprint(v)) }) + t.Run("map", func(t *testing.T) { + var v map[string]int + assertErr(t, json.Unmarshal([]byte(` { "a": 1, "b": 2, "c": 3, "d": 4 } `), &v)) + assertEq(t, "map.a", v["a"], 1) + assertEq(t, "map.b", v["b"], 2) + assertEq(t, "map.c", v["c"], 3) + assertEq(t, "map.d", v["d"], 4) + }) t.Run("struct", func(t *testing.T) { type T struct { AA int `json:"aa"`