Fix decoding of backslash char at the end of string

This commit is contained in:
Masaaki Goshima 2021-05-02 16:02:14 +09:00
parent 2f3afbf6ad
commit f341b31ea5
3 changed files with 135 additions and 30 deletions

View File

@ -55,10 +55,12 @@ func skipObject(buf []byte, cursor, depth int64) (int64, error) {
for { for {
cursor++ cursor++
switch buf[cursor] { switch buf[cursor] {
case '"': case '\\':
if buf[cursor-1] == '\\' { cursor++
continue if buf[cursor] == nul {
return 0, errUnexpectedEndOfJSON("string of object", cursor)
} }
case '"':
goto SWITCH_OUT goto SWITCH_OUT
case nul: case nul:
return 0, errUnexpectedEndOfJSON("string of object", cursor) return 0, errUnexpectedEndOfJSON("string of object", cursor)
@ -99,10 +101,12 @@ func skipArray(buf []byte, cursor, depth int64) (int64, error) {
for { for {
cursor++ cursor++
switch buf[cursor] { switch buf[cursor] {
case '"': case '\\':
if buf[cursor-1] == '\\' { cursor++
continue if buf[cursor] == nul {
return 0, errUnexpectedEndOfJSON("string of object", cursor)
} }
case '"':
goto SWITCH_OUT goto SWITCH_OUT
case nul: case nul:
return 0, errUnexpectedEndOfJSON("string of object", cursor) return 0, errUnexpectedEndOfJSON("string of object", cursor)
@ -130,10 +134,12 @@ func skipValue(buf []byte, cursor, depth int64) (int64, error) {
for { for {
cursor++ cursor++
switch buf[cursor] { switch buf[cursor] {
case '"': case '\\':
if buf[cursor-1] == '\\' { cursor++
continue if buf[cursor] == nul {
return 0, errUnexpectedEndOfJSON("string of object", cursor)
} }
case '"':
return cursor + 1, nil return cursor + 1, nil
case nul: case nul:
return 0, errUnexpectedEndOfJSON("string of object", cursor) return 0, errUnexpectedEndOfJSON("string of object", cursor)
@ -184,18 +190,8 @@ func skipValue(buf []byte, cursor, depth int64) (int64, error) {
cursor += 5 cursor += 5
return cursor, nil return cursor, nil
case 'n': case 'n':
buflen := int64(len(buf)) if err := validateNull(buf, cursor); err != nil {
if cursor+3 >= buflen { return 0, err
return 0, errUnexpectedEndOfJSON("null", cursor)
}
if buf[cursor+1] != 'u' {
return 0, errUnexpectedEndOfJSON("null", cursor)
}
if buf[cursor+2] != 'l' {
return 0, errUnexpectedEndOfJSON("null", cursor)
}
if buf[cursor+3] != 'l' {
return 0, errUnexpectedEndOfJSON("null", cursor)
} }
cursor += 4 cursor += 4
return cursor, nil return cursor, nil

View File

@ -126,10 +126,18 @@ func (s *stream) skipObject(depth int64) error {
for { for {
cursor++ cursor++
switch char(p, cursor) { switch char(p, cursor) {
case '"': case '\\':
if char(p, cursor-1) == '\\' { cursor++
if char(p, cursor) == nul {
s.cursor = cursor
if s.read() {
s.cursor-- // for retry current character
_, cursor, p = s.stat()
continue continue
} }
return errUnexpectedEndOfJSON("string of object", cursor)
}
case '"':
goto SWITCH_OUT goto SWITCH_OUT
case nul: case nul:
s.cursor = cursor s.cursor = cursor
@ -183,10 +191,18 @@ func (s *stream) skipArray(depth int64) error {
for { for {
cursor++ cursor++
switch char(p, cursor) { switch char(p, cursor) {
case '"': case '\\':
if char(p, cursor-1) == '\\' { cursor++
if char(p, cursor) == nul {
s.cursor = cursor
if s.read() {
s.cursor-- // for retry current character
_, cursor, p = s.stat()
continue continue
} }
return errUnexpectedEndOfJSON("string of object", cursor)
}
case '"':
goto SWITCH_OUT goto SWITCH_OUT
case nul: case nul:
s.cursor = cursor s.cursor = cursor
@ -235,10 +251,18 @@ func (s *stream) skipValue(depth int64) error {
for { for {
cursor++ cursor++
switch char(p, cursor) { switch char(p, cursor) {
case '"': case '\\':
if char(p, cursor-1) == '\\' { cursor++
if char(p, cursor) == nul {
s.cursor = cursor
if s.read() {
s.cursor-- // for retry current character
_, cursor, p = s.stat()
continue continue
} }
return errUnexpectedEndOfJSON("value of string", s.totalOffset())
}
case '"':
s.cursor = cursor + 1 s.cursor = cursor + 1
return nil return nil
case nul: case nul:

View File

@ -3417,3 +3417,88 @@ func TestDecodeByteSliceNull(t *testing.T) {
} }
}) })
} }
func TestDecodeBackSlash(t *testing.T) {
t.Run("unmarshal", func(t *testing.T) {
t.Run("string", func(t *testing.T) {
var v1 map[string]stdjson.RawMessage
if err := stdjson.Unmarshal([]byte(`{"c":"\\"}`), &v1); err != nil {
t.Fatal(err)
}
var v2 map[string]json.RawMessage
if err := json.Unmarshal([]byte(`{"c":"\\"}`), &v2); err != nil {
t.Fatal(err)
}
if len(v1) != len(v2) || !bytes.Equal(v1["c"], v2["c"]) {
t.Fatalf("failed to decode backslash: expected %#v but got %#v", v1, v2)
}
})
t.Run("array", func(t *testing.T) {
var v1 map[string]stdjson.RawMessage
if err := stdjson.Unmarshal([]byte(`{"c":["\\"]}`), &v1); err != nil {
t.Fatal(err)
}
var v2 map[string]json.RawMessage
if err := json.Unmarshal([]byte(`{"c":["\\"]}`), &v2); err != nil {
t.Fatal(err)
}
if len(v1) != len(v2) || !bytes.Equal(v1["c"], v2["c"]) {
t.Fatalf("failed to decode backslash: expected %#v but got %#v", v1, v2)
}
})
t.Run("object", func(t *testing.T) {
var v1 map[string]stdjson.RawMessage
if err := stdjson.Unmarshal([]byte(`{"c":{"\\":"\\"}}`), &v1); err != nil {
t.Fatal(err)
}
var v2 map[string]json.RawMessage
if err := json.Unmarshal([]byte(`{"c":{"\\":"\\"}}`), &v2); err != nil {
t.Fatal(err)
}
if len(v1) != len(v2) || !bytes.Equal(v1["c"], v2["c"]) {
t.Fatalf("failed to decode backslash: expected %#v but got %#v", v1, v2)
}
})
})
t.Run("stream", func(t *testing.T) {
t.Run("string", func(t *testing.T) {
var v1 map[string]stdjson.RawMessage
if err := stdjson.NewDecoder(strings.NewReader(`{"c":"\\"}`)).Decode(&v1); err != nil {
t.Fatal(err)
}
var v2 map[string]json.RawMessage
if err := json.NewDecoder(strings.NewReader(`{"c":"\\"}`)).Decode(&v2); err != nil {
t.Fatal(err)
}
if len(v1) != len(v2) || !bytes.Equal(v1["c"], v2["c"]) {
t.Fatalf("failed to decode backslash: expected %#v but got %#v", v1, v2)
}
})
t.Run("array", func(t *testing.T) {
var v1 map[string]stdjson.RawMessage
if err := stdjson.NewDecoder(strings.NewReader(`{"c":["\\"]}`)).Decode(&v1); err != nil {
t.Fatal(err)
}
var v2 map[string]json.RawMessage
if err := json.NewDecoder(strings.NewReader(`{"c":["\\"]}`)).Decode(&v2); err != nil {
t.Fatal(err)
}
if len(v1) != len(v2) || !bytes.Equal(v1["c"], v2["c"]) {
t.Fatalf("failed to decode backslash: expected %#v but got %#v", v1, v2)
}
})
t.Run("object", func(t *testing.T) {
var v1 map[string]stdjson.RawMessage
if err := stdjson.NewDecoder(strings.NewReader(`{"c":{"\\":"\\"}}`)).Decode(&v1); err != nil {
t.Fatal(err)
}
var v2 map[string]json.RawMessage
if err := json.NewDecoder(strings.NewReader(`{"c":{"\\":"\\"}}`)).Decode(&v2); err != nil {
t.Fatal(err)
}
if len(v1) != len(v2) || !bytes.Equal(v1["c"], v2["c"]) {
t.Fatalf("failed to decode backslash: expected %#v but got %#v", v1, v2)
}
})
})
}