diff --git a/decode_struct.go b/decode_struct.go index f857bfc..d459fb8 100644 --- a/decode_struct.go +++ b/decode_struct.go @@ -6,6 +6,8 @@ import ( "math/bits" "sort" "strings" + "unicode" + "unicode/utf16" "unsafe" ) @@ -136,6 +138,52 @@ func (d *structDecoder) tryOptimize() { } } +// decode from '\uXXXX' +func decodeKeyCharByUnicodeRune(buf []byte, cursor int64) ([]byte, int64) { + const defaultOffset = 4 + const surrogateOffset = 6 + + r := unicodeToRune(buf[cursor : cursor+defaultOffset]) + if utf16.IsSurrogate(r) { + cursor += defaultOffset + if cursor+surrogateOffset >= int64(len(buf)) || buf[cursor] != '\\' || buf[cursor+1] != 'u' { + return []byte(string(unicode.ReplacementChar)), cursor + defaultOffset - 1 + } + cursor += 2 + r2 := unicodeToRune(buf[cursor : cursor+defaultOffset]) + if r := utf16.DecodeRune(r, r2); r != unicode.ReplacementChar { + return []byte(string(r)), cursor + defaultOffset - 1 + } + } + return []byte(string(r)), cursor + defaultOffset - 1 +} + +func decodeKeyCharByEscapedChar(buf []byte, cursor int64) ([]byte, int64) { + c := buf[cursor] + cursor++ + switch c { + case '"': + return []byte{'"'}, cursor + case '\\': + return []byte{'\\'}, cursor + case '/': + return []byte{'/'}, cursor + case 'b': + return []byte{'\b'}, cursor + case 'f': + return []byte{'\f'}, cursor + case 'n': + return []byte{'\n'}, cursor + case 'r': + return []byte{'\r'}, cursor + case 't': + return []byte{'\t'}, cursor + case 'u': + return decodeKeyCharByUnicodeRune(buf, cursor) + } + return nil, cursor +} + func decodeKeyByBitmapUint8(d *structDecoder, buf []byte, cursor int64) (int64, *structFieldSet, error) { var ( field *structFieldSet @@ -174,24 +222,21 @@ func decodeKeyByBitmapUint8(d *structDecoder, buf []byte, cursor int64) (int64, return cursor, field, nil case nul: return 0, nil, errUnexpectedEndOfJSON("string", cursor) + case '\\': + cursor++ + chars, nextCursor := decodeKeyCharByEscapedChar(buf, cursor) + for _, c := range chars { + curBit &= bitmap[keyIdx][largeToSmallTable[c]] + if curBit == 0 { + return decodeKeyNotFound(b, cursor, field) + } + keyIdx++ + } + cursor = nextCursor default: curBit &= bitmap[keyIdx][largeToSmallTable[c]] if curBit == 0 { - for { - cursor++ - switch char(b, cursor) { - case '"': - cursor++ - return cursor, field, nil - case '\\': - cursor++ - if char(b, cursor) == nul { - return 0, nil, errUnexpectedEndOfJSON("string", cursor) - } - case nul: - return 0, nil, errUnexpectedEndOfJSON("string", cursor) - } - } + return decodeKeyNotFound(b, cursor, field) } keyIdx++ } @@ -203,6 +248,24 @@ func decodeKeyByBitmapUint8(d *structDecoder, buf []byte, cursor int64) (int64, } } +func decodeKeyNotFound(b unsafe.Pointer, cursor int64, field *structFieldSet) (int64, *structFieldSet, error) { + for { + cursor++ + switch char(b, cursor) { + case '"': + cursor++ + return cursor, field, nil + case '\\': + cursor++ + if char(b, cursor) == nul { + return 0, nil, errUnexpectedEndOfJSON("string", cursor) + } + case nul: + return 0, nil, errUnexpectedEndOfJSON("string", cursor) + } + } +} + func decodeKeyByBitmapUint16(d *structDecoder, buf []byte, cursor int64) (int64, *structFieldSet, error) { var ( field *structFieldSet @@ -241,24 +304,21 @@ func decodeKeyByBitmapUint16(d *structDecoder, buf []byte, cursor int64) (int64, return cursor, field, nil case nul: return 0, nil, errUnexpectedEndOfJSON("string", cursor) + case '\\': + cursor++ + chars, nextCursor := decodeKeyCharByEscapedChar(buf, cursor) + for _, c := range chars { + curBit &= bitmap[keyIdx][largeToSmallTable[c]] + if curBit == 0 { + return decodeKeyNotFound(b, cursor, field) + } + keyIdx++ + } + cursor = nextCursor default: curBit &= bitmap[keyIdx][largeToSmallTable[c]] if curBit == 0 { - for { - cursor++ - switch char(b, cursor) { - case '"': - cursor++ - return cursor, field, nil - case '\\': - cursor++ - if char(b, cursor) == nul { - return 0, nil, errUnexpectedEndOfJSON("string", cursor) - } - case nul: - return 0, nil, errUnexpectedEndOfJSON("string", cursor) - } - } + return decodeKeyNotFound(b, cursor, field) } keyIdx++ } diff --git a/decode_test.go b/decode_test.go index 7dab500..8298a1f 100644 --- a/decode_test.go +++ b/decode_test.go @@ -3594,3 +3594,18 @@ func TestIssue218(t *testing.T) { }) } } + +func TestDecodeEscapedCharField(t *testing.T) { + b := []byte(`{"\u6D88\u606F":"\u6D88\u606F"}`) + t.Run("unmarshal", func(t *testing.T) { + v := struct { + Msg string `json:"消息"` + }{} + if err := json.Unmarshal(b, &v); err != nil { + t.Fatal(err) + } + if !bytes.Equal([]byte(v.Msg), []byte("消息")) { + t.Fatal("failed to decode unicode char") + } + }) +}