From 8c3d1466ecae2a22a9625d1456ffaae84b13fd3e Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Sun, 23 May 2021 15:12:24 -0700 Subject: [PATCH 01/12] Allow QueryCursor's text callback to return an iterator --- cli/src/query.rs | 7 +- cli/src/tests/query_test.rs | 147 +++++++++++++++++++-------- highlight/src/lib.rs | 26 ++--- lib/binding_rust/lib.rs | 195 +++++++++++++++++++++++++++--------- tags/src/lib.rs | 8 +- 5 files changed, 272 insertions(+), 111 deletions(-) diff --git a/cli/src/query.rs b/cli/src/query.rs index bf0ae320..e303a002 100644 --- a/cli/src/query.rs +++ b/cli/src/query.rs @@ -3,7 +3,7 @@ use crate::query_testing; use std::fs; use std::io::{self, Write}; use std::path::Path; -use tree_sitter::{Language, Node, Parser, Query, QueryCursor}; +use tree_sitter::{Language, Parser, Query, QueryCursor}; pub fn query_files_at_paths( language: Language, @@ -38,12 +38,11 @@ pub fn query_files_at_paths( let source_code = fs::read(&path).map_err(Error::wrap(|| { format!("Error reading source file {:?}", path) }))?; - let text_callback = |n: Node| &source_code[n.byte_range()]; let tree = parser.parse(&source_code, None).unwrap(); if ordered_captures { for (mat, capture_index) in - query_cursor.captures(&query, tree.root_node(), text_callback) + query_cursor.captures(&query, tree.root_node(), source_code.as_slice()) { let capture = mat.captures[capture_index]; let capture_name = &query.capture_names()[capture.index as usize]; @@ -62,7 +61,7 @@ pub fn query_files_at_paths( }); } } else { - for m in query_cursor.matches(&query, tree.root_node(), text_callback) { + for m in query_cursor.matches(&query, tree.root_node(), source_code.as_slice()) { writeln!(&mut stdout, " pattern: {}", m.pattern_index)?; for capture in m.captures { let start = capture.node.start_position(); diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index d6153dd4..8e2c1313 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -761,7 +761,7 @@ fn 
test_query_matches_with_named_wildcard() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1645,7 +1645,7 @@ fn test_query_matches_with_too_many_permutations_to_track() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(&source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); // For this pathological query, some match permutations will be dropped. // Just check that a subset of the results are returned, and crash or @@ -1686,7 +1686,7 @@ fn test_query_matches_with_alternatives_and_too_many_permutations_to_track() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(&source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source.as_str()), @@ -1783,7 +1783,7 @@ fn test_query_matches_within_byte_range() { let matches = cursor .set_byte_range(0, 8) - .matches(&query, tree.root_node(), to_callback(source)); + .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1797,7 +1797,7 @@ fn test_query_matches_within_byte_range() { let matches = cursor .set_byte_range(5, 15) - .matches(&query, tree.root_node(), to_callback(source)); + .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1811,7 +1811,7 @@ fn test_query_matches_within_byte_range() { let matches = cursor 
.set_byte_range(12, 0) - .matches(&query, tree.root_node(), to_callback(source)); + .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1840,7 +1840,7 @@ fn test_query_matches_within_point_range() { let matches = cursor .set_point_range(Point::new(0, 0), Point::new(1, 3)) - .matches(&query, tree.root_node(), to_callback(source)); + .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1853,7 +1853,7 @@ fn test_query_matches_within_point_range() { let matches = cursor .set_point_range(Point::new(1, 0), Point::new(2, 3)) - .matches(&query, tree.root_node(), to_callback(source)); + .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1866,7 +1866,7 @@ fn test_query_matches_within_point_range() { let matches = cursor .set_point_range(Point::new(2, 1), Point::new(0, 0)) - .matches(&query, tree.root_node(), to_callback(source)); + .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -1905,7 +1905,7 @@ fn test_query_captures_within_byte_range() { let captures = cursor .set_byte_range(3, 27) - .captures(&query, tree.root_node(), to_callback(source)); + .captures(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_captures(captures, &query, source), @@ -1955,13 +1955,13 @@ fn test_query_matches_different_queries_same_cursor() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); - let matches = cursor.matches(&query1, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query1, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query1, source), &[(0, vec![("id1", "a")]),] ); - let matches = cursor.matches(&query3, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query3, tree.root_node(), source.as_bytes()); assert_eq!( 
collect_matches(matches, &query3, source), &[ @@ -1971,7 +1971,7 @@ fn test_query_matches_different_queries_same_cursor() { ] ); - let matches = cursor.matches(&query2, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query2, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query2, source), &[(0, vec![("id1", "a")]), (1, vec![("id2", "b")]),] @@ -1998,7 +1998,7 @@ fn test_query_matches_with_multiple_captures_on_a_node() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), &[( @@ -2016,7 +2016,7 @@ fn test_query_matches_with_multiple_captures_on_a_node() { // disabling captures still works when there are multiple captures on a // single node. query.disable_capture("name2"); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), &[( @@ -2087,7 +2087,7 @@ fn test_query_matches_with_captured_wildcard_at_root() { let tree = parser.parse(&source, None).unwrap(); let match_capture_names_and_rows = cursor - .matches(&query, tree.root_node(), to_callback(source)) + .matches(&query, tree.root_node(), source.as_bytes()) .map(|m| { m.captures .iter() @@ -2352,7 +2352,7 @@ fn test_query_captures_basic() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), @@ -2368,7 +2368,7 @@ fn test_query_captures_basic() { ], ); - let captures = 
cursor.captures(&query, tree.root_node(), to_callback(source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_captures(captures, &query, source), &[ @@ -2425,7 +2425,7 @@ fn test_query_captures_with_text_conditions() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_captures(captures, &query, source), &[ @@ -2564,7 +2564,7 @@ fn test_query_captures_with_duplicates() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_captures(captures, &query, source), &[("function", "x"), ("variable", "x"),], @@ -2608,7 +2608,7 @@ fn test_query_captures_with_many_nested_results_without_fields() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(&source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); let captures = collect_captures(captures, &query, &source); assert_eq!( @@ -2668,7 +2668,7 @@ fn test_query_captures_with_many_nested_results_with_fields() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(&source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); let captures = collect_captures(captures, &query, &source); assert_eq!( @@ -2765,7 +2765,7 @@ fn test_query_captures_with_too_many_nested_results() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = 
QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(&source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); let captures = collect_captures(captures, &query, &source); assert_eq!( @@ -2828,7 +2828,7 @@ fn test_query_captures_with_definite_pattern_containing_many_nested_matches() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_captures(captures, &query, source), [("l-bracket", "[")] @@ -2864,7 +2864,7 @@ fn test_query_captures_ordered_by_both_start_and_end_positions() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + let captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_captures(captures, &query, source), &[ @@ -2906,7 +2906,7 @@ fn test_query_captures_with_matches_removed() { let mut cursor = QueryCursor::new(); let mut captured_strings = Vec::new(); - for (m, i) in cursor.captures(&query, tree.root_node(), to_callback(source)) { + for (m, i) in cursor.captures(&query, tree.root_node(), source.as_bytes()) { let capture = m.captures[i]; let text = capture.node.utf8_text(source.as_bytes()).unwrap(); if text == "a" { @@ -2943,7 +2943,7 @@ fn test_query_captures_and_matches_iterators_are_fused() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let mut captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); assert_eq!(captures.next().unwrap().0.captures[0].index, 0); assert_eq!(captures.next().unwrap().0.captures[0].index, 0); @@ -2953,7 
+2953,7 @@ fn test_query_captures_and_matches_iterators_are_fused() { assert!(captures.next().is_none()); drop(captures); - let mut matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let mut matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!(matches.next().unwrap().captures[0].index, 0); assert_eq!(matches.next().unwrap().captures[0].index, 0); assert_eq!(matches.next().unwrap().captures[0].index, 0); @@ -2963,6 +2963,79 @@ fn test_query_captures_and_matches_iterators_are_fused() { }); } +#[test] +fn test_query_text_callback_returns_chunks() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + r#" + ((identifier) @leading_upper + (#match? @leading_upper "^[A-Z][A-Z_]*[a-z]")) + ((identifier) @all_upper + (#match? @all_upper "^[A-Z][A-Z_]*$")) + ((identifier) @all_lower + (#match? @all_lower "^[a-z][a-z_]*$")) + "#, + ) + .unwrap(); + + let source = "SOMETHING[a] = transform(AnotherThing[b].property[c], PARAMETER);"; + + // Store the source code in chunks of 3 bytes, and expose it via + // an iterator API. 
+ let source_chunks = source.as_bytes().chunks(3).collect::>(); + let chunks_in_range = |range: std::ops::Range| { + let mut offset = 0; + source_chunks.iter().filter_map(move |chunk| { + let end_offset = offset + chunk.len(); + if offset < range.end && range.start < end_offset { + let end_in_chunk = (range.end - offset).min(chunk.len()); + let start_in_chunk = range.start.max(offset) - offset; + offset = end_offset; + Some(&chunk[start_in_chunk..end_in_chunk]) + } else { + offset = end_offset; + None + } + }) + }; + assert_eq!( + chunks_in_range(0..9) + .map(|c| std::str::from_utf8(c).unwrap()) + .collect::(), + "SOMETHING", + ); + assert_eq!( + chunks_in_range(15..24) + .map(|c| std::str::from_utf8(c).unwrap()) + .collect::(), + "transform", + ); + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + let captures = cursor.captures(&query, tree.root_node(), |node: Node| { + chunks_in_range(node.byte_range()) + }); + + assert_eq!( + collect_captures(captures, &query, source), + &[ + ("all_upper", "SOMETHING"), + ("all_lower", "a"), + ("all_lower", "transform"), + ("leading_upper", "AnotherThing"), + ("all_lower", "b"), + ("all_lower", "c"), + ("all_upper", "PARAMETER"), + ] + ); + }); +} + #[test] fn test_query_start_byte_for_pattern() { let language = get_language("javascript"); @@ -3058,7 +3131,7 @@ fn test_query_lifetime_is_separate_from_nodes_lifetime() { let query = Query::new(language, query).unwrap(); let mut cursor = QueryCursor::new(); let node = cursor - .matches(&query, node, to_callback(source)) + .matches(&query, node, source.as_bytes()) .next() .unwrap() .captures[0] @@ -3078,7 +3151,7 @@ fn test_query_lifetime_is_separate_from_nodes_lifetime() { let query = Query::new(language, query).unwrap(); let mut cursor = QueryCursor::new(); let node = cursor - .captures(&query, node, to_callback(source)) + .captures(&query, node, 
source.as_bytes()) .next() .unwrap() .0 @@ -3123,7 +3196,7 @@ fn test_query_comments() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), &[(0, vec![("fn-name", "one")]),], @@ -3159,7 +3232,7 @@ fn test_query_disable_pattern() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), &[ @@ -3502,13 +3575,13 @@ fn assert_query_matches( parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!(collect_matches(matches, &query, source), expected); assert_eq!(cursor.did_exceed_match_limit(), false); } fn collect_matches<'a>( - matches: impl Iterator>, + matches: impl Iterator>, query: &'a Query, source: &'a str, ) -> Vec<(usize, Vec<(&'a str, &'a str)>)> { @@ -3523,7 +3596,7 @@ fn collect_matches<'a>( } fn collect_captures<'a>( - captures: impl Iterator, usize)>, + captures: impl Iterator, usize)>, query: &'a Query, source: &'a str, ) -> Vec<(&'a str, &'a str)> { @@ -3544,7 +3617,3 @@ fn format_captures<'a>( }) .collect() } - -fn to_callback<'a>(source: &'a str) -> impl Fn(Node) -> &'a [u8] { - move |n| &source.as_bytes()[n.byte_range()] -} diff --git a/highlight/src/lib.rs b/highlight/src/lib.rs index c1e7aba4..3f75c6dc 100644 --- a/highlight/src/lib.rs +++ b/highlight/src/lib.rs @@ 
-83,7 +83,7 @@ struct LocalScope<'a> { local_defs: Vec>, } -struct HighlightIter<'a, 'tree: 'a, F> +struct HighlightIter<'a, F> where F: FnMut(&str) -> Option<&'a HighlightConfiguration> + 'a, { @@ -92,16 +92,16 @@ where highlighter: &'a mut Highlighter, injection_callback: F, cancellation_flag: Option<&'a AtomicUsize>, - layers: Vec>, + layers: Vec>, iter_count: usize, next_event: Option, last_highlight_range: Option<(usize, usize, usize)>, } -struct HighlightIterLayer<'a, 'tree: 'a> { +struct HighlightIterLayer<'a> { _tree: Tree, cursor: QueryCursor, - captures: iter::Peekable>, + captures: iter::Peekable>, config: &'a HighlightConfiguration, highlight_end_stack: Vec, scope_stack: Vec>, @@ -319,7 +319,7 @@ impl HighlightConfiguration { } } -impl<'a, 'tree: 'a> HighlightIterLayer<'a, 'tree> { +impl<'a> HighlightIterLayer<'a> { /// Create a new 'layer' of highlighting for this document. /// /// In the even that the new layer contains "combined injections" (injections where multiple @@ -356,9 +356,7 @@ impl<'a, 'tree: 'a> HighlightIterLayer<'a, 'tree> { let mut injections_by_pattern_index = vec![(None, Vec::new(), false); combined_injections_query.pattern_count()]; let matches = - cursor.matches(combined_injections_query, tree.root_node(), |n: Node| { - &source[n.byte_range()] - }); + cursor.matches(combined_injections_query, tree.root_node(), source); for mat in matches { let entry = &mut injections_by_pattern_index[mat.pattern_index]; let (language_name, content_node, include_children) = @@ -395,9 +393,7 @@ impl<'a, 'tree: 'a> HighlightIterLayer<'a, 'tree> { let cursor_ref = unsafe { mem::transmute::<_, &'static mut QueryCursor>(&mut cursor) }; let captures = cursor_ref - .captures(&config.query, tree_ref.root_node(), move |n: Node| { - &source[n.byte_range()] - }) + .captures(&config.query, tree_ref.root_node(), source) .peekable(); result.push(HighlightIterLayer { @@ -548,7 +544,7 @@ impl<'a, 'tree: 'a> HighlightIterLayer<'a, 'tree> { } } -impl<'a, 'tree: 'a, F> 
HighlightIter<'a, 'tree, F> +impl<'a, F> HighlightIter<'a, F> where F: FnMut(&str) -> Option<&'a HighlightConfiguration> + 'a, { @@ -596,7 +592,7 @@ where } } - fn insert_layer(&mut self, mut layer: HighlightIterLayer<'a, 'tree>) { + fn insert_layer(&mut self, mut layer: HighlightIterLayer<'a>) { if let Some(sort_key) = layer.sort_key() { let mut i = 1; while i < self.layers.len() { @@ -615,7 +611,7 @@ where } } -impl<'a, 'tree: 'a, F> Iterator for HighlightIter<'a, 'tree, F> +impl<'a, F> Iterator for HighlightIter<'a, F> where F: FnMut(&str) -> Option<&'a HighlightConfiguration> + 'a, { @@ -1025,7 +1021,7 @@ impl HtmlRenderer { fn injection_for_match<'a>( config: &HighlightConfiguration, query: &'a Query, - query_match: &QueryMatch<'a>, + query_match: &QueryMatch<'a, 'a>, source: &'a [u8], ) -> (Option<&'a str>, Option>, bool) { let content_capture_index = config.injection_content_capture_index; diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 801f773f..92e3c9d6 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -102,7 +102,9 @@ pub struct Query { } /// A stateful object for executing a `Query` on a syntax `Tree`. -pub struct QueryCursor(NonNull); +pub struct QueryCursor { + ptr: NonNull, +} /// A key-value pair associated with a particular pattern in a `Query`. #[derive(Debug, PartialEq, Eq)] @@ -126,18 +128,36 @@ pub struct QueryPredicate { } /// A match of a `Query` to a particular set of `Node`s. -pub struct QueryMatch<'a> { +pub struct QueryMatch<'cursor, 'tree> { pub pattern_index: usize, - pub captures: &'a [QueryCapture<'a>], + pub captures: &'cursor [QueryCapture<'tree>], id: u32, cursor: *mut ffi::TSQueryCursor, } -/// A sequence of `QueryCapture`s within a `QueryMatch`. -pub struct QueryCaptures<'a, 'tree: 'a, T: AsRef<[u8]>> { +/// A sequence of `QueryMatch`es associated with a given `QueryCursor`. 
+pub struct QueryMatches<'a, 'tree: 'a, T: TextProvider<'a>> { ptr: *mut ffi::TSQueryCursor, query: &'a Query, - text_callback: Box) -> T + 'a>, + text_provider: T, + buffer1: Vec, + buffer2: Vec, + _tree: PhantomData<&'tree ()>, +} + +/// A sequence of `QueryCapture`s associated with a given `QueryCursor`. +pub struct QueryCaptures<'a, 'tree: 'a, T: TextProvider<'a>> { + ptr: *mut ffi::TSQueryCursor, + query: &'a Query, + text_provider: T, + buffer1: Vec, + buffer2: Vec, + _tree: PhantomData<&'tree ()>, +} + +pub trait TextProvider<'a> { + type I: Iterator + 'a; + fn text(&mut self, node: Node) -> Self::I; } /// A particular `Node` that has been captured with a particular name within a `Query`. @@ -178,6 +198,11 @@ pub enum QueryErrorKind { Structure, } +trait TextCallback<'a> { + fn call(&mut self, node: Node); + fn next_chunk(&mut self) -> Option<&'a [u8]>; +} + #[derive(Debug)] enum TextPredicate { CaptureEqString(u32, String, bool), @@ -1590,18 +1615,20 @@ impl Query { } } -impl<'a> QueryCursor { +impl QueryCursor { /// Create a new cursor for executing a given query. /// /// The cursor stores the state that is needed to iteratively search for matches. pub fn new() -> Self { - QueryCursor(unsafe { NonNull::new_unchecked(ffi::ts_query_cursor_new()) }) + QueryCursor { + ptr: unsafe { NonNull::new_unchecked(ffi::ts_query_cursor_new()) }, + } } /// Check if, on its last execution, this cursor exceeded its maximum number of /// in-progress matches. pub fn did_exceed_match_limit(&self) -> bool { - unsafe { ffi::ts_query_cursor_did_exceed_match_limit(self.0.as_ptr()) } + unsafe { ffi::ts_query_cursor_did_exceed_match_limit(self.ptr.as_ptr()) } } /// Iterate over all of the matches in the order that they were found. @@ -1609,52 +1636,50 @@ impl<'a> QueryCursor { /// Each match contains the index of the pattern that matched, and a list of captures. 
/// Because multiple patterns can match the same set of nodes, one match may contain /// captures that appear *before* some of the captures from a previous match. - pub fn matches<'tree: 'a, T: AsRef<[u8]>>( + pub fn matches<'a, 'tree: 'a, T: TextProvider<'a> + 'a>( &'a mut self, query: &'a Query, node: Node<'tree>, - mut text_callback: impl FnMut(Node<'tree>) -> T + 'a, - ) -> impl Iterator> + 'a { - let ptr = self.0.as_ptr(); + text_provider: T, + ) -> QueryMatches<'a, 'tree, T> { + let ptr = self.ptr.as_ptr(); unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) }; - std::iter::from_fn(move || loop { - unsafe { - let mut m = MaybeUninit::::uninit(); - if ffi::ts_query_cursor_next_match(ptr, m.as_mut_ptr()) { - let result = QueryMatch::new(m.assume_init(), ptr); - if result.satisfies_text_predicates(query, &mut text_callback) { - return Some(result); - } - } else { - return None; - } - } - }) + QueryMatches { + ptr, + query, + text_provider, + buffer1: Default::default(), + buffer2: Default::default(), + _tree: PhantomData, + } } /// Iterate over all of the individual captures in the order that they appear. /// /// This is useful if don't care about which pattern matched, and just want a single, /// ordered sequence of captures. 
- pub fn captures<'tree, T: AsRef<[u8]>>( + pub fn captures<'a, 'tree: 'a, T: TextProvider<'a> + 'a>( &'a mut self, query: &'a Query, node: Node<'tree>, - text_callback: impl FnMut(Node<'tree>) -> T + 'a, + text_provider: T, ) -> QueryCaptures<'a, 'tree, T> { - let ptr = self.0.as_ptr(); - unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) }; + let ptr = self.ptr.as_ptr(); + unsafe { ffi::ts_query_cursor_exec(self.ptr.as_ptr(), query.ptr.as_ptr(), node.0) }; QueryCaptures { ptr, query, - text_callback: Box::new(text_callback), + text_provider, + buffer1: Default::default(), + buffer2: Default::default(), + _tree: PhantomData, } } /// Set the range in which the query will be executed, in terms of byte offsets. pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self { unsafe { - ffi::ts_query_cursor_set_byte_range(self.0.as_ptr(), start as u32, end as u32); + ffi::ts_query_cursor_set_byte_range(self.ptr.as_ptr(), start as u32, end as u32); } self } @@ -1662,13 +1687,13 @@ impl<'a> QueryCursor { /// Set the range in which the query will be executed, in terms of rows and columns. 
pub fn set_point_range(&mut self, start: Point, end: Point) -> &mut Self { unsafe { - ffi::ts_query_cursor_set_point_range(self.0.as_ptr(), start.into(), end.into()); + ffi::ts_query_cursor_set_point_range(self.ptr.as_ptr(), start.into(), end.into()); } self } } -impl<'a> QueryMatch<'a> { +impl<'a, 'tree> QueryMatch<'a, 'tree> { pub fn remove(self) { unsafe { ffi::ts_query_cursor_remove_match(self.cursor, self.id) } } @@ -1681,7 +1706,7 @@ impl<'a> QueryMatch<'a> { captures: if m.capture_count > 0 { unsafe { slice::from_raw_parts( - m.captures as *const QueryCapture<'a>, + m.captures as *const QueryCapture<'tree>, m.capture_count as usize, ) } @@ -1691,31 +1716,55 @@ impl<'a> QueryMatch<'a> { } } - fn satisfies_text_predicates>( + fn satisfies_text_predicates( &self, query: &Query, - text_callback: &mut impl FnMut(Node<'a>) -> T, + buffer1: &mut Vec, + buffer2: &mut Vec, + text_provider: &mut impl TextProvider<'a>, ) -> bool { + fn get_text<'a, 'b: 'a, I: Iterator>( + buffer: &'a mut Vec, + mut chunks: I, + ) -> &'a [u8] { + let first_chunk = chunks.next().unwrap_or(&[]); + if let Some(next_chunk) = chunks.next() { + buffer.clear(); + buffer.extend_from_slice(first_chunk); + buffer.extend_from_slice(next_chunk); + for chunk in chunks { + buffer.extend_from_slice(chunk); + } + buffer.as_slice() + } else { + first_chunk + } + } + query.text_predicates[self.pattern_index] .iter() .all(|predicate| match predicate { TextPredicate::CaptureEqCapture(i, j, is_positive) => { let node1 = self.capture_for_index(*i).unwrap(); let node2 = self.capture_for_index(*j).unwrap(); - (text_callback(node1).as_ref() == text_callback(node2).as_ref()) == *is_positive + let text1 = get_text(buffer1, text_provider.text(node1)); + let text2 = get_text(buffer2, text_provider.text(node2)); + (text1 == text2) == *is_positive } TextPredicate::CaptureEqString(i, s, is_positive) => { let node = self.capture_for_index(*i).unwrap(); - (text_callback(node).as_ref() == s.as_bytes()) == *is_positive + 
let text = get_text(buffer1, text_provider.text(node)); + (text == s.as_bytes()) == *is_positive } TextPredicate::CaptureMatchString(i, r, is_positive) => { let node = self.capture_for_index(*i).unwrap(); - r.is_match(text_callback(node).as_ref()) == *is_positive + let text = get_text(buffer1, text_provider.text(node)); + r.is_match(text) == *is_positive } }) } - fn capture_for_index(&self, capture_index: u32) -> Option> { + fn capture_for_index(&self, capture_index: u32) -> Option> { for c in self.captures { if c.index == capture_index { return Some(c.node); @@ -1735,12 +1784,37 @@ impl QueryProperty { } } -impl<'a, 'tree: 'a, T: AsRef<[u8]>> Iterator for QueryCaptures<'a, 'tree, T> { - type Item = (QueryMatch<'tree>, usize); +impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryMatches<'a, 'tree, T> { + type Item = QueryMatch<'a, 'tree>; fn next(&mut self) -> Option { - loop { - unsafe { + unsafe { + loop { + let mut m = MaybeUninit::::uninit(); + if ffi::ts_query_cursor_next_match(self.ptr, m.as_mut_ptr()) { + let result = QueryMatch::new(m.assume_init(), self.ptr); + if result.satisfies_text_predicates( + self.query, + &mut self.buffer1, + &mut self.buffer2, + &mut self.text_provider, + ) { + return Some(result); + } + } else { + return None; + } + } + } + } +} + +impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryCaptures<'a, 'tree, T> { + type Item = (QueryMatch<'a, 'tree>, usize); + + fn next(&mut self) -> Option { + unsafe { + loop { let mut capture_index = 0u32; let mut m = MaybeUninit::::uninit(); if ffi::ts_query_cursor_next_capture( @@ -1749,7 +1823,12 @@ impl<'a, 'tree: 'a, T: AsRef<[u8]>> Iterator for QueryCaptures<'a, 'tree, T> { &mut capture_index as *mut u32, ) { let result = QueryMatch::new(m.assume_init(), self.ptr); - if result.satisfies_text_predicates(self.query, &mut self.text_callback) { + if result.satisfies_text_predicates( + self.query, + &mut self.buffer1, + &mut self.buffer2, + &mut self.text_provider, + ) { return Some((result, 
capture_index as usize)); } else { result.remove(); @@ -1762,7 +1841,7 @@ impl<'a, 'tree: 'a, T: AsRef<[u8]>> Iterator for QueryCaptures<'a, 'tree, T> { } } -impl<'a> fmt::Debug for QueryMatch<'a> { +impl<'cursor, 'tree> fmt::Debug for QueryMatch<'cursor, 'tree> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, @@ -1772,6 +1851,26 @@ impl<'a> fmt::Debug for QueryMatch<'a> { } } +impl<'a, F, I> TextProvider<'a> for F +where + F: FnMut(Node) -> I, + I: Iterator + 'a, +{ + type I = I; + + fn text(&mut self, node: Node) -> Self::I { + (self)(node) + } +} + +impl<'a> TextProvider<'a> for &'a [u8] { + type I = std::option::IntoIter<&'a [u8]>; + + fn text(&mut self, node: Node) -> Self::I { + Some(&self[node.byte_range()]).into_iter() + } +} + impl PartialEq for Query { fn eq(&self, other: &Self) -> bool { self.ptr == other.ptr @@ -1786,7 +1885,7 @@ impl Drop for Query { impl Drop for QueryCursor { fn drop(&mut self) { - unsafe { ffi::ts_query_cursor_delete(self.0.as_ptr()) } + unsafe { ffi::ts_query_cursor_delete(self.ptr.as_ptr()) } } } diff --git a/tags/src/lib.rs b/tags/src/lib.rs index 89809052..23b877c3 100644 --- a/tags/src/lib.rs +++ b/tags/src/lib.rs @@ -88,7 +88,7 @@ struct LocalScope<'a> { struct TagsIter<'a, I> where - I: Iterator>, + I: Iterator>, { matches: I, _tree: Tree, @@ -265,9 +265,7 @@ impl TagsContext { let tree_ref = unsafe { mem::transmute::<_, &'static Tree>(&tree) }; let matches = self .cursor - .matches(&config.query, tree_ref.root_node(), move |node| { - &source[node.byte_range()] - }); + .matches(&config.query, tree_ref.root_node(), source); Ok(( TagsIter { _tree: tree, @@ -291,7 +289,7 @@ impl TagsContext { impl<'a, I> Iterator for TagsIter<'a, I> where - I: Iterator>, + I: Iterator>, { type Item = Result; From a61f25bc58e3affe81aaacaaf5d9b6150a5e90ef Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 24 May 2021 21:07:59 -0700 Subject: [PATCH 02/12] Add APIs for advancing a QueryCursor to an arbitrary position --- 
cli/src/tests/query_test.rs | 68 ++++++++++++++++++++++ lib/binding_rust/bindings.rs | 105 +++++++++++++++++++--------------- lib/binding_rust/lib.rs | 28 +++++++++ lib/include/tree_sitter/api.h | 4 +- lib/src/query.c | 23 +++++++- 5 files changed, 178 insertions(+), 50 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 8e2c1313..2245f4f9 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -3036,6 +3036,74 @@ fn test_query_text_callback_returns_chunks() { }); } +#[test] +fn test_query_captures_advance_to_byte() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + r#" + (identifier) @id + (array + "[" @lbracket + "]" @rbracket) + "#, + ) + .unwrap(); + let source = "[one, two, [three, four, five, six, seven, eight, nine, ten], eleven, twelve, thirteen]"; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + cursor.set_byte_range( + source.find("two").unwrap() + 1, + source.find(", twelve").unwrap(), + ); + let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); + + // Retrieve four captures. + let mut results = Vec::new(); + for (mat, capture_ix) in captures.by_ref().take(4) { + let capture = mat.captures[capture_ix as usize]; + results.push(( + query.capture_names()[capture.index as usize].as_str(), + &source[capture.node.byte_range()], + )); + } + assert_eq!( + results, + vec![ + ("id", "two"), + ("lbracket", "["), + ("id", "three"), + ("id", "four") + ] + ); + + // Advance further ahead in the source, retrieve the remaining captures. 
+ results.clear(); + captures.advance_to_byte(source.find("ten").unwrap() + 1); + for (mat, capture_ix) in captures { + let capture = mat.captures[capture_ix as usize]; + results.push(( + query.capture_names()[capture.index as usize].as_str(), + &source[capture.node.byte_range()], + )); + } + assert_eq!( + results, + vec![("id", "ten"), ("rbracket", "]"), ("id", "eleven"),] + ); + + // Advance past the last capture. There are no more captures. + let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); + captures.advance_to_byte(source.len()); + assert!(captures.next().is_none()); + assert!(captures.next().is_none()); + }); +} + #[test] fn test_query_start_byte_for_pattern() { let language = get_language("javascript"); diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 50da12fc..a729c12c 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -1,6 +1,7 @@ -/* automatically generated by rust-bindgen */ +/* automatically generated by rust-bindgen 0.58.1 */ pub type __darwin_size_t = ::std::os::raw::c_ulong; +pub type size_t = usize; pub type FILE = [u64; 19usize]; pub type TSSymbol = u16; pub type TSFieldId = u16; @@ -31,11 +32,11 @@ pub struct TSQueryCursor { } pub const TSInputEncoding_TSInputEncodingUTF8: TSInputEncoding = 0; pub const TSInputEncoding_TSInputEncodingUTF16: TSInputEncoding = 1; -pub type TSInputEncoding = u32; +pub type TSInputEncoding = ::std::os::raw::c_uint; pub const TSSymbolType_TSSymbolTypeRegular: TSSymbolType = 0; pub const TSSymbolType_TSSymbolTypeAnonymous: TSSymbolType = 1; pub const TSSymbolType_TSSymbolTypeAuxiliary: TSSymbolType = 2; -pub type TSSymbolType = u32; +pub type TSSymbolType = ::std::os::raw::c_uint; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct TSPoint { @@ -66,7 +67,7 @@ pub struct TSInput { } pub const TSLogType_TSLogTypeParse: TSLogType = 0; pub const TSLogType_TSLogTypeLex: TSLogType = 1; -pub type TSLogType = u32; +pub type TSLogType = 
::std::os::raw::c_uint; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct TSLogger { @@ -120,7 +121,7 @@ pub struct TSQueryMatch { pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeDone: TSQueryPredicateStepType = 0; pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture: TSQueryPredicateStepType = 1; pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeString: TSQueryPredicateStepType = 2; -pub type TSQueryPredicateStepType = u32; +pub type TSQueryPredicateStepType = ::std::os::raw::c_uint; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct TSQueryPredicateStep { @@ -133,7 +134,7 @@ pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2; pub const TSQueryError_TSQueryErrorField: TSQueryError = 3; pub const TSQueryError_TSQueryErrorCapture: TSQueryError = 4; pub const TSQueryError_TSQueryErrorStructure: TSQueryError = 5; -pub type TSQueryError = u32; +pub type TSQueryError = ::std::os::raw::c_uint; extern "C" { #[doc = " Create a new parser."] pub fn ts_parser_new() -> *mut TSParser; @@ -148,13 +149,13 @@ extern "C" { #[doc = " Returns a boolean indicating whether or not the language was successfully"] #[doc = " assigned. True means assignment succeeded. False means there was a version"] #[doc = " mismatch: the language was generated with an incompatible version of the"] - #[doc = " Tree-sitter CLI. Check the language\'s version using `ts_language_version`"] - #[doc = " and compare it to this library\'s `TREE_SITTER_LANGUAGE_VERSION` and"] + #[doc = " Tree-sitter CLI. 
Check the language's version using `ts_language_version`"] + #[doc = " and compare it to this library's `TREE_SITTER_LANGUAGE_VERSION` and"] #[doc = " `TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION` constants."] pub fn ts_parser_set_language(self_: *mut TSParser, language: *const TSLanguage) -> bool; } extern "C" { - #[doc = " Get the parser\'s current language."] + #[doc = " Get the parser's current language."] pub fn ts_parser_language(self_: *const TSParser) -> *const TSLanguage; } extern "C" { @@ -167,7 +168,7 @@ extern "C" { #[doc = ""] #[doc = " The second and third parameters specify the location and length of an array"] #[doc = " of ranges. The parser does *not* take ownership of these ranges; it copies"] - #[doc = " the data, so it doesn\'t matter how these ranges are allocated."] + #[doc = " the data, so it doesn't matter how these ranges are allocated."] #[doc = ""] #[doc = " If `length` is zero, then the entire document will be parsed. Otherwise,"] #[doc = " the given ranges must be ordered from earliest to latest in the document,"] @@ -266,7 +267,7 @@ extern "C" { #[doc = ""] #[doc = " If the parser previously failed because of a timeout or a cancellation, then"] #[doc = " by default, it will resume where it left off on the next call to"] - #[doc = " `ts_parser_parse` or other parsing functions. If you don\'t want to resume,"] + #[doc = " `ts_parser_parse` or other parsing functions. 
If you don't want to resume,"] #[doc = " and instead intend to use this parser to parse some other document, you must"] #[doc = " call `ts_parser_reset` first."] pub fn ts_parser_reset(self_: *mut TSParser); @@ -284,16 +285,16 @@ extern "C" { pub fn ts_parser_timeout_micros(self_: *const TSParser) -> u64; } extern "C" { - #[doc = " Set the parser\'s current cancellation flag pointer."] + #[doc = " Set the parser's current cancellation flag pointer."] #[doc = ""] #[doc = " If a non-null pointer is assigned, then the parser will periodically read"] #[doc = " from this pointer during parsing. If it reads a non-zero value, it will"] #[doc = " halt early, returning NULL. See `ts_parser_parse` for more information."] - pub fn ts_parser_set_cancellation_flag(self_: *mut TSParser, flag: *const usize); + pub fn ts_parser_set_cancellation_flag(self_: *mut TSParser, flag: *const size_t); } extern "C" { - #[doc = " Get the parser\'s current cancellation flag pointer."] - pub fn ts_parser_cancellation_flag(self_: *const TSParser) -> *const usize; + #[doc = " Get the parser's current cancellation flag pointer."] + pub fn ts_parser_cancellation_flag(self_: *const TSParser) -> *const size_t; } extern "C" { #[doc = " Set the logger that a parser should use during parsing."] @@ -304,7 +305,7 @@ extern "C" { pub fn ts_parser_set_logger(self_: *mut TSParser, logger: TSLogger); } extern "C" { - #[doc = " Get the parser\'s current logger."] + #[doc = " Get the parser's current logger."] pub fn ts_parser_logger(self_: *const TSParser) -> TSLogger; } extern "C" { @@ -346,7 +347,7 @@ extern "C" { #[doc = " document, returning an array of ranges whose syntactic structure has changed."] #[doc = ""] #[doc = " For this to work correctly, the old syntax tree must have been edited such"] - #[doc = " that its ranges match up to the new tree. Generally, you\'ll want to call"] + #[doc = " that its ranges match up to the new tree. 
Generally, you'll want to call"] #[doc = " this function right after calling one of the `ts_parser_parse` functions."] #[doc = " You need to pass the old tree that was passed to parse, as well as the new"] #[doc = " tree that was returned from that function."] @@ -365,27 +366,27 @@ extern "C" { pub fn ts_tree_print_dot_graph(arg1: *const TSTree, arg2: *mut FILE); } extern "C" { - #[doc = " Get the node\'s type as a null-terminated string."] + #[doc = " Get the node's type as a null-terminated string."] pub fn ts_node_type(arg1: TSNode) -> *const ::std::os::raw::c_char; } extern "C" { - #[doc = " Get the node\'s type as a numerical id."] + #[doc = " Get the node's type as a numerical id."] pub fn ts_node_symbol(arg1: TSNode) -> TSSymbol; } extern "C" { - #[doc = " Get the node\'s start byte."] + #[doc = " Get the node's start byte."] pub fn ts_node_start_byte(arg1: TSNode) -> u32; } extern "C" { - #[doc = " Get the node\'s start position in terms of rows and columns."] + #[doc = " Get the node's start position in terms of rows and columns."] pub fn ts_node_start_point(arg1: TSNode) -> TSPoint; } extern "C" { - #[doc = " Get the node\'s end byte."] + #[doc = " Get the node's end byte."] pub fn ts_node_end_byte(arg1: TSNode) -> u32; } extern "C" { - #[doc = " Get the node\'s end position in terms of rows and columns."] + #[doc = " Get the node's end position in terms of rows and columns."] pub fn ts_node_end_point(arg1: TSNode) -> TSPoint; } extern "C" { @@ -426,11 +427,11 @@ extern "C" { pub fn ts_node_has_error(arg1: TSNode) -> bool; } extern "C" { - #[doc = " Get the node\'s immediate parent."] + #[doc = " Get the node's immediate parent."] pub fn ts_node_parent(arg1: TSNode) -> TSNode; } extern "C" { - #[doc = " Get the node\'s child at the given index, where zero represents the first"] + #[doc = " Get the node's child at the given index, where zero represents the first"] #[doc = " child."] pub fn ts_node_child(arg1: TSNode, arg2: u32) -> TSNode; } @@ -440,23 
+441,23 @@ extern "C" { pub fn ts_node_field_name_for_child(arg1: TSNode, arg2: u32) -> *const ::std::os::raw::c_char; } extern "C" { - #[doc = " Get the node\'s number of children."] + #[doc = " Get the node's number of children."] pub fn ts_node_child_count(arg1: TSNode) -> u32; } extern "C" { - #[doc = " Get the node\'s *named* child at the given index."] + #[doc = " Get the node's *named* child at the given index."] #[doc = ""] #[doc = " See also `ts_node_is_named`."] pub fn ts_node_named_child(arg1: TSNode, arg2: u32) -> TSNode; } extern "C" { - #[doc = " Get the node\'s number of *named* children."] + #[doc = " Get the node's number of *named* children."] #[doc = ""] #[doc = " See also `ts_node_is_named`."] pub fn ts_node_named_child_count(arg1: TSNode) -> u32; } extern "C" { - #[doc = " Get the node\'s child with the given field name."] + #[doc = " Get the node's child with the given field name."] pub fn ts_node_child_by_field_name( self_: TSNode, field_name: *const ::std::os::raw::c_char, @@ -464,32 +465,32 @@ extern "C" { ) -> TSNode; } extern "C" { - #[doc = " Get the node\'s child with the given numerical field id."] + #[doc = " Get the node's child with the given numerical field id."] #[doc = ""] #[doc = " You can convert a field name to an id using the"] #[doc = " `ts_language_field_id_for_name` function."] pub fn ts_node_child_by_field_id(arg1: TSNode, arg2: TSFieldId) -> TSNode; } extern "C" { - #[doc = " Get the node\'s next / previous sibling."] + #[doc = " Get the node's next / previous sibling."] pub fn ts_node_next_sibling(arg1: TSNode) -> TSNode; } extern "C" { pub fn ts_node_prev_sibling(arg1: TSNode) -> TSNode; } extern "C" { - #[doc = " Get the node\'s next / previous *named* sibling."] + #[doc = " Get the node's next / previous *named* sibling."] pub fn ts_node_next_named_sibling(arg1: TSNode) -> TSNode; } extern "C" { pub fn ts_node_prev_named_sibling(arg1: TSNode) -> TSNode; } extern "C" { - #[doc = " Get the node\'s first child that 
extends beyond the given byte offset."] + #[doc = " Get the node's first child that extends beyond the given byte offset."] pub fn ts_node_first_child_for_byte(arg1: TSNode, arg2: u32) -> TSNode; } extern "C" { - #[doc = " Get the node\'s first named child that extends beyond the given byte offset."] + #[doc = " Get the node's first named child that extends beyond the given byte offset."] pub fn ts_node_first_named_child_for_byte(arg1: TSNode, arg2: u32) -> TSNode; } extern "C" { @@ -544,22 +545,22 @@ extern "C" { pub fn ts_tree_cursor_reset(arg1: *mut TSTreeCursor, arg2: TSNode); } extern "C" { - #[doc = " Get the tree cursor\'s current node."] + #[doc = " Get the tree cursor's current node."] pub fn ts_tree_cursor_current_node(arg1: *const TSTreeCursor) -> TSNode; } extern "C" { - #[doc = " Get the field name of the tree cursor\'s current node."] + #[doc = " Get the field name of the tree cursor's current node."] #[doc = ""] - #[doc = " This returns `NULL` if the current node doesn\'t have a field."] + #[doc = " This returns `NULL` if the current node doesn't have a field."] #[doc = " See also `ts_node_child_by_field_name`."] pub fn ts_tree_cursor_current_field_name( arg1: *const TSTreeCursor, ) -> *const ::std::os::raw::c_char; } extern "C" { - #[doc = " Get the field name of the tree cursor\'s current node."] + #[doc = " Get the field name of the tree cursor's current node."] #[doc = ""] - #[doc = " This returns zero if the current node doesn\'t have a field."] + #[doc = " This returns zero if the current node doesn't have a field."] #[doc = " See also `ts_node_child_by_field_id`, `ts_language_field_id_for_name`."] pub fn ts_tree_cursor_current_field_id(arg1: *const TSTreeCursor) -> TSFieldId; } @@ -628,7 +629,7 @@ extern "C" { pub fn ts_query_string_count(arg1: *const TSQuery) -> u32; } extern "C" { - #[doc = " Get the byte offset where the given pattern starts in the query\'s source."] + #[doc = " Get the byte offset where the given pattern starts in the 
query's source."] #[doc = ""] #[doc = " This can be useful when combining queries by concatenating their source"] #[doc = " code strings."] @@ -659,9 +660,9 @@ extern "C" { pub fn ts_query_step_is_definite(self_: *const TSQuery, byte_offset: u32) -> bool; } extern "C" { - #[doc = " Get the name and length of one of the query\'s captures, or one of the"] - #[doc = " query\'s string literals. Each capture and string is associated with a"] - #[doc = " numeric id based on the order that it appeared in the query\'s source."] + #[doc = " Get the name and length of one of the query's captures, or one of the"] + #[doc = " query's string literals. Each capture and string is associated with a"] + #[doc = " numeric id based on the order that it appeared in the query's source."] pub fn ts_query_capture_name_for_id( arg1: *const TSQuery, id: u32, @@ -708,10 +709,10 @@ extern "C" { #[doc = " captures that appear *before* some of the captures from a previous match."] #[doc = " 2. Repeatedly call `ts_query_cursor_next_capture` to iterate over all of the"] #[doc = " individual *captures* in the order that they appear. 
This is useful if"] - #[doc = " don\'t care about which pattern matched, and just want a single ordered"] + #[doc = " don't care about which pattern matched, and just want a single ordered"] #[doc = " sequence of captures."] #[doc = ""] - #[doc = " If you don\'t care about consuming all of the results, you can stop calling"] + #[doc = " If you don't care about consuming all of the results, you can stop calling"] #[doc = " `ts_query_cursor_next_match` or `ts_query_cursor_next_capture` at any point."] #[doc = " You can then start executing another query on another node by calling"] #[doc = " `ts_query_cursor_exec` again."] @@ -736,8 +737,18 @@ extern "C" { pub fn ts_query_cursor_did_exceed_match_limit(arg1: *const TSQueryCursor) -> bool; } extern "C" { - #[doc = " Set the range of bytes or (row, column) positions in which the query"] + #[doc = " Get or set the range of bytes or (row, column) positions in which the query"] #[doc = " will be executed."] + pub fn ts_query_cursor_byte_range(arg1: *const TSQueryCursor, arg2: *mut u32, arg3: *mut u32); +} +extern "C" { + pub fn ts_query_cursor_point_range( + arg1: *const TSQueryCursor, + arg2: *mut TSPoint, + arg3: *mut TSPoint, + ); +} +extern "C" { pub fn ts_query_cursor_set_byte_range(arg1: *mut TSQueryCursor, arg2: u32, arg3: u32); } extern "C" { @@ -757,7 +768,7 @@ extern "C" { #[doc = " Advance to the next capture of the currently running query."] #[doc = ""] #[doc = " If there is a capture, write its match to `*match` and its index within"] - #[doc = " the matche\'s capture list to `*capture_index`. Otherwise, return `false`."] + #[doc = " the matche's capture list to `*capture_index`. 
Otherwise, return `false`."] pub fn ts_query_cursor_next_capture( arg1: *mut TSQueryCursor, match_: *mut TSQueryMatch, diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 92e3c9d6..ea99c067 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1809,6 +1809,34 @@ impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryMatches<'a, 'tree, T> { } } +impl<'a, 'tree, T: TextProvider<'a>> QueryCaptures<'a, 'tree, T> { + pub fn advance_to_byte(&mut self, offset: usize) { + unsafe { + let mut current_start = 0u32; + let mut current_end = 0u32; + ffi::ts_query_cursor_byte_range( + self.ptr, + &mut current_start as *mut u32, + &mut current_end as *mut u32, + ); + ffi::ts_query_cursor_set_byte_range(self.ptr, offset as u32, current_end); + } + } + + pub fn advance_to_point(&mut self, point: Point) { + unsafe { + let mut current_start = ffi::TSPoint { row: 0, column: 0 }; + let mut current_end = current_start; + ffi::ts_query_cursor_point_range( + self.ptr, + &mut current_start as *mut _, + &mut current_end as *mut _, + ); + ffi::ts_query_cursor_set_point_range(self.ptr, point.into(), current_end); + } + } +} + impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryCaptures<'a, 'tree, T> { type Item = (QueryMatch<'a, 'tree>, usize); diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 43315415..01d84a6e 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -809,9 +809,11 @@ void ts_query_cursor_exec(TSQueryCursor *, const TSQuery *, TSNode); bool ts_query_cursor_did_exceed_match_limit(const TSQueryCursor *); /** - * Set the range of bytes or (row, column) positions in which the query + * Get or set the range of bytes or (row, column) positions in which the query * will be executed. 
*/ +void ts_query_cursor_byte_range(const TSQueryCursor *, uint32_t *, uint32_t *); +void ts_query_cursor_point_range(const TSQueryCursor *, TSPoint *, TSPoint *); void ts_query_cursor_set_byte_range(TSQueryCursor *, uint32_t, uint32_t); void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint); diff --git a/lib/src/query.c b/lib/src/query.c index 65dbe1fe..278b3a3c 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -2302,6 +2302,24 @@ void ts_query_cursor_exec( self->did_exceed_match_limit = false; } +void ts_query_cursor_byte_range( + const TSQueryCursor *self, + uint32_t *start_byte, + uint32_t *end_byte +) { + *start_byte = self->start_byte; + *end_byte = self->end_byte; +} + +void ts_query_cursor_point_range( + const TSQueryCursor *self, + TSPoint *start_point, + TSPoint *end_point +) { + *start_point = self->start_point; + *end_point = self->end_point; +} + void ts_query_cursor_set_byte_range( TSQueryCursor *self, uint32_t start_byte, @@ -2639,7 +2657,7 @@ static inline bool ts_query_cursor__advance( } else if (ts_tree_cursor_goto_parent(&self->cursor)) { self->depth--; } else { - LOG("halt at root"); + LOG("halt at root\n"); self->halted = true; } @@ -2696,6 +2714,7 @@ static inline bool ts_query_cursor__advance( if (!ts_tree_cursor_goto_next_sibling(&self->cursor)) { self->ascending = true; } + LOG("skip until start of range\n"); continue; } @@ -2704,7 +2723,7 @@ static inline bool ts_query_cursor__advance( self->end_byte <= ts_node_start_byte(node) || point_lte(self->end_point, ts_node_start_point(node)) ) { - LOG("halt at end of range"); + LOG("halt at end of range\n"); self->halted = true; continue; } From f597cc6a75eb0ed9f4b815ce1a11545145949bfa Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 25 May 2021 13:06:24 -0700 Subject: [PATCH 03/12] Preserve matches that contain the QueryCursor's start byte Co-Authored-By: Nathan Sobo Co-Authored-By: Antonio Scandurra --- cli/src/tests/query_test.rs | 101 +++++++++++--- 
lib/binding_rust/bindings.rs | 15 +-- lib/binding_rust/lib.rs | 22 +--- lib/include/tree_sitter/api.h | 6 +- lib/src/query.c | 241 +++++++++++++++++----------------- 5 files changed, 214 insertions(+), 171 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 2245f4f9..4fde4f2a 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -3039,32 +3039,49 @@ fn test_query_text_callback_returns_chunks() { #[test] fn test_query_captures_advance_to_byte() { allocations::record(|| { - let language = get_language("javascript"); + let language = get_language("rust"); let query = Query::new( language, r#" - (identifier) @id - (array - "[" @lbracket - "]" @rbracket) + (function_item + name: (identifier) @fn_name) + + (mod_item + name: (identifier) @mod_name + body: (declaration_list + "{" @lbrace + "}" @rbrace)) + + ; functions that return Result<()> + ((function_item + return_type: (generic_type + type: (type_identifier) @result + type_arguments: (type_arguments + (unit_type))) + body: _ @fallible_fn_body) + (#eq? @result "Result")) "#, ) .unwrap(); - let source = "[one, two, [three, four, five, six, seven, eight, nine, ten], eleven, twelve, thirteen]"; + let source = " + mod m1 { + mod m2 { + fn f1() -> Option<()> { Some(()) } + } + fn f2() -> Result<()> { Ok(()) } + fn f3() {} + } + "; let mut parser = Parser::new(); parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - cursor.set_byte_range( - source.find("two").unwrap() + 1, - source.find(", twelve").unwrap(), - ); let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); - // Retrieve four captures. 
+ // Retrieve some captures let mut results = Vec::new(); - for (mat, capture_ix) in captures.by_ref().take(4) { + for (mat, capture_ix) in captures.by_ref().take(5) { let capture = mat.captures[capture_ix as usize]; results.push(( query.capture_names()[capture.index as usize].as_str(), @@ -3074,16 +3091,18 @@ fn test_query_captures_advance_to_byte() { assert_eq!( results, vec![ - ("id", "two"), - ("lbracket", "["), - ("id", "three"), - ("id", "four") + ("mod_name", "m1"), + ("lbrace", "{"), + ("mod_name", "m2"), + ("lbrace", "{"), + ("fn_name", "f1"), ] ); - // Advance further ahead in the source, retrieve the remaining captures. results.clear(); - captures.advance_to_byte(source.find("ten").unwrap() + 1); + captures.advance_to_byte(source.find("Ok").unwrap()); + + // Advance further ahead in the source, retrieve the remaining captures. for (mat, capture_ix) in captures { let capture = mat.captures[capture_ix as usize]; results.push(( @@ -3093,7 +3112,11 @@ fn test_query_captures_advance_to_byte() { } assert_eq!( results, - vec![("id", "ten"), ("rbracket", "]"), ("id", "eleven"),] + vec![ + ("fallible_fn_body", "{ Ok(()) }"), + ("fn_name", "f3"), + ("rbrace", "}") + ] ); // Advance past the last capture. There are no more captures. @@ -3104,6 +3127,46 @@ fn test_query_captures_advance_to_byte() { }); } +#[test] +fn test_query_advance_to_byte_within_node() { + allocations::record(|| { + let language = get_language("rust"); + let query = Query::new( + language, + r#" + (fn_item + name: (identifier) @name + return_type: _? @ret) + + (mod_item + name: (identifier) @name + body: _ @body) + "#, + ) + .unwrap(); + let source = " + fn foo() -> i32 {} + + ... 
+ + mod foo {} + "; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); + + captures.advance_to_byte(source.find("{").unwrap()); + + assert_eq!( + collect_captures(captures, &query, source), + &[("body", "{}"),] + ); + }) +} + #[test] fn test_query_start_byte_for_pattern() { let language = get_language("javascript"); diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index a729c12c..dccc9aca 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -737,18 +737,8 @@ extern "C" { pub fn ts_query_cursor_did_exceed_match_limit(arg1: *const TSQueryCursor) -> bool; } extern "C" { - #[doc = " Get or set the range of bytes or (row, column) positions in which the query"] + #[doc = " Set the range of bytes or (row, column) positions in which the query"] #[doc = " will be executed."] - pub fn ts_query_cursor_byte_range(arg1: *const TSQueryCursor, arg2: *mut u32, arg3: *mut u32); -} -extern "C" { - pub fn ts_query_cursor_point_range( - arg1: *const TSQueryCursor, - arg2: *mut TSPoint, - arg3: *mut TSPoint, - ); -} -extern "C" { pub fn ts_query_cursor_set_byte_range(arg1: *mut TSQueryCursor, arg2: u32, arg3: u32); } extern "C" { @@ -764,6 +754,9 @@ extern "C" { extern "C" { pub fn ts_query_cursor_remove_match(arg1: *mut TSQueryCursor, id: u32); } +extern "C" { + pub fn ts_query_cursor_advance_to_byte(arg1: *mut TSQueryCursor, offset: u32); +} extern "C" { #[doc = " Advance to the next capture of the currently running query."] #[doc = ""] diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index ea99c067..da2b3252 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1812,27 +1812,7 @@ impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryMatches<'a, 'tree, T> { impl<'a, 'tree, T: TextProvider<'a>> 
QueryCaptures<'a, 'tree, T> { pub fn advance_to_byte(&mut self, offset: usize) { unsafe { - let mut current_start = 0u32; - let mut current_end = 0u32; - ffi::ts_query_cursor_byte_range( - self.ptr, - &mut current_start as *mut u32, - &mut current_end as *mut u32, - ); - ffi::ts_query_cursor_set_byte_range(self.ptr, offset as u32, current_end); - } - } - - pub fn advance_to_point(&mut self, point: Point) { - unsafe { - let mut current_start = ffi::TSPoint { row: 0, column: 0 }; - let mut current_end = current_start; - ffi::ts_query_cursor_point_range( - self.ptr, - &mut current_start as *mut _, - &mut current_end as *mut _, - ); - ffi::ts_query_cursor_set_point_range(self.ptr, point.into(), current_end); + ffi::ts_query_cursor_advance_to_byte(self.ptr, offset as u32); } } } diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 01d84a6e..6889a121 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -809,11 +809,9 @@ void ts_query_cursor_exec(TSQueryCursor *, const TSQuery *, TSNode); bool ts_query_cursor_did_exceed_match_limit(const TSQueryCursor *); /** - * Get or set the range of bytes or (row, column) positions in which the query + * Set the range of bytes or (row, column) positions in which the query * will be executed. */ -void ts_query_cursor_byte_range(const TSQueryCursor *, uint32_t *, uint32_t *); -void ts_query_cursor_point_range(const TSQueryCursor *, TSPoint *, TSPoint *); void ts_query_cursor_set_byte_range(TSQueryCursor *, uint32_t, uint32_t); void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint); @@ -826,6 +824,8 @@ void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint); bool ts_query_cursor_next_match(TSQueryCursor *, TSQueryMatch *match); void ts_query_cursor_remove_match(TSQueryCursor *, uint32_t id); +void ts_query_cursor_advance_to_byte(TSQueryCursor *, uint32_t offset); + /** * Advance to the next capture of the currently running query. 
* diff --git a/lib/src/query.c b/lib/src/query.c index 278b3a3c..d70e5afd 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -256,10 +256,7 @@ struct TSQueryCursor { CaptureListPool capture_list_pool; uint32_t depth; uint32_t start_byte; - uint32_t end_byte; uint32_t next_state_id; - TSPoint start_point; - TSPoint end_point; bool ascending; bool halted; bool did_exceed_match_limit; @@ -2264,9 +2261,6 @@ TSQueryCursor *ts_query_cursor_new(void) { .finished_states = array_new(), .capture_list_pool = capture_list_pool_new(), .start_byte = 0, - .end_byte = UINT32_MAX, - .start_point = {0, 0}, - .end_point = POINT_MAX, }; array_reserve(&self->states, 8); array_reserve(&self->finished_states, 8); @@ -2296,40 +2290,18 @@ void ts_query_cursor_exec( capture_list_pool_reset(&self->capture_list_pool); self->next_state_id = 0; self->depth = 0; + self->start_byte = 0; self->ascending = false; self->halted = false; self->query = query; self->did_exceed_match_limit = false; } -void ts_query_cursor_byte_range( - const TSQueryCursor *self, - uint32_t *start_byte, - uint32_t *end_byte -) { - *start_byte = self->start_byte; - *end_byte = self->end_byte; -} - -void ts_query_cursor_point_range( - const TSQueryCursor *self, - TSPoint *start_point, - TSPoint *end_point -) { - *start_point = self->start_point; - *end_point = self->end_point; -} - void ts_query_cursor_set_byte_range( TSQueryCursor *self, uint32_t start_byte, uint32_t end_byte ) { - if (end_byte == 0) { - end_byte = UINT32_MAX; - } - self->start_byte = start_byte; - self->end_byte = end_byte; } void ts_query_cursor_set_point_range( @@ -2337,11 +2309,6 @@ void ts_query_cursor_set_point_range( TSPoint start_point, TSPoint end_point ) { - if (end_point.row == 0 && end_point.column == 0) { - end_point = POINT_MAX; - } - self->start_point = start_point; - self->end_point = end_point; } // Search through all of the in-progress states, and find the captured @@ -2358,31 +2325,41 @@ static bool 
ts_query_cursor__first_in_progress_capture( *byte_offset = UINT32_MAX; *pattern_index = UINT32_MAX; for (unsigned i = 0; i < self->states.size; i++) { - const QueryState *state = &self->states.contents[i]; + QueryState *state = &self->states.contents[i]; if (state->dead) continue; + const CaptureList *captures = capture_list_pool_get( &self->capture_list_pool, state->capture_list_id ); - if (captures->size > state->consumed_capture_count) { - uint32_t capture_byte = ts_node_start_byte(captures->contents[state->consumed_capture_count].node); - if ( - !result || - capture_byte < *byte_offset || - (capture_byte == *byte_offset && state->pattern_index < *pattern_index) - ) { - QueryStep *step = &self->query->steps.contents[state->step_index]; - if (is_definite) { - *is_definite = step->is_definite; - } else if (step->is_definite) { - continue; - } + if (state->consumed_capture_count >= captures->size) { + continue; + } - result = true; - *state_index = i; - *byte_offset = capture_byte; - *pattern_index = state->pattern_index; + TSNode node = captures->contents[state->consumed_capture_count].node; + if (ts_node_end_byte(node) <= self->start_byte) { + state->consumed_capture_count++; + i--; + continue; + } + + uint32_t node_start_byte = ts_node_start_byte(node); + if ( + !result || + node_start_byte < *byte_offset || + (node_start_byte == *byte_offset && state->pattern_index < *pattern_index) + ) { + QueryStep *step = &self->query->steps.contents[state->step_index]; + if (is_definite) { + *is_definite = step->is_definite; + } else if (step->is_definite) { + continue; } + + result = true; + *state_index = i; + *byte_offset = node_start_byte; + *pattern_index = state->pattern_index; } } return result; @@ -2707,26 +2684,8 @@ static inline bool ts_query_cursor__advance( else { // If this node is before the selected range, then avoid descending into it. 
TSNode node = ts_tree_cursor_current_node(&self->cursor); - if ( - ts_node_end_byte(node) <= self->start_byte || - point_lte(ts_node_end_point(node), self->start_point) - ) { - if (!ts_tree_cursor_goto_next_sibling(&self->cursor)) { - self->ascending = true; - } - LOG("skip until start of range\n"); - continue; - } - // If this node is after the selected range, then stop walking. - if ( - self->end_byte <= ts_node_start_byte(node) || - point_lte(self->end_point, ts_node_start_point(node)) - ) { - LOG("halt at end of range\n"); - self->halted = true; - continue; - } + bool node_exceeds_start_byte = ts_node_end_byte(node) > self->start_byte; // Get the properties of the current node. TSSymbol symbol = ts_node_symbol(node); @@ -2755,36 +2714,44 @@ static inline bool ts_query_cursor__advance( self->finished_states.size ); - // Add new states for any patterns whose root node is a wildcard. - for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { - PatternEntry *pattern = &self->query->pattern_map.contents[i]; - QueryStep *step = &self->query->steps.contents[pattern->step_index]; + if (node_exceeds_start_byte) { + // Add new states for any patterns whose root node is a wildcard. + for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { + PatternEntry *pattern = &self->query->pattern_map.contents[i]; + QueryStep *step = &self->query->steps.contents[pattern->step_index]; - // If this node matches the first step of the pattern, then add a new - // state at the start of this pattern. - if (step->field && field_id != step->field) continue; - if (step->supertype_symbol && !supertype_count) continue; - ts_query_cursor__add_state(self, pattern); - } - - // Add new states for any patterns whose root node matches this node. 
- unsigned i; - if (ts_query__pattern_map_search(self->query, symbol, &i)) { - PatternEntry *pattern = &self->query->pattern_map.contents[i]; - QueryStep *step = &self->query->steps.contents[pattern->step_index]; - do { // If this node matches the first step of the pattern, then add a new // state at the start of this pattern. - if (!step->field || field_id == step->field) { - ts_query_cursor__add_state(self, pattern); - } + if (step->field && field_id != step->field) continue; + if (step->supertype_symbol && !supertype_count) continue; + ts_query_cursor__add_state(self, pattern); + } - // Advance to the next pattern whose root node matches this node. - i++; - if (i == self->query->pattern_map.size) break; - pattern = &self->query->pattern_map.contents[i]; - step = &self->query->steps.contents[pattern->step_index]; - } while (step->symbol == symbol); + // Add new states for any patterns whose root node matches this node. + unsigned i; + if (ts_query__pattern_map_search(self->query, symbol, &i)) { + PatternEntry *pattern = &self->query->pattern_map.contents[i]; + QueryStep *step = &self->query->steps.contents[pattern->step_index]; + do { + // If this node matches the first step of the pattern, then add a new + // state at the start of this pattern. + if (!step->field || field_id == step->field) { + ts_query_cursor__add_state(self, pattern); + } + + // Advance to the next pattern whose root node matches this node. + i++; + if (i == self->query->pattern_map.size) break; + pattern = &self->query->pattern_map.contents[i]; + step = &self->query->steps.contents[pattern->step_index]; + } while (step->symbol == symbol); + } + } else { + LOG( + " not starting new patterns. node end byte: %u, start_byte: %u\n", + ts_node_end_byte(node), + self->start_byte + ); } // Update all of the in-progress states with current node. @@ -3070,8 +3037,32 @@ static inline bool ts_query_cursor__advance( } } - // Continue descending if possible. 
- if (ts_tree_cursor_goto_first_child(&self->cursor)) { + // When the current node ends prior to the desired start offset, + // only descend for the purpose of continuing in-progress matches. + bool should_descend = node_exceeds_start_byte; + if (!should_descend) { + for (unsigned i = 0; i < self->states.size; i++) { + QueryState *state = &self->states.contents[i];; + QueryStep *next_step = &self->query->steps.contents[state->step_index]; + if ( + next_step->depth != PATTERN_DONE_MARKER && + state->start_depth + next_step->depth > self->depth + ) { + should_descend = true; + break; + } + } + } + + if (!should_descend) { + LOG( + " not descending. node end byte: %u, start byte: %u\n", + ts_node_end_byte(node), + self->start_byte + ); + } + + if (should_descend && ts_tree_cursor_goto_first_child(&self->cursor)) { self->depth++; } else { self->ascending = true; @@ -3080,6 +3071,14 @@ static inline bool ts_query_cursor__advance( } } +void ts_query_cursor_advance_to_byte( + TSQueryCursor *self, + uint32_t offset +) { + LOG("advance_to_byte %u\n", offset); + self->start_byte = offset; +} + bool ts_query_cursor_next_match( TSQueryCursor *self, TSQueryMatch *match @@ -3148,35 +3147,43 @@ bool ts_query_cursor_next_capture( QueryState *first_finished_state = NULL; uint32_t first_finished_capture_byte = first_unfinished_capture_byte; uint32_t first_finished_pattern_index = first_unfinished_pattern_index; - for (unsigned i = 0; i < self->finished_states.size; i++) { + for (unsigned i = 0; i < self->finished_states.size;) { QueryState *state = &self->finished_states.contents[i]; const CaptureList *captures = capture_list_pool_get( &self->capture_list_pool, state->capture_list_id ); - if (captures->size > state->consumed_capture_count) { - uint32_t capture_byte = ts_node_start_byte( - captures->contents[state->consumed_capture_count].node - ); - if ( - capture_byte < first_finished_capture_byte || - ( - capture_byte == first_finished_capture_byte && - state->pattern_index < 
first_finished_pattern_index - ) - ) { - first_finished_state = state; - first_finished_capture_byte = capture_byte; - first_finished_pattern_index = state->pattern_index; - } - } else { + + // Remove states whose captures are all consumed. + if (state->consumed_capture_count >= captures->size) { capture_list_pool_release( &self->capture_list_pool, state->capture_list_id ); array_erase(&self->finished_states, i); - i--; + continue; } + + // Skip captures that precede the cursor's start byte. + TSNode node = captures->contents[state->consumed_capture_count].node; + if (ts_node_end_byte(node) <= self->start_byte) { + state->consumed_capture_count++; + continue; + } + + uint32_t node_start_byte = ts_node_start_byte(node); + if ( + node_start_byte < first_finished_capture_byte || + ( + node_start_byte == first_finished_capture_byte && + state->pattern_index < first_finished_pattern_index + ) + ) { + first_finished_state = state; + first_finished_capture_byte = node_start_byte; + first_finished_pattern_index = state->pattern_index; + } + i++; } // If there is finished capture that is clearly before any unfinished From fda35894d4c9c1cd078e6275ea31884177f456ba Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 25 May 2021 13:11:22 -0700 Subject: [PATCH 04/12] Stop matching new patterns past the end of QueryCursor's range This restores the original signatures of the `set_byte_range` and `set_point_range` functions. Now, the QueryCursor will properly report matches that intersect, but are not fully contained by its range. 
Co-Authored-By: Nathan Sobo --- cli/src/tests/query_test.rs | 217 ++++++++++++++-------------------- lib/binding_rust/lib.rs | 4 +- lib/include/tree_sitter/api.h | 2 - lib/src/point.h | 4 + lib/src/query.c | 47 +++++--- 5 files changed, 122 insertions(+), 152 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 4fde4f2a..daa3b04b 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1918,6 +1918,92 @@ fn test_query_captures_within_byte_range() { }); } +#[test] +fn test_query_captures_within_byte_range_assigned_after_iterating() { + allocations::record(|| { + let language = get_language("rust"); + let query = Query::new( + language, + r#" + (function_item + name: (identifier) @fn_name) + + (mod_item + name: (identifier) @mod_name + body: (declaration_list + "{" @lbrace + "}" @rbrace)) + + ; functions that return Result<()> + ((function_item + return_type: (generic_type + type: (type_identifier) @result + type_arguments: (type_arguments + (unit_type))) + body: _ @fallible_fn_body) + (#eq? 
@result "Result")) + "#, + ) + .unwrap(); + let source = " + mod m1 { + mod m2 { + fn f1() -> Option<()> { Some(()) } + } + fn f2() -> Result<()> { Ok(()) } + fn f3() {} + } + "; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); + + // Retrieve some captures + let mut results = Vec::new(); + for (mat, capture_ix) in captures.by_ref().take(5) { + let capture = mat.captures[capture_ix as usize]; + results.push(( + query.capture_names()[capture.index as usize].as_str(), + &source[capture.node.byte_range()], + )); + } + assert_eq!( + results, + vec![ + ("mod_name", "m1"), + ("lbrace", "{"), + ("mod_name", "m2"), + ("lbrace", "{"), + ("fn_name", "f1"), + ] + ); + + // Advance to a range that only partially intersects some matches. + // Captures from these matches are reported, but only those that + // intersect the range. 
+ results.clear(); + captures.set_byte_range(source.find("Ok").unwrap(), source.len()); + for (mat, capture_ix) in captures { + let capture = mat.captures[capture_ix as usize]; + results.push(( + query.capture_names()[capture.index as usize].as_str(), + &source[capture.node.byte_range()], + )); + } + assert_eq!( + results, + vec![ + ("fallible_fn_body", "{ Ok(()) }"), + ("fn_name", "f3"), + ("rbrace", "}") + ] + ); + }); +} + #[test] fn test_query_matches_different_queries_same_cursor() { allocations::record(|| { @@ -3036,137 +3122,6 @@ fn test_query_text_callback_returns_chunks() { }); } -#[test] -fn test_query_captures_advance_to_byte() { - allocations::record(|| { - let language = get_language("rust"); - let query = Query::new( - language, - r#" - (function_item - name: (identifier) @fn_name) - - (mod_item - name: (identifier) @mod_name - body: (declaration_list - "{" @lbrace - "}" @rbrace)) - - ; functions that return Result<()> - ((function_item - return_type: (generic_type - type: (type_identifier) @result - type_arguments: (type_arguments - (unit_type))) - body: _ @fallible_fn_body) - (#eq? 
@result "Result")) - "#, - ) - .unwrap(); - let source = " - mod m1 { - mod m2 { - fn f1() -> Option<()> { Some(()) } - } - fn f2() -> Result<()> { Ok(()) } - fn f3() {} - } - "; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(&source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); - - // Retrieve some captures - let mut results = Vec::new(); - for (mat, capture_ix) in captures.by_ref().take(5) { - let capture = mat.captures[capture_ix as usize]; - results.push(( - query.capture_names()[capture.index as usize].as_str(), - &source[capture.node.byte_range()], - )); - } - assert_eq!( - results, - vec![ - ("mod_name", "m1"), - ("lbrace", "{"), - ("mod_name", "m2"), - ("lbrace", "{"), - ("fn_name", "f1"), - ] - ); - - results.clear(); - captures.advance_to_byte(source.find("Ok").unwrap()); - - // Advance further ahead in the source, retrieve the remaining captures. - for (mat, capture_ix) in captures { - let capture = mat.captures[capture_ix as usize]; - results.push(( - query.capture_names()[capture.index as usize].as_str(), - &source[capture.node.byte_range()], - )); - } - assert_eq!( - results, - vec![ - ("fallible_fn_body", "{ Ok(()) }"), - ("fn_name", "f3"), - ("rbrace", "}") - ] - ); - - // Advance past the last capture. There are no more captures. - let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); - captures.advance_to_byte(source.len()); - assert!(captures.next().is_none()); - assert!(captures.next().is_none()); - }); -} - -#[test] -fn test_query_advance_to_byte_within_node() { - allocations::record(|| { - let language = get_language("rust"); - let query = Query::new( - language, - r#" - (fn_item - name: (identifier) @name - return_type: _? @ret) - - (mod_item - name: (identifier) @name - body: _ @body) - "#, - ) - .unwrap(); - let source = " - fn foo() -> i32 {} - - ... 
- - mod foo {} - "; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(&source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); - - captures.advance_to_byte(source.find("{").unwrap()); - - assert_eq!( - collect_captures(captures, &query, source), - &[("body", "{}"),] - ); - }) -} - #[test] fn test_query_start_byte_for_pattern() { let language = get_language("javascript"); diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index da2b3252..88124f08 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1810,9 +1810,9 @@ impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryMatches<'a, 'tree, T> { } impl<'a, 'tree, T: TextProvider<'a>> QueryCaptures<'a, 'tree, T> { - pub fn advance_to_byte(&mut self, offset: usize) { + pub fn set_byte_range(&mut self, start: usize, end: usize) { unsafe { - ffi::ts_query_cursor_advance_to_byte(self.ptr, offset as u32); + ffi::ts_query_cursor_set_byte_range(self.ptr, start as u32, end as u32); } } } diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 6889a121..43315415 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -824,8 +824,6 @@ void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint); bool ts_query_cursor_next_match(TSQueryCursor *, TSQueryMatch *match); void ts_query_cursor_remove_match(TSQueryCursor *, uint32_t id); -void ts_query_cursor_advance_to_byte(TSQueryCursor *, uint32_t offset); - /** * Advance to the next capture of the currently running query. 
* diff --git a/lib/src/point.h b/lib/src/point.h index a50d2021..c3bf3c26 100644 --- a/lib/src/point.h +++ b/lib/src/point.h @@ -33,6 +33,10 @@ static inline bool point_lt(TSPoint a, TSPoint b) { return (a.row < b.row) || (a.row == b.row && a.column < b.column); } +static inline bool point_gt(TSPoint a, TSPoint b) { + return (a.row > b.row) || (a.row == b.row && a.column > b.column); +} + static inline bool point_eq(TSPoint a, TSPoint b) { return a.row == b.row && a.column == b.column; } diff --git a/lib/src/query.c b/lib/src/query.c index d70e5afd..00f66ec0 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -256,6 +256,9 @@ struct TSQueryCursor { CaptureListPool capture_list_pool; uint32_t depth; uint32_t start_byte; + uint32_t end_byte; + TSPoint start_point; + TSPoint end_point; uint32_t next_state_id; bool ascending; bool halted; @@ -2261,6 +2264,9 @@ TSQueryCursor *ts_query_cursor_new(void) { .finished_states = array_new(), .capture_list_pool = capture_list_pool_new(), .start_byte = 0, + .end_byte = UINT32_MAX, + .start_point = {0, 0}, + .end_point = POINT_MAX, }; array_reserve(&self->states, 8); array_reserve(&self->finished_states, 8); @@ -2290,7 +2296,6 @@ void ts_query_cursor_exec( capture_list_pool_reset(&self->capture_list_pool); self->next_state_id = 0; self->depth = 0; - self->start_byte = 0; self->ascending = false; self->halted = false; self->query = query; @@ -2302,6 +2307,11 @@ void ts_query_cursor_set_byte_range( uint32_t start_byte, uint32_t end_byte ) { + if (end_byte == 0) { + end_byte = UINT32_MAX; + } + self->start_byte = start_byte; + self->end_byte = end_byte; } void ts_query_cursor_set_point_range( @@ -2309,6 +2319,11 @@ void ts_query_cursor_set_point_range( TSPoint start_point, TSPoint end_point ) { + if (end_point.row == 0 && end_point.column == 0) { + end_point = POINT_MAX; + } + self->start_point = start_point; + self->end_point = end_point; } // Search through all of the in-progress states, and find the captured @@ -2337,7 +2352,10 
@@ static bool ts_query_cursor__first_in_progress_capture( } TSNode node = captures->contents[state->consumed_capture_count].node; - if (ts_node_end_byte(node) <= self->start_byte) { + if ( + ts_node_end_byte(node) <= self->start_byte || + point_lte(ts_node_end_point(node), self->start_point) + ) { state->consumed_capture_count++; i--; continue; @@ -2682,12 +2700,8 @@ static inline bool ts_query_cursor__advance( // Enter a new node. else { - // If this node is before the selected range, then avoid descending into it. - TSNode node = ts_tree_cursor_current_node(&self->cursor); - - bool node_exceeds_start_byte = ts_node_end_byte(node) > self->start_byte; - // Get the properties of the current node. + TSNode node = ts_tree_cursor_current_node(&self->cursor); TSSymbol symbol = ts_node_symbol(node); bool is_named = ts_node_is_named(node); bool has_later_siblings; @@ -2714,7 +2728,14 @@ static inline bool ts_query_cursor__advance( self->finished_states.size ); - if (node_exceeds_start_byte) { + bool node_intersects_range = ( + ts_node_end_byte(node) > self->start_byte && + ts_node_start_byte(node) < self->end_byte && + point_gt(ts_node_end_point(node), self->start_point) && + point_lt(ts_node_start_point(node), self->end_point) + ); + + if (node_intersects_range) { // Add new states for any patterns whose root node is a wildcard. for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { PatternEntry *pattern = &self->query->pattern_map.contents[i]; @@ -3039,7 +3060,7 @@ static inline bool ts_query_cursor__advance( // When the current node ends prior to the desired start offset, // only descend for the purpose of continuing in-progress matches. 
- bool should_descend = node_exceeds_start_byte; + bool should_descend = node_intersects_range; if (!should_descend) { for (unsigned i = 0; i < self->states.size; i++) { QueryState *state = &self->states.contents[i];; @@ -3071,14 +3092,6 @@ static inline bool ts_query_cursor__advance( } } -void ts_query_cursor_advance_to_byte( - TSQueryCursor *self, - uint32_t offset -) { - LOG("advance_to_byte %u\n", offset); - self->start_byte = offset; -} - bool ts_query_cursor_next_match( TSQueryCursor *self, TSQueryMatch *match From 036aceed574c2c23eee8f0ff90be5a2409e524c1 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 25 May 2021 17:58:30 -0700 Subject: [PATCH 05/12] In script/generate-bindings, add flags for latest bindgen --- lib/binding_rust/bindings.rs | 18 +++++++----------- script/generate-bindings | 14 ++++++++------ 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index dccc9aca..22a6cea0 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -1,7 +1,6 @@ /* automatically generated by rust-bindgen 0.58.1 */ pub type __darwin_size_t = ::std::os::raw::c_ulong; -pub type size_t = usize; pub type FILE = [u64; 19usize]; pub type TSSymbol = u16; pub type TSFieldId = u16; @@ -32,11 +31,11 @@ pub struct TSQueryCursor { } pub const TSInputEncoding_TSInputEncodingUTF8: TSInputEncoding = 0; pub const TSInputEncoding_TSInputEncodingUTF16: TSInputEncoding = 1; -pub type TSInputEncoding = ::std::os::raw::c_uint; +pub type TSInputEncoding = u32; pub const TSSymbolType_TSSymbolTypeRegular: TSSymbolType = 0; pub const TSSymbolType_TSSymbolTypeAnonymous: TSSymbolType = 1; pub const TSSymbolType_TSSymbolTypeAuxiliary: TSSymbolType = 2; -pub type TSSymbolType = ::std::os::raw::c_uint; +pub type TSSymbolType = u32; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct TSPoint { @@ -67,7 +66,7 @@ pub struct TSInput { } pub const TSLogType_TSLogTypeParse: TSLogType = 0; pub const 
TSLogType_TSLogTypeLex: TSLogType = 1; -pub type TSLogType = ::std::os::raw::c_uint; +pub type TSLogType = u32; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct TSLogger { @@ -121,7 +120,7 @@ pub struct TSQueryMatch { pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeDone: TSQueryPredicateStepType = 0; pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture: TSQueryPredicateStepType = 1; pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeString: TSQueryPredicateStepType = 2; -pub type TSQueryPredicateStepType = ::std::os::raw::c_uint; +pub type TSQueryPredicateStepType = u32; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct TSQueryPredicateStep { @@ -134,7 +133,7 @@ pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2; pub const TSQueryError_TSQueryErrorField: TSQueryError = 3; pub const TSQueryError_TSQueryErrorCapture: TSQueryError = 4; pub const TSQueryError_TSQueryErrorStructure: TSQueryError = 5; -pub type TSQueryError = ::std::os::raw::c_uint; +pub type TSQueryError = u32; extern "C" { #[doc = " Create a new parser."] pub fn ts_parser_new() -> *mut TSParser; @@ -290,11 +289,11 @@ extern "C" { #[doc = " If a non-null pointer is assigned, then the parser will periodically read"] #[doc = " from this pointer during parsing. If it reads a non-zero value, it will"] #[doc = " halt early, returning NULL. 
See `ts_parser_parse` for more information."] - pub fn ts_parser_set_cancellation_flag(self_: *mut TSParser, flag: *const size_t); + pub fn ts_parser_set_cancellation_flag(self_: *mut TSParser, flag: *const usize); } extern "C" { #[doc = " Get the parser's current cancellation flag pointer."] - pub fn ts_parser_cancellation_flag(self_: *const TSParser) -> *const size_t; + pub fn ts_parser_cancellation_flag(self_: *const TSParser) -> *const usize; } extern "C" { #[doc = " Set the logger that a parser should use during parsing."] @@ -754,9 +753,6 @@ extern "C" { extern "C" { pub fn ts_query_cursor_remove_match(arg1: *mut TSQueryCursor, id: u32); } -extern "C" { - pub fn ts_query_cursor_advance_to_byte(arg1: *mut TSQueryCursor, offset: u32); -} extern "C" { #[doc = " Advance to the next capture of the currently running query."] #[doc = ""] diff --git a/script/generate-bindings b/script/generate-bindings index b1ac9640..19c79e96 100755 --- a/script/generate-bindings +++ b/script/generate-bindings @@ -3,12 +3,14 @@ output_path=lib/binding_rust/bindings.rs header_path='lib/include/tree_sitter/api.h' -bindgen \ - --no-layout-tests \ - --whitelist-type '^TS.*' \ - --whitelist-function '^ts_.*' \ - --opaque-type FILE \ - --distrust-clang-mangling \ +bindgen \ + --no-layout-tests \ + --whitelist-type '^TS.*' \ + --whitelist-function '^ts_.*' \ + --opaque-type FILE \ + --size_t-is-usize \ + --translate-enum-integer-types \ + --distrust-clang-mangling \ $header_path > $output_path echo "" >> $output_path From 919e9745a64c2dba8d57f5f3e3654d77891b04c9 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 27 May 2021 12:30:17 -0700 Subject: [PATCH 06/12] Add `ts_tree_cursor_goto_first_child_for_point` function This function (and the similar `ts_tree_cursor_goto_first_child_for_byte`) allows you to efficiently seek the tree cursor to a given position, exploiting the tree's internal balancing, without having to visit all of the preceding siblings of each node. 
--- cli/src/tests/tree_test.rs | 107 ++++++++++++++++++++++++++++++++++ lib/binding_rust/bindings.rs | 6 +- lib/binding_rust/lib.rs | 15 +++++ lib/include/tree_sitter/api.h | 3 +- lib/src/tree_cursor.c | 41 +++++++++++-- 5 files changed, 166 insertions(+), 6 deletions(-) diff --git a/cli/src/tests/tree_test.rs b/cli/src/tests/tree_test.rs index fb461d1b..db13ca3a 100644 --- a/cli/src/tests/tree_test.rs +++ b/cli/src/tests/tree_test.rs @@ -262,6 +262,113 @@ fn test_tree_cursor_fields() { assert_eq!(cursor.field_name(), Some("parameters")); } +#[test] +fn test_tree_cursor_child_for_point() { + let mut parser = Parser::new(); + parser.set_language(get_language("javascript")).unwrap(); + let source = &" + [ + one, + { + two: tree + }, + four, five, six + ];"[1..]; + let tree = parser.parse(source, None).unwrap(); + + let mut c = tree.walk(); + assert_eq!(c.node().kind(), "program"); + + assert_eq!(c.goto_first_child_for_point(Point::new(7, 0)), None); + assert_eq!(c.goto_first_child_for_point(Point::new(6, 6)), None); + assert_eq!(c.node().kind(), "program"); + + // descend to expression statement + assert_eq!(c.goto_first_child_for_point(Point::new(6, 5)), Some(0)); + assert_eq!(c.node().kind(), "expression_statement"); + + // step into ';' and back up + assert_eq!(c.goto_first_child_for_point(Point::new(7, 0)), None); + assert_eq!(c.goto_first_child_for_point(Point::new(6, 5)), Some(1)); + assert_eq!( + (c.node().kind(), c.node().start_position()), + (";", Point::new(6, 5)) + ); + assert!(c.goto_parent()); + + // descend into array + assert_eq!(c.goto_first_child_for_point(Point::new(6, 4)), Some(0)); + assert_eq!( + (c.node().kind(), c.node().start_position()), + ("array", Point::new(0, 4)) + ); + + // step into '[' and back up + assert_eq!(c.goto_first_child_for_point(Point::new(0, 4)), Some(0)); + assert_eq!( + (c.node().kind(), c.node().start_position()), + ("[", Point::new(0, 4)) + ); + assert!(c.goto_parent()); + + // step into identifier 'one' and back up + 
assert_eq!(c.goto_first_child_for_point(Point::new(0, 5)), Some(1)); + assert_eq!( + (c.node().kind(), c.node().start_position()), + ("identifier", Point::new(1, 8)) + ); + assert!(c.goto_parent()); + assert_eq!(c.goto_first_child_for_point(Point::new(1, 10)), Some(1)); + assert_eq!( + (c.node().kind(), c.node().start_position()), + ("identifier", Point::new(1, 8)) + ); + assert!(c.goto_parent()); + + // step into first ',' and back up + assert_eq!(c.goto_first_child_for_point(Point::new(1, 11)), Some(2)); + assert_eq!( + (c.node().kind(), c.node().start_position()), + (",", Point::new(1, 11)) + ); + assert!(c.goto_parent()); + + // step into identifier 'four' and back up + assert_eq!(c.goto_first_child_for_point(Point::new(4, 10)), Some(5)); + assert_eq!( + (c.node().kind(), c.node().start_position()), + ("identifier", Point::new(5, 8)) + ); + assert!(c.goto_parent()); + assert_eq!(c.goto_first_child_for_point(Point::new(5, 0)), Some(5)); + assert_eq!( + (c.node().kind(), c.node().start_position()), + ("identifier", Point::new(5, 8)) + ); + assert!(c.goto_parent()); + + // step into ']' and back up + assert_eq!(c.goto_first_child_for_point(Point::new(6, 0)), Some(10)); + assert_eq!( + (c.node().kind(), c.node().start_position()), + ("]", Point::new(6, 4)) + ); + assert!(c.goto_parent()); + assert_eq!(c.goto_first_child_for_point(Point::new(5, 23)), Some(10)); + assert_eq!( + (c.node().kind(), c.node().start_position()), + ("]", Point::new(6, 4)) + ); + assert!(c.goto_parent()); + + // descend into object + assert_eq!(c.goto_first_child_for_point(Point::new(2, 0)), Some(3)); + assert_eq!( + (c.node().kind(), c.node().start_position()), + ("object", Point::new(2, 8)) + ); +} + #[test] fn test_tree_node_equality() { let mut parser = Parser::new(); diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 22a6cea0..38021fe7 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -586,12 +586,16 @@ extern "C" { } extern "C" 
{ #[doc = " Move the cursor to the first child of its current node that extends beyond"] - #[doc = " the given byte offset."] + #[doc = " the given byte offset or point."] #[doc = ""] #[doc = " This returns the index of the child node if one was found, and returns -1"] #[doc = " if no such child was found."] pub fn ts_tree_cursor_goto_first_child_for_byte(arg1: *mut TSTreeCursor, arg2: u32) -> i64; } +extern "C" { + pub fn ts_tree_cursor_goto_first_child_for_point(arg1: *mut TSTreeCursor, arg2: TSPoint) + -> i64; +} extern "C" { pub fn ts_tree_cursor_copy(arg1: *const TSTreeCursor) -> TSTreeCursor; } diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 88124f08..2cf5898e 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1172,6 +1172,21 @@ impl<'a> TreeCursor<'a> { } } + /// Move this cursor to the first child of its current node that extends beyond + /// the given point. + /// + /// This returns the index of the child node if one was found, and returns `None` + /// if no such child was found. + pub fn goto_first_child_for_point(&mut self, point: Point) -> Option<usize> { + let result = + unsafe { ffi::ts_tree_cursor_goto_first_child_for_point(&mut self.0, point.into()) }; + if result < 0 { + None + } else { + Some(result as usize) + } + } + /// Re-initialize this tree cursor to start at a different node. pub fn reset(&mut self, node: Node<'a>) { unsafe { ffi::ts_tree_cursor_reset(&mut self.0, node.0) }; diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 43315415..44e07396 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -651,12 +651,13 @@ bool ts_tree_cursor_goto_first_child(TSTreeCursor *); /** * Move the cursor to the first child of its current node that extends beyond - * the given byte offset. + * the given byte offset or point. * * This returns the index of the child node if one was found, and returns -1 * if no such child was found.
*/ int64_t ts_tree_cursor_goto_first_child_for_byte(TSTreeCursor *, uint32_t); +int64_t ts_tree_cursor_goto_first_child_for_point(TSTreeCursor *, TSPoint); TSTreeCursor ts_tree_cursor_copy(const TSTreeCursor *); diff --git a/lib/src/tree_cursor.c b/lib/src/tree_cursor.c index c4ee7a90..6b4829f5 100644 --- a/lib/src/tree_cursor.c +++ b/lib/src/tree_cursor.c @@ -159,10 +159,43 @@ int64_t ts_tree_cursor_goto_first_child_for_byte(TSTreeCursor *_self, uint32_t g } } while (did_descend); - if (self->stack.size > initial_size && - ts_tree_cursor_goto_next_sibling((TSTreeCursor *)self)) { - return visible_child_index; - } + self->stack.size = initial_size; + return -1; +} + +int64_t ts_tree_cursor_goto_first_child_for_point(TSTreeCursor *_self, TSPoint goal_point) { + TreeCursor *self = (TreeCursor *)_self; + uint32_t initial_size = self->stack.size; + uint32_t visible_child_index = 0; + + bool did_descend; + do { + did_descend = false; + + bool visible; + TreeCursorEntry entry; + CursorChildIterator iterator = ts_tree_cursor_iterate_children(self); + while (ts_tree_cursor_child_iterator_next(&iterator, &entry, &visible)) { + TSPoint end_point = point_add(entry.position.extent, ts_subtree_size(*entry.subtree).extent); + bool at_goal = point_gt(end_point, goal_point); + uint32_t visible_child_count = ts_subtree_visible_child_count(*entry.subtree); + if (at_goal) { + if (visible) { + array_push(&self->stack, entry); + return visible_child_index; + } + if (visible_child_count > 0) { + array_push(&self->stack, entry); + did_descend = true; + break; + } + } else if (visible) { + visible_child_index++; + } else { + visible_child_index += visible_child_count; + } + } + } while (did_descend); self->stack.size = initial_size; return -1; From 851f55afcea09b465328095dd29d7e396669e31e Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 28 May 2021 11:58:38 -0700 Subject: [PATCH 07/12] Report non-rooted matches that intersect cursor's range restriction Co-Authored-By: Nathan Sobo 
--- cli/src/tests/query_test.rs | 54 ++++++++++++++++ lib/src/query.c | 119 ++++++++++++++++++++++-------------- 2 files changed, 126 insertions(+), 47 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index daa3b04b..06578ba8 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1918,6 +1918,60 @@ fn test_query_captures_within_byte_range() { }); } +#[test] +fn test_query_matches_with_unrooted_patterns_intersecting_byte_range() { + allocations::record(|| { + let language = get_language("rust"); + let query = Query::new( + language, + r#" + ("{" @left "}" @right) + ("<" @left ">" @right) + "#, + ) + .unwrap(); + + let source = "mod a { fn a(f: B) { g(f) } }"; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + + // within the type parameter list + let offset = source.find("D: E>").unwrap(); + let matches = cursor.set_byte_range(offset, offset).matches( + &query, + tree.root_node(), + source.as_bytes(), + ); + assert_eq!( + collect_matches(matches, &query, source), + &[ + (1, vec![("left", "<"), ("right", ">")]), + (0, vec![("left", "{"), ("right", "}")]), + ] + ); + + // from within the type parameter list to within the function body + let start_offset = source.find("D: E>").unwrap(); + let end_offset = source.find("g(f)").unwrap(); + let matches = cursor.set_byte_range(start_offset, end_offset).matches( + &query, + tree.root_node(), + source.as_bytes(), + ); + assert_eq!( + collect_matches(matches, &query, source), + &[ + (1, vec![("left", "<"), ("right", ">")]), + (0, vec![("left", "{"), ("right", "}")]), + (0, vec![("left", "{"), ("right", "}")]), + ] + ); + }); +} + #[test] fn test_query_captures_within_byte_range_assigned_after_iterating() { allocations::record(|| { diff --git a/lib/src/query.c b/lib/src/query.c index 00f66ec0..9feb1177 100644 --- a/lib/src/query.c +++ b/lib/src/query.c 
@@ -104,16 +104,20 @@ typedef struct { } SymbolTable; /* - * PatternEntry - Information about the starting point for matching a - * particular pattern, consisting of the index of the pattern within the query, - * and the index of the patter's first step in the shared `steps` array. These - * entries are stored in a 'pattern map' - a sorted array that makes it - * possible to efficiently lookup patterns based on the symbol for their first - * step. + * PatternEntry - Information about the starting point for matching a particular + * pattern. These entries are stored in a 'pattern map' - a sorted array that + * makes it possible to efficiently lookup patterns based on the symbol for their + * first step. The entry consists of the following fields: + * - `pattern_index` - the index of the pattern within the query + * - `step_index` - the index of the pattern's first step in the shared `steps` array + * - `is_rooted` - whether or not the pattern has a single root node. This property + * affects decisions about whether or not to start the pattern for nodes outside + * of a QueryCursor's range restriction. 
*/ typedef struct { uint16_t step_index; uint16_t pattern_index; + bool is_rooted; } PatternEntry; typedef struct { @@ -691,8 +695,7 @@ static inline bool ts_query__pattern_map_search( static inline void ts_query__pattern_map_insert( TSQuery *self, TSSymbol symbol, - uint32_t start_step_index, - uint32_t pattern_index + PatternEntry new_entry ) { uint32_t index; ts_query__pattern_map_search(self, symbol, &index); @@ -705,7 +708,7 @@ static inline void ts_query__pattern_map_insert( PatternEntry *entry = &self->pattern_map.contents[index]; if ( self->steps.contents[entry->step_index].symbol == symbol && - entry->pattern_index < pattern_index + entry->pattern_index < new_entry.pattern_index ) { index++; } else { @@ -713,10 +716,7 @@ static inline void ts_query__pattern_map_insert( } } - array_insert(&self->pattern_map, index, ((PatternEntry) { - .step_index = start_step_index, - .pattern_index = pattern_index, - })); + array_insert(&self->pattern_map, index, new_entry); } static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { @@ -2108,7 +2108,24 @@ TSQuery *ts_query_new( } } - ts_query__pattern_map_insert(self, step->symbol, start_step_index, pattern_index); + // Determine whether the pattern has a single root node. This affects + // decisions about whether or not to start matching the pattern when + // a query cursor has a range restriction. 
+ bool is_rooted = true; + uint32_t start_depth = step->depth; + for (uint32_t step_index = start_step_index + 1; step_index < self->steps.size; step_index++) { + QueryStep *step = &self->steps.contents[step_index]; + if (step->depth == start_depth) { + is_rooted = false; + break; + } + } + + ts_query__pattern_map_insert(self, step->symbol, (PatternEntry) { + .step_index = start_step_index, + .pattern_index = pattern_index, + .is_rooted = is_rooted + }); if (step->symbol == WILDCARD_SYMBOL) { self->wildcard_root_pattern_count++; } @@ -2702,6 +2719,7 @@ static inline bool ts_query_cursor__advance( else { // Get the properties of the current node. TSNode node = ts_tree_cursor_current_node(&self->cursor); + TSNode parent_node = ts_tree_cursor_parent_node(&self->cursor); TSSymbol symbol = ts_node_symbol(node); bool is_named = ts_node_is_named(node); bool has_later_siblings; @@ -2735,44 +2753,51 @@ static inline bool ts_query_cursor__advance( point_lt(ts_node_start_point(node), self->end_point) ); - if (node_intersects_range) { - // Add new states for any patterns whose root node is a wildcard. - for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { - PatternEntry *pattern = &self->query->pattern_map.contents[i]; - QueryStep *step = &self->query->steps.contents[pattern->step_index]; + bool parent_intersects_range = ts_node_is_null(parent_node) || ( + ts_node_end_byte(parent_node) > self->start_byte && + ts_node_start_byte(parent_node) < self->end_byte && + point_gt(ts_node_end_point(parent_node), self->start_point) && + point_lt(ts_node_start_point(parent_node), self->end_point) + ); - // If this node matches the first step of the pattern, then add a new - // state at the start of this pattern. - if (step->field && field_id != step->field) continue; - if (step->supertype_symbol && !supertype_count) continue; + // Add new states for any patterns whose root node is a wildcard. 
+ for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { + PatternEntry *pattern = &self->query->pattern_map.contents[i]; + + // If this node matches the first step of the pattern, then add a new + // state at the start of this pattern. + QueryStep *step = &self->query->steps.contents[pattern->step_index]; + if ( + (node_intersects_range || (!pattern->is_rooted && parent_intersects_range)) && + (!step->field || field_id == step->field) && + (!step->supertype_symbol || supertype_count > 0) + ) { ts_query_cursor__add_state(self, pattern); } + } - // Add new states for any patterns whose root node matches this node. - unsigned i; - if (ts_query__pattern_map_search(self->query, symbol, &i)) { - PatternEntry *pattern = &self->query->pattern_map.contents[i]; - QueryStep *step = &self->query->steps.contents[pattern->step_index]; - do { - // If this node matches the first step of the pattern, then add a new - // state at the start of this pattern. - if (!step->field || field_id == step->field) { - ts_query_cursor__add_state(self, pattern); - } + // Add new states for any patterns whose root node matches this node. + unsigned i; + if (ts_query__pattern_map_search(self->query, symbol, &i)) { + PatternEntry *pattern = &self->query->pattern_map.contents[i]; - // Advance to the next pattern whose root node matches this node. - i++; - if (i == self->query->pattern_map.size) break; - pattern = &self->query->pattern_map.contents[i]; - step = &self->query->steps.contents[pattern->step_index]; - } while (step->symbol == symbol); - } - } else { - LOG( - " not starting new patterns. node end byte: %u, start_byte: %u\n", - ts_node_end_byte(node), - self->start_byte - ); + QueryStep *step = &self->query->steps.contents[pattern->step_index]; + do { + // If this node matches the first step of the pattern, then add a new + // state at the start of this pattern. 
+ if ( + (node_intersects_range || (!pattern->is_rooted && parent_intersects_range)) && + (!step->field || field_id == step->field) + ) { + ts_query_cursor__add_state(self, pattern); + } + + // Advance to the next pattern whose root node matches this node. + i++; + if (i == self->query->pattern_map.size) break; + pattern = &self->query->pattern_map.contents[i]; + step = &self->query->steps.contents[pattern->step_index]; + } while (step->symbol == symbol); } // Update all of the in-progress states with current node. From dab11134c2cbf973c9aef83831439e282a9b18c9 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 28 May 2021 12:27:50 -0700 Subject: [PATCH 08/12] Add Query::capture_index_for_name method Co-Authored-By: Nathan Sobo --- lib/binding_rust/lib.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 2cf5898e..d187c965 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1512,6 +1512,14 @@ impl Query { &self.capture_names } + /// Get the index for a given capture name. + pub fn capture_index_for_name(&self, name: &str) -> Option { + self.capture_names + .iter() + .position(|n| n == name) + .map(|ix| ix as u32) + } + /// Get the properties that are checked for the given pattern index. /// /// This includes predicates with the operators `is?` and `is-not?`. 
From 97dfee63257b5e92197399b381aa993514640adf Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 28 May 2021 12:38:30 -0700 Subject: [PATCH 09/12] Add QueryMatch::nodes_for_capture_index Co-Authored-By: Nathan Sobo --- lib/binding_rust/lib.rs | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index d187c965..c06fa01a 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1721,6 +1721,19 @@ impl<'a, 'tree> QueryMatch<'a, 'tree> { unsafe { ffi::ts_query_cursor_remove_match(self.cursor, self.id) } } + pub fn nodes_for_capture_index( + &self, + capture_ix: u32, + ) -> impl Iterator> + '_ { + self.captures.iter().filter_map(move |capture| { + if capture.index == capture_ix { + Some(capture.node) + } else { + None + } + }) + } + fn new(m: ffi::TSQueryMatch, cursor: *mut ffi::TSQueryCursor) -> Self { QueryMatch { cursor, @@ -1768,33 +1781,24 @@ impl<'a, 'tree> QueryMatch<'a, 'tree> { .iter() .all(|predicate| match predicate { TextPredicate::CaptureEqCapture(i, j, is_positive) => { - let node1 = self.capture_for_index(*i).unwrap(); - let node2 = self.capture_for_index(*j).unwrap(); + let node1 = self.nodes_for_capture_index(*i).next().unwrap(); + let node2 = self.nodes_for_capture_index(*j).next().unwrap(); let text1 = get_text(buffer1, text_provider.text(node1)); let text2 = get_text(buffer2, text_provider.text(node2)); (text1 == text2) == *is_positive } TextPredicate::CaptureEqString(i, s, is_positive) => { - let node = self.capture_for_index(*i).unwrap(); + let node = self.nodes_for_capture_index(*i).next().unwrap(); let text = get_text(buffer1, text_provider.text(node)); (text == s.as_bytes()) == *is_positive } TextPredicate::CaptureMatchString(i, r, is_positive) => { - let node = self.capture_for_index(*i).unwrap(); + let node = self.nodes_for_capture_index(*i).next().unwrap(); let text = get_text(buffer1, text_provider.text(node)); r.is_match(text) 
== *is_positive } }) } - - fn capture_for_index(&self, capture_index: u32) -> Option> { - for c in self.captures { - if c.index == capture_index { - return Some(c.node); - } - } - None - } } impl QueryProperty { From 7f4eb9a222670ca0807c4f2b4af2c8257e4061ba Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 28 May 2021 14:07:54 -0700 Subject: [PATCH 10/12] Provide ::set_{byte,point}_range on both query iterators --- lib/binding_rust/lib.rs | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index c06fa01a..4dc7af2a 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -198,11 +198,6 @@ pub enum QueryErrorKind { Structure, } -trait TextCallback<'a> { - fn call(&mut self, node: Node); - fn next_chunk(&mut self) -> Option<&'a [u8]>; -} - #[derive(Debug)] enum TextPredicate { CaptureEqString(u32, String, bool), @@ -1836,14 +1831,6 @@ impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryMatches<'a, 'tree, T> { } } -impl<'a, 'tree, T: TextProvider<'a>> QueryCaptures<'a, 'tree, T> { - pub fn set_byte_range(&mut self, start: usize, end: usize) { - unsafe { - ffi::ts_query_cursor_set_byte_range(self.ptr, start as u32, end as u32); - } - } -} - impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryCaptures<'a, 'tree, T> { type Item = (QueryMatch<'a, 'tree>, usize); @@ -1876,6 +1863,34 @@ impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryCaptures<'a, 'tree, T> { } } +impl<'a, 'tree, T: TextProvider<'a>> QueryMatches<'a, 'tree, T> { + pub fn set_byte_range(&mut self, start: usize, end: usize) { + unsafe { + ffi::ts_query_cursor_set_byte_range(self.ptr, start as u32, end as u32); + } + } + + pub fn set_point_range(&mut self, start: Point, end: Point) { + unsafe { + ffi::ts_query_cursor_set_point_range(self.ptr, start.into(), end.into()); + } + } +} + +impl<'a, 'tree, T: TextProvider<'a>> QueryCaptures<'a, 'tree, T> { + pub fn 
set_byte_range(&mut self, start: usize, end: usize) { + unsafe { + ffi::ts_query_cursor_set_byte_range(self.ptr, start as u32, end as u32); + } + } + + pub fn set_point_range(&mut self, start: Point, end: Point) { + unsafe { + ffi::ts_query_cursor_set_point_range(self.ptr, start.into(), end.into()); + } + } +} + impl<'cursor, 'tree> fmt::Debug for QueryMatch<'cursor, 'tree> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( From d72771a19f4143530b1cfd23808e344f1276e176 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 28 May 2021 14:15:05 -0700 Subject: [PATCH 11/12] Make ::set_{byte,point}_range methods take a Range Co-Authored-By: Nathan Sobo --- cli/src/main.rs | 2 +- cli/src/query.rs | 15 ++++++----- cli/src/tests/query_test.rs | 20 +++++++------- lib/binding_rust/lib.rs | 52 +++++++++++++++++++++++-------------- 4 files changed, 53 insertions(+), 36 deletions(-) diff --git a/cli/src/main.rs b/cli/src/main.rs index a2d0a7da..8a307ccd 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -309,7 +309,7 @@ fn run() -> error::Result<()> { let query_path = Path::new(matches.value_of("query-path").unwrap()); let range = matches.value_of("byte-range").map(|br| { let r: Vec<&str> = br.split(":").collect(); - (r[0].parse().unwrap(), r[1].parse().unwrap()) + r[0].parse().unwrap()..r[1].parse().unwrap() }); let should_test = matches.is_present("test"); query::query_files_at_paths( diff --git a/cli/src/query.rs b/cli/src/query.rs index e303a002..9d8ca578 100644 --- a/cli/src/query.rs +++ b/cli/src/query.rs @@ -1,8 +1,11 @@ use super::error::{Error, Result}; use crate::query_testing; -use std::fs; -use std::io::{self, Write}; -use std::path::Path; +use std::{ + fs, + io::{self, Write}, + ops::Range, + path::Path, +}; use tree_sitter::{Language, Parser, Query, QueryCursor}; pub fn query_files_at_paths( @@ -10,7 +13,7 @@ pub fn query_files_at_paths( paths: Vec, query_path: &Path, ordered_captures: bool, - range: Option<(usize, usize)>, + range: Option>, 
should_test: bool, ) -> Result<()> { let stdout = io::stdout(); @@ -23,8 +26,8 @@ pub fn query_files_at_paths( .map_err(|e| Error::new(format!("Query compilation failed: {:?}", e)))?; let mut query_cursor = QueryCursor::new(); - if let Some((beg, end)) = range { - query_cursor.set_byte_range(beg, end); + if let Some(range) = range { + query_cursor.set_byte_range(range); } let mut parser = Parser::new(); diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 06578ba8..21393dd2 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1782,7 +1782,7 @@ fn test_query_matches_within_byte_range() { let matches = cursor - .set_byte_range(0, 8) + .set_byte_range(0..8) .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( @@ -1796,7 +1796,7 @@ fn test_query_matches_within_byte_range() { let matches = cursor - .set_byte_range(5, 15) + .set_byte_range(5..15) .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( @@ -1810,7 +1810,7 @@ fn test_query_matches_within_byte_range() { let matches = cursor - .set_byte_range(12, 0) + .set_byte_range(12..0) .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( @@ -1839,7 +1839,7 @@ fn test_query_matches_within_point_range() { let mut cursor = QueryCursor::new(); let matches = cursor - .set_point_range(Point::new(0, 0), Point::new(1, 3)) + .set_point_range(Point::new(0, 0)..Point::new(1, 3)) .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( @@ -1852,7 +1852,7 @@ fn test_query_matches_within_point_range() { ); let matches = cursor - .set_point_range(Point::new(1, 0), Point::new(2, 3)) + .set_point_range(Point::new(1, 0)..Point::new(2, 3)) .matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( @@ -1865,7 +1865,7 @@ fn test_query_matches_within_point_range() { ); let matches = cursor - .set_point_range(Point::new(2, 1), Point::new(0, 0)) + .set_point_range(Point::new(2, 1)..Point::new(0, 0)) .matches(&query, tree.root_node(), 
source.as_bytes()); assert_eq!( @@ -1904,7 +1904,7 @@ fn test_query_captures_within_byte_range() { let mut cursor = QueryCursor::new(); let captures = cursor - .set_byte_range(3, 27) + .set_byte_range(3..27) .captures(&query, tree.root_node(), source.as_bytes()); assert_eq!( @@ -1940,7 +1940,7 @@ fn test_query_matches_with_unrooted_patterns_intersecting_byte_range() { // within the type parameter list let offset = source.find("D: E>").unwrap(); - let matches = cursor.set_byte_range(offset, offset).matches( + let matches = cursor.set_byte_range(offset..offset).matches( &query, tree.root_node(), source.as_bytes(), @@ -1956,7 +1956,7 @@ fn test_query_matches_with_unrooted_patterns_intersecting_byte_range() { // from within the type parameter list to within the function body let start_offset = source.find("D: E>").unwrap(); let end_offset = source.find("g(f)").unwrap(); - let matches = cursor.set_byte_range(start_offset, end_offset).matches( + let matches = cursor.set_byte_range(start_offset..end_offset).matches( &query, tree.root_node(), source.as_bytes(), @@ -2039,7 +2039,7 @@ fn test_query_captures_within_byte_range_assigned_after_iterating() { // Captures from these matches are reported, but only those that // intersect the range. 
results.clear(); - captures.set_byte_range(source.find("Ok").unwrap(), source.len()); + captures.set_byte_range(source.find("Ok").unwrap()..source.len()); for (mat, capture_ix) in captures { let capture = mat.captures[capture_ix as usize]; results.push(( diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 4dc7af2a..8313618a 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -7,13 +7,19 @@ pub mod allocations; #[cfg(unix)] use std::os::unix::io::AsRawFd; -use std::ffi::CStr; -use std::marker::PhantomData; -use std::mem::MaybeUninit; -use std::os::raw::{c_char, c_void}; -use std::ptr::NonNull; -use std::sync::atomic::AtomicUsize; -use std::{char, error, fmt, hash, iter, ptr, slice, str, u16}; +use std::{ + char, error, + ffi::CStr, + fmt, hash, iter, + marker::PhantomData, + mem::MaybeUninit, + ops, + os::raw::{c_char, c_void}, + ptr::{self, NonNull}, + slice, str, + sync::atomic::AtomicUsize, + u16, +}; /// The latest ABI version that is supported by the current version of the /// library. @@ -1695,17 +1701,25 @@ impl QueryCursor { } /// Set the range in which the query will be executed, in terms of byte offsets. - pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self { + pub fn set_byte_range(&mut self, range: ops::Range) -> &mut Self { unsafe { - ffi::ts_query_cursor_set_byte_range(self.ptr.as_ptr(), start as u32, end as u32); + ffi::ts_query_cursor_set_byte_range( + self.ptr.as_ptr(), + range.start as u32, + range.end as u32, + ); } self } /// Set the range in which the query will be executed, in terms of rows and columns. 
- pub fn set_point_range(&mut self, start: Point, end: Point) -> &mut Self { + pub fn set_point_range(&mut self, range: ops::Range) -> &mut Self { unsafe { - ffi::ts_query_cursor_set_point_range(self.ptr.as_ptr(), start.into(), end.into()); + ffi::ts_query_cursor_set_point_range( + self.ptr.as_ptr(), + range.start.into(), + range.end.into(), + ); } self } @@ -1864,29 +1878,29 @@ impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryCaptures<'a, 'tree, T> { } impl<'a, 'tree, T: TextProvider<'a>> QueryMatches<'a, 'tree, T> { - pub fn set_byte_range(&mut self, start: usize, end: usize) { + pub fn set_byte_range(&mut self, range: ops::Range) { unsafe { - ffi::ts_query_cursor_set_byte_range(self.ptr, start as u32, end as u32); + ffi::ts_query_cursor_set_byte_range(self.ptr, range.start as u32, range.end as u32); } } - pub fn set_point_range(&mut self, start: Point, end: Point) { + pub fn set_point_range(&mut self, range: ops::Range) { unsafe { - ffi::ts_query_cursor_set_point_range(self.ptr, start.into(), end.into()); + ffi::ts_query_cursor_set_point_range(self.ptr, range.start.into(), range.end.into()); } } } impl<'a, 'tree, T: TextProvider<'a>> QueryCaptures<'a, 'tree, T> { - pub fn set_byte_range(&mut self, start: usize, end: usize) { + pub fn set_byte_range(&mut self, range: ops::Range) { unsafe { - ffi::ts_query_cursor_set_byte_range(self.ptr, start as u32, end as u32); + ffi::ts_query_cursor_set_byte_range(self.ptr, range.start as u32, range.end as u32); } } - pub fn set_point_range(&mut self, start: Point, end: Point) { + pub fn set_point_range(&mut self, range: ops::Range) { unsafe { - ffi::ts_query_cursor_set_point_range(self.ptr, start.into(), end.into()); + ffi::ts_query_cursor_set_point_range(self.ptr, range.start.into(), range.end.into()); } } } From 84168949236315dc329124466ae71ea5af6577b6 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 2 Jun 2021 09:15:04 -0700 Subject: [PATCH 12/12] Use std::iter::Once in impl TextProvider for [u8] 
Co-Authored-By: Douglas Creager --- lib/binding_rust/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 8313618a..cc60d62b 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1928,10 +1928,10 @@ where } impl<'a> TextProvider<'a> for &'a [u8] { - type I = std::option::IntoIter<&'a [u8]>; + type I = iter::Once<&'a [u8]>; fn text(&mut self, node: Node) -> Self::I { - Some(&self[node.byte_range()]).into_iter() + iter::once(&self[node.byte_range()]) } }