diff --git a/internal/decoder/map.go b/internal/decoder/map.go index 367a8a1..dd480e1 100644 --- a/internal/decoder/map.go +++ b/internal/decoder/map.go @@ -1,6 +1,7 @@ package decoder import ( + "reflect" "unsafe" "github.com/goccy/go-json/internal/errors" @@ -8,33 +9,49 @@ import ( ) type mapDecoder struct { - mapType *runtime.Type - keyType *runtime.Type - valueType *runtime.Type - keyDecoder Decoder - valueDecoder Decoder - structName string - fieldName string + mapType *runtime.Type + keyType *runtime.Type + valueType *runtime.Type + stringKeyType bool + keyDecoder Decoder + valueDecoder Decoder + structName string + fieldName string } func newMapDecoder(mapType *runtime.Type, keyType *runtime.Type, keyDec Decoder, valueType *runtime.Type, valueDec Decoder, structName, fieldName string) *mapDecoder { return &mapDecoder{ - mapType: mapType, - keyDecoder: keyDec, - keyType: keyType, - valueType: valueType, - valueDecoder: valueDec, - structName: structName, - fieldName: fieldName, + mapType: mapType, + keyDecoder: keyDec, + keyType: keyType, + stringKeyType: keyType.Kind() == reflect.String, + valueType: valueType, + valueDecoder: valueDec, + structName: structName, + fieldName: fieldName, } } //go:linkname makemap reflect.makemap func makemap(*runtime.Type, int) unsafe.Pointer +//nolint:golint +//go:linkname mapassign_faststr runtime.mapassign_faststr +//go:noescape +func mapassign_faststr(t *runtime.Type, m unsafe.Pointer, s string) unsafe.Pointer + //go:linkname mapassign reflect.mapassign //go:noescape -func mapassign(t *runtime.Type, m unsafe.Pointer, key, val unsafe.Pointer) +func mapassign(t *runtime.Type, m unsafe.Pointer, k, v unsafe.Pointer) + +func (d *mapDecoder) mapassign(t *runtime.Type, m, k, v unsafe.Pointer) { + if d.stringKeyType { + mapV := mapassign_faststr(d.mapType, m, *(*string)(k)) + typedmemmove(d.valueType, mapV, v) + } else { + mapassign(t, m, k, v) + } +} func (d *mapDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) error { depth++ @@ -77,7 +94,7 @@ func (d *mapDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) erro if err := d.valueDecoder.DecodeStream(s, depth, v); err != nil { return err } - mapassign(d.mapType, mapValue, k, v) + d.mapassign(d.mapType, mapValue, k, v) s.skipWhiteSpace() if s.equalChar('}') { **(**unsafe.Pointer)(unsafe.Pointer(&p)) = mapValue @@ -141,7 +158,7 @@ func (d *mapDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsafe.P if err != nil { return 0, err } - mapassign(d.mapType, mapValue, k, v) + d.mapassign(d.mapType, mapValue, k, v) cursor = skipWhiteSpace(buf, valueCursor) if buf[cursor] == '}' { **(**unsafe.Pointer)(unsafe.Pointer(&p)) = mapValue