From f341b31ea593c5b69c6694299d4bf8a096d5e19a Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Sun, 2 May 2021 16:02:14 +0900 Subject: [PATCH] Fix decoding of backslash char at the end of string --- decode_context.go | 38 ++++++++++----------- decode_stream.go | 42 ++++++++++++++++++----- decode_test.go | 85 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 135 insertions(+), 30 deletions(-) diff --git a/decode_context.go b/decode_context.go index 305f0d6..eac8163 100644 --- a/decode_context.go +++ b/decode_context.go @@ -55,10 +55,12 @@ func skipObject(buf []byte, cursor, depth int64) (int64, error) { for { cursor++ switch buf[cursor] { - case '"': - if buf[cursor-1] == '\\' { - continue + case '\\': + cursor++ + if buf[cursor] == nul { + return 0, errUnexpectedEndOfJSON("string of object", cursor) } + case '"': goto SWITCH_OUT case nul: return 0, errUnexpectedEndOfJSON("string of object", cursor) @@ -99,10 +101,12 @@ func skipArray(buf []byte, cursor, depth int64) (int64, error) { for { cursor++ switch buf[cursor] { - case '"': - if buf[cursor-1] == '\\' { - continue + case '\\': + cursor++ + if buf[cursor] == nul { + return 0, errUnexpectedEndOfJSON("string of object", cursor) } + case '"': goto SWITCH_OUT case nul: return 0, errUnexpectedEndOfJSON("string of object", cursor) @@ -130,10 +134,12 @@ func skipValue(buf []byte, cursor, depth int64) (int64, error) { for { cursor++ switch buf[cursor] { - case '"': - if buf[cursor-1] == '\\' { - continue + case '\\': + cursor++ + if buf[cursor] == nul { + return 0, errUnexpectedEndOfJSON("string of object", cursor) } + case '"': return cursor + 1, nil case nul: return 0, errUnexpectedEndOfJSON("string of object", cursor) @@ -184,18 +190,8 @@ func skipValue(buf []byte, cursor, depth int64) (int64, error) { cursor += 5 return cursor, nil case 'n': - buflen := int64(len(buf)) - if cursor+3 >= buflen { - return 0, errUnexpectedEndOfJSON("null", cursor) - } - if buf[cursor+1] != 'u' { - return 0, errUnexpectedEndOfJSON("null", cursor) - } - if buf[cursor+2] != 'l' { - return 0, errUnexpectedEndOfJSON("null", cursor) - } - if buf[cursor+3] != 'l' { - return 0, errUnexpectedEndOfJSON("null", cursor) + if err := validateNull(buf, cursor); err != nil { + return 0, err } cursor += 4 return cursor, nil diff --git a/decode_stream.go b/decode_stream.go index 6019aaa..257b8fc 100644 --- a/decode_stream.go +++ b/decode_stream.go @@ -126,10 +126,18 @@ func (s *stream) skipObject(depth int64) error { for { cursor++ switch char(p, cursor) { - case '"': - if char(p, cursor-1) == '\\' { - continue + case '\\': + cursor++ + if char(p, cursor) == nul { + s.cursor = cursor + if s.read() { + s.cursor-- // for retry current character + _, cursor, p = s.stat() + continue + } + return errUnexpectedEndOfJSON("string of object", cursor) } + case '"': goto SWITCH_OUT case nul: s.cursor = cursor @@ -183,10 +191,18 @@ func (s *stream) skipArray(depth int64) error { for { cursor++ switch char(p, cursor) { - case '"': - if char(p, cursor-1) == '\\' { - continue + case '\\': + cursor++ + if char(p, cursor) == nul { + s.cursor = cursor + if s.read() { + s.cursor-- // for retry current character + _, cursor, p = s.stat() + continue + } + return errUnexpectedEndOfJSON("string of object", cursor) } + case '"': goto SWITCH_OUT case nul: s.cursor = cursor @@ -235,10 +251,18 @@ func (s *stream) skipValue(depth int64) error { for { cursor++ switch char(p, cursor) { - case '"': - if char(p, cursor-1) == '\\' { - continue + case '\\': + cursor++ + if char(p, cursor) == nul { + s.cursor = cursor + if s.read() { + s.cursor-- // for retry current character + _, cursor, p = s.stat() + continue + } + return errUnexpectedEndOfJSON("value of string", s.totalOffset()) } + case '"': s.cursor = cursor + 1 return nil case nul: diff --git a/decode_test.go b/decode_test.go index cfc6df3..fb4ea65 100644 --- a/decode_test.go +++ b/decode_test.go @@ -3417,3 +3417,88 @@ func TestDecodeByteSliceNull(t *testing.T) { } }) } + +func TestDecodeBackSlash(t *testing.T) { + t.Run("unmarshal", func(t *testing.T) { + t.Run("string", func(t *testing.T) { + var v1 map[string]stdjson.RawMessage + if err := stdjson.Unmarshal([]byte(`{"c":"\\"}`), &v1); err != nil { + t.Fatal(err) + } + var v2 map[string]json.RawMessage + if err := json.Unmarshal([]byte(`{"c":"\\"}`), &v2); err != nil { + t.Fatal(err) + } + if len(v1) != len(v2) || !bytes.Equal(v1["c"], v2["c"]) { + t.Fatalf("failed to decode backslash: expected %#v but got %#v", v1, v2) + } + }) + t.Run("array", func(t *testing.T) { + var v1 map[string]stdjson.RawMessage + if err := stdjson.Unmarshal([]byte(`{"c":["\\"]}`), &v1); err != nil { + t.Fatal(err) + } + var v2 map[string]json.RawMessage + if err := json.Unmarshal([]byte(`{"c":["\\"]}`), &v2); err != nil { + t.Fatal(err) + } + if len(v1) != len(v2) || !bytes.Equal(v1["c"], v2["c"]) { + t.Fatalf("failed to decode backslash: expected %#v but got %#v", v1, v2) + } + }) + t.Run("object", func(t *testing.T) { + var v1 map[string]stdjson.RawMessage + if err := stdjson.Unmarshal([]byte(`{"c":{"\\":"\\"}}`), &v1); err != nil { + t.Fatal(err) + } + var v2 map[string]json.RawMessage + if err := json.Unmarshal([]byte(`{"c":{"\\":"\\"}}`), &v2); err != nil { + t.Fatal(err) + } + if len(v1) != len(v2) || !bytes.Equal(v1["c"], v2["c"]) { + t.Fatalf("failed to decode backslash: expected %#v but got %#v", v1, v2) + } + }) + }) + t.Run("stream", func(t *testing.T) { + t.Run("string", func(t *testing.T) { + var v1 map[string]stdjson.RawMessage + if err := stdjson.NewDecoder(strings.NewReader(`{"c":"\\"}`)).Decode(&v1); err != nil { + t.Fatal(err) + } + var v2 map[string]json.RawMessage + if err := json.NewDecoder(strings.NewReader(`{"c":"\\"}`)).Decode(&v2); err != nil { + t.Fatal(err) + } + if len(v1) != len(v2) || !bytes.Equal(v1["c"], v2["c"]) { + t.Fatalf("failed to decode backslash: expected %#v but got %#v", v1, v2) + } + }) + t.Run("array", func(t *testing.T) { + var v1 map[string]stdjson.RawMessage + if err := stdjson.NewDecoder(strings.NewReader(`{"c":["\\"]}`)).Decode(&v1); err != nil { + t.Fatal(err) + } + var v2 map[string]json.RawMessage + if err := json.NewDecoder(strings.NewReader(`{"c":["\\"]}`)).Decode(&v2); err != nil { + t.Fatal(err) + } + if len(v1) != len(v2) || !bytes.Equal(v1["c"], v2["c"]) { + t.Fatalf("failed to decode backslash: expected %#v but got %#v", v1, v2) + } + }) + t.Run("object", func(t *testing.T) { + var v1 map[string]stdjson.RawMessage + if err := stdjson.NewDecoder(strings.NewReader(`{"c":{"\\":"\\"}}`)).Decode(&v1); err != nil { + t.Fatal(err) + } + var v2 map[string]json.RawMessage + if err := json.NewDecoder(strings.NewReader(`{"c":{"\\":"\\"}}`)).Decode(&v2); err != nil { + t.Fatal(err) + } + if len(v1) != len(v2) || !bytes.Equal(v1["c"], v2["c"]) { + t.Fatalf("failed to decode backslash: expected %#v but got %#v", v1, v2) + } + }) + }) +}