From 8c3d1466ecae2a22a9625d1456ffaae84b13fd3e Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Sun, 23 May 2021 15:12:24 -0700 Subject: [PATCH] Allow QueryCursor's text callback to return an iterator --- cli/src/query.rs | 7 +- cli/src/tests/query_test.rs | 147 +++++++++++++++++++-------- highlight/src/lib.rs | 26 ++--- lib/binding_rust/lib.rs | 195 +++++++++++++++++++++++++++--------- tags/src/lib.rs | 8 +- 5 files changed, 272 insertions(+), 111 deletions(-) diff --git a/cli/src/query.rs b/cli/src/query.rs index bf0ae320..e303a002 100644 --- a/cli/src/query.rs +++ b/cli/src/query.rs @@ -3,7 +3,7 @@ use crate::query_testing; use std::fs; use std::io::{self, Write}; use std::path::Path; -use tree_sitter::{Language, Node, Parser, Query, QueryCursor}; +use tree_sitter::{Language, Parser, Query, QueryCursor}; pub fn query_files_at_paths( language: Language, @@ -38,12 +38,11 @@ pub fn query_files_at_paths( let source_code = fs::read(&path).map_err(Error::wrap(|| { format!("Error reading source file {:?}", path) }))?; - let text_callback = |n: Node| &source_code[n.byte_range()]; let tree = parser.parse(&source_code, None).unwrap(); if ordered_captures { for (mat, capture_index) in - query_cursor.captures(&query, tree.root_node(), text_callback) + query_cursor.captures(&query, tree.root_node(), source_code.as_slice()) { let capture = mat.captures[capture_index]; let capture_name = &query.capture_names()[capture.index as usize]; @@ -62,7 +61,7 @@ pub fn query_files_at_paths( }); } } else { - for m in query_cursor.matches(&query, tree.root_node(), text_callback) { + for m in query_cursor.matches(&query, tree.root_node(), source_code.as_slice()) { writeln!(&mut stdout, " pattern: {}", m.pattern_index)?; for capture in m.captures { let start = capture.node.start_position(); diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index d6153dd4..8e2c1313 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -761,7 +761,7 @@ fn test_query_matches_with_named_wildcard() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1645,7 +1645,7 @@ fn test_query_matches_with_too_many_permutations_to_track() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(&source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); // For this pathological query, some match permutations will be dropped. // Just check that a subset of the results are returned, and crash or @@ -1686,7 +1686,7 @@ fn test_query_matches_with_alternatives_and_too_many_permutations_to_track() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(&source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source.as_str()), @@ -1783,7 +1783,7 @@ fn test_query_matches_within_byte_range() { let matches = cursor .set_byte_range(0, 8) - .matches(&query, tree.root_node(), to_callback(source)); + .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1797,7 +1797,7 @@ fn test_query_matches_within_byte_range() { let matches = cursor .set_byte_range(5, 15) - .matches(&query, tree.root_node(), to_callback(source)); + .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1811,7 +1811,7 @@ fn test_query_matches_within_byte_range() { let matches = cursor .set_byte_range(12, 0) - .matches(&query, tree.root_node(), to_callback(source)); + .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1840,7 +1840,7 @@ fn test_query_matches_within_point_range() { let matches = cursor .set_point_range(Point::new(0, 0), Point::new(1, 3)) - .matches(&query, tree.root_node(), to_callback(source)); + .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1853,7 +1853,7 @@ fn test_query_matches_within_point_range() { let matches = cursor .set_point_range(Point::new(1, 0), Point::new(2, 3)) - .matches(&query, tree.root_node(), to_callback(source)); + .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1866,7 +1866,7 @@ fn test_query_matches_within_point_range() { let matches = cursor .set_point_range(Point::new(2, 1), Point::new(0, 0)) - .matches(&query, tree.root_node(), to_callback(source)); + .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1905,7 +1905,7 @@ fn test_query_captures_within_byte_range() { let captures = cursor .set_byte_range(3, 27) - .captures(&query, tree.root_node(), to_callback(source)); + .captures(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_captures(captures, &query, source), @@ -1955,13 +1955,13 @@ fn test_query_matches_different_queries_same_cursor() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); - let matches = cursor.matches(&query1, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query1, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query1, source), &[(0, vec![("id1", "a")]),] ); - let matches = cursor.matches(&query3, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query3, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query3, source), &[ @@ -1971,7 +1971,7 @@ fn test_query_matches_different_queries_same_cursor() { ] ); - let matches = cursor.matches(&query2, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query2, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query2, source), &[(0, vec![("id1", "a")]), (1, vec![("id2", "b")]),] @@ -1998,7 +1998,7 @@ fn test_query_matches_with_multiple_captures_on_a_node() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), &[( @@ -2016,7 +2016,7 @@ fn test_query_matches_with_multiple_captures_on_a_node() { // disabling captures still works when there are multiple captures on a // single node. query.disable_capture("name2"); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), &[( @@ -2087,7 +2087,7 @@ fn test_query_matches_with_captured_wildcard_at_root() { let tree = parser.parse(&source, None).unwrap(); let match_capture_names_and_rows = cursor - .matches(&query, tree.root_node(), to_callback(source)) + .matches(&query, tree.root_node(), source.as_bytes()) .map(|m| { m.captures .iter() @@ -2352,7 +2352,7 @@ fn test_query_captures_basic() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -2368,7 +2368,7 @@ fn test_query_captures_basic() { ], ); - let captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_captures(captures, &query, source), &[ @@ -2425,7 +2425,7 @@ fn test_query_captures_with_text_conditions() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_captures(captures, &query, source), &[ @@ -2564,7 +2564,7 @@ fn test_query_captures_with_duplicates() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_captures(captures, &query, source), &[("function", "x"), ("variable", "x"),], @@ -2608,7 +2608,7 @@ fn test_query_captures_with_many_nested_results_without_fields() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(&source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); let captures = collect_captures(captures, &query, &source); assert_eq!( @@ -2668,7 +2668,7 @@ fn test_query_captures_with_many_nested_results_with_fields() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(&source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); let captures = collect_captures(captures, &query, &source); assert_eq!( @@ -2765,7 +2765,7 @@ fn test_query_captures_with_too_many_nested_results() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(&source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); let captures = collect_captures(captures, &query, &source); assert_eq!( @@ -2828,7 +2828,7 @@ fn test_query_captures_with_definite_pattern_containing_many_nested_matches() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_captures(captures, &query, source), [("l-bracket", "[")] @@ -2864,7 +2864,7 @@ fn test_query_captures_ordered_by_both_start_and_end_positions() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_captures(captures, &query, source), &[ @@ -2906,7 +2906,7 @@ fn test_query_captures_with_matches_removed() { let mut cursor = QueryCursor::new(); let mut captured_strings = Vec::new(); - for (m, i) in cursor.captures(&query, tree.root_node(), to_callback(source)) { + for (m, i) in cursor.captures(&query, tree.root_node(), source.as_bytes()) { let capture = m.captures[i]; let text = capture.node.utf8_text(source.as_bytes()).unwrap(); if text == "a" { @@ -2943,7 +2943,7 @@ fn test_query_captures_and_matches_iterators_are_fused() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let mut captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); assert_eq!(captures.next().unwrap().0.captures[0].index, 0); assert_eq!(captures.next().unwrap().0.captures[0].index, 0); @@ -2953,7 +2953,7 @@ fn test_query_captures_and_matches_iterators_are_fused() { assert!(captures.next().is_none()); drop(captures); - let mut matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let mut matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!(matches.next().unwrap().captures[0].index, 0); assert_eq!(matches.next().unwrap().captures[0].index, 0); assert_eq!(matches.next().unwrap().captures[0].index, 0); @@ -2963,6 +2963,79 @@ fn test_query_captures_and_matches_iterators_are_fused() { }); } +#[test] +fn test_query_text_callback_returns_chunks() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + r#" + ((identifier) @leading_upper + (#match? @leading_upper "^[A-Z][A-Z_]*[a-z]")) + ((identifier) @all_upper + (#match? @all_upper "^[A-Z][A-Z_]*$")) + ((identifier) @all_lower + (#match? @all_lower "^[a-z][a-z_]*$")) + "#, + ) + .unwrap(); + + let source = "SOMETHING[a] = transform(AnotherThing[b].property[c], PARAMETER);"; + + // Store the source code in chunks of 3 bytes, and expose it via + // an iterator API. + let source_chunks = source.as_bytes().chunks(3).collect::>(); + let chunks_in_range = |range: std::ops::Range| { + let mut offset = 0; + source_chunks.iter().filter_map(move |chunk| { + let end_offset = offset + chunk.len(); + if offset < range.end && range.start < end_offset { + let end_in_chunk = (range.end - offset).min(chunk.len()); + let start_in_chunk = range.start.max(offset) - offset; + offset = end_offset; + Some(&chunk[start_in_chunk..end_in_chunk]) + } else { + offset = end_offset; + None + } + }) + }; + assert_eq!( + chunks_in_range(0..9) + .map(|c| std::str::from_utf8(c).unwrap()) + .collect::(), + "SOMETHING", + ); + assert_eq!( + chunks_in_range(15..24) + .map(|c| std::str::from_utf8(c).unwrap()) + .collect::(), + "transform", + ); + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + let captures = cursor.captures(&query, tree.root_node(), |node: Node| { + chunks_in_range(node.byte_range()) + }); + + assert_eq!( + collect_captures(captures, &query, source), + &[ + ("all_upper", "SOMETHING"), + ("all_lower", "a"), + ("all_lower", "transform"), + ("leading_upper", "AnotherThing"), + ("all_lower", "b"), + ("all_lower", "c"), + ("all_upper", "PARAMETER"), + ] + ); + }); +} + #[test] fn test_query_start_byte_for_pattern() { let language = get_language("javascript"); @@ -3058,7 +3131,7 @@ fn test_query_lifetime_is_separate_from_nodes_lifetime() { let query = Query::new(language, query).unwrap(); let mut cursor = QueryCursor::new(); let node = cursor - .matches(&query, node, to_callback(source)) + .matches(&query, node, source.as_bytes()) .next() .unwrap() .captures[0] @@ -3078,7 +3151,7 @@ fn test_query_lifetime_is_separate_from_nodes_lifetime() { let query = Query::new(language, query).unwrap(); let mut cursor = QueryCursor::new(); let node = cursor - .captures(&query, node, to_callback(source)) + .captures(&query, node, source.as_bytes()) .next() .unwrap() .0 @@ -3123,7 +3196,7 @@ fn test_query_comments() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), &[(0, vec![("fn-name", "one")]),], @@ -3159,7 +3232,7 @@ fn test_query_disable_pattern() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), &[ @@ -3502,13 +3575,13 @@ fn assert_query_matches( parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!(collect_matches(matches, &query, source), expected); assert_eq!(cursor.did_exceed_match_limit(), false); } fn collect_matches<'a>( - matches: impl Iterator>, + matches: impl Iterator>, query: &'a Query, source: &'a str, ) -> Vec<(usize, Vec<(&'a str, &'a str)>)> { @@ -3523,7 +3596,7 @@ fn collect_matches<'a>( } fn collect_captures<'a>( - captures: impl Iterator, usize)>, + captures: impl Iterator, usize)>, query: &'a Query, source: &'a str, ) -> Vec<(&'a str, &'a str)> { @@ -3544,7 +3617,3 @@ fn format_captures<'a>( }) .collect() } - -fn to_callback<'a>(source: &'a str) -> impl Fn(Node) -> &'a [u8] { - move |n| &source.as_bytes()[n.byte_range()] -} diff --git a/highlight/src/lib.rs b/highlight/src/lib.rs index c1e7aba4..3f75c6dc 100644 --- a/highlight/src/lib.rs +++ b/highlight/src/lib.rs @@ -83,7 +83,7 @@ struct LocalScope<'a> { local_defs: Vec>, } -struct HighlightIter<'a, 'tree: 'a, F> +struct HighlightIter<'a, F> where F: FnMut(&str) -> Option<&'a HighlightConfiguration> + 'a, { @@ -92,16 +92,16 @@ where highlighter: &'a mut Highlighter, injection_callback: F, cancellation_flag: Option<&'a AtomicUsize>, - layers: Vec>, + layers: Vec>, iter_count: usize, next_event: Option, last_highlight_range: Option<(usize, usize, usize)>, } -struct HighlightIterLayer<'a, 'tree: 'a> { +struct HighlightIterLayer<'a> { _tree: Tree, cursor: QueryCursor, - captures: iter::Peekable>, + captures: iter::Peekable>, config: &'a HighlightConfiguration, highlight_end_stack: Vec, scope_stack: Vec>, @@ -319,7 +319,7 @@ impl HighlightConfiguration { } } -impl<'a, 'tree: 'a> HighlightIterLayer<'a, 'tree> { +impl<'a> HighlightIterLayer<'a> { /// Create a new 'layer' of highlighting for this document. /// /// In the even that the new layer contains "combined injections" (injections where multiple @@ -356,9 +356,7 @@ impl<'a, 'tree: 'a> HighlightIterLayer<'a, 'tree> { let mut injections_by_pattern_index = vec![(None, Vec::new(), false); combined_injections_query.pattern_count()]; let matches = - cursor.matches(combined_injections_query, tree.root_node(), |n: Node| { - &source[n.byte_range()] - }); + cursor.matches(combined_injections_query, tree.root_node(), source); for mat in matches { let entry = &mut injections_by_pattern_index[mat.pattern_index]; let (language_name, content_node, include_children) = @@ -395,9 +393,7 @@ impl<'a, 'tree: 'a> HighlightIterLayer<'a, 'tree> { let cursor_ref = unsafe { mem::transmute::<_, &'static mut QueryCursor>(&mut cursor) }; let captures = cursor_ref - .captures(&config.query, tree_ref.root_node(), move |n: Node| { - &source[n.byte_range()] - }) + .captures(&config.query, tree_ref.root_node(), source) .peekable(); result.push(HighlightIterLayer { @@ -548,7 +544,7 @@ impl<'a, 'tree: 'a> HighlightIterLayer<'a, 'tree> { } } -impl<'a, 'tree: 'a, F> HighlightIter<'a, 'tree, F> +impl<'a, F> HighlightIter<'a, F> where F: FnMut(&str) -> Option<&'a HighlightConfiguration> + 'a, { @@ -596,7 +592,7 @@ where } } - fn insert_layer(&mut self, mut layer: HighlightIterLayer<'a, 'tree>) { + fn insert_layer(&mut self, mut layer: HighlightIterLayer<'a>) { if let Some(sort_key) = layer.sort_key() { let mut i = 1; while i < self.layers.len() { @@ -615,7 +611,7 @@ where } } -impl<'a, 'tree: 'a, F> Iterator for HighlightIter<'a, 'tree, F> +impl<'a, F> Iterator for HighlightIter<'a, F> where F: FnMut(&str) -> Option<&'a HighlightConfiguration> + 'a, { @@ -1025,7 +1021,7 @@ impl HtmlRenderer { fn injection_for_match<'a>( config: &HighlightConfiguration, query: &'a Query, - query_match: &QueryMatch<'a>, + query_match: &QueryMatch<'a, 'a>, source: &'a [u8], ) -> (Option<&'a str>, Option>, bool) { let content_capture_index = config.injection_content_capture_index; diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 801f773f..92e3c9d6 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -102,7 +102,9 @@ pub struct Query { } /// A stateful object for executing a `Query` on a syntax `Tree`. -pub struct QueryCursor(NonNull); +pub struct QueryCursor { + ptr: NonNull, +} /// A key-value pair associated with a particular pattern in a `Query`. #[derive(Debug, PartialEq, Eq)] @@ -126,18 +128,36 @@ pub struct QueryPredicate { } /// A match of a `Query` to a particular set of `Node`s. -pub struct QueryMatch<'a> { +pub struct QueryMatch<'cursor, 'tree> { pub pattern_index: usize, - pub captures: &'a [QueryCapture<'a>], + pub captures: &'cursor [QueryCapture<'tree>], id: u32, cursor: *mut ffi::TSQueryCursor, } -/// A sequence of `QueryCapture`s within a `QueryMatch`. -pub struct QueryCaptures<'a, 'tree: 'a, T: AsRef<[u8]>> { +/// A sequence of `QueryMatch`es associated with a given `QueryCursor`. +pub struct QueryMatches<'a, 'tree: 'a, T: TextProvider<'a>> { ptr: *mut ffi::TSQueryCursor, query: &'a Query, - text_callback: Box) -> T + 'a>, + text_provider: T, + buffer1: Vec, + buffer2: Vec, + _tree: PhantomData<&'tree ()>, +} + +/// A sequence of `QueryCapture`s associated with a given `QueryCursor`. +pub struct QueryCaptures<'a, 'tree: 'a, T: TextProvider<'a>> { + ptr: *mut ffi::TSQueryCursor, + query: &'a Query, + text_provider: T, + buffer1: Vec, + buffer2: Vec, + _tree: PhantomData<&'tree ()>, +} + +pub trait TextProvider<'a> { + type I: Iterator + 'a; + fn text(&mut self, node: Node) -> Self::I; } /// A particular `Node` that has been captured with a particular name within a `Query`. @@ -178,6 +198,11 @@ pub enum QueryErrorKind { Structure, } +trait TextCallback<'a> { + fn call(&mut self, node: Node); + fn next_chunk(&mut self) -> Option<&'a [u8]>; +} + #[derive(Debug)] enum TextPredicate { CaptureEqString(u32, String, bool), @@ -1590,18 +1615,20 @@ impl Query { } } -impl<'a> QueryCursor { +impl QueryCursor { /// Create a new cursor for executing a given query. /// /// The cursor stores the state that is needed to iteratively search for matches. pub fn new() -> Self { - QueryCursor(unsafe { NonNull::new_unchecked(ffi::ts_query_cursor_new()) }) + QueryCursor { + ptr: unsafe { NonNull::new_unchecked(ffi::ts_query_cursor_new()) }, + } } /// Check if, on its last execution, this cursor exceeded its maximum number of /// in-progress matches. pub fn did_exceed_match_limit(&self) -> bool { - unsafe { ffi::ts_query_cursor_did_exceed_match_limit(self.0.as_ptr()) } + unsafe { ffi::ts_query_cursor_did_exceed_match_limit(self.ptr.as_ptr()) } } /// Iterate over all of the matches in the order that they were found. @@ -1609,52 +1636,50 @@ impl<'a> QueryCursor { /// Each match contains the index of the pattern that matched, and a list of captures. /// Because multiple patterns can match the same set of nodes, one match may contain /// captures that appear *before* some of the captures from a previous match. - pub fn matches<'tree: 'a, T: AsRef<[u8]>>( + pub fn matches<'a, 'tree: 'a, T: TextProvider<'a> + 'a>( &'a mut self, query: &'a Query, node: Node<'tree>, - mut text_callback: impl FnMut(Node<'tree>) -> T + 'a, - ) -> impl Iterator> + 'a { - let ptr = self.0.as_ptr(); + text_provider: T, + ) -> QueryMatches<'a, 'tree, T> { + let ptr = self.ptr.as_ptr(); unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) }; - std::iter::from_fn(move || loop { - unsafe { - let mut m = MaybeUninit::::uninit(); - if ffi::ts_query_cursor_next_match(ptr, m.as_mut_ptr()) { - let result = QueryMatch::new(m.assume_init(), ptr); - if result.satisfies_text_predicates(query, &mut text_callback) { - return Some(result); - } - } else { - return None; - } - } - }) + QueryMatches { + ptr, + query, + text_provider, + buffer1: Default::default(), + buffer2: Default::default(), + _tree: PhantomData, + } } /// Iterate over all of the individual captures in the order that they appear. /// /// This is useful if don't care about which pattern matched, and just want a single, /// ordered sequence of captures. - pub fn captures<'tree, T: AsRef<[u8]>>( + pub fn captures<'a, 'tree: 'a, T: TextProvider<'a> + 'a>( &'a mut self, query: &'a Query, node: Node<'tree>, - text_callback: impl FnMut(Node<'tree>) -> T + 'a, + text_provider: T, ) -> QueryCaptures<'a, 'tree, T> { - let ptr = self.0.as_ptr(); - unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) }; + let ptr = self.ptr.as_ptr(); + unsafe { ffi::ts_query_cursor_exec(self.ptr.as_ptr(), query.ptr.as_ptr(), node.0) }; QueryCaptures { ptr, query, - text_callback: Box::new(text_callback), + text_provider, + buffer1: Default::default(), + buffer2: Default::default(), + _tree: PhantomData, } } /// Set the range in which the query will be executed, in terms of byte offsets. pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self { unsafe { - ffi::ts_query_cursor_set_byte_range(self.0.as_ptr(), start as u32, end as u32); + ffi::ts_query_cursor_set_byte_range(self.ptr.as_ptr(), start as u32, end as u32); } self } @@ -1662,13 +1687,13 @@ impl<'a> QueryCursor { /// Set the range in which the query will be executed, in terms of rows and columns. pub fn set_point_range(&mut self, start: Point, end: Point) -> &mut Self { unsafe { - ffi::ts_query_cursor_set_point_range(self.0.as_ptr(), start.into(), end.into()); + ffi::ts_query_cursor_set_point_range(self.ptr.as_ptr(), start.into(), end.into()); } self } } -impl<'a> QueryMatch<'a> { +impl<'a, 'tree> QueryMatch<'a, 'tree> { pub fn remove(self) { unsafe { ffi::ts_query_cursor_remove_match(self.cursor, self.id) } } @@ -1681,7 +1706,7 @@ impl<'a> QueryMatch<'a> { captures: if m.capture_count > 0 { unsafe { slice::from_raw_parts( - m.captures as *const QueryCapture<'a>, + m.captures as *const QueryCapture<'tree>, m.capture_count as usize, ) } @@ -1691,31 +1716,55 @@ impl<'a> QueryMatch<'a> { } } - fn satisfies_text_predicates>( + fn satisfies_text_predicates( &self, query: &Query, - text_callback: &mut impl FnMut(Node<'a>) -> T, + buffer1: &mut Vec, + buffer2: &mut Vec, + text_provider: &mut impl TextProvider<'a>, ) -> bool { + fn get_text<'a, 'b: 'a, I: Iterator>( + buffer: &'a mut Vec, + mut chunks: I, + ) -> &'a [u8] { + let first_chunk = chunks.next().unwrap_or(&[]); + if let Some(next_chunk) = chunks.next() { + buffer.clear(); + buffer.extend_from_slice(first_chunk); + buffer.extend_from_slice(next_chunk); + for chunk in chunks { + buffer.extend_from_slice(chunk); + } + buffer.as_slice() + } else { + first_chunk + } + } + query.text_predicates[self.pattern_index] .iter() .all(|predicate| match predicate { TextPredicate::CaptureEqCapture(i, j, is_positive) => { let node1 = self.capture_for_index(*i).unwrap(); let node2 = self.capture_for_index(*j).unwrap(); - (text_callback(node1).as_ref() == text_callback(node2).as_ref()) == *is_positive + let text1 = get_text(buffer1, text_provider.text(node1)); + let text2 = get_text(buffer2, text_provider.text(node2)); + (text1 == text2) == *is_positive } TextPredicate::CaptureEqString(i, s, is_positive) => { let node = self.capture_for_index(*i).unwrap(); - (text_callback(node).as_ref() == s.as_bytes()) == *is_positive + let text = get_text(buffer1, text_provider.text(node)); + (text == s.as_bytes()) == *is_positive } TextPredicate::CaptureMatchString(i, r, is_positive) => { let node = self.capture_for_index(*i).unwrap(); - r.is_match(text_callback(node).as_ref()) == *is_positive + let text = get_text(buffer1, text_provider.text(node)); + r.is_match(text) == *is_positive } }) } - fn capture_for_index(&self, capture_index: u32) -> Option> { + fn capture_for_index(&self, capture_index: u32) -> Option> { for c in self.captures { if c.index == capture_index { return Some(c.node); @@ -1735,12 +1784,37 @@ impl QueryProperty { } } -impl<'a, 'tree: 'a, T: AsRef<[u8]>> Iterator for QueryCaptures<'a, 'tree, T> { - type Item = (QueryMatch<'tree>, usize); +impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryMatches<'a, 'tree, T> { + type Item = QueryMatch<'a, 'tree>; fn next(&mut self) -> Option { - loop { - unsafe { + unsafe { + loop { + let mut m = MaybeUninit::::uninit(); + if ffi::ts_query_cursor_next_match(self.ptr, m.as_mut_ptr()) { + let result = QueryMatch::new(m.assume_init(), self.ptr); + if result.satisfies_text_predicates( + self.query, + &mut self.buffer1, + &mut self.buffer2, + &mut self.text_provider, + ) { + return Some(result); + } + } else { + return None; + } + } + } + } +} + +impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryCaptures<'a, 'tree, T> { + type Item = (QueryMatch<'a, 'tree>, usize); + + fn next(&mut self) -> Option { + unsafe { + loop { let mut capture_index = 0u32; let mut m = MaybeUninit::::uninit(); if ffi::ts_query_cursor_next_capture( @@ -1749,7 +1823,12 @@ impl<'a, 'tree: 'a, T: AsRef<[u8]>> Iterator for QueryCaptures<'a, 'tree, T> { &mut capture_index as *mut u32, ) { let result = QueryMatch::new(m.assume_init(), self.ptr); - if result.satisfies_text_predicates(self.query, &mut self.text_callback) { + if result.satisfies_text_predicates( + self.query, + &mut self.buffer1, + &mut self.buffer2, + &mut self.text_provider, + ) { return Some((result, capture_index as usize)); } else { result.remove(); @@ -1762,7 +1841,7 @@ impl<'a, 'tree: 'a, T: AsRef<[u8]>> Iterator for QueryCaptures<'a, 'tree, T> { } } -impl<'a> fmt::Debug for QueryMatch<'a> { +impl<'cursor, 'tree> fmt::Debug for QueryMatch<'cursor, 'tree> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, @@ -1772,6 +1851,26 @@ impl<'a> fmt::Debug for QueryMatch<'a> { } } +impl<'a, F, I> TextProvider<'a> for F +where + F: FnMut(Node) -> I, + I: Iterator + 'a, +{ + type I = I; + + fn text(&mut self, node: Node) -> Self::I { + (self)(node) + } +} + +impl<'a> TextProvider<'a> for &'a [u8] { + type I = std::option::IntoIter<&'a [u8]>; + + fn text(&mut self, node: Node) -> Self::I { + Some(&self[node.byte_range()]).into_iter() + } +} + impl PartialEq for Query { fn eq(&self, other: &Self) -> bool { self.ptr == other.ptr @@ -1786,7 +1885,7 @@ impl Drop for Query { impl Drop for QueryCursor { fn drop(&mut self) { - unsafe { ffi::ts_query_cursor_delete(self.0.as_ptr()) } + unsafe { ffi::ts_query_cursor_delete(self.ptr.as_ptr()) } } } diff --git a/tags/src/lib.rs b/tags/src/lib.rs index 89809052..23b877c3 100644 --- a/tags/src/lib.rs +++ b/tags/src/lib.rs @@ -88,7 +88,7 @@ struct LocalScope<'a> { struct TagsIter<'a, I> where - I: Iterator>, + I: Iterator>, { matches: I, _tree: Tree, @@ -265,9 +265,7 @@ impl TagsContext { let tree_ref = unsafe { mem::transmute::<_, &'static Tree>(&tree) }; let matches = self .cursor - .matches(&config.query, tree_ref.root_node(), move |node| { - &source[node.byte_range()] - }); + .matches(&config.query, tree_ref.root_node(), source); Ok(( TagsIter { _tree: tree, @@ -291,7 +289,7 @@ impl TagsContext { impl<'a, I> Iterator for TagsIter<'a, I> where - I: Iterator>, + I: Iterator>, { type Item = Result;