Fix decoding of null value

This commit is contained in:
Masaaki Goshima 2021-02-17 01:51:42 +09:00
parent d8c3c8d209
commit 5351464001
10 changed files with 164 additions and 7 deletions

View File

@ -81,7 +81,7 @@ func (d *boolDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
if err := nullBytes(s); err != nil { if err := nullBytes(s); err != nil {
return err return err
} }
**(**bool)(unsafe.Pointer(&p)) = false return nil
case nul: case nul:
if s.read() { if s.read() {
continue continue
@ -147,7 +147,6 @@ func (d *boolDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64,
return 0, errInvalidCharacter(buf[cursor+3], "null", cursor) return 0, errInvalidCharacter(buf[cursor+3], "null", cursor)
} }
cursor += 4 cursor += 4
**(**bool)(unsafe.Pointer(&p)) = false
return cursor, nil return cursor, nil
} }
return 0, errUnexpectedEndOfJSON("bool", cursor) return 0, errUnexpectedEndOfJSON("bool", cursor)

View File

@ -72,6 +72,11 @@ func (d *floatDecoder) decodeStreamByte(s *stream) ([]byte, error) {
continue continue
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
return floatBytes(s), nil return floatBytes(s), nil
case 'n':
if err := nullBytes(s); err != nil {
return nil, err
}
return nil, nil
case nul: case nul:
if s.read() { if s.read() {
continue continue
@ -102,6 +107,21 @@ func (d *floatDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, erro
} }
num := buf[start:cursor] num := buf[start:cursor]
return num, cursor, nil return num, cursor, nil
case 'n':
if cursor+3 >= buflen {
return nil, 0, errUnexpectedEndOfJSON("null", cursor)
}
if buf[cursor+1] != 'u' {
return nil, 0, errInvalidCharacter(buf[cursor+1], "null", cursor)
}
if buf[cursor+2] != 'l' {
return nil, 0, errInvalidCharacter(buf[cursor+2], "null", cursor)
}
if buf[cursor+3] != 'l' {
return nil, 0, errInvalidCharacter(buf[cursor+3], "null", cursor)
}
cursor += 4
return nil, cursor, nil
default: default:
return nil, 0, errUnexpectedEndOfJSON("float", cursor) return nil, 0, errUnexpectedEndOfJSON("float", cursor)
} }
@ -114,6 +134,9 @@ func (d *floatDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
if err != nil { if err != nil {
return err return err
} }
if bytes == nil {
return nil
}
str := *(*string)(unsafe.Pointer(&bytes)) str := *(*string)(unsafe.Pointer(&bytes))
f64, err := strconv.ParseFloat(str, 64) f64, err := strconv.ParseFloat(str, 64)
if err != nil { if err != nil {
@ -128,6 +151,9 @@ func (d *floatDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64
if err != nil { if err != nil {
return 0, err return 0, err
} }
if bytes == nil {
return c, nil
}
cursor = c cursor = c
if !validEndNumberChar[buf[cursor]] { if !validEndNumberChar[buf[cursor]] {
return 0, errUnexpectedEndOfJSON("float", cursor) return 0, errUnexpectedEndOfJSON("float", cursor)

View File

@ -116,6 +116,11 @@ func (d *intDecoder) decodeStreamByte(s *stream) ([]byte, error) {
} }
num := s.buf[start:s.cursor] num := s.buf[start:s.cursor]
return num, nil return num, nil
case 'n':
if err := nullBytes(s); err != nil {
return nil, err
}
return nil, nil
case nul: case nul:
if s.read() { if s.read() {
continue continue
@ -146,6 +151,22 @@ func (d *intDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error)
} }
num := buf[start:cursor] num := buf[start:cursor]
return num, cursor, nil return num, cursor, nil
case 'n':
buflen := int64(len(buf))
if cursor+3 >= buflen {
return nil, 0, errUnexpectedEndOfJSON("null", cursor)
}
if buf[cursor+1] != 'u' {
return nil, 0, errInvalidCharacter(buf[cursor+1], "null", cursor)
}
if buf[cursor+2] != 'l' {
return nil, 0, errInvalidCharacter(buf[cursor+2], "null", cursor)
}
if buf[cursor+3] != 'l' {
return nil, 0, errInvalidCharacter(buf[cursor+3], "null", cursor)
}
cursor += 4
return nil, cursor, nil
default: default:
return nil, 0, d.typeError([]byte{char(b, cursor)}, cursor) return nil, 0, d.typeError([]byte{char(b, cursor)}, cursor)
} }
@ -157,6 +178,9 @@ func (d *intDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
if err != nil { if err != nil {
return err return err
} }
if bytes == nil {
return nil
}
i64 := d.parseInt(bytes) i64 := d.parseInt(bytes)
switch d.kind { switch d.kind {
case reflect.Int8: case reflect.Int8:
@ -182,6 +206,9 @@ func (d *intDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64,
if err != nil { if err != nil {
return 0, err return 0, err
} }
if bytes == nil {
return c, nil
}
cursor = c cursor = c
i64 := d.parseInt(bytes) i64 := d.parseInt(bytes)
switch d.kind { switch d.kind {

View File

@ -1,6 +1,7 @@
package json package json
import ( import (
"bytes"
"encoding" "encoding"
"reflect" "reflect"
"unsafe" "unsafe"
@ -56,12 +57,34 @@ func decodeStreamUnmarshaler(s *stream, unmarshaler Unmarshaler) error {
return nil return nil
} }
func decodeStreamTextUnmarshaler(s *stream, unmarshaler encoding.TextUnmarshaler) error { func decodeUnmarshaler(buf []byte, cursor int64, unmarshaler Unmarshaler) (int64, error) {
cursor = skipWhiteSpace(buf, cursor)
start := cursor
end, err := skipValue(buf, cursor)
if err != nil {
return 0, err
}
src := buf[start:end]
dst := make([]byte, len(src))
copy(dst, src)
if err := unmarshaler.UnmarshalJSON(dst); err != nil {
return 0, err
}
return end, nil
}
func decodeStreamTextUnmarshaler(s *stream, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) error {
start := s.cursor start := s.cursor
if err := s.skipValue(); err != nil { if err := s.skipValue(); err != nil {
return err return err
} }
src := s.buf[start:s.cursor] src := s.buf[start:s.cursor]
if bytes.Equal(src, nullbytes) {
*(*unsafe.Pointer)(p) = nil
return nil
}
dst := make([]byte, len(src)) dst := make([]byte, len(src))
copy(dst, src) copy(dst, src)
@ -71,6 +94,27 @@ func decodeStreamTextUnmarshaler(s *stream, unmarshaler encoding.TextUnmarshaler
return nil return nil
} }
func decodeTextUnmarshaler(buf []byte, cursor int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) (int64, error) {
cursor = skipWhiteSpace(buf, cursor)
start := cursor
end, err := skipValue(buf, cursor)
if err != nil {
return 0, err
}
src := buf[start:end]
if bytes.Equal(src, nullbytes) {
*(*unsafe.Pointer)(p) = nil
return end, nil
}
if s, ok := unquoteBytes(src); ok {
src = s
}
if err := unmarshaler.UnmarshalText(src); err != nil {
return 0, err
}
return end, nil
}
func (d *interfaceDecoder) decodeStreamEmptyInterface(s *stream, p unsafe.Pointer) error { func (d *interfaceDecoder) decodeStreamEmptyInterface(s *stream, p unsafe.Pointer) error {
s.skipWhiteSpace() s.skipWhiteSpace()
for { for {
@ -168,9 +212,9 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
return decodeStreamUnmarshaler(s, u) return decodeStreamUnmarshaler(s, u)
} }
if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok {
return decodeStreamTextUnmarshaler(s, u) return decodeStreamTextUnmarshaler(s, u, p)
} }
return nil return &UnsupportedTypeError{Type: rv.Type()}
} }
iface := rv.Interface() iface := rv.Interface()
ifaceHeader := (*interfaceHeader)(unsafe.Pointer(&iface)) ifaceHeader := (*interfaceHeader)(unsafe.Pointer(&iface))
@ -182,6 +226,7 @@ func (d *interfaceDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
if typ.Kind() == reflect.Ptr && typ.Elem() == d.typ || typ.Kind() != reflect.Ptr { if typ.Kind() == reflect.Ptr && typ.Elem() == d.typ || typ.Kind() != reflect.Ptr {
return d.decodeStreamEmptyInterface(s, p) return d.decodeStreamEmptyInterface(s, p)
} }
s.skipWhiteSpace()
if s.char() == 'n' { if s.char() == 'n' {
if err := nullBytes(s); err != nil { if err := nullBytes(s); err != nil {
return err return err
@ -202,6 +247,16 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i
ptr: p, ptr: p,
})) }))
rv := reflect.ValueOf(runtimeInterfaceValue) rv := reflect.ValueOf(runtimeInterfaceValue)
if rv.NumMethod() > 0 && rv.CanInterface() {
if u, ok := rv.Interface().(Unmarshaler); ok {
return decodeUnmarshaler(buf, cursor, u)
}
if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok {
return decodeTextUnmarshaler(buf, cursor, u, p)
}
return 0, &UnsupportedTypeError{Type: rv.Type()}
}
iface := rv.Interface() iface := rv.Interface()
ifaceHeader := (*interfaceHeader)(unsafe.Pointer(&iface)) ifaceHeader := (*interfaceHeader)(unsafe.Pointer(&iface))
typ := ifaceHeader.typ typ := ifaceHeader.typ
@ -212,6 +267,7 @@ func (d *interfaceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (i
if typ.Kind() == reflect.Ptr && typ.Elem() == d.typ || typ.Kind() != reflect.Ptr { if typ.Kind() == reflect.Ptr && typ.Elem() == d.typ || typ.Kind() != reflect.Ptr {
return d.decodeEmptyInterface(buf, cursor, p) return d.decodeEmptyInterface(buf, cursor, p)
} }
cursor = skipWhiteSpace(buf, cursor)
if buf[cursor] == 'n' { if buf[cursor] == 'n' {
if cursor+3 >= int64(len(buf)) { if cursor+3 >= int64(len(buf)) {
return 0, errUnexpectedEndOfJSON("null", cursor) return 0, errUnexpectedEndOfJSON("null", cursor)

View File

@ -40,6 +40,7 @@ func (d *mapDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
if err := nullBytes(s); err != nil { if err := nullBytes(s); err != nil {
return err return err
} }
**(**unsafe.Pointer)(unsafe.Pointer(&p)) = nil
return nil return nil
case '{': case '{':
default: default:
@ -107,6 +108,7 @@ func (d *mapDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64,
return 0, errInvalidCharacter(buf[cursor+3], "null", cursor) return 0, errInvalidCharacter(buf[cursor+3], "null", cursor)
} }
cursor += 4 cursor += 4
**(**unsafe.Pointer)(unsafe.Pointer(&p)) = nil
return cursor, nil return cursor, nil
case '{': case '{':
default: default:

View File

@ -83,6 +83,7 @@ func (d *sliceDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
if err := nullBytes(s); err != nil { if err := nullBytes(s); err != nil {
return err return err
} }
*(*unsafe.Pointer)(p) = nil
return nil return nil
case '[': case '[':
s.cursor++ s.cursor++
@ -187,6 +188,7 @@ func (d *sliceDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64
return 0, errInvalidCharacter(buf[cursor+3], "null", cursor) return 0, errInvalidCharacter(buf[cursor+3], "null", cursor)
} }
cursor += 4 cursor += 4
*(*unsafe.Pointer)(p) = nil
return cursor, nil return cursor, nil
case '[': case '[':
cursor++ cursor++

View File

@ -35,6 +35,9 @@ func (d *stringDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
if err != nil { if err != nil {
return err return err
} }
if bytes == nil {
return nil
}
**(**string)(unsafe.Pointer(&p)) = *(*string)(unsafe.Pointer(&bytes)) **(**string)(unsafe.Pointer(&p)) = *(*string)(unsafe.Pointer(&bytes))
s.reset() s.reset()
return nil return nil
@ -45,6 +48,9 @@ func (d *stringDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int6
if err != nil { if err != nil {
return 0, err return 0, err
} }
if bytes == nil {
return c, nil
}
cursor = c cursor = c
**(**string)(unsafe.Pointer(&p)) = *(*string)(unsafe.Pointer(&bytes)) **(**string)(unsafe.Pointer(&p)) = *(*string)(unsafe.Pointer(&bytes))
return cursor, nil return cursor, nil

View File

@ -2114,7 +2114,6 @@ func TestInterfaceSet(t *testing.T) {
} }
} }
/*
// JSON null values should be ignored for primitives and string values instead of resulting in an error. // JSON null values should be ignored for primitives and string values instead of resulting in an error.
// Issue 2540 // Issue 2540
func TestUnmarshalNulls(t *testing.T) { func TestUnmarshalNulls(t *testing.T) {
@ -2239,7 +2238,6 @@ func TestUnmarshalNulls(t *testing.T) {
t.Errorf("Unmarshal of big.Int null set int to %v", nulls.BigInt.String()) t.Errorf("Unmarshal of big.Int null set int to %v", nulls.BigInt.String())
} }
} }
*/
func TestStringKind(t *testing.T) { func TestStringKind(t *testing.T) {
type stringKind string type stringKind string

View File

@ -70,6 +70,11 @@ func (d *uintDecoder) decodeStreamByte(s *stream) ([]byte, error) {
} }
num := s.buf[start:s.cursor] num := s.buf[start:s.cursor]
return num, nil return num, nil
case 'n':
if err := nullBytes(s); err != nil {
return nil, err
}
return nil, nil
case nul: case nul:
if s.read() { if s.read() {
continue continue
@ -100,6 +105,21 @@ func (d *uintDecoder) decodeByte(buf []byte, cursor int64) ([]byte, int64, error
} }
num := buf[start:cursor] num := buf[start:cursor]
return num, cursor, nil return num, cursor, nil
case 'n':
if cursor+3 >= buflen {
return nil, 0, errUnexpectedEndOfJSON("null", cursor)
}
if buf[cursor+1] != 'u' {
return nil, 0, errInvalidCharacter(buf[cursor+1], "null", cursor)
}
if buf[cursor+2] != 'l' {
return nil, 0, errInvalidCharacter(buf[cursor+2], "null", cursor)
}
if buf[cursor+3] != 'l' {
return nil, 0, errInvalidCharacter(buf[cursor+3], "null", cursor)
}
cursor += 4
return nil, cursor, nil
default: default:
return nil, 0, d.typeError([]byte{buf[cursor]}, cursor) return nil, 0, d.typeError([]byte{buf[cursor]}, cursor)
} }
@ -112,6 +132,9 @@ func (d *uintDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
if err != nil { if err != nil {
return err return err
} }
if bytes == nil {
return nil
}
u64 := d.parseUint(bytes) u64 := d.parseUint(bytes)
switch d.kind { switch d.kind {
case reflect.Uint8: case reflect.Uint8:
@ -136,6 +159,9 @@ func (d *uintDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer) (int64,
if err != nil { if err != nil {
return 0, err return 0, err
} }
if bytes == nil {
return c, nil
}
cursor = c cursor = c
u64 := d.parseUint(bytes) u64 := d.parseUint(bytes)
switch d.kind { switch d.kind {

View File

@ -1,6 +1,7 @@
package json package json
import ( import (
"bytes"
"encoding" "encoding"
"unicode" "unicode"
"unicode/utf16" "unicode/utf16"
@ -32,6 +33,10 @@ func (d *unmarshalTextDecoder) annotateError(cursor int64, err error) {
} }
} }
var (
nullbytes = []byte(`null`)
)
func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error { func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
s.skipWhiteSpace() s.skipWhiteSpace()
start := s.cursor start := s.cursor
@ -54,6 +59,11 @@ func (d *unmarshalTextDecoder) decodeStream(s *stream, p unsafe.Pointer) error {
Type: rtype2type(d.typ), Type: rtype2type(d.typ),
Offset: s.totalOffset(), Offset: s.totalOffset(),
} }
case 'n':
if bytes.Equal(src, nullbytes) {
*(*unsafe.Pointer)(p) = nil
return nil
}
} }
dst := make([]byte, len(src)) dst := make([]byte, len(src))
copy(dst, src) copy(dst, src)
@ -80,6 +90,11 @@ func (d *unmarshalTextDecoder) decode(buf []byte, cursor int64, p unsafe.Pointer
return 0, err return 0, err
} }
src := buf[start:end] src := buf[start:end]
if bytes.Equal(src, nullbytes) {
*(*unsafe.Pointer)(p) = nil
return end, nil
}
if s, ok := unquoteBytes(src); ok { if s, ok := unquoteBytes(src); ok {
src = s src = s
} }