From 9872a083b7f8442f02d59d7bbea986c3a676201f Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 3 Oct 2019 12:45:58 -0700 Subject: [PATCH] rust: Change QueryCursor::captures to expose the full match --- cli/src/query.rs | 13 ++- cli/src/tests/query_test.rs | 65 +++++++++++-- lib/.ccls | 4 + lib/binding_rust/bindings.rs | 3 + lib/binding_rust/lib.rs | 173 ++++++++++++++++------------------ lib/include/tree_sitter/api.h | 1 + lib/src/query.c | 17 ++++ 7 files changed, 172 insertions(+), 104 deletions(-) create mode 100644 lib/.ccls diff --git a/cli/src/query.rs b/cli/src/query.rs index f373a314..c3a03679 100644 --- a/cli/src/query.rs +++ b/cli/src/query.rs @@ -34,12 +34,15 @@ pub fn query_files_at_paths( let tree = parser.parse(&source_code, None).unwrap(); if ordered_captures { - for (pattern_index, capture) in query_cursor.captures(&query, tree.root_node(), text_callback) { + for (mat, capture_index) in + query_cursor.captures(&query, tree.root_node(), text_callback) + { + let capture = mat.captures[capture_index]; writeln!( &mut stdout, " pattern: {}, capture: {}, row: {}, text: {:?}", - pattern_index, - &query.capture_names()[capture.index], + mat.pattern_index, + &query.capture_names()[capture.index as usize], capture.node.start_position().row, capture.node.utf8_text(&source_code).unwrap_or("") )?; @@ -47,11 +50,11 @@ pub fn query_files_at_paths( } else { for m in query_cursor.matches(&query, tree.root_node(), text_callback) { writeln!(&mut stdout, " pattern: {}", m.pattern_index)?; - for capture in m.captures() { + for capture in m.captures { writeln!( &mut stdout, " capture: {}, row: {}, text: {:?}", - &query.capture_names()[capture.index], + &query.capture_names()[capture.index as usize], capture.node.start_position().row, capture.node.utf8_text(&source_code).unwrap_or("") )?; diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 289ad9d2..740d6d73 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -874,6 +874,45 @@ fn test_query_captures_ordered_by_both_start_and_end_positions() { }); } +#[test] +fn test_query_captures_with_matches_removed() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + r#" + (binary_expression + left: (identifier) @left + operator: * @op + right: (identifier) @right) + "#, + ) + .unwrap(); + + let source = " + a === b && c > d && e < f; + "; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + + let mut captured_strings = Vec::new(); + for (m, i) in cursor.captures(&query, tree.root_node(), to_callback(source)) { + let capture = m.captures[i]; + let text = capture.node.utf8_text(source.as_bytes()).unwrap(); + if text == "a" { + m.remove(); + continue; + } + captured_strings.push(text); + } + + assert_eq!(captured_strings, &["c", ">", "d", "e", "<", "f",]); + }); +} + #[test] fn test_query_start_byte_for_pattern() { let language = get_language("javascript"); @@ -985,22 +1024,30 @@ fn collect_matches<'a>( .map(|m| { ( m.pattern_index, - collect_captures(m.captures().map(|c| (m.pattern_index, c)), query, source), + format_captures(m.captures.iter().cloned(), query, source), ) }) .collect() } -fn collect_captures<'a, 'b>( - captures: impl Iterator)>, - query: &'b Query, - source: &'b str, -) -> Vec<(&'b str, &'b str)> { +fn collect_captures<'a>( + captures: impl Iterator, usize)>, + query: &'a Query, + source: &'a str, +) -> Vec<(&'a str, &'a str)> { + format_captures(captures.map(|(m, i)| m.captures[i]), query, source) +} + +fn format_captures<'a>( + captures: impl Iterator>, + query: &'a Query, + source: &'a str, +) -> Vec<(&'a str, &'a str)> { captures - .map(|(_, QueryCapture { index, node })| { + .map(|capture| { ( - query.capture_names()[index].as_str(), - node.utf8_text(source.as_bytes()).unwrap(), + query.capture_names()[capture.index as usize].as_str(), + capture.node.utf8_text(source.as_bytes()).unwrap(), ) }) .collect() diff --git a/lib/.ccls b/lib/.ccls new file mode 100644 index 00000000..fdb974d8 --- /dev/null +++ b/lib/.ccls @@ -0,0 +1,4 @@ +-std=c99 +-Isrc +-Iinclude +-Iutf8proc diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 210a6f57..df1249a3 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -701,6 +701,9 @@ extern "C" { #[doc = " Otherwise, return `false`."] pub fn ts_query_cursor_next_match(arg1: *mut TSQueryCursor, match_: *mut TSQueryMatch) -> bool; } +extern "C" { + pub fn ts_query_cursor_remove_match(arg1: *mut TSQueryCursor, id: u32); +} extern "C" { #[doc = " Advance to the next capture of the currently running query."] #[doc = ""] diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 39c99515..6dd398e2 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -122,6 +122,7 @@ pub struct PropertySheetJSON

{ } #[derive(Clone, Copy)] +#[repr(transparent)] pub struct Node<'a>(ffi::TSNode, PhantomData<&'a ()>); pub struct Parser(NonNull); @@ -163,15 +164,19 @@ pub struct Query { pub struct QueryCursor(NonNull); +#[derive(Clone)] pub struct QueryMatch<'a> { pub pattern_index: usize, - captures: &'a [ffi::TSQueryCapture], + pub captures: &'a [QueryCapture<'a>], + id: u32, + cursor: *mut ffi::TSQueryCursor, } -#[derive(Clone)] +#[derive(Clone, Copy)] +#[repr(C)] pub struct QueryCapture<'a> { - pub index: usize, pub node: Node<'a>, + pub index: u32, } #[derive(Debug, PartialEq, Eq)] @@ -1244,16 +1249,6 @@ impl Query { } } -impl QueryProperty { - pub fn new(key: &str, value: Option<&str>, capture_id: Option) -> Self { - QueryProperty { - capture_id, - key: key.to_string().into_boxed_str(), - value: value.map(|s| s.to_string().into_boxed_str()), - } - } -} - impl QueryCursor { pub fn new() -> Self { QueryCursor(unsafe { NonNull::new_unchecked(ffi::ts_query_cursor_new()) }) @@ -1267,27 +1262,16 @@ impl QueryCursor { ) -> impl Iterator> + 'a { let ptr = self.0.as_ptr(); unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) }; - std::iter::from_fn(move || -> Option> { - loop { - unsafe { - let mut m = MaybeUninit::::uninit(); - if ffi::ts_query_cursor_next_match(ptr, m.as_mut_ptr()) { - let m = m.assume_init(); - let captures = slice::from_raw_parts(m.captures, m.capture_count as usize); - if Self::captures_match_text_predicates( - query, - captures, - m.pattern_index as usize, - &mut text_callback, - ) { - return Some(QueryMatch { - pattern_index: m.pattern_index as usize, - captures, - }); - } - } else { - return None; + std::iter::from_fn(move || loop { + unsafe { + let mut m = MaybeUninit::::uninit(); + if ffi::ts_query_cursor_next_match(ptr, m.as_mut_ptr()) { + let result = QueryMatch::new(m.assume_init(), ptr); + if result.satisfies_text_predicates(query, &mut text_callback) { + return Some(result); } + } else { + return None; } } }) @@ -1298,34 +1282,23 @@ impl QueryCursor { query: &'a Query, node: Node<'a>, mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a, - ) -> impl Iterator + 'a { + ) -> impl Iterator, usize)> + 'a { let ptr = self.0.as_ptr(); unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) }; std::iter::from_fn(move || loop { unsafe { - let mut m = MaybeUninit::::uninit(); let mut capture_index = 0u32; + let mut m = MaybeUninit::::uninit(); if ffi::ts_query_cursor_next_capture( ptr, m.as_mut_ptr(), &mut capture_index as *mut u32, ) { - let m = m.assume_init(); - let captures = slice::from_raw_parts(m.captures, m.capture_count as usize); - if Self::captures_match_text_predicates( - query, - captures, - m.pattern_index as usize, - &mut text_callback, - ) { - let capture = captures[capture_index as usize]; - return Some(( - m.pattern_index as usize, - QueryCapture { - index: capture.index as usize, - node: Node::new(capture.node).unwrap(), - }, - )); + let result = QueryMatch::new(m.assume_init(), ptr); + if result.satisfies_text_predicates(query, &mut text_callback) { + return Some((result, capture_index as usize)); + } else { + result.remove(); } } else { return None; @@ -1334,40 +1307,6 @@ impl QueryCursor { }) } - fn captures_match_text_predicates<'a>( - query: &'a Query, - captures: &'a [ffi::TSQueryCapture], - pattern_index: usize, - text_callback: &mut impl FnMut(Node<'a>) -> &'a [u8], - ) -> bool { - query.text_predicates[pattern_index] - .iter() - .all(|predicate| match predicate { - TextPredicate::CaptureEqCapture(i, j) => { - let node1 = Self::capture_for_id(captures, *i).unwrap(); - let node2 = Self::capture_for_id(captures, *j).unwrap(); - text_callback(node1) == text_callback(node2) - } - TextPredicate::CaptureEqString(i, s) => { - let node = Self::capture_for_id(captures, *i).unwrap(); - text_callback(node) == s.as_bytes() - } - TextPredicate::CaptureMatchString(i, r) => { - let node = Self::capture_for_id(captures, *i).unwrap(); - r.is_match(text_callback(node)) - } - }) - } - - fn capture_for_id(captures: &[ffi::TSQueryCapture], capture_id: u32) -> Option { - for c in captures { - if c.index == capture_id { - return Node::new(c.node); - } - } - None - } - pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self { unsafe { ffi::ts_query_cursor_set_byte_range(self.0.as_ptr(), start as u32, end as u32); @@ -1384,11 +1323,65 @@ impl QueryCursor { } impl<'a> QueryMatch<'a> { - pub fn captures(&self) -> impl ExactSizeIterator { - self.captures.iter().map(|capture| QueryCapture { - index: capture.index as usize, - node: Node::new(capture.node).unwrap(), - }) + pub fn remove(self) { + unsafe { ffi::ts_query_cursor_remove_match(self.cursor, self.id) } + } + + fn new(m: ffi::TSQueryMatch, cursor: *mut ffi::TSQueryCursor) -> Self { + QueryMatch { + cursor, + id: m.id, + pattern_index: m.pattern_index as usize, + captures: unsafe { + slice::from_raw_parts( + m.captures as *const QueryCapture<'a>, + m.capture_count as usize, + ) + }, + } + } + + fn satisfies_text_predicates( + &self, + query: &Query, + text_callback: &mut impl FnMut(Node<'a>) -> &[u8], + ) -> bool { + query.text_predicates[self.pattern_index] + .iter() + .all(|predicate| match predicate { + TextPredicate::CaptureEqCapture(i, j) => { + let node1 = self.capture_for_index(*i).unwrap(); + let node2 = self.capture_for_index(*j).unwrap(); + text_callback(node1) == text_callback(node2) + } + TextPredicate::CaptureEqString(i, s) => { + let node = self.capture_for_index(*i).unwrap(); + text_callback(node) == s.as_bytes() + } + TextPredicate::CaptureMatchString(i, r) => { + let node = self.capture_for_index(*i).unwrap(); + r.is_match(text_callback(node)) + } + }) + } + + fn capture_for_index(&self, capture_index: u32) -> Option> { + for c in self.captures { + if c.index == capture_index { + return Some(c.node); + } + } + None + } +} + +impl QueryProperty { + pub fn new(key: &str, value: Option<&str>, capture_id: Option) -> Self { + QueryProperty { + capture_id, + key: key.to_string().into_boxed_str(), + value: value.map(|s| s.to_string().into_boxed_str()), + } } } diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index c1a97134..b53174fa 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -764,6 +764,7 @@ void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint); * Otherwise, return `false`. */ bool ts_query_cursor_next_match(TSQueryCursor *, TSQueryMatch *match); +void ts_query_cursor_remove_match(TSQueryCursor *, uint32_t id); /** * Advance to the next capture of the currently running query. diff --git a/lib/src/query.c b/lib/src/query.c index c6ec3962..2716a3d7 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -1234,6 +1234,23 @@ bool ts_query_cursor_next_match( return true; } +void ts_query_cursor_remove_match( + TSQueryCursor *self, + uint32_t match_id +) { + for (unsigned i = 0; i < self->finished_states.size; i++) { + const QueryState *state = &self->finished_states.contents[i]; + if (state->id == match_id) { + capture_list_pool_release( + &self->capture_list_pool, + state->capture_list_id + ); + array_erase(&self->finished_states, i); + return; + } + } +} + bool ts_query_cursor_next_capture( TSQueryCursor *self, TSQueryMatch *match,