diff --git a/src/runtime/lexer.c b/src/runtime/lexer.c index 5ecd5a84..f4ddaf25 100644 --- a/src/runtime/lexer.c +++ b/src/runtime/lexer.c @@ -8,11 +8,15 @@ #define LOG(...) \ if (self->logger.log) { \ snprintf(self->debug_buffer, TREE_SITTER_SERIALIZATION_BUFFER_SIZE, __VA_ARGS__); \ - self->logger.log(self->logger.payload, TSLogTypeLex, self->debug_buffer); \ + self->logger.log(self->logger.payload, TSLogTypeLex, self->debug_buffer); \ } #define LOG_CHARACTER(message, character) \ - LOG(character < 255 ? message " character:'%c'" : message " character:%d", character) + LOG( \ + 32 <= character && character < 127 ? \ + message " character:'%c'" : \ + message " character:%d", character \ + ) static const char empty_chunk[3] = { 0, 0 }; @@ -27,6 +31,12 @@ static void ts_lexer__get_chunk(Lexer *self) { if (!self->chunk_size) self->chunk = empty_chunk; } +typedef utf8proc_ssize_t (*DecodeFunction)( + const utf8proc_uint8_t *, + utf8proc_ssize_t, + utf8proc_int32_t * +); + static void ts_lexer__get_lookahead(Lexer *self) { uint32_t position_in_chunk = self->current_position.bytes - self->chunk_start; const uint8_t *chunk = (const uint8_t *)self->chunk + position_in_chunk; @@ -38,15 +48,22 @@ static void ts_lexer__get_lookahead(Lexer *self) { return; } - if (self->input.encoding == TSInputEncodingUTF8) { - int64_t lookahead_size = utf8proc_iterate(chunk, size, &self->data.lookahead); - if (lookahead_size < 0) { - self->lookahead_size = 1; - } else { - self->lookahead_size = lookahead_size; - } - } else { - self->lookahead_size = utf16_iterate(chunk, size, &self->data.lookahead); + DecodeFunction decode = + self->input.encoding == TSInputEncodingUTF8 ? utf8proc_iterate : utf16_iterate; + + self->lookahead_size = decode(chunk, size, &self->data.lookahead); + + // If this chunk ended in the middle of a multi-byte character, + // try again with a fresh chunk. + if (self->data.lookahead == -1 && size < 4) { + ts_lexer__get_chunk(self); + chunk = (const uint8_t *)self->chunk; + size = self->chunk_size; + self->lookahead_size = decode(chunk, size, &self->data.lookahead); + } + + if (self->data.lookahead == -1) { + self->lookahead_size = 1; } } diff --git a/src/runtime/utf16.c b/src/runtime/utf16.c index 050caf19..adb82edf 100644 --- a/src/runtime/utf16.c +++ b/src/runtime/utf16.c @@ -1,6 +1,10 @@ #include "runtime/utf16.h" -int utf16_iterate(const uint8_t *string, size_t length, int32_t *code_point) { +utf8proc_ssize_t utf16_iterate( + const utf8proc_uint8_t *string, + utf8proc_ssize_t length, + utf8proc_int32_t *code_point +) { if (length < 2) { *code_point = -1; return 0; diff --git a/src/runtime/utf16.h b/src/runtime/utf16.h index 70146dd8..0cf69218 100644 --- a/src/runtime/utf16.h +++ b/src/runtime/utf16.h @@ -7,11 +7,12 @@ extern "C" { #include #include +#include "utf8proc.h" // Analogous to utf8proc's utf8proc_iterate function. Reads one code point from -// the given string and stores it in the location pointed to by `code_point`. +// the given UTF16 string and stores it in the location pointed to by `code_point`. // Returns the number of bytes in `string` that were read. -int utf16_iterate(const uint8_t *string, size_t length, int32_t *code_point); +utf8proc_ssize_t utf16_iterate(const utf8proc_uint8_t *, utf8proc_ssize_t, utf8proc_int32_t *); #ifdef __cplusplus } diff --git a/test/helpers/encoding_helpers.cc b/test/helpers/encoding_helpers.cc deleted file mode 100644 index 1169bb2d..00000000 --- a/test/helpers/encoding_helpers.cc +++ /dev/null @@ -1,64 +0,0 @@ -#include "helpers/encoding_helpers.h" -#include "runtime/utf16.h" -#include -#include "utf8proc.h" - -static inline int string_iterate(TSInputEncoding encoding, const uint8_t *string, size_t length, int32_t *code_point) { - if (encoding == TSInputEncodingUTF8) { - int32_t character_size = utf8proc_iterate(string, length, code_point); - if (character_size < 0) { - return 1; - } else { - return character_size; - } - } else { - return utf16_iterate(string, length, code_point); - } -} - -size_t string_char_count(TSInputEncoding encoding, const std::string &input) { - const char *string = input.data(); - size_t size = input.size(); - size_t character = 0, byte = 0; - - while (byte < size) { - int32_t code_point; - byte += string_iterate(encoding, (uint8_t *)string + byte, size - byte, &code_point); - character++; - } - - return character; -} - -long string_byte_for_character(TSInputEncoding encoding, const std::string &input, size_t byte_offset, size_t goal_character) { - const char *string = input.data() + byte_offset; - size_t size = input.size() - byte_offset; - size_t character = 0, byte = 0; - - while (character < goal_character) { - if (byte >= size) - return -1; - - int32_t code_point; - byte += string_iterate(encoding, (uint8_t *)string + byte, size - byte, &code_point); - character++; - } - - return byte; -} - -size_t utf8_char_count(const std::string &input) { - return string_char_count(TSInputEncodingUTF8, input); -} - -size_t utf16_char_count(const std::string &input) { - return string_char_count(TSInputEncodingUTF16, input); -} - -long utf8_byte_for_character(const std::string &input, size_t byte_offset, size_t goal_character) { - return string_byte_for_character(TSInputEncodingUTF8, input, byte_offset, goal_character); -} - -long utf16_byte_for_character(const std::string &input, size_t byte_offset, size_t goal_character) { - return string_byte_for_character(TSInputEncodingUTF16, input, byte_offset, goal_character); -} diff --git a/test/helpers/encoding_helpers.h b/test/helpers/encoding_helpers.h deleted file mode 100644 index 070b2326..00000000 --- a/test/helpers/encoding_helpers.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef HELPERS_ENCODING_HELPERS_H_ -#define HELPERS_ENCODING_HELPERS_H_ - -#include -#include "tree_sitter/runtime.h" - -size_t string_char_count(TSInputEncoding, const std::string &); -size_t utf8_char_count(const std::string &); -size_t utf16_char_count(const std::string &); - -long string_byte_for_character(TSInputEncoding, const std::string &, size_t byte_offset, size_t character); -long utf8_byte_for_character(const std::string &, size_t byte_offset, size_t character); -long utf16_byte_for_character(const std::string &, size_t byte_offset, size_t character); - -#endif // HELPERS_ENCODING_HELPERS_H_ diff --git a/test/helpers/spy_input.cc b/test/helpers/spy_input.cc index 34c5d997..8aa8963b 100644 --- a/test/helpers/spy_input.cc +++ b/test/helpers/spy_input.cc @@ -1,5 +1,4 @@ #include "helpers/spy_input.h" -#include "helpers/encoding_helpers.h" #include "helpers/point_helpers.h" #include "runtime/point.h" #include @@ -46,19 +45,13 @@ const char *SpyInput::read(void *payload, uint32_t byte_offset, TSPoint position, uint32_t *bytes_read) { auto spy = static_cast(payload); - if (byte_offset >= spy->content.size()) { - *bytes_read = 0; - return ""; + unsigned end_byte = byte_offset + spy->chars_per_chunk; + if (end_byte > spy->content.size()) { + end_byte = spy->content.size(); } - long byte_count = string_byte_for_character(spy->encoding, spy->content, byte_offset, spy->chars_per_chunk); - if (byte_count < 0) { - byte_count = spy->content.size() - byte_offset; - } - - string result = spy->content.substr(byte_offset, byte_count); - *bytes_read = byte_count; - add_byte_range(&spy->ranges_read, byte_offset, byte_count); + *bytes_read = end_byte - byte_offset; + add_byte_range(&spy->ranges_read, byte_offset, *bytes_read); /* * This class stores its entire `content` in a contiguous buffer, but we want @@ -70,9 +63,9 @@ const char *SpyInput::read(void *payload, uint32_t byte_offset, * can detect code reading too many bytes from the buffer. */ delete[] spy->buffer; - if (byte_count) { - spy->buffer = new char[byte_count]; - memcpy(spy->buffer, result.data(), byte_count); + if (*bytes_read) { + spy->buffer = new char[*bytes_read](); + memcpy(spy->buffer, spy->content.data() + byte_offset, *bytes_read); } else { spy->buffer = nullptr; } diff --git a/test/integration/real_grammars.cc b/test/integration/real_grammars.cc index 0a0e6d7d..7384799c 100644 --- a/test/integration/real_grammars.cc +++ b/test/integration/real_grammars.cc @@ -5,7 +5,6 @@ #include "helpers/spy_input.h" #include "helpers/stderr_logger.h" #include "helpers/point_helpers.h" -#include "helpers/encoding_helpers.h" #include "helpers/record_alloc.h" #include "helpers/random_helpers.h" #include "helpers/scope_sequence.h" @@ -57,7 +56,7 @@ for (auto &language_name : test_languages) { SpyInput *input; it(("parses " + entry.description + ": initial parse").c_str(), [&]() { - input = new SpyInput(entry.input, 3); + input = new SpyInput(entry.input, 4); if (debug_graphs_enabled) printf("%s\n\n", input->content.c_str()); TSTree *tree = ts_parser_parse(parser, nullptr, input->input()); @@ -77,8 +76,8 @@ for (auto &language_name : test_languages) { set> insertions; for (size_t i = 0; i < 60; i++) { - size_t edit_position = default_generator(utf8_char_count(entry.input)); - size_t deletion_size = default_generator(utf8_char_count(entry.input) - edit_position); + size_t edit_position = default_generator(entry.input.size()); + size_t deletion_size = default_generator(entry.input.size() - edit_position); string inserted_text = default_generator.words(default_generator(4) + 1); if (insertions.insert({edit_position, inserted_text}).second) { diff --git a/test/runtime/parser_test.cc b/test/runtime/parser_test.cc index 9a354252..e30b82c3 100644 --- a/test/runtime/parser_test.cc +++ b/test/runtime/parser_test.cc @@ -544,6 +544,21 @@ describe("Parser", [&]() { root = ts_tree_root_node(tree); AssertThat(ts_node_end_point(root), Equals({0, 28})); }); + + it("handles input chunks that end in the middle of multi-byte characters", [&]() { + ts_parser_set_language(parser, load_real_language("c")); + spy_input->content = "A b = {'👍','👍'};"; + spy_input->chars_per_chunk = 4; + + tree = ts_parser_parse(parser, nullptr, spy_input->input()); + root = ts_tree_root_node(tree); + assert_root_node( + "(translation_unit (declaration " + "(type_identifier) " + "(init_declarator " + "(identifier) " + "(initializer_list (char_literal) (char_literal)))))"); + }); }); describe("set_language(language)", [&]() { diff --git a/test/runtime/tree_test.cc b/test/runtime/tree_test.cc index c60f2af6..86abfea1 100644 --- a/test/runtime/tree_test.cc +++ b/test/runtime/tree_test.cc @@ -11,7 +11,6 @@ #include "helpers/load_language.h" #include "helpers/random_helpers.h" #include "helpers/read_test_entries.h" -#include "helpers/encoding_helpers.h" #include "helpers/tree_helpers.h" TSPoint point(uint32_t row, uint32_t column) { @@ -71,8 +70,8 @@ describe("Tree", [&]() { for (unsigned j = 0; j < 10; j++) { random.sleep_some(); - size_t edit_position = random(utf8_char_count(input->content)); - size_t deletion_size = random(utf8_char_count(input->content) - edit_position); + size_t edit_position = random(input->content.size()); + size_t deletion_size = random(input->content.size() - edit_position); string inserted_text = random.words(random(4) + 1); TSInputEdit edit = input->replace(edit_position, deletion_size, inserted_text); diff --git a/tests.gyp b/tests.gyp index 80e8d618..ed97c3e4 100644 --- a/tests.gyp +++ b/tests.gyp @@ -50,7 +50,6 @@ 'test/compiler/rules/character_set_test.cc', 'test/compiler/rules/rule_test.cc', 'test/compiler/util/string_helpers_test.cc', - 'test/helpers/encoding_helpers.cc', 'test/helpers/file_helpers.cc', 'test/helpers/load_language.cc', 'test/helpers/point_helpers.cc',