diff --git a/cli/src/tests/node_test.rs b/cli/src/tests/node_test.rs
index d798e726..fdabc499 100644
--- a/cli/src/tests/node_test.rs
+++ b/cli/src/tests/node_test.rs
@@ -169,24 +169,22 @@ fn test_node_child() {
assert_eq!(tree.root_node().parent(), None);
assert_eq!(
- tree.root_node()
- .child_containing_descendant(null_node)
- .unwrap(),
+ tree.root_node().child_with_descendant(null_node).unwrap(),
array_node
);
assert_eq!(
- array_node.child_containing_descendant(null_node).unwrap(),
+ array_node.child_with_descendant(null_node).unwrap(),
object_node
);
assert_eq!(
- object_node.child_containing_descendant(null_node).unwrap(),
+ object_node.child_with_descendant(null_node).unwrap(),
pair_node
);
assert_eq!(
- pair_node.child_containing_descendant(null_node).unwrap(),
+ pair_node.child_with_descendant(null_node).unwrap(),
null_node
);
- assert_eq!(null_node.child_containing_descendant(null_node), None);
+ assert_eq!(null_node.child_with_descendant(null_node), None);
}
#[test]
@@ -288,16 +286,14 @@ fn test_parent_of_zero_width_node() {
assert_eq!(block_parent.to_string(), "(function_definition name: (identifier) parameters: (parameters (identifier)) body: (block))");
assert_eq!(
- root.child_containing_descendant(block).unwrap(),
+ root.child_with_descendant(block).unwrap(),
function_definition
);
assert_eq!(
- function_definition
- .child_containing_descendant(block)
- .unwrap(),
+ function_definition.child_with_descendant(block).unwrap(),
block
);
- assert_eq!(block.child_containing_descendant(block), None);
+ assert_eq!(block.child_with_descendant(block), None);
let code = "";
parser.set_language(&get_language("html")).unwrap();
@@ -477,24 +473,22 @@ fn test_node_named_child() {
assert_eq!(tree.root_node().parent(), None);
assert_eq!(
- tree.root_node()
- .child_containing_descendant(null_node)
- .unwrap(),
+ tree.root_node().child_with_descendant(null_node).unwrap(),
array_node
);
assert_eq!(
- array_node.child_containing_descendant(null_node).unwrap(),
+ array_node.child_with_descendant(null_node).unwrap(),
object_node
);
assert_eq!(
- object_node.child_containing_descendant(null_node).unwrap(),
+ object_node.child_with_descendant(null_node).unwrap(),
pair_node
);
assert_eq!(
- pair_node.child_containing_descendant(null_node).unwrap(),
+ pair_node.child_with_descendant(null_node).unwrap(),
null_node
);
- assert_eq!(null_node.child_containing_descendant(null_node), None);
+ assert_eq!(null_node.child_with_descendant(null_node), None);
}
#[test]
diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs
index 445de081..6b69a099 100644
--- a/lib/binding_rust/bindings.rs
+++ b/lib/binding_rust/bindings.rs
@@ -1,4 +1,4 @@
-/* automatically generated by rust-bindgen 0.70.0 */
+/* automatically generated by rust-bindgen 0.70.1 */
pub const TREE_SITTER_LANGUAGE_VERSION: u32 = 14;
pub const TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION: u32 = 13;
@@ -359,9 +359,13 @@ extern "C" {
pub fn ts_node_parent(self_: TSNode) -> TSNode;
}
extern "C" {
- #[doc = " Get the node's child that contains `descendant`."]
+ #[doc = " @deprecated use [`ts_node_contains_descendant`] instead, this will be removed in 0.25\n\n Get the node's child containing `descendant`. This will not return\n the descendant if it is a direct child of `self`, for that use\n `ts_node_contains_descendant`."]
pub fn ts_node_child_containing_descendant(self_: TSNode, descendant: TSNode) -> TSNode;
}
+extern "C" {
+ #[doc = " Get the node that contains `descendant`.\n\n Note that this can return `descendant` itself, unlike the deprecated function\n [`ts_node_child_containing_descendant`]."]
+ pub fn ts_node_child_with_descendant(self_: TSNode, descendant: TSNode) -> TSNode;
+}
extern "C" {
#[doc = " Get the node's child at the given index, where zero represents the first\n child."]
pub fn ts_node_child(self_: TSNode, child_index: u32) -> TSNode;
diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs
index 6229b3b0..bba21bea 100644
--- a/lib/binding_rust/lib.rs
+++ b/lib/binding_rust/lib.rs
@@ -1359,13 +1359,26 @@ impl<'tree> Node<'tree> {
Self::new(unsafe { ffi::ts_node_parent(self.0) })
}
- /// Get this node's child that contains `descendant`.
+ /// Get this node's child containing `descendant`. This will not return
+ /// the descendant if it is a direct child of `self`, for that use
+ /// [`Node::child_contains_descendant`].
#[doc(alias = "ts_node_child_containing_descendant")]
#[must_use]
+ #[deprecated(since = "0.24.0", note = "Prefer child_with_descendant instead")]
pub fn child_containing_descendant(&self, descendant: Self) -> Option {
Self::new(unsafe { ffi::ts_node_child_containing_descendant(self.0, descendant.0) })
}
+ /// Get the node that contains `descendant`.
+ ///
+ /// Note that this can return `descendant` itself, unlike the deprecated function
+ /// [`Node::child_containing_descendant`].
+ #[doc(alias = "ts_node_child_with_descendant")]
+ #[must_use]
+ pub fn child_with_descendant(&self, descendant: Self) -> Option {
+ Self::new(unsafe { ffi::ts_node_child_with_descendant(self.0, descendant.0) })
+ }
+
/// Get this node's next sibling.
#[doc(alias = "ts_node_next_sibling")]
#[must_use]
diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h
index 4c19bbdf..8bfc4843 100644
--- a/lib/include/tree_sitter/api.h
+++ b/lib/include/tree_sitter/api.h
@@ -554,10 +554,22 @@ TSStateId ts_node_next_parse_state(TSNode self);
TSNode ts_node_parent(TSNode self);
/**
- * Get the node's child that contains `descendant`.
+ * @deprecated use [`ts_node_contains_descendant`] instead, this will be removed in 0.25
+ *
+ * Get the node's child containing `descendant`. This will not return
+ * the descendant if it is a direct child of `self`, for that use
+ * `ts_node_contains_descendant`.
*/
TSNode ts_node_child_containing_descendant(TSNode self, TSNode descendant);
+/**
+ * Get the node that contains `descendant`.
+ *
+ * Note that this can return `descendant` itself, unlike the deprecated function
+ * [`ts_node_child_containing_descendant`].
+ */
+TSNode ts_node_child_with_descendant(TSNode self, TSNode descendant);
+
/**
* Get the node's child at the given index, where zero represents the first
* child.
diff --git a/lib/src/node.c b/lib/src/node.c
index 3f07e442..818735a1 100644
--- a/lib/src/node.c
+++ b/lib/src/node.c
@@ -550,16 +550,54 @@ TSNode ts_node_parent(TSNode self) {
while (true) {
TSNode next_node = ts_node_child_containing_descendant(node, self);
- if (next_node.id == self.id) break;
+ if (ts_node_is_null(next_node)) break;
node = next_node;
}
return node;
}
-TSNode ts_node_child_containing_descendant(TSNode self, TSNode subnode) {
- uint32_t start_byte = ts_node_start_byte(subnode);
- uint32_t end_byte = ts_node_end_byte(subnode);
+TSNode ts_node_child_containing_descendant(TSNode self, TSNode descendant) {
+ uint32_t start_byte = ts_node_start_byte(descendant);
+ uint32_t end_byte = ts_node_end_byte(descendant);
+
+ do {
+ NodeChildIterator iter = ts_node_iterate_children(&self);
+ do {
+ if (
+ !ts_node_child_iterator_next(&iter, &self)
+ || ts_node_start_byte(self) > start_byte
+ || self.id == descendant.id
+ ) {
+ return ts_node__null();
+ }
+
+ // Here we check the current self node and *all* of its zero-width token siblings that follow.
+ // If any of these nodes contain the target subnode, we return that node. Otherwise, we restore the node we started at
+ // for the loop condition, and that will continue with the next *non-zero-width* sibling.
+ TSNode old = self;
+ // While the next sibling is a zero-width token
+ while (ts_node_child_iterator_next_sibling_is_empty_adjacent(&iter, self)) {
+ TSNode current_node = ts_node_child_containing_descendant(self, descendant);
+ // If the target child is in self, return it
+ if (!ts_node_is_null(current_node)) {
+ return current_node;
+ }
+ ts_node_child_iterator_next(&iter, &self);
+ if (self.id == descendant.id) {
+ return ts_node__null();
+ }
+ }
+ self = old;
+ } while (iter.position.bytes < end_byte || ts_node_child_count(self) == 0);
+ } while (!ts_node__is_relevant(self, true));
+
+ return self;
+}
+
+TSNode ts_node_child_with_descendant(TSNode self, TSNode descendant) {
+ uint32_t start_byte = ts_node_start_byte(descendant);
+ uint32_t end_byte = ts_node_end_byte(descendant);
do {
NodeChildIterator iter = ts_node_iterate_children(&self);
@@ -570,7 +608,7 @@ TSNode ts_node_child_containing_descendant(TSNode self, TSNode subnode) {
) {
return ts_node__null();
}
- if (self.id == subnode.id) {
+ if (self.id == descendant.id) {
return self;
}
@@ -580,13 +618,13 @@ TSNode ts_node_child_containing_descendant(TSNode self, TSNode subnode) {
TSNode old = self;
// While the next sibling is a zero-width token
while (ts_node_child_iterator_next_sibling_is_empty_adjacent(&iter, self)) {
- TSNode current_node = ts_node_child_containing_descendant(self, subnode);
+ TSNode current_node = ts_node_child_with_descendant(self, descendant);
// If the target child is in self, return it
if (!ts_node_is_null(current_node)) {
return current_node;
}
ts_node_child_iterator_next(&iter, &self);
- if (self.id == subnode.id) {
+ if (self.id == descendant.id) {
return self;
}
}