diff --git a/spec/runtime/node_spec.cc b/spec/runtime/node_spec.cc index 8e89ebb5..a8e48d46 100644 --- a/spec/runtime/node_spec.cc +++ b/spec/runtime/node_spec.cc @@ -65,23 +65,27 @@ describe("Node", []() { describe("next_sibling() and prev_sibling()", [&]() { it("returns the node's next and previous siblings", [&]() { - TSNode array = ts_node_child(root, 0); - TSNode number1 = ts_node_child(array, 0); - TSNode number2 = ts_node_child(array, 1); - TSNode number3 = ts_node_child(array, 2); + TSNode array_node = ts_node_child(root, 0); + TSNode number_node = ts_node_child(array_node, 0); + TSNode false_node = ts_node_child(array_node, 1); + TSNode object_node = ts_node_child(array_node, 2); + TSNode string_node = ts_node_child(object_node, 0); + TSNode null_node = ts_node_child(object_node, 1); - AssertThat(ts_node_next_sibling(number1), Equals(number2)); - AssertThat(ts_node_next_sibling(number2), Equals(number3)); - AssertThat(ts_node_prev_sibling(number3), Equals(number2)); - AssertThat(ts_node_prev_sibling(number2), Equals(number1)); + AssertThat(ts_node_next_sibling(number_node), Equals(false_node)); + AssertThat(ts_node_next_sibling(false_node), Equals(object_node)); + AssertThat(ts_node_next_sibling(string_node), Equals(null_node)); + AssertThat(ts_node_prev_sibling(object_node), Equals(false_node)); + AssertThat(ts_node_prev_sibling(false_node), Equals(number_node)); + AssertThat(ts_node_prev_sibling(null_node), Equals(string_node)); }); it("returns null when the node has no parent", [&]() { - TSNode array = ts_node_child(root, 0); + TSNode array_node = ts_node_child(root, 0); AssertThat(ts_node_next_sibling(root).data, Equals(nullptr)); AssertThat(ts_node_prev_sibling(root).data, Equals(nullptr)); - AssertThat(ts_node_next_sibling(array).data, Equals(nullptr)); - AssertThat(ts_node_prev_sibling(array).data, Equals(nullptr)); + AssertThat(ts_node_next_sibling(array_node).data, Equals(nullptr)); + AssertThat(ts_node_prev_sibling(array_node).data, Equals(nullptr)); }); }); diff --git a/src/runtime/node.c b/src/runtime/node.c index 86f450d4..dec994bb 100644 --- a/src/runtime/node.c +++ b/src/runtime/node.c @@ -33,53 +33,72 @@ const char *ts_node_string(TSNode this, const TSDocument *document) { return ts_tree_string(get_tree(this), document->parser.language->symbol_names); } -typedef struct { - TSNode node; - size_t index; -} NodeWithIndex; - -static inline NodeWithIndex ts_node_parent_with_index(TSNode this) { +TSNode ts_node_parent(TSNode this) { TSLength position = this.position; const TSTree *tree = get_tree(this); - size_t index = 0; do { - TSTree *parent = tree->parent; + TSTree *parent = tree->context.parent; if (!parent) - return (NodeWithIndex){ ts_node_null(), 0 }; + return ts_node_null(); for (size_t i = 0; i < parent->child_count; i++) { TSTree *child = parent->children[i]; if (child == tree) break; - index += ts_tree_is_visible(child) ? 1 : child->visible_child_count; position = ts_length_sub(position, ts_tree_total_size(child)); } tree = parent; } while (!ts_tree_is_visible(tree)); - return (NodeWithIndex){ ts_node_make(tree, position), index }; -} - -TSNode ts_node_parent(TSNode this) { - return ts_node_parent_with_index(this).node; + return ts_node_make(tree, position); } TSNode ts_node_prev_sibling(TSNode this) { - NodeWithIndex parent = ts_node_parent_with_index(this); - if (parent.node.data && parent.index > 0) - return ts_node_child(parent.node, parent.index - 1); - else - return ts_node_null(); + const TSTree *tree = get_tree(this); + TSLength position = this.position; + do { + TSTree *parent = tree->context.parent; + if (!parent) + break; + + for (size_t i = tree->context.index - 1; i + 1 > 0; i--) { + const TSTree *child = parent->children[i]; + position = ts_length_sub(position, ts_tree_total_size(child)); + if (ts_tree_is_visible(child)) + return ts_node_make(child, position); + if (child->visible_child_count > 0) + return ts_node_child(ts_node_make(child, position), child->visible_child_count - 1); + } + + tree = parent; + } while (!ts_tree_is_visible(tree)); + + return ts_node_null(); } TSNode ts_node_next_sibling(TSNode this) { - NodeWithIndex parent = ts_node_parent_with_index(this); - if (parent.node.data) - return ts_node_child(parent.node, parent.index + 1); - else - return ts_node_null(); + const TSTree *tree = get_tree(this); + TSLength position = this.position; + do { + TSTree *parent = tree->context.parent; + if (!parent) + break; + + for (size_t i = tree->context.index + 1; i < parent->child_count; i++) { + const TSTree *child = parent->children[i]; + position = ts_length_add(position, ts_tree_total_size(parent->children[i - 1])); + if (ts_tree_is_visible(child)) + return ts_node_make(child, position); + if (child->visible_child_count > 0) + return ts_node_child(ts_node_make(child, position), 0); + } + + tree = parent; + } while (!ts_tree_is_visible(tree)); + + return ts_node_null(); } size_t ts_node_child_count(TSNode this) { diff --git a/src/runtime/tree.c b/src/runtime/tree.c index 7214b5d7..da22dc0e 100644 --- a/src/runtime/tree.c +++ b/src/runtime/tree.c @@ -42,7 +42,8 @@ TSTree *ts_tree_make_node(TSSymbol symbol, size_t child_count, for (size_t i = 0; i < child_count; i++) { TSTree *child = children[i]; ts_tree_retain(child); - child->parent = result; + child->context.parent = result; + child->context.index = i; if (i == 0) { padding = child->padding; @@ -78,7 +79,7 @@ TSTree *ts_tree_make_node(TSSymbol symbol, size_t child_count, } *result = (TSTree){.ref_count = 1, - .parent = NULL, + .context = {.parent = NULL, .index = 0}, .symbol = symbol, .children = children, .child_count = child_count, diff --git a/src/runtime/tree.h b/src/runtime/tree.h index d107f5e7..9b463ed8 100644 --- a/src/runtime/tree.h +++ b/src/runtime/tree.h @@ -17,7 +17,10 @@ typedef enum { } TSTreeOptions; struct TSTree { - struct TSTree *parent; + struct { + struct TSTree *parent; + size_t index; + } context; size_t child_count; size_t visible_child_count; union {