diff --git a/src/runtime/node.c b/src/runtime/node.c index 3c858b92..607cf9de 100644 --- a/src/runtime/node.c +++ b/src/runtime/node.c @@ -125,6 +125,7 @@ static inline uint32_t ts_node__relevant_child_count(TSNode self, bool include_a } static inline TSNode ts_node__child(TSNode self, uint32_t child_index, bool include_anonymous) { + const TSTree *tree = ts_node__tree(&self); TSNode result = self; bool did_descend = true; @@ -136,7 +137,10 @@ static inline TSNode ts_node__child(TSNode self, uint32_t child_index, bool incl ChildIterator iterator = ts_node_iterate_children(&result); while (ts_node_child_iterator_next(&iterator, &child)) { if (ts_node__is_relevant(child, include_anonymous)) { - if (index == child_index) return child; + if (index == child_index) { + ts_tree_set_cached_parent(tree, &child, &self); + return child; + } index++; } else { uint32_t grandchild_index = child_index - index; @@ -290,6 +294,7 @@ static inline TSNode ts_node__descendant_for_byte_range(TSNode self, uint32_t mi bool include_anonymous) { TSNode node = self; TSNode last_visible_node = self; + const TSTree *tree = ts_node__tree(&self); bool did_descend = true; while (did_descend) { @@ -301,7 +306,10 @@ static inline TSNode ts_node__descendant_for_byte_range(TSNode self, uint32_t mi if (iterator.position.bytes > max) { if (ts_node_start_byte(child) > min) break; node = child; - if (ts_node__is_relevant(node, include_anonymous)) last_visible_node = node; + if (ts_node__is_relevant(node, include_anonymous)) { + ts_tree_set_cached_parent(tree, &child, &last_visible_node); + last_visible_node = node; + } did_descend = true; break; } @@ -316,8 +324,7 @@ static inline TSNode ts_node__descendant_for_point_range(TSNode self, TSPoint mi bool include_anonymous) { TSNode node = self; TSNode last_visible_node = self; - TSPoint start_position = ts_node_start_point(self); - TSPoint end_position = ts_node_end_point(self); + const TSTree *tree = ts_node__tree(&self); bool did_descend = true; while (did_descend) { @@ -326,19 +333,16 @@ static inline TSNode ts_node__descendant_for_point_range(TSNode self, TSPoint mi TSNode child; ChildIterator iterator = ts_node_iterate_children(&node); while (ts_node_child_iterator_next(&iterator, &child)) { - const Subtree *child_tree = ts_node__subtree(child); - if (iterator.child_index != 1) { - start_position = point_add(start_position, child_tree->padding.extent); - } - end_position = point_add(start_position, child_tree->size.extent); - if (point_gt(end_position, max)) { - if (point_gt(start_position, min)) break; + if (point_gt(iterator.position.extent, max)) { + if (point_gt(ts_node_start_point(child), min)) break; node = child; - if (ts_node__is_relevant(node, include_anonymous)) last_visible_node = node; + if (ts_node__is_relevant(node, include_anonymous)) { + ts_tree_set_cached_parent(tree, &child, &last_visible_node); + last_visible_node = node; + } did_descend = true; break; } - start_position = end_position; } } @@ -397,7 +401,11 @@ bool ts_node_has_error(TSNode self) { } TSNode ts_node_parent(TSNode self) { - TSNode node = ts_tree_root_node(ts_node__tree(&self)); + const TSTree *tree = ts_node__tree(&self); + TSNode node = ts_tree_get_cached_parent(tree, &self); + if (node.id) return node; + + node = ts_tree_root_node(tree); uint32_t end_byte = ts_node_end_byte(self); if (ts_node__subtree(node) == ts_node__subtree(self)) return ts_node__null(); @@ -416,6 +424,7 @@ TSNode ts_node_parent(TSNode self) { if (iterator.position.bytes >= end_byte) { node = child; if (ts_node__is_relevant(child, true)) { + ts_tree_set_cached_parent(tree, &node, &last_visible_node); last_visible_node = node; } did_descend = true; diff --git a/src/runtime/tree.c b/src/runtime/tree.c index 9d7c36fc..72fba7a8 100644 --- a/src/runtime/tree.c +++ b/src/runtime/tree.c @@ -5,10 +5,15 @@ #include "runtime/tree_cursor.h" #include "runtime/tree.h" +static const unsigned PARENT_CACHE_CAPACITY = 32; + TSTree *ts_tree_new(const Subtree *root, const TSLanguage *language) { TSTree *result = ts_malloc(sizeof(TSTree)); result->root = root; result->language = language; + result->parent_cache = NULL; + result->parent_cache_start = 0; + result->parent_cache_size = 0; return result; } @@ -21,6 +26,7 @@ void ts_tree_delete(TSTree *self) { SubtreePool pool = ts_subtree_pool_new(0); ts_subtree_release(&pool, self->root); ts_subtree_pool_delete(&pool); + if (self->parent_cache) ts_free(self->parent_cache); ts_free(self); } @@ -51,3 +57,38 @@ TSRange *ts_tree_get_changed_ranges(const TSTree *self, const TSTree *other, uin void ts_tree_print_dot_graph(const TSTree *self, FILE *file) { ts_subtree_print_dot_graph(self->root, self->language, file); } + +TSNode ts_tree_get_cached_parent(const TSTree *self, const TSNode *node) { + for (uint32_t i = 0; i < self->parent_cache_size; i++) { + uint32_t index = (self->parent_cache_start + i) % PARENT_CACHE_CAPACITY; + ParentCacheEntry *entry = &self->parent_cache[index]; + if (entry->child == node->id) { + return ts_node_new(self, entry->parent, entry->position, entry->alias_symbol); + } + } + return ts_node_new(NULL, NULL, length_zero(), 0); +} + +void ts_tree_set_cached_parent(const TSTree *_self, const TSNode *node, const TSNode *parent) { + TSTree *self = (TSTree *)_self; + if (!self->parent_cache) { + self->parent_cache = ts_calloc(PARENT_CACHE_CAPACITY, sizeof(ParentCacheEntry)); + } + + uint32_t index = (self->parent_cache_start + self->parent_cache_size) % PARENT_CACHE_CAPACITY; + self->parent_cache[index] = (ParentCacheEntry) { + .child = node->id, + .parent = parent->id, + .position = { + parent->context[0], + {parent->context[1], parent->context[2]} + }, + .alias_symbol = parent->context[3], + }; + + if (self->parent_cache_size == PARENT_CACHE_CAPACITY) { + self->parent_cache_start++; + } else { + self->parent_cache_size++; + } +} diff --git a/src/runtime/tree.h b/src/runtime/tree.h index 50ae8490..99481d88 100644 --- a/src/runtime/tree.h +++ b/src/runtime/tree.h @@ -5,13 +5,25 @@ extern "C" { #endif +typedef struct { + const Subtree *child; + const Subtree *parent; + Length position; + TSSymbol alias_symbol; +} ParentCacheEntry; + struct TSTree { const Subtree *root; const TSLanguage *language; + ParentCacheEntry *parent_cache; + uint32_t parent_cache_start; + uint32_t parent_cache_size; }; TSTree *ts_tree_new(const Subtree *root, const TSLanguage *language); TSNode ts_node_new(const TSTree *, const Subtree *, Length, TSSymbol); +TSNode ts_tree_get_cached_parent(const TSTree *, const TSNode *); +void ts_tree_set_cached_parent(const TSTree *, const TSNode *, const TSNode *); #ifdef __cplusplus } diff --git a/test/runtime/node_test.cc b/test/runtime/node_test.cc index 8683a9ef..f1097916 100644 --- a/test/runtime/node_test.cc +++ b/test/runtime/node_test.cc @@ -438,6 +438,13 @@ describe("Node", [&]() { AssertThat(ts_node_end_byte(leaf), Equals(number_end_index)); AssertThat(ts_node_start_point(leaf), Equals({ 3, 2 })); AssertThat(ts_node_end_point(leaf), Equals({ 3, 5 })); + + TSNode parent = ts_node_parent(leaf); + AssertThat(ts_node_type(parent), Equals("array")); + AssertThat(ts_node_start_byte(parent), Equals(array_index)); + parent = ts_node_parent(parent); + AssertThat(ts_node_type(parent), Equals("value")); + AssertThat(ts_node_start_byte(parent), Equals(array_index)); }); }); @@ -495,6 +502,8 @@ describe("Node", [&]() { AssertThat(ts_node_end_byte(node2), Equals(null_end_index)); AssertThat(ts_node_start_point(node2), Equals({ 6, 4 })); AssertThat(ts_node_end_point(node2), Equals({ 6, 13 })); + + AssertThat(ts_node_parent(node1), Equals(node2)); }); it("works in the presence of multi-byte characters", [&]() { @@ -530,6 +539,8 @@ describe("Node", [&]() { AssertThat(ts_node_end_byte(node2), Equals(null_end_index)); AssertThat(ts_node_start_point(node2), Equals({ 6, 4 })); AssertThat(ts_node_end_point(node2), Equals({ 6, 13 })); + + AssertThat(ts_node_parent(node1), Equals(node2)); }); }); });