feat!: implement StreamingIterator instead of Iterator for QueryMatches and QueryCaptures

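This fixes undefined behavior (UB) that occurred when `collect` was called on either `QueryMatches` or `QueryCaptures`. A `StreamingIterator` only lends out each item until the next call to `next`, so items can no longer be collected while still borrowed.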

Co-authored-by: Amaan Qureshi <amaanq12@gmail.com>
This commit is contained in:
Lukas Seidel 2024-09-29 23:34:48 +02:00 committed by GitHub
parent 12007d3ebe
commit 6b1ebd3d29
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 271 additions and 105 deletions

View file

@ -1,6 +1,7 @@
use std::{cmp::Ordering, fmt::Write, ops::Range};
use rand::prelude::Rng;
use streaming_iterator::{IntoStreamingIterator, StreamingIterator};
use tree_sitter::{
Language, Node, Parser, Point, Query, QueryCapture, QueryCursor, QueryMatch, Tree, TreeCursor,
};
@ -324,39 +325,39 @@ pub fn assert_query_matches(
}
/// Gathers every match produced by `matches` into a vector of
/// `(pattern_index, captures)` pairs, where the captures of each match are
/// rendered by `format_captures` using `query` and `source`.
///
/// NOTE(review): this span is a diff rendering whose +/- markers were stripped.
/// The `impl Iterator` parameter and the `.map(..).collect()` chain appear to be
/// the removed lines; the `StreamingIterator` parameter and the `while let` loop
/// appear to be their replacements — confirm against the real file.
pub fn collect_matches<'a>(
matches: impl Iterator<Item = QueryMatch<'a, 'a>>,
mut matches: impl StreamingIterator<Item = QueryMatch<'a, 'a>>,
query: &'a Query,
source: &'a str,
) -> Vec<(usize, Vec<(&'a str, &'a str)>)> {
matches
.map(|m| {
(
m.pattern_index,
format_captures(m.captures.iter().copied(), query, source),
)
})
.collect()
let mut result = Vec::new();
// Each match is fully converted to name/text pairs before `next` is called
// again, so nothing pushed into `result` holds on to the match itself.
while let Some(m) = matches.next() {
result.push((
m.pattern_index,
format_captures(m.captures.iter().into_streaming_iter_ref(), query, source),
));
}
result
}
/// Formats the capture selected by each `(match, capture index)` pair in
/// `captures`, returning the `(capture name, text)` pairs produced by
/// `format_captures`.
///
/// NOTE(review): this span is a diff rendering whose +/- markers were stripped.
/// The `impl Iterator` parameter and the `m.captures[i]` body appear to be the
/// removed lines; the `StreamingIterator` parameter and the `m.captures[*i]`
/// body appear to be their replacements — confirm against the real file.
pub fn collect_captures<'a>(
captures: impl Iterator<Item = (QueryMatch<'a, 'a>, usize)>,
captures: impl StreamingIterator<Item = (QueryMatch<'a, 'a>, usize)>,
query: &'a Query,
source: &'a str,
) -> Vec<(&'a str, &'a str)> {
format_captures(captures.map(|(m, i)| m.captures[i]), query, source)
// In this variant the index is dereferenced (`*i`) before indexing.
format_captures(captures.map(|(m, i)| m.captures[*i]), query, source)
}
/// Converts each capture into a `(capture name, node text)` pair.
///
/// The name is looked up in `query.capture_names()` by the capture's `index`;
/// the text is the captured node's UTF-8 text taken from `source`.
///
/// # Panics
///
/// Panics if `utf8_text` returns an error for a captured node (the result is
/// unwrapped).
///
/// NOTE(review): this span is a diff rendering whose +/- markers were stripped.
/// The `impl Iterator` parameter and the `.map(..).collect()` chain appear to be
/// the removed lines; the `StreamingIterator` parameter and the `while let` loop
/// appear to be their replacements — confirm against the real file.
fn format_captures<'a>(
captures: impl Iterator<Item = QueryCapture<'a>>,
mut captures: impl StreamingIterator<Item = QueryCapture<'a>>,
query: &'a Query,
source: &'a str,
) -> Vec<(&'a str, &'a str)> {
captures
.map(|capture| {
(
query.capture_names()[capture.index as usize],
capture.node.utf8_text(source.as_bytes()).unwrap(),
)
})
.collect()
let mut result = Vec::new();
while let Some(capture) = captures.next() {
result.push((
query.capture_names()[capture.index as usize],
capture.node.utf8_text(source.as_bytes()).unwrap(),
));
}
result
}

View file

@ -3,6 +3,7 @@ use std::{env, fmt::Write};
use indoc::indoc;
use lazy_static::lazy_static;
use rand::{prelude::StdRng, SeedableRng};
use streaming_iterator::StreamingIterator;
use tree_sitter::{
CaptureQuantifier, Language, Node, Parser, Point, Query, QueryCursor, QueryError,
QueryErrorKind, QueryPredicate, QueryPredicateArg, QueryProperty,
@ -2267,29 +2268,50 @@ fn test_query_matches_with_wildcard_at_root_intersecting_byte_range() {
// After the first line of the class definition
let offset = source.find("A:").unwrap() + 2;
let matches = cursor
.set_byte_range(offset..offset)
.matches(&query, tree.root_node(), source.as_bytes())
.map(|mat| mat.captures[0].node.kind())
.collect::<Vec<_>>();
let mut matches = Vec::new();
let mut match_iter = cursor.set_byte_range(offset..offset).matches(
&query,
tree.root_node(),
source.as_bytes(),
);
while let Some(mat) = match_iter.next() {
if let Some(capture) = mat.captures.first() {
matches.push(capture.node.kind());
}
}
assert_eq!(matches, &["class_definition"]);
// After the first line of the function definition
let offset = source.find("b():").unwrap() + 4;
let matches = cursor
.set_byte_range(offset..offset)
.matches(&query, tree.root_node(), source.as_bytes())
.map(|mat| mat.captures[0].node.kind())
.collect::<Vec<_>>();
let mut matches = Vec::new();
let mut match_iter = cursor.set_byte_range(offset..offset).matches(
&query,
tree.root_node(),
source.as_bytes(),
);
while let Some(mat) = match_iter.next() {
if let Some(capture) = mat.captures.first() {
matches.push(capture.node.kind());
}
}
assert_eq!(matches, &["class_definition", "function_definition"]);
// After the first line of the if statement
let offset = source.find("c:").unwrap() + 2;
let matches = cursor
.set_byte_range(offset..offset)
.matches(&query, tree.root_node(), source.as_bytes())
.map(|mat| mat.captures[0].node.kind())
.collect::<Vec<_>>();
let mut matches = Vec::new();
let mut match_iter = cursor.set_byte_range(offset..offset).matches(
&query,
tree.root_node(),
source.as_bytes(),
);
while let Some(mat) = match_iter.next() {
if let Some(capture) = mat.captures.first() {
matches.push(capture.node.kind());
}
}
assert_eq!(
matches,
&["class_definition", "function_definition", "if_statement"]
@ -2342,8 +2364,9 @@ fn test_query_captures_within_byte_range_assigned_after_iterating() {
// Retrieve some captures
let mut results = Vec::new();
for (mat, capture_ix) in captures.by_ref().take(5) {
let capture = mat.captures[capture_ix];
let mut first_five = captures.by_ref().take(5);
while let Some((mat, capture_ix)) = first_five.next() {
let capture = mat.captures[*capture_ix];
results.push((
query.capture_names()[capture.index as usize],
&source[capture.node.byte_range()],
@ -2365,8 +2388,8 @@ fn test_query_captures_within_byte_range_assigned_after_iterating() {
// intersect the range.
results.clear();
captures.set_byte_range(source.find("Ok").unwrap()..source.len());
for (mat, capture_ix) in captures {
let capture = mat.captures[capture_ix];
while let Some((mat, capture_ix)) = captures.next() {
let capture = mat.captures[*capture_ix];
results.push((
query.capture_names()[capture.index as usize],
&source[capture.node.byte_range()],
@ -2602,21 +2625,23 @@ fn test_query_matches_with_captured_wildcard_at_root() {
parser.set_language(&language).unwrap();
let tree = parser.parse(source, None).unwrap();
let match_capture_names_and_rows = cursor
.matches(&query, tree.root_node(), source.as_bytes())
.map(|m| {
m.captures
.iter()
.map(|c| {
(
query.capture_names()[c.index as usize],
c.node.kind(),
c.node.start_position().row,
)
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let mut match_capture_names_and_rows = Vec::new();
let mut match_iter = cursor.matches(&query, tree.root_node(), source.as_bytes());
while let Some(m) = match_iter.next() {
let captures = m
.captures
.iter()
.map(|c| {
(
query.capture_names()[c.index as usize],
c.node.kind(),
c.node.start_position().row,
)
})
.collect::<Vec<_>>();
match_capture_names_and_rows.push(captures);
}
assert_eq!(
match_capture_names_and_rows,
@ -3460,9 +3485,13 @@ fn test_query_captures_with_matches_removed() {
let mut cursor = QueryCursor::new();
let mut captured_strings = Vec::new();
for (m, i) in cursor.captures(&query, tree.root_node(), source.as_bytes()) {
let capture = m.captures[i];
let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
while let Some((m, i)) = captures.next() {
println!("captured: {:?}, {}", m, i);
let capture = m.captures[*i];
let text = capture.node.utf8_text(source.as_bytes()).unwrap();
println!("captured: {:?}", text);
if text == "a" {
m.remove();
continue;
@ -3504,8 +3533,9 @@ fn test_query_captures_with_matches_removed_before_they_finish() {
let mut cursor = QueryCursor::new();
let mut captured_strings = Vec::new();
for (m, i) in cursor.captures(&query, tree.root_node(), source.as_bytes()) {
let capture = m.captures[i];
let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
while let Some((m, i)) = captures.next() {
let capture = m.captures[*i];
let text = capture.node.utf8_text(source.as_bytes()).unwrap();
if text == "as" {
m.remove();
@ -3912,21 +3942,24 @@ fn test_query_random() {
panic!("failed to build query for pattern {pattern} - {e}. seed: {seed}");
}
};
let mut actual_matches = cursor
.matches(
&query,
test_tree.root_node(),
include_bytes!("parser_test.rs").as_ref(),
)
.map(|mat| Match {
let mut actual_matches = Vec::new();
let mut match_iter = cursor.matches(
&query,
test_tree.root_node(),
include_bytes!("parser_test.rs").as_ref(),
);
while let Some(mat) = match_iter.next() {
let transformed_match = Match {
last_node: None,
captures: mat
.captures
.iter()
.map(|c| (query.capture_names()[c.index as usize], c.node))
.collect::<Vec<_>>(),
})
.collect::<Vec<_>>();
};
actual_matches.push(transformed_match);
}
// actual_matches.sort_unstable();
actual_matches.dedup();
@ -4908,12 +4941,12 @@ fn test_consecutive_zero_or_modifiers() {
assert!(matches.next().is_some());
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, three_tree.root_node(), three_source.as_bytes());
let mut matches = cursor.matches(&query, three_tree.root_node(), three_source.as_bytes());
let mut len_3 = false;
let mut len_1 = false;
for m in matches {
while let Some(m) = matches.next() {
if m.captures.len() == 3 {
len_3 = true;
}

View file

@ -1,5 +1,6 @@
use std::{iter, sync::Arc};
use streaming_iterator::StreamingIterator;
use tree_sitter::{Language, Node, Parser, Point, Query, QueryCursor, TextProvider, Tree};
use crate::tests::helpers::fixtures::get_language;
@ -30,8 +31,8 @@ fn tree_query<I: AsRef<[u8]>>(tree: &Tree, text: impl TextProvider<I>, language:
let mut cursor = QueryCursor::new();
let mut captures = cursor.captures(&query, tree.root_node(), text);
let (match_, idx) = captures.next().unwrap();
let capture = match_.captures[idx];
assert_eq!(capture.index as usize, idx);
let capture = match_.captures[*idx];
assert_eq!(capture.index as usize, *idx);
assert_eq!("comment", capture.node.kind());
}

View file

@ -1,6 +1,7 @@
use std::fs;
use lazy_static::lazy_static;
use streaming_iterator::StreamingIterator;
use tree_sitter::{
wasmtime::Engine, Parser, Query, QueryCursor, WasmError, WasmErrorKind, WasmStore,
};