diff --git a/decode.go b/decode.go index e9b1c31..668ffe0 100644 --- a/decode.go +++ b/decode.go @@ -258,5 +258,5 @@ func (d *Decoder) InputOffset() int64 { // UseNumber causes the Decoder to unmarshal a number into an interface{} as a // Number instead of as a float64. func (d *Decoder) UseNumber() { - + d.s.useNumber = true } diff --git a/decode_interface.go b/decode_interface.go index e6af61c..7c3066e 100644 --- a/decode_interface.go +++ b/decode_interface.go @@ -11,7 +11,20 @@ type interfaceDecoder struct { } func newInterfaceDecoder(typ *rtype) *interfaceDecoder { - return &interfaceDecoder{typ: typ} + return &interfaceDecoder{ + typ: typ, + } +} + +func (d *interfaceDecoder) numDecoder(s *stream) decoder { + if s.useNumber { + return newNumberDecoder(func(p uintptr, v Number) { + *(*interface{})(unsafe.Pointer(p)) = v + }) + } + return newFloatDecoder(func(p uintptr, v float64) { + *(*interface{})(unsafe.Pointer(p)) = v + }) } var ( @@ -28,7 +41,11 @@ func (d *interfaceDecoder) decodeStream(s *stream, p uintptr) error { var v map[interface{}]interface{} ptr := unsafe.Pointer(&v) d.dummy = ptr - dec := newMapDecoder(interfaceMapType, newInterfaceDecoder(d.typ), newInterfaceDecoder(d.typ)) + dec := newMapDecoder( + interfaceMapType, + newInterfaceDecoder(d.typ), + newInterfaceDecoder(d.typ), + ) if err := dec.decodeStream(s, uintptr(ptr)); err != nil { return err } @@ -38,16 +55,18 @@ func (d *interfaceDecoder) decodeStream(s *stream, p uintptr) error { var v []interface{} ptr := unsafe.Pointer(&v) d.dummy = ptr // escape ptr - dec := newSliceDecoder(newInterfaceDecoder(d.typ), d.typ, d.typ.Size()) + dec := newSliceDecoder( + newInterfaceDecoder(d.typ), + d.typ, + d.typ.Size(), + ) if err := dec.decodeStream(s, uintptr(ptr)); err != nil { return err } *(*interface{})(unsafe.Pointer(p)) = v return nil case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': - return newFloatDecoder(func(p uintptr, v float64) { - *(*interface{})(unsafe.Pointer(p)) = v - }).decodeStream(s, p) + return d.numDecoder(s).decodeStream(s, p) case '"': s.cursor++ start := s.cursor @@ -104,7 +123,11 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p uintptr) (int64, e var v map[interface{}]interface{} ptr := unsafe.Pointer(&v) d.dummy = ptr - dec := newMapDecoder(interfaceMapType, newInterfaceDecoder(d.typ), newInterfaceDecoder(d.typ)) + dec := newMapDecoder( + interfaceMapType, + newInterfaceDecoder(d.typ), + newInterfaceDecoder(d.typ), + ) cursor, err := dec.decode(buf, cursor, uintptr(ptr)) if err != nil { return 0, err @@ -115,7 +138,11 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p uintptr) (int64, e var v []interface{} ptr := unsafe.Pointer(&v) d.dummy = ptr // escape ptr - dec := newSliceDecoder(newInterfaceDecoder(d.typ), d.typ, d.typ.Size()) + dec := newSliceDecoder( + newInterfaceDecoder(d.typ), + d.typ, + d.typ.Size(), + ) cursor, err := dec.decode(buf, cursor, uintptr(ptr)) if err != nil { return 0, err diff --git a/decode_map.go b/decode_map.go index 6dd2695..a14cc6a 100644 --- a/decode_map.go +++ b/decode_map.go @@ -8,6 +8,7 @@ type mapDecoder struct { mapType *rtype keyDecoder decoder valueDecoder decoder + dummy *interfaceHeader } func newMapDecoder(mapType *rtype, keyDec decoder, valueDec decoder) *mapDecoder { @@ -37,11 +38,13 @@ func (d *mapDecoder) setValue(buf []byte, cursor int64, key interface{}) (int64, func (d *mapDecoder) setKeyStream(s *stream, key interface{}) error { header := (*interfaceHeader)(unsafe.Pointer(&key)) + d.dummy = header return d.keyDecoder.decodeStream(s, uintptr(header.ptr)) } func (d *mapDecoder) setValueStream(s *stream, key interface{}) error { header := (*interfaceHeader)(unsafe.Pointer(&key)) + d.dummy = header return d.valueDecoder.decodeStream(s, uintptr(header.ptr)) } diff --git a/decode_number.go b/decode_number.go new file mode 100644 index 0000000..0f1c856 --- /dev/null +++ b/decode_number.go @@ -0,0 +1,38 @@ +package json + +import ( + "unsafe" +) + +type numberDecoder struct { + *floatDecoder + op func(uintptr, Number) +} + +func newNumberDecoder(op func(uintptr, Number)) *numberDecoder { + return &numberDecoder{ + floatDecoder: newFloatDecoder(nil), + op: op, + } +} + +func (d *numberDecoder) decodeStream(s *stream, p uintptr) error { + bytes, err := d.floatDecoder.decodeStreamByte(s) + if err != nil { + return err + } + str := *(*string)(unsafe.Pointer(&bytes)) + d.op(p, Number(str)) + return nil +} + +func (d *numberDecoder) decode(buf []byte, cursor int64, p uintptr) (int64, error) { + bytes, c, err := d.floatDecoder.decodeByte(buf, cursor) + if err != nil { + return 0, err + } + cursor = c + s := *(*string)(unsafe.Pointer(&bytes)) + d.op(p, Number(s)) + return cursor, nil +} diff --git a/decode_stream.go b/decode_stream.go index ed5e889..8f771d3 100644 --- a/decode_stream.go +++ b/decode_stream.go @@ -10,12 +10,13 @@ const ( ) type stream struct { - buf []byte - length int64 - r io.Reader - offset int64 - cursor int64 - allRead bool + buf []byte + length int64 + r io.Reader + offset int64 + cursor int64 + allRead bool + useNumber bool } func (s *stream) buffered() io.Reader { diff --git a/decode_test.go b/decode_test.go index 28683b8..3ded818 100644 --- a/decode_test.go +++ b/decode_test.go @@ -217,6 +217,14 @@ func Test_Decoder(t *testing.T) { }) } +func Test_Decoder_UseNumber(t *testing.T) { + dec := json.NewDecoder(strings.NewReader(`{"a": 3.14}`)) + dec.UseNumber() + var v map[string]interface{} + assertErr(t, dec.Decode(&v)) + assertEq(t, "json.Number", "json.Number", fmt.Sprintf("%T", v["a"])) +} + type unmarshalJSON struct { v int } diff --git a/json.go b/json.go index a29a061..8675aa5 100644 --- a/json.go +++ b/json.go @@ -1,6 +1,9 @@ package json -import "bytes" +import ( + "bytes" + "strconv" +) // Marshaler is the interface implemented by types that // can marshal themselves into valid JSON. @@ -275,3 +278,19 @@ func UnmarshalNoEscape(data []byte, v interface{}) error { // nil, for JSON null // type Token interface{} + +// A Number represents a JSON number literal. +type Number string + +// String returns the literal text of the number. +func (n Number) String() string { return string(n) } + +// Float64 returns the number as a float64. +func (n Number) Float64() (float64, error) { + return strconv.ParseFloat(string(n), 64) +} + +// Int64 returns the number as an int64. +func (n Number) Int64() (int64, error) { + return strconv.ParseInt(string(n), 10, 64) +}