From 7e3f57265549f26f4fe3ac1ee8ee3b1c6ee182f4 Mon Sep 17 00:00:00 2001 From: Amaan Qureshi Date: Sun, 8 Sep 2024 20:49:13 -0400 Subject: [PATCH] feat: add `field_name_for_named_child` --- cli/src/tests/node_test.rs | 45 +++++++++++++++++++++++++++++++++++ lib/binding_rust/bindings.rs | 9 ++++++- lib/binding_rust/lib.rs | 8 +++++++ lib/include/tree_sitter/api.h | 6 +++++ lib/src/node.c | 42 ++++++++++++++++++++++++++++++++ 5 files changed, 109 insertions(+), 1 deletion(-) diff --git a/cli/src/tests/node_test.rs b/cli/src/tests/node_test.rs index e05ed932..259a2f81 100644 --- a/cli/src/tests/node_test.rs +++ b/cli/src/tests/node_test.rs @@ -308,6 +308,13 @@ fn test_node_field_name_for_child() { .child_by_field_name("value") .unwrap(); + // ------------------- + // left: (identifier) 0 + // operator: "+" 1 <--- (not a named child) + // (comment) 2 <--- (is an extra) + // right: (identifier) 3 + // ------------------- + assert_eq!(binary_expression_node.field_name_for_child(0), Some("left")); assert_eq!( binary_expression_node.field_name_for_child(1), @@ -323,6 +330,44 @@ fn test_node_field_name_for_child() { assert_eq!(binary_expression_node.field_name_for_child(4), None); } +#[test] +fn test_node_field_name_for_named_child() { + let mut parser = Parser::new(); + parser.set_language(&get_language("c")).unwrap(); + let tree = parser + .parse("int w = x + /* y is special! */ y;", None) + .unwrap(); + let translation_unit_node = tree.root_node(); + let declaration_node = translation_unit_node.named_child(0).unwrap(); + + let binary_expression_node = declaration_node + .child_by_field_name("declarator") + .unwrap() + .child_by_field_name("value") + .unwrap(); + + // ------------------- + // left: (identifier) 0 + // operator: "+" _ <--- (not a named child) + // (comment) 1 <--- (is an extra) + // right: (identifier) 2 + // ------------------- + + assert_eq!( + binary_expression_node.field_name_for_named_child(0), + Some("left") + ); + // The comment should not have a field name, as it's just an extra + assert_eq!(binary_expression_node.field_name_for_named_child(1), None); + // The operator is not a named child, so the named child at index 2 is the right child + assert_eq!( + binary_expression_node.field_name_for_named_child(2), + Some("right") + ); + // Negative test - Not a valid child index + assert_eq!(binary_expression_node.field_name_for_named_child(3), None); +} + #[test] fn test_node_child_by_field_name_with_extra_hidden_children() { let mut parser = Parser::new(); diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index feaa8ca8..445de081 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -373,6 +373,13 @@ extern "C" { child_index: u32, ) -> *const ::core::ffi::c_char; } +extern "C" { + #[doc = " Get the field name for node's named child at the given index, where zero\n represents the first named child. Returns NULL, if no field is found."] + pub fn ts_node_field_name_for_named_child( + self_: TSNode, + named_child_index: u32, + ) -> *const ::core::ffi::c_char; +} extern "C" { #[doc = " Get the node's number of children."] pub fn ts_node_child_count(self_: TSNode) -> u32; @@ -642,7 +649,7 @@ extern "C" { pub fn ts_query_cursor_set_timeout_micros(self_: *mut TSQueryCursor, timeout_micros: u64); } extern "C" { - #[doc = " Get the duration in microseconds that query execution is allowed to take."] + #[doc = " Get the duration in microseconds that query execution is allowed to take.\n\n This is set via [`ts_query_cursor_set_timeout_micros`]."] pub fn ts_query_cursor_timeout_micros(self_: *const TSQueryCursor) -> u64; } extern "C" { diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 79cfadf3..b5856c0d 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1236,6 +1236,14 @@ impl<'tree> Node<'tree> { } } + /// Get the field name of this node's named child at the given index. + pub fn field_name_for_named_child(&self, named_child_index: u32) -> Option<&'static str> { + unsafe { + let ptr = ffi::ts_node_field_name_for_named_child(self.0, named_child_index); + (!ptr.is_null()).then(|| CStr::from_ptr(ptr).to_str().unwrap()) + } + } + /// Iterate over this node's children. /// /// A [`TreeCursor`] is used to retrieve the children efficiently. Obtain diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 5ea845f5..4c19bbdf 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -570,6 +570,12 @@ TSNode ts_node_child(TSNode self, uint32_t child_index); */ const char *ts_node_field_name_for_child(TSNode self, uint32_t child_index); +/** + * Get the field name for node's named child at the given index, where zero + * represents the first named child. Returns NULL, if no field is found. + */ +const char *ts_node_field_name_for_named_child(TSNode self, uint32_t named_child_index); + /** * Get the node's number of children. */ diff --git a/lib/src/node.c b/lib/src/node.c index ef04d17e..849978c0 100644 --- a/lib/src/node.c +++ b/lib/src/node.c @@ -688,6 +688,48 @@ const char *ts_node_field_name_for_child(TSNode self, uint32_t child_index) { return NULL; } +const char *ts_node_field_name_for_named_child(TSNode self, uint32_t named_child_index) { + TSNode result = self; + bool did_descend = true; + const char *inherited_field_name = NULL; + + while (did_descend) { + did_descend = false; + + TSNode child; + uint32_t index = 0; + NodeChildIterator iterator = ts_node_iterate_children(&result); + while (ts_node_child_iterator_next(&iterator, &child)) { + if (ts_node__is_relevant(child, false)) { + if (index == named_child_index) { + if (ts_node_is_extra(child)) { + return NULL; + } + const char *field_name = ts_node__field_name_from_language(result, iterator.structural_child_index - 1); + if (field_name) return field_name; + return inherited_field_name; + } + index++; + } else { + uint32_t named_grandchild_index = named_child_index - index; + uint32_t grandchild_count = ts_node__relevant_child_count(child, false); + if (named_grandchild_index < grandchild_count) { + const char *field_name = ts_node__field_name_from_language(result, iterator.structural_child_index - 1); + if (field_name) inherited_field_name = field_name; + + did_descend = true; + result = child; + named_child_index = named_grandchild_index; + break; + } + index += grandchild_count; + } + } + } + + return NULL; +} + TSNode ts_node_child_by_field_name( TSNode self, const char *name,