From 5351464001f247eb964bae3c5875d70978ee49ea Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Wed, 17 Feb 2021 01:51:42 +0900 Subject: [PATCH] Fix decoding of null value --- decode_bool.go | 3 +- decode_float.go | 26 +++++++++++++++++ decode_int.go | 27 +++++++++++++++++ decode_interface.go | 62 ++++++++++++++++++++++++++++++++++++++-- decode_map.go | 2 ++ decode_slice.go | 2 ++ decode_string.go | 6 ++++ decode_test.go | 2 -- decode_uint.go | 26 +++++++++++++++++ decode_unmarshal_text.go | 15 ++++++++++ 10 files changed, 164 insertions(+), 7 deletions(-) diff --git a/decode_bool.go b/decode_bool.go index eefc33c..ccb96b0 100644 --- a/decode_bool.go +++ b/decode_bool.go @@ -81,7 +81,7 @@ func (d *boolDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err := nullBytes(s); err != nil { return err } - **(**bool)(unsafe.Pointer(&p)) = false + return nil case nul: if s.read() { continue @@ -147,7 +147,6 @@ func (d *boolDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, return 0, errInvalidCharacter(buf[cursor+3], "null", cursor) } cursor += 4 - **(**bool)(unsafe.Pointer(&p)) = false return cursor, nil } return 0, errUnexpectedEndOfJSON("bool", cursor) diff --git a/decode_float.go b/decode_float.go index 47a0cdf..f818b1f 100644 --- a/decode_float.go +++ b/decode_float.go @@ -72,6 +72,11 @@ func (d *floatDecoder) decodeStreamByte(s *stream) ([]byte, error) { continue case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': return floatBytes(s), nil + case 'n': + if err := nullBytes(s); err != nil { + return nil, err + } + return nil, nil case nul: if s.read() { continue @@ -102,6 +107,21 @@ func (d *floatDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, erro } num := buf[start:cursor] return num, cursor, nil + case 'n': + if cursor+3 >= buflen { + return nil, 0, errUnexpectedEndOfJSON("null", cursor) + } + if buf[cursor+1] != 'u' { + return nil, 0, errInvalidCharacter(buf[cursor+1], "null", cursor) + } + if buf[cursor+2] != 'l' { + return nil, 0, errInvalidCharacter(buf[cursor+2], "null", cursor) + } + if buf[cursor+3] != 'l' { + return nil, 0, errInvalidCharacter(buf[cursor+3], "null", cursor) + } + cursor += 4 + return nil, cursor, nil default: return nil, 0, errUnexpectedEndOfJSON("float", cursor) } @@ -114,6 +134,9 @@ func (d *floatDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err != nil { return err } + if bytes == nil { + return nil + } str := *(*string)(unsafe.Pointer(&bytes)) f64, err := strconv.ParseFloat(str, 64) if err != nil { @@ -128,6 +151,9 @@ func (d *floatDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64 if err != nil { return 0, err } + if bytes == nil { + return c, nil + } cursor = c if !validEndNumberChar[buf[cursor]] { return 0, errUnexpectedEndOfJSON("float", cursor) diff --git a/decode_int.go b/decode_int.go index b091812..394e691 100644 --- a/decode_int.go +++ b/decode_int.go @@ -116,6 +116,11 @@ func (d *intDecoder) decodeStreamByte(s *stream) ([]byte, error) { } num := s.buf[start:s.cursor] return num, nil + case 'n': + if err := nullBytes(s); err != nil { + return nil, err + } + return nil, nil case nul: if s.read() { continue @@ -146,6 +151,22 @@ func (d *intDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error) } num := buf[start:cursor] return num, cursor, nil + case 'n': + buflen := int64(len(buf)) + if cursor+3 >= buflen { + return nil, 0, errUnexpectedEndOfJSON("null", cursor) + } + if buf[cursor+1] != 'u' { + return nil, 0, errInvalidCharacter(buf[cursor+1], "null", cursor) + } + if buf[cursor+2] != 'l' { + return nil, 0, errInvalidCharacter(buf[cursor+2], "null", cursor) + } + if buf[cursor+3] != 'l' { + return nil, 0, errInvalidCharacter(buf[cursor+3], "null", cursor) + } + cursor += 4 + return nil, cursor, nil default: return nil, 0, d.typeError([]byte{char(b, cursor)}, cursor) } @@ -157,6 +178,9 @@ func (d *intDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err != nil { return err } + if bytes == nil { + return nil + } i64 := d.parseInt(bytes) switch d.kind { case reflect.Int8: @@ -182,6 +206,9 @@ func (d *intDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, if err != nil { return 0, err } + if bytes == nil { + return c, nil + } cursor = c i64 := d.parseInt(bytes) switch d.kind { diff --git a/decode_interface.go b/decode_interface.go index f213862..d08f133 100644 --- a/decode_interface.go +++ b/decode_interface.go @@ -1,6 +1,7 @@ package json import ( + "bytes" "encoding" "reflect" "unsafe" @@ -56,12 +57,34 @@ func decodeStreamUnmarshaler(s *stream, unmarshaler Unmarshaler) error { return nil } -func decodeStreamTextUnmarshaler(s *stream, unmarshaler encoding.TextUnmarshaler) error { +func decodeUnmarshaler(buf []byte, cursor int64, unmarshaler Unmarshaler) (int64, error) { + cursor = skipWhiteSpace(buf, cursor) + start := cursor + end, err := skipValue(buf, cursor) + if err != nil { + return 0, err + } + src := buf[start:end] + dst := make([]byte, len(src)) + copy(dst, src) + + if err := unmarshaler.UnmarshalJSON(dst); err != nil { + return 0, err + } + return end, nil +} + +func decodeStreamTextUnmarshaler(s *stream, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) error { start := s.cursor if err := s.skipValue(); err != nil { return err } src := s.buf[start:s.cursor] + if bytes.Equal(src, nullbytes) { + *(*unsafe.Pointer)(p) = nil + return nil + } + dst := make([]byte, len(src)) copy(dst, src) @@ -71,6 +94,27 @@ func decodeStreamTextUnmarshaler(s *stream, unmarshaler encoding.TextUnmarshaler return nil } +func decodeTextUnmarshaler(buf []byte, cursor int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) (int64, error) { + cursor = skipWhiteSpace(buf, cursor) + start := cursor + end, err := skipValue(buf, cursor) + if err != nil { + return 0, err + } + src := buf[start:end] + if bytes.Equal(src, nullbytes) { + *(*unsafe.Pointer)(p) = nil + return end, nil + } + if s, ok := unquoteBytes(src); ok { + src = s + } + if err := unmarshaler.UnmarshalText(src); err != nil { + return 0, err + } + return end, nil +} + func (d *interfaceDecoder) decodeStreamEmptyInterface(s *stream, p unsafe.Pointer) error { s.skipWhiteSpace() for { @@ -168,9 +212,9 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { return decodeStreamUnmarshaler(s, u) } if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { - return decodeStreamTextUnmarshaler(s, u) + return decodeStreamTextUnmarshaler(s, u, p) } - return nil + return &UnsupportedTypeError{Type: rv.Type()} } iface := rv.Interface() ifaceHeader := (*interfaceHeader)(unsafe.Pointer(&iface)) @@ -182,6 +226,7 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if typ.Kind() == reflect.Ptr && typ.Elem() == d.typ || typ.Kind() != reflect.Ptr { return d.decodeStreamEmptyInterface(s, p) } + s.skipWhiteSpace() if s.char() == 'n' { if err := nullBytes(s); err != nil { return err @@ -202,6 +247,16 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i ptr: p, })) rv := reflect.ValueOf(runtimeInterfaceValue) + if rv.NumMethod() > 0 && rv.CanInterface() { + if u, ok := rv.Interface().(Unmarshaler); ok { + return decodeUnmarshaler(buf, cursor, u) + } + if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { + return decodeTextUnmarshaler(buf, cursor, u, p) + } + return 0, &UnsupportedTypeError{Type: rv.Type()} + } + iface := rv.Interface() ifaceHeader := (*interfaceHeader)(unsafe.Pointer(&iface)) typ := ifaceHeader.typ @@ -212,6 +267,7 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i if typ.Kind() == reflect.Ptr && typ.Elem() == d.typ || typ.Kind() != reflect.Ptr { return d.decodeEmptyInterface(buf, cursor, p) } + cursor = skipWhiteSpace(buf, cursor) if buf[cursor] == 'n' { if cursor+3 >= int64(len(buf)) { return 0, errUnexpectedEndOfJSON("null", cursor) diff --git a/decode_map.go b/decode_map.go index 0e87b1d..716d1f1 100644 --- a/decode_map.go +++ b/decode_map.go @@ -40,6 +40,7 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err := nullBytes(s); err != nil { return err } + **(**unsafe.Pointer)(unsafe.Pointer(&p)) = nil return nil case '{': default: @@ -107,6 +108,7 @@ func (d *mapDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, return 0, errInvalidCharacter(buf[cursor+3], "null", cursor) } cursor += 4 + **(**unsafe.Pointer)(unsafe.Pointer(&p)) = nil return cursor, nil case '{': default: diff --git a/decode_slice.go b/decode_slice.go index 39dae6f..3717d1c 100644 --- a/decode_slice.go +++ b/decode_slice.go @@ -83,6 +83,7 @@ func (d *sliceDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err := nullBytes(s); err != nil { return err } + *(*unsafe.Pointer)(p) = nil return nil case '[': s.cursor++ @@ -187,6 +188,7 @@ func (d *sliceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64 return 0, errInvalidCharacter(buf[cursor+3], "null", cursor) } cursor += 4 + *(*unsafe.Pointer)(p) = nil return cursor, nil case '[': cursor++ diff --git a/decode_string.go b/decode_string.go index 1d30b5d..c09d460 100644 --- a/decode_string.go +++ b/decode_string.go @@ -35,6 +35,9 @@ func (d *stringDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err != nil { return err } + if bytes == nil { + return nil + } **(**string)(unsafe.Pointer(&p)) = *(*string)(unsafe.Pointer(&bytes)) s.reset() return nil @@ -45,6 +48,9 @@ func (d *stringDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int6 if err != nil { return 0, err } + if bytes == nil { + return c, nil + } cursor = c **(**string)(unsafe.Pointer(&p)) = *(*string)(unsafe.Pointer(&bytes)) return cursor, nil diff --git a/decode_test.go b/decode_test.go index 9aeaa03..da86ee2 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2114,7 +2114,6 @@ func TestInterfaceSet(t *testing.T) { } } -/* // JSON null values should be ignored for primitives and string values instead of resulting in an error. // Issue 2540 func TestUnmarshalNulls(t *testing.T) { @@ -2239,7 +2238,6 @@ func TestUnmarshalNulls(t *testing.T) { t.Errorf("Unmarshal of big.Int null set int to %v", nulls.BigInt.String()) } } -*/ func TestStringKind(t *testing.T) { type stringKind string diff --git a/decode_uint.go b/decode_uint.go index cc67922..b4c9f1c 100644 --- a/decode_uint.go +++ b/decode_uint.go @@ -70,6 +70,11 @@ func (d *uintDecoder) decodeStreamByte(s *stream) ([]byte, error) { } num := s.buf[start:s.cursor] return num, nil + case 'n': + if err := nullBytes(s); err != nil { + return nil, err + } + return nil, nil case nul: if s.read() { continue @@ -100,6 +105,21 @@ func (d *uintDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error } num := buf[start:cursor] return num, cursor, nil + case 'n': + if cursor+3 >= buflen { + return nil, 0, errUnexpectedEndOfJSON("null", cursor) + } + if buf[cursor+1] != 'u' { + return nil, 0, errInvalidCharacter(buf[cursor+1], "null", cursor) + } + if buf[cursor+2] != 'l' { + return nil, 0, errInvalidCharacter(buf[cursor+2], "null", cursor) + } + if buf[cursor+3] != 'l' { + return nil, 0, errInvalidCharacter(buf[cursor+3], "null", cursor) + } + cursor += 4 + return nil, cursor, nil default: return nil, 0, d.typeError([]byte{buf[cursor]}, cursor) } @@ -112,6 +132,9 @@ func (d *uintDecoder) decodeStream(s *stream, p unsafe.Pointer) error { if err != nil { return err } + if bytes == nil { + return nil + } u64 := d.parseUint(bytes) switch d.kind { case reflect.Uint8: @@ -136,6 +159,9 @@ func (d *uintDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64, if err != nil { return 0, err } + if bytes == nil { + return c, nil + } cursor = c u64 := d.parseUint(bytes) switch d.kind { diff --git a/decode_unmarshal_text.go b/decode_unmarshal_text.go index 94a0425..1cde243 100644 --- a/decode_unmarshal_text.go +++ b/decode_unmarshal_text.go @@ -1,6 +1,7 @@ package json import ( + "bytes" "encoding" "unicode" "unicode/utf16" @@ -32,6 +33,10 @@ func (d *unmarshalTextDecoder) annotateError(cursor int64, err error) { } } +var ( + nullbytes = []byte(`null`) +) + func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error { s.skipWhiteSpace() start := s.cursor @@ -54,6 +59,11 @@ func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error { Type: rtype2type(d.typ), Offset: s.totalOffset(), } + case 'n': + if bytes.Equal(src, nullbytes) { + *(*unsafe.Pointer)(p) = nil + return nil + } } dst := make([]byte, len(src)) copy(dst, src) @@ -80,6 +90,11 @@ func (d *unmarshalTextDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer return 0, err } src := buf[start:end] + if bytes.Equal(src, nullbytes) { + *(*unsafe.Pointer)(p) = nil + return end, nil + } + if s, ok := unquoteBytes(src); ok { src = s }