Fix decoding of null value for interface type that does not implement Unmarshaler

This commit is contained in:
Masaaki Goshima 2021-05-02 03:36:58 +09:00
parent 117e9eff37
commit 275aade00d
3 changed files with 72 additions and 11 deletions

View File

@ -204,3 +204,19 @@ func skipValue(buf []byte, cursor, depth int64) (int64, error) {
} }
} }
} }
func validateNull(buf []byte, cursor int64) error {
if cursor+3 >= int64(len(buf)) {
return errUnexpectedEndOfJSON("null", cursor)
}
if buf[cursor+1] != 'u' {
return errInvalidCharacter(buf[cursor+1], "null", cursor)
}
if buf[cursor+2] != 'l' {
return errInvalidCharacter(buf[cursor+2], "null", cursor)
}
if buf[cursor+3] != 'l' {
return errInvalidCharacter(buf[cursor+3], "null", cursor)
}
return nil
}

View File

@ -256,6 +256,14 @@ func (d *interfaceDecoder) decodeStream(s *stream, depth int64, p unsafe.Pointer
if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok {
return decodeStreamTextUnmarshaler(s, depth, u, p) return decodeStreamTextUnmarshaler(s, depth, u, p)
} }
s.skipWhiteSpace()
if s.char() == 'n' {
if err := nullBytes(s); err != nil {
return err
}
*(*interface{})(p) = nil
return nil
}
return d.errUnmarshalType(rv.Type(), s.totalOffset()) return d.errUnmarshalType(rv.Type(), s.totalOffset())
} }
iface := rv.Interface() iface := rv.Interface()
@ -306,6 +314,15 @@ func (d *interfaceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Poin
if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok { if u, ok := rv.Interface().(encoding.TextUnmarshaler); ok {
return decodeTextUnmarshaler(buf, cursor, depth, u, p) return decodeTextUnmarshaler(buf, cursor, depth, u, p)
} }
cursor = skipWhiteSpace(buf, cursor)
if buf[cursor] == 'n' {
if err := validateNull(buf, cursor); err != nil {
return 0, err
}
cursor += 4
**(**interface{})(unsafe.Pointer(&p)) = nil
return cursor, nil
}
return 0, d.errUnmarshalType(rv.Type(), cursor) return 0, d.errUnmarshalType(rv.Type(), cursor)
} }
@ -321,17 +338,8 @@ func (d *interfaceDecoder) decode(buf []byte, cursor, depth int64, p unsafe.Poin
} }
cursor = skipWhiteSpace(buf, cursor) cursor = skipWhiteSpace(buf, cursor)
if buf[cursor] == 'n' { if buf[cursor] == 'n' {
if cursor+3 >= int64(len(buf)) { if err := validateNull(buf, cursor); err != nil {
return 0, errUnexpectedEndOfJSON("null", cursor) return 0, err
}
if buf[cursor+1] != 'u' {
return 0, errInvalidCharacter(buf[cursor+1], "null", cursor)
}
if buf[cursor+2] != 'l' {
return 0, errInvalidCharacter(buf[cursor+2], "null", cursor)
}
if buf[cursor+3] != 'l' {
return 0, errInvalidCharacter(buf[cursor+3], "null", cursor)
} }
cursor += 4 cursor += 4
**(**interface{})(unsafe.Pointer(&p)) = nil **(**interface{})(unsafe.Pointer(&p)) = nil

View File

@ -3351,3 +3351,40 @@ func TestInvalidNumber(t *testing.T) {
}) })
}) })
} }
type someInterface interface {
DoesNotMatter()
}
func TestDecodeUnknownInterface(t *testing.T) {
t.Run("unmarshal", func(t *testing.T) {
var v map[string]someInterface
if err := json.Unmarshal([]byte(`{"a":null,"b":null}`), &v); err != nil {
t.Fatal(err)
}
if len(v) != 2 {
t.Fatalf("failed to decode: %v", v)
}
if a, exists := v["a"]; a != nil || !exists {
t.Fatalf("failed to decode: %v", v)
}
if b, exists := v["b"]; b != nil || !exists {
t.Fatalf("failed to decode: %v", v)
}
})
t.Run("stream", func(t *testing.T) {
var v map[string]someInterface
if err := json.NewDecoder(strings.NewReader(`{"a":null,"b":null}`)).Decode(&v); err != nil {
t.Fatal(err)
}
if len(v) != 2 {
t.Fatalf("failed to decode: %v", v)
}
if a, exists := v["a"]; a != nil || !exists {
t.Fatalf("failed to decode: %v", v)
}
if b, exists := v["b"]; b != nil || !exists {
t.Fatalf("failed to decode: %v", v)
}
})
}