Allow QueryCursor's text callback to return an iterator

This commit is contained in:
Max Brunsfeld 2021-05-23 15:12:24 -07:00
parent 0e445c47fa
commit 8c3d1466ec
5 changed files with 272 additions and 111 deletions

View file

@ -3,7 +3,7 @@ use crate::query_testing;
use std::fs;
use std::io::{self, Write};
use std::path::Path;
use tree_sitter::{Language, Node, Parser, Query, QueryCursor};
use tree_sitter::{Language, Parser, Query, QueryCursor};
pub fn query_files_at_paths(
language: Language,
@ -38,12 +38,11 @@ pub fn query_files_at_paths(
let source_code = fs::read(&path).map_err(Error::wrap(|| {
format!("Error reading source file {:?}", path)
}))?;
let text_callback = |n: Node| &source_code[n.byte_range()];
let tree = parser.parse(&source_code, None).unwrap();
if ordered_captures {
for (mat, capture_index) in
query_cursor.captures(&query, tree.root_node(), text_callback)
query_cursor.captures(&query, tree.root_node(), source_code.as_slice())
{
let capture = mat.captures[capture_index];
let capture_name = &query.capture_names()[capture.index as usize];
@ -62,7 +61,7 @@ pub fn query_files_at_paths(
});
}
} else {
for m in query_cursor.matches(&query, tree.root_node(), text_callback) {
for m in query_cursor.matches(&query, tree.root_node(), source_code.as_slice()) {
writeln!(&mut stdout, " pattern: {}", m.pattern_index)?;
for capture in m.captures {
let start = capture.node.start_position();

View file

@ -761,7 +761,7 @@ fn test_query_matches_with_named_wildcard() {
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
let matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source),
@ -1645,7 +1645,7 @@ fn test_query_matches_with_too_many_permutations_to_track() {
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(&source));
let matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
// For this pathological query, some match permutations will be dropped.
// Just check that a subset of the results are returned, and crash or
@ -1686,7 +1686,7 @@ fn test_query_matches_with_alternatives_and_too_many_permutations_to_track() {
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(&source));
let matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source.as_str()),
@ -1783,7 +1783,7 @@ fn test_query_matches_within_byte_range() {
let matches =
cursor
.set_byte_range(0, 8)
.matches(&query, tree.root_node(), to_callback(source));
.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source),
@ -1797,7 +1797,7 @@ fn test_query_matches_within_byte_range() {
let matches =
cursor
.set_byte_range(5, 15)
.matches(&query, tree.root_node(), to_callback(source));
.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source),
@ -1811,7 +1811,7 @@ fn test_query_matches_within_byte_range() {
let matches =
cursor
.set_byte_range(12, 0)
.matches(&query, tree.root_node(), to_callback(source));
.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source),
@ -1840,7 +1840,7 @@ fn test_query_matches_within_point_range() {
let matches = cursor
.set_point_range(Point::new(0, 0), Point::new(1, 3))
.matches(&query, tree.root_node(), to_callback(source));
.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source),
@ -1853,7 +1853,7 @@ fn test_query_matches_within_point_range() {
let matches = cursor
.set_point_range(Point::new(1, 0), Point::new(2, 3))
.matches(&query, tree.root_node(), to_callback(source));
.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source),
@ -1866,7 +1866,7 @@ fn test_query_matches_within_point_range() {
let matches = cursor
.set_point_range(Point::new(2, 1), Point::new(0, 0))
.matches(&query, tree.root_node(), to_callback(source));
.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source),
@ -1905,7 +1905,7 @@ fn test_query_captures_within_byte_range() {
let captures =
cursor
.set_byte_range(3, 27)
.captures(&query, tree.root_node(), to_callback(source));
.captures(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_captures(captures, &query, source),
@ -1955,13 +1955,13 @@ fn test_query_matches_different_queries_same_cursor() {
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let matches = cursor.matches(&query1, tree.root_node(), to_callback(source));
let matches = cursor.matches(&query1, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query1, source),
&[(0, vec![("id1", "a")]),]
);
let matches = cursor.matches(&query3, tree.root_node(), to_callback(source));
let matches = cursor.matches(&query3, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query3, source),
&[
@ -1971,7 +1971,7 @@ fn test_query_matches_different_queries_same_cursor() {
]
);
let matches = cursor.matches(&query2, tree.root_node(), to_callback(source));
let matches = cursor.matches(&query2, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query2, source),
&[(0, vec![("id1", "a")]), (1, vec![("id2", "b")]),]
@ -1998,7 +1998,7 @@ fn test_query_matches_with_multiple_captures_on_a_node() {
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
let matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source),
&[(
@ -2016,7 +2016,7 @@ fn test_query_matches_with_multiple_captures_on_a_node() {
// disabling captures still works when there are multiple captures on a
// single node.
query.disable_capture("name2");
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
let matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source),
&[(
@ -2087,7 +2087,7 @@ fn test_query_matches_with_captured_wildcard_at_root() {
let tree = parser.parse(&source, None).unwrap();
let match_capture_names_and_rows = cursor
.matches(&query, tree.root_node(), to_callback(source))
.matches(&query, tree.root_node(), source.as_bytes())
.map(|m| {
m.captures
.iter()
@ -2352,7 +2352,7 @@ fn test_query_captures_basic() {
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
let matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source),
@ -2368,7 +2368,7 @@ fn test_query_captures_basic() {
],
);
let captures = cursor.captures(&query, tree.root_node(), to_callback(source));
let captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_captures(captures, &query, source),
&[
@ -2425,7 +2425,7 @@ fn test_query_captures_with_text_conditions() {
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let captures = cursor.captures(&query, tree.root_node(), to_callback(source));
let captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_captures(captures, &query, source),
&[
@ -2564,7 +2564,7 @@ fn test_query_captures_with_duplicates() {
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let captures = cursor.captures(&query, tree.root_node(), to_callback(source));
let captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_captures(captures, &query, source),
&[("function", "x"), ("variable", "x"),],
@ -2608,7 +2608,7 @@ fn test_query_captures_with_many_nested_results_without_fields() {
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let captures = cursor.captures(&query, tree.root_node(), to_callback(&source));
let captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
let captures = collect_captures(captures, &query, &source);
assert_eq!(
@ -2668,7 +2668,7 @@ fn test_query_captures_with_many_nested_results_with_fields() {
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let captures = cursor.captures(&query, tree.root_node(), to_callback(&source));
let captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
let captures = collect_captures(captures, &query, &source);
assert_eq!(
@ -2765,7 +2765,7 @@ fn test_query_captures_with_too_many_nested_results() {
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let captures = cursor.captures(&query, tree.root_node(), to_callback(&source));
let captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
let captures = collect_captures(captures, &query, &source);
assert_eq!(
@ -2828,7 +2828,7 @@ fn test_query_captures_with_definite_pattern_containing_many_nested_matches() {
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let captures = cursor.captures(&query, tree.root_node(), to_callback(source));
let captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_captures(captures, &query, source),
[("l-bracket", "[")]
@ -2864,7 +2864,7 @@ fn test_query_captures_ordered_by_both_start_and_end_positions() {
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let captures = cursor.captures(&query, tree.root_node(), to_callback(source));
let captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_captures(captures, &query, source),
&[
@ -2906,7 +2906,7 @@ fn test_query_captures_with_matches_removed() {
let mut cursor = QueryCursor::new();
let mut captured_strings = Vec::new();
for (m, i) in cursor.captures(&query, tree.root_node(), to_callback(source)) {
for (m, i) in cursor.captures(&query, tree.root_node(), source.as_bytes()) {
let capture = m.captures[i];
let text = capture.node.utf8_text(source.as_bytes()).unwrap();
if text == "a" {
@ -2943,7 +2943,7 @@ fn test_query_captures_and_matches_iterators_are_fused() {
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let mut captures = cursor.captures(&query, tree.root_node(), to_callback(source));
let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
assert_eq!(captures.next().unwrap().0.captures[0].index, 0);
assert_eq!(captures.next().unwrap().0.captures[0].index, 0);
@ -2953,7 +2953,7 @@ fn test_query_captures_and_matches_iterators_are_fused() {
assert!(captures.next().is_none());
drop(captures);
let mut matches = cursor.matches(&query, tree.root_node(), to_callback(source));
let mut matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(matches.next().unwrap().captures[0].index, 0);
assert_eq!(matches.next().unwrap().captures[0].index, 0);
assert_eq!(matches.next().unwrap().captures[0].index, 0);
@ -2963,6 +2963,79 @@ fn test_query_captures_and_matches_iterators_are_fused() {
});
}
// Checks that `QueryCursor::captures` accepts a text callback that returns an
// iterator of byte chunks (instead of one contiguous slice), and that the
// `#match?` predicates in the query still select the expected captures.
#[test]
fn test_query_text_callback_returns_chunks() {
allocations::record(|| {
let language = get_language("javascript");
// Three patterns over identifiers, distinguished only by regex predicates on
// the identifier's text.
let query = Query::new(
language,
r#"
((identifier) @leading_upper
(#match? @leading_upper "^[A-Z][A-Z_]*[a-z]"))
((identifier) @all_upper
(#match? @all_upper "^[A-Z][A-Z_]*$"))
((identifier) @all_lower
(#match? @all_lower "^[a-z][a-z_]*$"))
"#,
)
.unwrap();
let source = "SOMETHING[a] = transform(AnotherThing[b].property[c], PARAMETER);";
// Store the source code in chunks of 3 bytes, and expose it via
// an iterator API.
let source_chunks = source.as_bytes().chunks(3).collect::<Vec<_>>();
// Given a byte range into `source`, returns an iterator over the pieces of
// the stored chunks that overlap that range, in order. Chunks that do not
// overlap the range are filtered out.
let chunks_in_range = |range: std::ops::Range<usize>| {
// Absolute byte offset (within `source`) of the start of the chunk currently
// being visited; advanced by each chunk's length on every iteration.
let mut offset = 0;
source_chunks.iter().filter_map(move |chunk| {
let end_offset = offset + chunk.len();
// The chunk covers `offset..end_offset`; keep it only if that interval
// overlaps `range`.
if offset < range.end && range.start < end_offset {
// Convert the range's bounds to chunk-relative indices, clamped so the
// slice never reaches outside this chunk.
let end_in_chunk = (range.end - offset).min(chunk.len());
let start_in_chunk = range.start.max(offset) - offset;
offset = end_offset;
Some(&chunk[start_in_chunk..end_in_chunk])
} else {
offset = end_offset;
None
}
})
};
// Sanity-check the helper on its own before handing it to the cursor:
// concatenating the yielded pieces must reproduce the requested span.
assert_eq!(
chunks_in_range(0..9)
.map(|c| std::str::from_utf8(c).unwrap())
.collect::<String>(),
"SOMETHING",
);
assert_eq!(
chunks_in_range(15..24)
.map(|c| std::str::from_utf8(c).unwrap())
.collect::<String>(),
"transform",
);
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
// The text callback yields each node's text as a sequence of chunk pieces
// rather than as a single slice.
let captures = cursor.captures(&query, tree.root_node(), |node: Node| {
chunks_in_range(node.byte_range())
});
// `property` is absent from the expected list even though it is lowercase
// text. NOTE(review): presumably it parses as a `property_identifier` rather
// than an `identifier` in the JavaScript grammar — confirm.
assert_eq!(
collect_captures(captures, &query, source),
&[
("all_upper", "SOMETHING"),
("all_lower", "a"),
("all_lower", "transform"),
("leading_upper", "AnotherThing"),
("all_lower", "b"),
("all_lower", "c"),
("all_upper", "PARAMETER"),
]
);
});
}
#[test]
fn test_query_start_byte_for_pattern() {
let language = get_language("javascript");
@ -3058,7 +3131,7 @@ fn test_query_lifetime_is_separate_from_nodes_lifetime() {
let query = Query::new(language, query).unwrap();
let mut cursor = QueryCursor::new();
let node = cursor
.matches(&query, node, to_callback(source))
.matches(&query, node, source.as_bytes())
.next()
.unwrap()
.captures[0]
@ -3078,7 +3151,7 @@ fn test_query_lifetime_is_separate_from_nodes_lifetime() {
let query = Query::new(language, query).unwrap();
let mut cursor = QueryCursor::new();
let node = cursor
.captures(&query, node, to_callback(source))
.captures(&query, node, source.as_bytes())
.next()
.unwrap()
.0
@ -3123,7 +3196,7 @@ fn test_query_comments() {
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
let matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source),
&[(0, vec![("fn-name", "one")]),],
@ -3159,7 +3232,7 @@ fn test_query_disable_pattern() {
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
let matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source),
&[
@ -3502,13 +3575,13 @@ fn assert_query_matches(
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
let matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(collect_matches(matches, &query, source), expected);
assert_eq!(cursor.did_exceed_match_limit(), false);
}
fn collect_matches<'a>(
matches: impl Iterator<Item = QueryMatch<'a>>,
matches: impl Iterator<Item = QueryMatch<'a, 'a>>,
query: &'a Query,
source: &'a str,
) -> Vec<(usize, Vec<(&'a str, &'a str)>)> {
@ -3523,7 +3596,7 @@ fn collect_matches<'a>(
}
fn collect_captures<'a>(
captures: impl Iterator<Item = (QueryMatch<'a>, usize)>,
captures: impl Iterator<Item = (QueryMatch<'a, 'a>, usize)>,
query: &'a Query,
source: &'a str,
) -> Vec<(&'a str, &'a str)> {
@ -3544,7 +3617,3 @@ fn format_captures<'a>(
})
.collect()
}
/// Builds a text callback for query cursors from an in-memory source string.
///
/// The returned closure maps a node to the bytes of `source` covered by the
/// node's byte range. Indexing panics if that range lies outside `source`.
fn to_callback<'a>(source: &'a str) -> impl Fn(Node) -> &'a [u8] {
    let bytes = source.as_bytes();
    move |node| &bytes[node.byte_range()]
}