diff --git a/decode_array.go b/decode_array.go index 7e34d85..5a0eac7 100644 --- a/decode_array.go +++ b/decode_array.go @@ -20,6 +20,11 @@ func (d *arrayDecoder) decodeStream(s *stream, p uintptr) error { for { switch s.char() { case ' ', '\n', '\t', '\r': + case 'n': + if err := nullBytes(s); err != nil { + return err + } + return nil case '[': idx := 0 for { @@ -63,6 +68,22 @@ func (d *arrayDecoder) decode(buf []byte, cursor int64, p uintptr) (int64, error switch buf[cursor] { case ' ', '\n', '\t', '\r': continue + case 'n': + buflen := int64(len(buf)) + if cursor+3 >= buflen { + return 0, errUnexpectedEndOfJSON("null", cursor) + } + if buf[cursor+1] != 'u' { + return 0, errInvalidCharacter(buf[cursor+1], "null", cursor) + } + if buf[cursor+2] != 'l' { + return 0, errInvalidCharacter(buf[cursor+2], "null", cursor) + } + if buf[cursor+3] != 'l' { + return 0, errInvalidCharacter(buf[cursor+3], "null", cursor) + } + cursor += 4 + return cursor, nil case '[': idx := 0 for { diff --git a/decode_map.go b/decode_map.go index 2ae99fb..cd36587 100644 --- a/decode_map.go +++ b/decode_map.go @@ -47,7 +47,14 @@ func (d *mapDecoder) setValueStream(s *stream, key interface{}) error { func (d *mapDecoder) decodeStream(s *stream, p uintptr) error { s.skipWhiteSpace() - if s.char() != '{' { + switch s.char() { + case 'n': + if err := nullBytes(s); err != nil { + return err + } + return nil + case '{': + default: return errExpected("{ character for map value", s.totalOffset()) } mapValue := makemap(d.mapType, 0) @@ -94,7 +101,24 @@ func (d *mapDecoder) decode(buf []byte, cursor int64, p uintptr) (int64, error) if buflen < 2 { return 0, errExpected("{} for map", cursor) } - if buf[cursor] != '{' { + switch buf[cursor] { + case 'n': + if cursor+3 >= buflen { + return 0, errUnexpectedEndOfJSON("null", cursor) + } + if buf[cursor+1] != 'u' { + return 0, errInvalidCharacter(buf[cursor+1], "null", cursor) + } + if buf[cursor+2] != 'l' { + return 0, errInvalidCharacter(buf[cursor+2], "null", cursor) + } + if buf[cursor+3] != 'l' { + return 0, errInvalidCharacter(buf[cursor+3], "null", cursor) + } + cursor += 4 + return cursor, nil + case '{': + default: return 0, errExpected("{ character for map value", cursor) } cursor++ diff --git a/decode_slice.go b/decode_slice.go index 34ffad8..974a6df 100644 --- a/decode_slice.go +++ b/decode_slice.go @@ -62,6 +62,11 @@ func (d *sliceDecoder) decodeStream(s *stream, p uintptr) error { case ' ', '\n', '\t', '\r': s.cursor++ continue + case 'n': + if err := nullBytes(s); err != nil { + return err + } + return nil case '[': idx := 0 slice := d.newSlice() @@ -136,6 +141,22 @@ func (d *sliceDecoder) decode(buf []byte, cursor int64, p uintptr) (int64, error switch buf[cursor] { case ' ', '\n', '\t', '\r': continue + case 'n': + buflen := int64(len(buf)) + if cursor+3 >= buflen { + return 0, errUnexpectedEndOfJSON("null", cursor) + } + if buf[cursor+1] != 'u' { + return 0, errInvalidCharacter(buf[cursor+1], "null", cursor) + } + if buf[cursor+2] != 'l' { + return 0, errInvalidCharacter(buf[cursor+2], "null", cursor) + } + if buf[cursor+3] != 'l' { + return 0, errInvalidCharacter(buf[cursor+3], "null", cursor) + } + cursor += 4 + return cursor, nil case '[': idx := 0 slice := d.newSlice() diff --git a/decode_test.go b/decode_test.go index 562f00c..61d004a 100644 --- a/decode_test.go +++ b/decode_test.go @@ -135,12 +135,26 @@ func Test_Decoder(t *testing.T) { assertEq(t, "struct.D.AA", 2, v.D.AA) assertEq(t, "struct.D.BB", "world", v.D.BB) assertEq(t, "struct.D.CC", true, v.D.CC) - t.Run("struct.null", func(t *testing.T) { + t.Run("struct.field null", func(t *testing.T) { var v struct { A string + B []string + C []int + D map[string]interface{} + E [2]string + F interface{} } - assertErr(t, json.Unmarshal([]byte(`{"a":null}`), &v)) - assertEq(t, "string is null", v.A, "") + assertErr(t, json.Unmarshal([]byte(`{"a":null,"b":null,"c":null,"d":null,"e":null,"f":null}`), &v)) + assertEq(t, "string", v.A, "") + assertNeq(t, "[]string", v.B, nil) + assertEq(t, "[]string", len(v.B), 0) + assertNeq(t, "[]int", v.C, nil) + assertEq(t, "[]int", len(v.C), 0) + assertNeq(t, "map", v.D, nil) + assertEq(t, "map", len(v.D), 0) + assertNeq(t, "array", v.E, nil) + assertEq(t, "array", len(v.E), 2) + assertEq(t, "interface{}", v.F, nil) }) }) t.Run("interface", func(t *testing.T) { diff --git a/helper_test.go b/helper_test.go index 26c0f12..53157ce 100644 --- a/helper_test.go +++ b/helper_test.go @@ -15,3 +15,10 @@ func assertEq(t *testing.T, msg string, exp interface{}, act interface{}) { t.Fatalf("failed to test for %s. exp=[%v] but act=[%v]", msg, exp, act) } } + +func assertNeq(t *testing.T, msg string, exp interface{}, act interface{}) { + t.Helper() + if exp == act { + t.Fatalf("failed to test for %s. expected value is not [%v] but got same value", msg, act) + } +}