diff --git a/include/tree_sitter/runtime.h b/include/tree_sitter/runtime.h index a4f4df1f..59e73cd4 100644 --- a/include/tree_sitter/runtime.h +++ b/include/tree_sitter/runtime.h @@ -8,10 +8,16 @@ extern "C" { #include #include +typedef enum { + TSInputEncodingUTF8, + TSInputEncodingUTF16, +} TSInputEncoding; + typedef struct { void *payload; const char *(*read_fn)(void *payload, size_t *bytes_read); int (*seek_fn)(void *payload, size_t character, size_t byte); + TSInputEncoding encoding; } TSInput; typedef enum { diff --git a/project.gyp b/project.gyp index d4fa8b44..16d629a8 100644 --- a/project.gyp +++ b/project.gyp @@ -112,6 +112,7 @@ 'src/runtime/parser.c', 'src/runtime/string_input.c', 'src/runtime/tree.c', + 'src/runtime/utf16.c', 'externals/utf8proc/utf8proc.c', ], 'cflags_c': [ diff --git a/spec/runtime/document_spec.cc b/spec/runtime/document_spec.cc index 8e2bed99..acc176dd 100644 --- a/spec/runtime/document_spec.cc +++ b/spec/runtime/document_spec.cc @@ -40,6 +40,20 @@ describe("Document", [&]() { delete spy_input; }); + it("handles both UTF8 and UTF16 encodings", [&]() { + const char16_t content[] = u"[true, false]"; + spy_input->content = string((const char *)content, sizeof(content)); + spy_input->encoding = TSInputEncodingUTF16; + + ts_document_set_input(doc, spy_input->input()); + ts_document_invalidate(doc); + ts_document_parse(doc); + + root = ts_document_root_node(doc); + AssertThat(ts_node_string(root, doc), Equals( + "(array (true) (false))")); + }); + it("allows the input to be retrieved later", [&]() { ts_document_set_input(doc, spy_input->input()); AssertThat(ts_document_input(doc).payload, Equals(spy_input)); diff --git a/spec/runtime/helpers/encoding_helpers.cc b/spec/runtime/helpers/encoding_helpers.cc new file mode 100644 index 00000000..ec61f0a3 --- /dev/null +++ b/spec/runtime/helpers/encoding_helpers.cc @@ -0,0 +1,58 @@ +#include "runtime/helpers/encoding_helpers.h" +#include "runtime/utf16.h" +#include +#include "utf8proc.h" + +static inline int string_iterate(TSInputEncoding encoding, const uint8_t *string, size_t length, int32_t *code_point) { + if (encoding == TSInputEncodingUTF8) + return utf8proc_iterate(string, length, code_point); + else + return utf16_iterate(string, length, code_point); +} + +size_t string_char_count(TSInputEncoding encoding, const std::string &input) { + const char *string = input.data(); + size_t size = input.size(); + size_t character = 0, byte = 0; + + while (byte < size) { + int32_t code_point; + byte += string_iterate(encoding, (uint8_t *)string + byte, size - byte, &code_point); + character++; + } + + return character; +} + +long string_byte_for_character(TSInputEncoding encoding, const std::string &input, size_t byte_offset, size_t goal_character) { + const char *string = input.data() + byte_offset; + size_t size = input.size() - byte_offset; + size_t character = 0, byte = 0; + + while (character < goal_character) { + if (byte >= size) + return -1; + + int32_t code_point; + byte += string_iterate(encoding, (uint8_t *)string + byte, size - byte, &code_point); + character++; + } + + return byte; +} + +size_t utf8_char_count(const std::string &input) { + return string_char_count(TSInputEncodingUTF8, input); +} + +size_t utf16_char_count(const std::string &input) { + return string_char_count(TSInputEncodingUTF16, input); +} + +long utf8_byte_for_character(const std::string &input, size_t byte_offset, size_t goal_character) { + return string_byte_for_character(TSInputEncodingUTF8, input, byte_offset, goal_character); +} + +long utf16_byte_for_character(const std::string &input, size_t byte_offset, size_t goal_character) { + return string_byte_for_character(TSInputEncodingUTF16, input, byte_offset, goal_character); +} diff --git a/spec/runtime/helpers/encoding_helpers.h b/spec/runtime/helpers/encoding_helpers.h new file mode 100644 index 00000000..070b2326 --- /dev/null +++ b/spec/runtime/helpers/encoding_helpers.h @@ -0,0 +1,15 @@ +#ifndef HELPERS_ENCODING_HELPERS_H_ +#define HELPERS_ENCODING_HELPERS_H_ + +#include +#include "tree_sitter/runtime.h" + +size_t string_char_count(TSInputEncoding, const std::string &); +size_t utf8_char_count(const std::string &); +size_t utf16_char_count(const std::string &); + +long string_byte_for_character(TSInputEncoding, const std::string &, size_t byte_offset, size_t character); +long utf8_byte_for_character(const std::string &, size_t byte_offset, size_t character); +long utf16_byte_for_character(const std::string &, size_t byte_offset, size_t character); + +#endif // HELPERS_ENCODING_HELPERS_H_ diff --git a/spec/runtime/helpers/spy_input.cc b/spec/runtime/helpers/spy_input.cc index 4e1020b6..49a1050c 100644 --- a/spec/runtime/helpers/spy_input.cc +++ b/spec/runtime/helpers/spy_input.cc @@ -1,53 +1,20 @@ #include "runtime/helpers/spy_input.h" +#include "runtime/helpers/encoding_helpers.h" #include #include -#include "utf8proc.h" #include using std::string; static const size_t UTF8_MAX_CHAR_SIZE = 4; -size_t SpyInput::char_count(const string &text) { - const char *bytes = text.data(); - size_t len = text.size(); - size_t character = 0, byte = 0; - int32_t dest_char; - - while (byte < len) { - byte += utf8proc_iterate( - (uint8_t *)bytes + byte, - len - byte, - &dest_char); - character++; - } - - return character; -} - -static long byte_for_character(const char *str, size_t len, size_t goal_character) { - size_t character = 0, byte = 0; - int32_t dest_char; - - while (character < goal_character) { - if (byte >= len) - return -1; - byte += utf8proc_iterate( - (uint8_t *)str + byte, - len - byte, - &dest_char); - character++; - } - - return byte; -} - SpyInput::SpyInput(string content, size_t chars_per_chunk) : chars_per_chunk(chars_per_chunk), buffer_size(UTF8_MAX_CHAR_SIZE * chars_per_chunk), buffer(new char[buffer_size]), byte_offset(0), content(content), + encoding(TSInputEncodingUTF8), strings_read({""}) {} SpyInput::~SpyInput() { @@ -62,14 +29,14 @@ const char * SpyInput::read(void *payload, size_t *bytes_read) { return ""; } - const char *start = spy->content.data() + spy->byte_offset; - long byte_count = byte_for_character(start, spy->content.size() - spy->byte_offset, spy->chars_per_chunk); + long byte_count = string_byte_for_character(spy->encoding, spy->content, spy->byte_offset, spy->chars_per_chunk); if (byte_count < 0) byte_count = spy->content.size() - spy->byte_offset; + string result = spy->content.substr(spy->byte_offset, byte_count); *bytes_read = byte_count; + spy->strings_read.back() += result; spy->byte_offset += byte_count; - spy->strings_read.back() += string(start, byte_count); /* * This class stores its entire `content` in a contiguous buffer, but we want @@ -80,7 +47,7 @@ const char * SpyInput::read(void *payload, size_t *bytes_read) { * content. */ memset(spy->buffer, 0, spy->buffer_size); - memcpy(spy->buffer, start, byte_count); + memcpy(spy->buffer, result.data(), byte_count); return spy->buffer; } @@ -95,6 +62,7 @@ int SpyInput::seek(void *payload, size_t character, size_t byte) { TSInput SpyInput::input() { TSInput result; result.payload = this; + result.encoding = encoding; result.seek_fn = seek; result.read_fn = read; return result; @@ -102,7 +70,7 @@ TSInput SpyInput::input() { TSInputEdit SpyInput::replace(size_t start_char, size_t chars_removed, string text) { string text_removed = swap_substr(start_char, chars_removed, text); - size_t chars_inserted = SpyInput::char_count(text); + size_t chars_inserted = string_char_count(encoding, text); undo_stack.push_back(SpyInputEdit{start_char, chars_inserted, text_removed}); return {start_char, chars_inserted, chars_removed}; } @@ -111,20 +79,17 @@ TSInputEdit SpyInput::undo() { SpyInputEdit entry = undo_stack.back(); undo_stack.pop_back(); swap_substr(entry.position, entry.chars_removed, entry.text_inserted); - size_t chars_inserted = SpyInput::char_count(entry.text_inserted); + size_t chars_inserted = string_char_count(encoding, entry.text_inserted); return TSInputEdit{entry.position, chars_inserted, entry.chars_removed}; } string SpyInput::swap_substr(size_t start_char, size_t chars_removed, string text) { - const char *bytes = content.data(); - size_t size = content.size(); - - long start_byte = byte_for_character(bytes, size, start_char); + long start_byte = string_byte_for_character(encoding, content, 0, start_char); assert(start_byte >= 0); - long bytes_removed = byte_for_character(bytes + start_byte, size - start_byte, chars_removed); + long bytes_removed = string_byte_for_character(encoding, content, start_byte, chars_removed); if (bytes_removed < 0) - bytes_removed = size - start_byte; + bytes_removed = content.size() - start_byte; string text_removed = content.substr(start_byte, bytes_removed); content.erase(start_byte, bytes_removed); diff --git a/spec/runtime/helpers/spy_input.h b/spec/runtime/helpers/spy_input.h index a0b268b6..3a9d5122 100644 --- a/spec/runtime/helpers/spy_input.h +++ b/spec/runtime/helpers/spy_input.h @@ -31,9 +31,8 @@ class SpyInput { TSInputEdit replace(size_t start_char, size_t chars_removed, std::string text); TSInputEdit undo(); - static size_t char_count(const std::string &); - std::string content; + TSInputEncoding encoding; std::vector strings_read; }; diff --git a/spec/runtime/language_specs.cc b/spec/runtime/language_specs.cc index d2bc0756..57f1fba2 100644 --- a/spec/runtime/language_specs.cc +++ b/spec/runtime/language_specs.cc @@ -7,6 +7,7 @@ #include "runtime/helpers/spy_input.h" #include "runtime/helpers/log_debugger.h" #include "runtime/helpers/point_helpers.h" +#include "runtime/helpers/encoding_helpers.h" extern "C" const TSLanguage *ts_language_javascript(); extern "C" const TSLanguage *ts_language_json(); @@ -145,8 +146,8 @@ describe("Languages", [&]() { std::set> insertions; for (size_t i = 0; i < 80; i++) { - size_t edit_position = random() % SpyInput::char_count(entry.input); - size_t deletion_size = random() % (SpyInput::char_count(entry.input) - edit_position); + size_t edit_position = random() % utf8_char_count(entry.input); + size_t deletion_size = random() % (utf8_char_count(entry.input) - edit_position); string inserted_text = random_words(random() % 4 + 1); if (insertions.insert({edit_position, inserted_text}).second) { diff --git a/src/runtime/lexer.c b/src/runtime/lexer.c index 7029cf91..1e8d6f3e 100644 --- a/src/runtime/lexer.c +++ b/src/runtime/lexer.c @@ -4,6 +4,7 @@ #include "runtime/tree.h" #include "runtime/length.h" #include "runtime/debugger.h" +#include "runtime/utf16.h" #include "utf8proc.h" #define LOG(...) \ @@ -18,7 +19,7 @@ : "lookahead char:%d", \ self->lookahead); -static const char *empty_chunk = ""; +static const char empty_chunk[2] = { 0, 0 }; static void ts_lexer__get_chunk(TSLexer *self) { TSInput input = self->input; @@ -35,9 +36,14 @@ static void ts_lexer__get_chunk(TSLexer *self) { static void ts_lexer__get_lookahead(TSLexer *self) { size_t position_in_chunk = self->current_position.bytes - self->chunk_start; - self->lookahead_size = utf8proc_iterate( - (const uint8_t *)self->chunk + position_in_chunk, - self->chunk_size - position_in_chunk + 1, &self->lookahead); + const uint8_t *chunk = (const uint8_t *)self->chunk + position_in_chunk; + size_t size = self->chunk_size - position_in_chunk + 1; + + if (self->input.encoding == TSInputEncodingUTF8) + self->lookahead_size = utf8proc_iterate(chunk, size, &self->lookahead); + else + self->lookahead_size = utf16_iterate(chunk, size, &self->lookahead); + LOG_LOOKAHEAD(); } diff --git a/src/runtime/utf16.c b/src/runtime/utf16.c new file mode 100644 index 00000000..a8ae6bdd --- /dev/null +++ b/src/runtime/utf16.c @@ -0,0 +1,24 @@ +#include "runtime/utf16.h" + +int utf16_iterate(const uint8_t *string, size_t length, int32_t *code_point) { + uint16_t *units = (uint16_t *)string; + uint16_t unit = units[0]; + + if (unit < 0xd800 || unit >= 0xe000) { + *code_point = unit; + return 2; + } + + if (unit < 0xdc00) { + if (length >= 4) { + uint16_t next_unit = units[1]; + if (next_unit >= 0xdc00 && next_unit < 0xe000) { + *code_point = 0x10000 + ((unit - 0xd800) << 10) + (next_unit - 0xdc00); + return 4; + } + } + } + + *code_point = -1; + return 2; +} diff --git a/src/runtime/utf16.h b/src/runtime/utf16.h new file mode 100644 index 00000000..70146dd8 --- /dev/null +++ b/src/runtime/utf16.h @@ -0,0 +1,20 @@ +#ifndef RUNTIME_UTF16_H_ +#define RUNTIME_UTF16_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include + +// Analogous to utf8proc's utf8proc_iterate function. Reads one code point from +// the given string and stores it in the location pointed to by `code_point`. +// Returns the number of bytes in `string` that were read. +int utf16_iterate(const uint8_t *string, size_t length, int32_t *code_point); + +#ifdef __cplusplus +} +#endif + +#endif // RUNTIME_UTF16_H_