diff --git a/decode_context.go b/decode_context.go index 9f87b05..305f0d6 100644 --- a/decode_context.go +++ b/decode_context.go @@ -204,3 +204,19 @@ func skipValue(buf []byte, cursor, depth int64) (int64, error) { } } } + +func validateNull(buf []byte, cursor int64) error { + if cursor+3 >= int64(len(buf)) { + return errUnexpectedEndOfJSON("null", cursor) + } + if buf[cursor+1] != 'u' { + return errInvalidCharacter(buf[cursor+1], "null", cursor) + } + if buf[cursor+2] != 'l' { + return errInvalidCharacter(buf[cursor+2], "null", cursor) + } + if buf[cursor+3] != 'l' { + return errInvalidCharacter(buf[cursor+3], "null", cursor) + } + return nil +} diff --git a/decode_interface.go b/decode_interface.go index d2eafab..fee2cdb 100644 --- a/decode_interface.go +++ b/decode_interface.go @@ -256,6 +256,14 @@ func (d *interfaceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { return decodeStreamTextUnmarshaler(s, depth, u, p) } + s.skipWhiteSpace() + if s.char() == 'n' { + if err := nullBytes(s); err != nil { + return err + } + *(*interface{})(p) = nil + return nil + } return d.errUnmarshalType(rv.Type(), s.totalOffset()) } iface := rv.Interface() @@ -306,6 +314,15 @@ func (d *interfaceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Poin if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { return decodeTextUnmarshaler(buf, cursor, depth, u, p) } + cursor = skipWhiteSpace(buf, cursor) + if buf[cursor] == 'n' { + if err := validateNull(buf, cursor); err != nil { + return 0, err + } + cursor += 4 + **(**interface{})(unsafe.Pointer(&p)) = nil + return cursor, nil + } return 0, d.errUnmarshalType(rv.Type(), cursor) } @@ -321,17 +338,8 @@ func (d *interfaceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Poin } cursor = skipWhiteSpace(buf, cursor) if buf[cursor] == 'n' { - if cursor+3 >= int64(len(buf)) { - return 0, errUnexpectedEndOfJSON("null", cursor) - } - if buf[cursor+1] != 'u' { - return 0, errInvalidCharacter(buf[cursor+1], "null", cursor) - } - if buf[cursor+2] != 'l' { - return 0, errInvalidCharacter(buf[cursor+2], "null", cursor) - } - if buf[cursor+3] != 'l' { - return 0, errInvalidCharacter(buf[cursor+3], "null", cursor) + if err := validateNull(buf, cursor); err != nil { + return 0, err } cursor += 4 **(**interface{})(unsafe.Pointer(&p)) = nil diff --git a/decode_test.go b/decode_test.go index 4fbf1ec..c057add 100644 --- a/decode_test.go +++ b/decode_test.go @@ -3351,3 +3351,40 @@ func TestInvalidNumber(t *testing.T) { }) }) } + +type someInterface interface { + DoesNotMatter() +} + +func TestDecodeUnknownInterface(t *testing.T) { + t.Run("unmarshal", func(t *testing.T) { + var v map[string]someInterface + if err := json.Unmarshal([]byte(`{"a":null,"b":null}`), &v); err != nil { + t.Fatal(err) + } + if len(v) != 2 { + t.Fatalf("failed to decode: %v", v) + } + if a, exists := v["a"]; a != nil || !exists { + t.Fatalf("failed to decode: %v", v) + } + if b, exists := v["b"]; b != nil || !exists { + t.Fatalf("failed to decode: %v", v) + } + }) + t.Run("stream", func(t *testing.T) { + var v map[string]someInterface + if err := json.NewDecoder(strings.NewReader(`{"a":null,"b":null}`)).Decode(&v); err != nil { + t.Fatal(err) + } + if len(v) != 2 { + t.Fatalf("failed to decode: %v", v) + } + if a, exists := v["a"]; a != nil || !exists { + t.Fatalf("failed to decode: %v", v) + } + if b, exists := v["b"]; b != nil || !exists { + t.Fatalf("failed to decode: %v", v) + } + }) +}