diff --git a/include/tree_sitter/runtime.h b/include/tree_sitter/runtime.h index 5ee8b2e7..a4dbedc9 100644 --- a/include/tree_sitter/runtime.h +++ b/include/tree_sitter/runtime.h @@ -34,8 +34,10 @@ typedef struct { } TSPoint; typedef struct { - TSPoint start; - TSPoint end; + TSPoint start_point; + TSPoint end_point; + uint32_t start_byte; + uint32_t end_byte; } TSRange; typedef struct { @@ -90,6 +92,8 @@ void ts_parser_set_enabled(TSParser *, bool); size_t ts_parser_operation_limit(const TSParser *); void ts_parser_set_operation_limit(TSParser *, size_t); void ts_parser_reset(TSParser *); +void ts_parser_set_included_ranges(TSParser *, const TSRange *, uint32_t); +const TSRange *ts_parser_included_ranges(const TSParser *, uint32_t *); TSTree *ts_tree_copy(const TSTree *); void ts_tree_delete(TSTree *); diff --git a/src/runtime/get_changed_ranges.c b/src/runtime/get_changed_ranges.c index d46bf704..f83e0493 100644 --- a/src/runtime/get_changed_ranges.c +++ b/src/runtime/get_changed_ranges.c @@ -9,17 +9,18 @@ typedef Array(TSRange) RangeArray; -static void range_array_add(RangeArray *results, TSPoint start, TSPoint end) { +static void range_array_add(RangeArray *results, Length start, Length end) { if (results->size > 0) { TSRange *last_range = array_back(results); - if (point_lte(start, last_range->end)) { - last_range->end = end; + if (start.bytes <= last_range->end_byte) { + last_range->end_byte = end.bytes; + last_range->end_point = end.extent; return; } } - if (point_lt(start, end)) { - TSRange range = { start, end }; + if (start.bytes < end.bytes) { + TSRange range = { start.extent, end.extent, start.bytes, end.bytes }; array_push(results, range); } } @@ -272,10 +273,10 @@ unsigned ts_subtree_get_changed_ranges(const Subtree *old_tree, const Subtree *n Length position = iterator_start_position(&old_iter); Length next_position = iterator_start_position(&new_iter); if (position.bytes < next_position.bytes) { - range_array_add(&results, position.extent, next_position.extent); + range_array_add(&results, position, next_position); position = next_position; } else if (position.bytes > next_position.bytes) { - range_array_add(&results, next_position.extent, position.extent); + range_array_add(&results, next_position, position); next_position = position; } @@ -343,7 +344,7 @@ unsigned ts_subtree_get_changed_ranges(const Subtree *old_tree, const Subtree *n ); #endif - range_array_add(&results, position.extent, next_position.extent); + range_array_add(&results, position, next_position); } position = next_position; diff --git a/src/runtime/lexer.c b/src/runtime/lexer.c index e78a0778..2443fb1a 100644 --- a/src/runtime/lexer.c +++ b/src/runtime/lexer.c @@ -72,6 +72,22 @@ static void ts_lexer__advance(void *payload, bool skip) { LOG_CHARACTER("consume", self->data.lookahead); } + TSRange *current_range = &self->included_ranges[self->current_included_range_index]; + if (self->current_position.bytes == current_range->end_byte) { + self->current_included_range_index++; + if (self->current_included_range_index == self->included_range_count) { + self->data.lookahead = '\0'; + self->lookahead_size = 1; + return; + } else { + current_range++; + self->current_position = (Length) { + current_range->start_byte, + current_range->start_point, + }; + } + } + if (self->current_position.bytes >= self->chunk_start + self->chunk_size) { ts_lexer__get_chunk(self); } @@ -118,14 +134,23 @@ void ts_lexer_init(Lexer *self) { }, .chunk = NULL, .chunk_start = 0, + .current_position = {UINT32_MAX, {0, 0}}, .logger = { .payload = NULL, .log = NULL }, + .current_included_range_index = 0, }; + + self->included_ranges = NULL; + ts_lexer_set_included_ranges(self, NULL, 0); ts_lexer_reset(self, length_zero()); } +void ts_lexer_delete(Lexer *self) { + ts_free(self->included_ranges); +} + void ts_lexer_set_input(Lexer *self, TSInput input) { self->input = input; self->data.lookahead = 0; @@ -135,22 +160,52 @@ void ts_lexer_set_input(Lexer *self, TSInput input) { self->chunk_size = 0; } -void ts_lexer_reset(Lexer *self, Length position) { - if (position.bytes != self->current_position.bytes) { - self->token_start_position = position; - self->token_end_position = LENGTH_UNDEFINED; - self->current_position = position; +static void ts_lexer_goto(Lexer *self, Length position) { + bool found_included_range = false; + for (unsigned i = 0; i < self->included_range_count; i++) { + TSRange *included_range = &self->included_ranges[i]; + if (included_range->end_byte > position.bytes) { + if (included_range->start_byte > position.bytes) { + position = (Length) { + .bytes = included_range->start_byte, + .extent = included_range->start_point, + }; + } - if (self->chunk && (position.bytes < self->chunk_start || - position.bytes >= self->chunk_start + self->chunk_size)) { - self->chunk = 0; - self->chunk_start = 0; - self->chunk_size = 0; + self->current_included_range_index = i; + found_included_range = true; + break; } - - self->lookahead_size = 0; - self->data.lookahead = 0; } + + if (!found_included_range) { + TSRange *last_included_range = &self->included_ranges[self->included_range_count - 1]; + position = (Length) { + .bytes = last_included_range->end_byte, + .extent = last_included_range->end_point, + }; + self->chunk = empty_chunk; + self->chunk_start = position.bytes; + self->chunk_size = 1; + } + + self->token_start_position = position; + self->token_end_position = LENGTH_UNDEFINED; + self->current_position = position; + + if (self->chunk && (position.bytes < self->chunk_start || + position.bytes >= self->chunk_start + self->chunk_size)) { + self->chunk = 0; + self->chunk_start = 0; + self->chunk_size = 0; + } + + self->lookahead_size = 0; + self->data.lookahead = 0; +} + +void ts_lexer_reset(Lexer *self, Length position) { + if (position.bytes != self->current_position.bytes) ts_lexer_goto(self, position); } void ts_lexer_start(Lexer *self) { @@ -164,3 +219,36 @@ void ts_lexer_start(Lexer *self) { void ts_lexer_advance_to_end(Lexer *self) { while (self->data.lookahead != 0) ts_lexer__advance(self, false); } + +static const TSRange DEFAULT_RANGES[] = { + { + .start_point = { + .row = 0, + .column = 0, + }, + .end_point = { + .row = UINT32_MAX, + .column = UINT32_MAX, + }, + .start_byte = 0, + .end_byte = UINT32_MAX + } +}; + +void ts_lexer_set_included_ranges(Lexer *self, const TSRange *ranges, uint32_t count) { + if (!ranges) { + ranges = DEFAULT_RANGES; + count = 1; + } + + size_t sz = count * sizeof(TSRange); + self->included_ranges = ts_realloc(self->included_ranges, sz); + memcpy(self->included_ranges, ranges, sz); + self->included_range_count = count; + ts_lexer_goto(self, self->current_position); +} + +TSRange *ts_lexer_included_ranges(const Lexer *self, uint32_t *count) { + *count = self->included_range_count; + return self->included_ranges; +} diff --git a/src/runtime/lexer.h b/src/runtime/lexer.h index d6cf6279..68ded2b0 100644 --- a/src/runtime/lexer.h +++ b/src/runtime/lexer.h @@ -16,6 +16,10 @@ typedef struct { Length token_start_position; Length token_end_position; + TSRange * included_ranges; + size_t included_range_count; + size_t current_included_range_index; + const char *chunk; uint32_t chunk_start; uint32_t chunk_size; @@ -27,10 +31,13 @@ typedef struct { } Lexer; void ts_lexer_init(Lexer *); +void ts_lexer_delete(Lexer *); void ts_lexer_set_input(Lexer *, TSInput); void ts_lexer_reset(Lexer *, Length); void ts_lexer_start(Lexer *); void ts_lexer_advance_to_end(Lexer *); +void ts_lexer_set_included_ranges(Lexer *self, const TSRange *ranges, uint32_t count); +TSRange *ts_lexer_included_ranges(const Lexer *self, uint32_t *count); #ifdef __cplusplus } diff --git a/src/runtime/parser.c b/src/runtime/parser.c index ed854ea0..cda33c58 100644 --- a/src/runtime/parser.c +++ b/src/runtime/parser.c @@ -1359,6 +1359,7 @@ void ts_parser_delete(TSParser *self) { ts_subtree_release(&self->tree_pool, self->old_tree); self->old_tree = NULL; } + ts_lexer_delete(&self->lexer); ts_parser__set_cached_token(self, 0, NULL, NULL); ts_subtree_pool_delete(&self->tree_pool); reusable_node_delete(&self->reusable_node); @@ -1419,6 +1420,14 @@ void ts_parser_set_operation_limit(TSParser *self, size_t limit) { self->operation_limit = limit; } +void ts_parser_set_included_ranges(TSParser *self, const TSRange *ranges, uint32_t count) { + ts_lexer_set_included_ranges(&self->lexer, ranges, count); +} + +const TSRange *ts_parser_included_ranges(const TSParser *self, uint32_t *count) { + return ts_lexer_included_ranges(&self->lexer, count); +} + void ts_parser_reset(TSParser *self) { if (self->language->external_scanner.deserialize) { self->language->external_scanner.deserialize(self->external_scanner_payload, NULL, 0); diff --git a/test/helpers/point_helpers.cc b/test/helpers/point_helpers.cc index 40dd67fd..fd6d8bb1 100644 --- a/test/helpers/point_helpers.cc +++ b/test/helpers/point_helpers.cc @@ -11,7 +11,12 @@ bool operator==(const TSPoint &left, const TSPoint &right) { } bool operator==(const TSRange &left, const TSRange &right) { - return left.start == right.start && left.end == right.end; + return ( + left.start_byte == right.start_byte && + left.end_byte == right.end_byte && + left.start_point == right.start_point && + left.end_point == right.end_point + ); } bool operator==(const Length &left, const Length &right) { @@ -34,7 +39,7 @@ std::ostream &operator<<(std::ostream &stream, const TSPoint &point) { } std::ostream &operator<<(std::ostream &stream, const TSRange &range) { - return stream << "{" << range.start << ", " << range.end << "}"; + return stream << "{" << range.start_point << ", " << range.end_point << "}"; } ostream &operator<<(ostream &stream, const Length &length) { diff --git a/test/helpers/scope_sequence.cc b/test/helpers/scope_sequence.cc index 8851f0c4..1121b80e 100644 --- a/test/helpers/scope_sequence.cc +++ b/test/helpers/scope_sequence.cc @@ -73,7 +73,7 @@ void verify_changed_ranges(const ScopeSequence &old_sequence, const ScopeSequenc bool found_containing_range = false; for (size_t j = 0; j < range_count; j++) { TSRange range = ranges[j]; - if (range.start <= current_position && current_position <= range.end) { + if (range.start_point <= current_position && current_position <= range.end_point) { found_containing_range = true; break; } diff --git a/test/runtime/parser_test.cc b/test/runtime/parser_test.cc index b9a41715..fd2697da 100644 --- a/test/runtime/parser_test.cc +++ b/test/runtime/parser_test.cc @@ -768,6 +768,113 @@ describe("Parser", [&]() { assert_root_node("(value (array (null) (number) (number) (number) (number)))"); }); }); + + describe("set_skipped_ranges", [&]() { + it("can parse code within a single range of a document", [&]() { + string source_code = "hi"; + + ts_parser_set_language(parser, load_real_language("html")); + TSTree *html_tree = ts_parser_parse_string(parser, nullptr, source_code.c_str(), source_code.size()); + TSNode script_content_node = ts_node_child( + ts_node_child(ts_tree_root_node(html_tree), 1), + 1 + ); + AssertThat(ts_node_type(script_content_node), Equals("raw_text")); + TSRange included_range = { + ts_node_start_point(script_content_node), + ts_node_end_point(script_content_node), + ts_node_start_byte(script_content_node), + ts_node_end_byte(script_content_node), + }; + ts_tree_delete(html_tree); + + ts_parser_set_included_ranges(parser, &included_range, 1); + ts_parser_set_language(parser, load_real_language("javascript")); + tree = ts_parser_parse_string(parser, nullptr, source_code.c_str(), source_code.size()); + + assert_root_node("(program (expression_statement (call_expression " + "(member_expression (identifier) (property_identifier)) " + "(arguments (string)))))"); + + AssertThat( + ts_node_start_point(ts_tree_root_node(tree)), + Equals({0, static_cast(source_code.find("console"))}) + ); + }); + + it("can parse code spread across multiple ranges in a document", [&]() { + string source_code = + "html `
Hello, ${name.toUpperCase()}, it's ${now()}.
`"; + + ts_parser_set_language(parser, load_real_language("javascript")); + TSTree *js_tree = ts_parser_parse_string(parser, nullptr, source_code.c_str(), source_code.size()); + TSNode root_node = ts_tree_root_node(js_tree); + TSNode string_node = ts_node_descendant_for_byte_range( + root_node, + source_code.find("
"), + source_code.find("Hell") + ); + TSNode open_quote_node = ts_node_child(string_node, 0); + TSNode interpolation_node1 = ts_node_child(string_node, 1); + TSNode interpolation_node2 = ts_node_child(string_node, 2); + TSNode close_quote_node = ts_node_child(string_node, 3); + + AssertThat(ts_node_type(string_node), Equals("template_string")); + AssertThat(ts_node_type(open_quote_node), Equals("`")); + AssertThat(ts_node_type(interpolation_node1), Equals("template_substitution")); + AssertThat(ts_node_type(interpolation_node2), Equals("template_substitution")); + AssertThat(ts_node_type(close_quote_node), Equals("`")); + ts_tree_delete(js_tree); + + TSRange included_ranges[] = { + { + ts_node_end_point(open_quote_node), + ts_node_start_point(interpolation_node1), + ts_node_end_byte(open_quote_node), + ts_node_start_byte(interpolation_node1), + }, + { + ts_node_end_point(interpolation_node1), + ts_node_start_point(interpolation_node2), + ts_node_end_byte(interpolation_node1), + ts_node_start_byte(interpolation_node2), + }, + { + ts_node_end_point(interpolation_node2), + ts_node_start_point(close_quote_node), + ts_node_end_byte(interpolation_node2), + ts_node_start_byte(close_quote_node), + } + }; + + ts_parser_set_included_ranges(parser, included_ranges, 3); + ts_parser_set_language(parser, load_real_language("html")); + tree = ts_parser_parse_string(parser, nullptr, source_code.c_str(), source_code.size()); + + assert_root_node("(fragment " + "(element " + "(start_tag (tag_name)) " + "(text) " + "(element " + "(start_tag (tag_name)) " + "(end_tag (tag_name))) " + "(text) " + "(end_tag (tag_name))))"); + + root_node = ts_tree_root_node(tree); + TSNode hello_text_node = ts_node_child(ts_node_child(root_node, 0), 1); + + AssertThat(ts_node_type(hello_text_node), Equals("text")); + AssertThat( + ts_node_start_point(hello_text_node), + Equals({0, static_cast(source_code.find("Hello"))}) + ); + AssertThat( + ts_node_end_point(hello_text_node), + Equals({0, static_cast(source_code.find(""))}) + ); + }); + }); }); END_TEST diff --git a/test/runtime/tree_test.cc b/test/runtime/tree_test.cc index b599f568..d703cd60 100644 --- a/test/runtime/tree_test.cc +++ b/test/runtime/tree_test.cc @@ -131,16 +131,22 @@ describe("Tree", [&]() { return result; }; + auto range_for_text = [&](string start_text, string end_text) { + return TSRange { + point(0, input->content.find(start_text)), + point(0, input->content.find(end_text)), + static_cast(input->content.find(start_text)), + static_cast(input->content.find(end_text)), + }; + }; + it("reports changes when one token has been updated", [&]() { // Replace `null` with `nothing` auto ranges = get_changed_ranges_for_edit([&]() { return input->replace(input->content.find("ull"), 1, "othing"); }); AssertThat(ranges, Equals(vector({ - TSRange{ - point(0, input->content.find("nothing")), - point(0, input->content.find("}")) - }, + range_for_text("nothing", "}"), }))); // Replace `nothing` with `null` again @@ -148,10 +154,7 @@ describe("Tree", [&]() { return input->undo(); }); AssertThat(ranges, Equals(vector({ - TSRange{ - point(0, input->content.find("null")), - point(0, input->content.find("}")) - }, + range_for_text("null", "}"), }))); }); @@ -192,10 +195,7 @@ describe("Tree", [&]() { return input->replace(input->content.find("}"), 0, ", b: false"); }); AssertThat(ranges, Equals(vector({ - TSRange{ - point(0, input->content.find(",")), - point(0, input->content.find("}")) - }, + range_for_text(",", "}"), }))); // Add a third key-value pair in between the first two @@ -209,10 +209,7 @@ describe("Tree", [&]() { "(pair (property_identifier) (false)))))" ); AssertThat(ranges, Equals(vector({ - TSRange{ - point(0, input->content.find(", c")), - point(0, input->content.find(", b")) - }, + range_for_text(", c", ", b"), }))); // Delete the middle pair. @@ -247,10 +244,7 @@ describe("Tree", [&]() { "(pair (property_identifier) (binary_expression (identifier) (null))))))" ); AssertThat(ranges, Equals(vector({ - TSRange{ - point(0, input->content.find("b ===")), - point(0, input->content.find("}")) - }, + range_for_text("b ===", "}"), }))); }); });