diff --git a/include/tree_sitter/runtime.h b/include/tree_sitter/runtime.h
index 8b9a863e..be321720 100644
--- a/include/tree_sitter/runtime.h
+++ b/include/tree_sitter/runtime.h
@@ -45,6 +45,11 @@ typedef struct {
   size_t column;
 } TSPoint;
 
+typedef struct {
+  TSPoint start;
+  TSPoint end;
+} TSRange;
+
 typedef struct {
   const void *data;
   size_t offset[3];
@@ -98,6 +103,7 @@ void ts_document_set_logger(TSDocument *, TSLogger);
 void ts_document_print_debugging_graphs(TSDocument *, bool);
 void ts_document_edit(TSDocument *, TSInputEdit);
 int ts_document_parse(TSDocument *);
+int ts_document_parse_and_diff(TSDocument *, TSRange **, size_t *);
 void ts_document_invalidate(TSDocument *);
 TSNode ts_document_root_node(const TSDocument *);
 size_t ts_document_parse_count(const TSDocument *);
diff --git a/spec/helpers/point_helpers.cc b/spec/helpers/point_helpers.cc
index 97a444a9..f07d1da7 100644
--- a/spec/helpers/point_helpers.cc
+++ b/spec/helpers/point_helpers.cc
@@ -8,10 +8,18 @@ bool operator==(const TSPoint &left, const TSPoint &right) {
   return left.row == right.row && left.column == right.column;
 }
 
+bool operator==(const TSRange &left, const TSRange &right) {
+  return left.start == right.start && left.end == right.end;
+}
+
 std::ostream &operator<<(std::ostream &stream, const TSPoint &point) {
   return stream << "{" << point.row << ", " << point.column << "}";
 }
 
+std::ostream &operator<<(std::ostream &stream, const TSRange &range) {
+  return stream << "{" << range.start << ", " << range.end << "}";
+}
+
 bool operator<(const TSPoint &left, const TSPoint &right) {
   if (left.row < right.row) return true;
   if (left.row > right.row) return false;
diff --git a/spec/helpers/point_helpers.h b/spec/helpers/point_helpers.h
index 321f05ad..3caf14c2 100644
--- a/spec/helpers/point_helpers.h
+++ b/spec/helpers/point_helpers.h
@@ -7,6 +7,10 @@ bool operator<(const TSPoint &left, const TSPoint &right);
 
 bool operator>(const TSPoint &left, const TSPoint &right);
 
+bool operator==(const TSRange &left, const TSRange &right);
+
 std::ostream &operator<<(std::ostream &stream, const TSPoint &point);
 
+std::ostream &operator<<(std::ostream &stream, const TSRange &range);
+
 #endif  // HELPERS_POINT_HELPERS_H_
diff --git a/spec/runtime/document_spec.cc b/spec/runtime/document_spec.cc
index ec479b4e..9e393134 100644
--- a/spec/runtime/document_spec.cc
+++ b/spec/runtime/document_spec.cc
@@ -3,6 +3,7 @@
 #include "helpers/record_alloc.h"
 #include "helpers/stream_methods.h"
 #include "helpers/tree_helpers.h"
+#include "helpers/point_helpers.h"
 #include "helpers/spy_logger.h"
 #include "helpers/spy_input.h"
 #include "helpers/load_language.h"
@@ -192,6 +193,140 @@ describe("Document", [&]() {
       });
     });
   });
+
+  describe("parse_and_diff()", [&]() {
+    SpyInput *input;
+
+    before_each([&]() {
+      ts_document_set_language(doc, get_test_language("javascript"));
+      input = new SpyInput("{a: null};", 3);
+      ts_document_set_input(doc, input->input());
+      ts_document_parse(doc);
+      assert_node_string_equals(
+        ts_document_root_node(doc),
+        "(program (expression_statement (object (pair (identifier) (null)))))");
+    });
+
+    after_each([&]() {
+      delete input;
+    });
+
+    auto get_ranges = [&](std::function<TSInputEdit()> callback) -> vector<TSRange> {
+      TSInputEdit edit = callback();
+      ts_document_edit(doc, edit);
+
+      TSRange *ranges;
+      size_t range_count = 0;
+      ts_document_parse_and_diff(doc, &ranges, &range_count);
+
+      vector<TSRange> result;
+      for (size_t i = 0; i < range_count; i++)
+        result.push_back(ranges[i]);
+      ts_free(ranges);
+
+      return result;
+    };
+
+    it("reports changes when one token has been updated", [&]() {
+      // Replace `null` with `nothing`
+      auto ranges = get_ranges([&]() {
+        return input->replace(input->content.find("ull"), 1, "othing");
+      });
+
+      AssertThat(ranges, Equals(vector<TSRange>({
+        TSRange{
+          TSPoint{0, input->content.find("nothing")},
+          TSPoint{0, input->content.find("}")}
+        },
+      })));
+
+      // Replace `nothing` with `null` again
+      ranges = get_ranges([&]() {
+        return input->undo();
+      });
+
+      AssertThat(ranges, Equals(vector<TSRange>({
+        TSRange{
+          TSPoint{0, input->content.find("null")},
+          TSPoint{0, input->content.find("}")}
+        },
+      })));
+    });
+
+    it("reports changes when tokens have been appended", [&]() {
+      // Add a second key-value pair
+      auto ranges = get_ranges([&]() {
+        return input->replace(input->content.find("}"), 0, ", b: false");
+      });
+
+      AssertThat(ranges, Equals(vector<TSRange>({
+        TSRange{
+          TSPoint{0, input->content.find(",")},
+          TSPoint{0, input->content.find("}")},
+        },
+      })));
+
+      // Add a third key-value pair in between the first two
+      ranges = get_ranges([&]() {
+        return input->replace(input->content.find(", b"), 0, ", c: 1");
+      });
+
+      assert_node_string_equals(
+        ts_document_root_node(doc),
+        "(program (expression_statement (object "
+          "(pair (identifier) (null)) "
+          "(pair (identifier) (number)) "
+          "(pair (identifier) (false)))))");
+
+      AssertThat(ranges, Equals(vector<TSRange>({
+        TSRange{
+          TSPoint{0, input->content.find(", c")},
+          TSPoint{0, input->content.find(", b")},
+        },
+      })));
+
+      // Delete the middle pair.
+      ranges = get_ranges([&]() {
+        return input->undo();
+      });
+
+      assert_node_string_equals(
+        ts_document_root_node(doc),
+        "(program (expression_statement (object "
+          "(pair (identifier) (null)) "
+          "(pair (identifier) (false)))))");
+
+      AssertThat(ranges, Equals(vector<TSRange>({
+      })));
+
+      // Delete the second pair.
+      ranges = get_ranges([&]() {
+        return input->undo();
+      });
+
+      assert_node_string_equals(
+        ts_document_root_node(doc),
+        "(program (expression_statement (object "
+          "(pair (identifier) (null)))))");
+
+      AssertThat(ranges, Equals(vector<TSRange>({
+      })));
+    });
+
+    it("reports changes when trees have been wrapped", [&]() {
+      // Wrap the object in an assignment expression.
+      auto ranges = get_ranges([&]() {
+        return input->replace(0, 0, "x.y = ");
+      });
+
+      AssertThat(ranges, Equals(vector<TSRange>({
+        TSRange{
+          TSPoint{0, 0},
+          TSPoint{0, input->content.find(";")},
+        },
+      })));
+    });
+  });
 });
 
 END_TEST
diff --git a/src/runtime/document.c b/src/runtime/document.c
index 1211eb5e..22e4bb5b 100644
--- a/src/runtime/document.c
+++ b/src/runtime/document.c
@@ -89,7 +89,135 @@ void ts_document_edit(TSDocument *self, TSInputEdit edit) {
   ts_tree_edit(self->tree, edit);
 }
 
-int ts_document_parse(TSDocument *self) {
+typedef Array(TSRange) RangeArray;
+
+// Debugging helpers for ts_tree_diff. NAME expects a `doc` variable in scope;
+// PRINT expands to nothing unless the commented-out definition is restored.
+#define NAME(t) ((t) ? (ts_language_symbol_name(doc->parser.language, ((TSTree *)(t))->symbol)) : "")
+// #define PRINT(msg, ...) for (size_t k = 0; k < depth; k++) { printf(" "); } printf(msg "\n", __VA_ARGS__);
+#define PRINT(msg, ...)
+
+// Record the extent of `node` as changed. When `*extend_last_change` is set,
+// the end of the last recorded range is moved to the end of `node` instead of
+// appending a new range. Returns the result of array_push, or true otherwise.
+static bool push_diff(RangeArray *results, TSNode *node, bool *extend_last_change) {
+  TSPoint start = ts_node_start_point(*node);
+  TSPoint end = ts_node_end_point(*node);
+  if (*extend_last_change) {
+    TSRange *last_range = array_back(results);
+    last_range->end = end;
+    return true;
+  }
+  *extend_last_change = true;
+  return array_push(results, ((TSRange){start, end}));
+}
+
+// Walk the children of `old` alongside the nodes of the new tree, starting at
+// `*new_node`, and record in `results` the extents of new nodes that differ.
+// `*new_node` is advanced through the new tree as the walk proceeds. Returns
+// false as soon as recording a range fails.
+static bool ts_tree_diff(TSDocument *doc, TSTree *old, TSNode *new_node,
+                         size_t depth, RangeArray *results, bool *extend_last_change) {
+  TSTree *new = (TSTree *)(new_node->data);
+
+  PRINT("At %lu, ('%s', %lu) vs ('%s', %lu) {",
+        ts_node_start_char(*new_node),
+        NAME(old), old->size.chars,
+        NAME(new), new->size.chars);
+
+  if (old->visible) {
+    if (old == new || (old->symbol == new->symbol &&
+                       old->size.chars == new->size.chars && !old->has_changes)) {
+      *extend_last_change = false;
+      PRINT("}", NULL);
+      return true;
+    }
+
+    if (old->symbol != new->symbol) {
+      PRINT("}", NULL);
+      return push_diff(results, new_node, extend_last_change);
+    }
+
+    TSNode child = ts_node_child(*new_node, 0);
+    if (child.data) {
+      *new_node = child;
+    } else {
+      PRINT("}", NULL);
+      return true;
+    }
+  }
+
+  depth++;
+  size_t old_child_start;
+  size_t old_child_end = ts_node_start_char(*new_node) - old->padding.chars;
+
+  for (size_t j = 0; j < old->child_count; j++) {
+    TSTree *old_child = old->children[j];
+    if (old_child->padding.chars == 0 && old_child->size.chars == 0)
+      continue;
+
+    old_child_start = old_child_end + old_child->padding.chars;
+    old_child_end = old_child_start + old_child->size.chars;
+
+    while (true) {
+      size_t new_child_start = ts_node_start_char(*new_node);
+      if (new_child_start < old_child_start) {
+        PRINT("skip new:('%s', %lu), old:('%s', %lu), old_parent:%s",
+              NAME(new_node->data), ts_node_start_char(*new_node), NAME(old_child),
+              old_child_start, NAME(old));
+
+        if (!push_diff(results, new_node, extend_last_change))
+          return false;
+
+        TSNode next = ts_node_next_sibling(*new_node);
+        if (next.data) {
+          PRINT("advance before diff ('%s', %lu) -> ('%s', %lu)",
+                NAME(new_node->data), ts_node_start_char(*new_node), NAME(next.data),
+                ts_node_start_char(next));
+          *new_node = next;
+        } else {
+          // No more new siblings: stop here, otherwise this loop would
+          // re-examine the same node forever.
+          break;
+        }
+      } else if (new_child_start == old_child_start) {
+        if (!ts_tree_diff(doc, old_child, new_node, depth, results, extend_last_change))
+          return false;
+
+        if (old_child->visible) {
+          TSNode next = ts_node_next_sibling(*new_node);
+          if (next.data) {
+            PRINT("advance after diff ('%s', %lu) -> ('%s', %lu)",
+                  NAME(new_node->data), ts_node_start_char(*new_node), NAME(next.data),
+                  ts_node_start_char(next));
+            *new_node = next;
+          }
+        }
+        break;
+      } else {
+        break;
+      }
+    }
+  }
+
+  depth--;
+  if (old->visible) {
+    *new_node = ts_node_parent(*new_node);
+  }
+
+  PRINT("}", NULL);
+  return true;
+}
+
+// Parse the document and, when both `ranges` and `range_count` are non-NULL
+// and a previous tree exists, report the ranges that differ from that tree.
+// On success `*ranges` holds `*range_count` entries; the array is handed to
+// the caller, which releases it with ts_free. Returns 0 on success, -1 on
+// failure.
+int ts_document_parse_and_diff(TSDocument *self, TSRange **ranges, size_t *range_count) {
+  if (ranges) *ranges = NULL;
+  if (range_count) *range_count = 0;
+
   if (!self->input.read || !self->parser.language)
     return -1;
 
@@ -101,14 +229,42 @@ int ts_document_parse(TSDocument *self) {
   if (!tree)
     return -1;
 
-  if (self->tree)
-    ts_tree_release(self->tree);
+  if (self->tree) {
+    TSTree *old_tree = self->tree;
+    self->tree = tree;
+    TSNode new_root = ts_document_root_node(self);
+
+    // ts_tree_print_dot_graph(old_tree, self->parser.language, stderr);
+    // ts_tree_print_dot_graph(tree, self->parser.language, stderr);
+
+    if (ranges && range_count) {
+      bool extend_last_change = false;
+      RangeArray result = {0, 0, 0};
+      if (!ts_tree_diff(self, old_tree, &new_root, 0, &result, &extend_last_change)) {
+        // The new tree is already installed; drop the partial result and
+        // the old tree so that neither is leaked on this error path.
+        ts_free(result.contents);
+        ts_tree_release(old_tree);
+        return -1;
+      }
+      *ranges = result.contents;
+      *range_count = result.size;
+    }
+
+    ts_tree_release(old_tree);
+  }
+
   self->tree = tree;
   self->parse_count++;
   self->valid = true;
   return 0;
 }
 
+// Parse the document without computing the changed ranges.
+int ts_document_parse(TSDocument *self) {
+  return ts_document_parse_and_diff(self, NULL, NULL);
+}
+
 void ts_document_invalidate(TSDocument *self) {
   self->valid = false;
 }