feat!: properly handle predicates used on quantified captures

This commit is contained in:
Amaan Qureshi 2023-08-18 19:45:00 -04:00
parent a0cf0a7104
commit 09ac28c77d
No known key found for this signature in database
GPG key ID: E67890ADC4227273
3 changed files with 218 additions and 63 deletions

View file

@ -4574,6 +4574,89 @@ fn test_capture_quantifiers() {
});
}
/// Exercises text predicates on a quantified capture (`(comment)+ @name`).
///
/// Each `Row` parses `code` with `language`, runs `pattern` through
/// `QueryCursor::captures`, and compares the collected `(capture name, text)`
/// pairs against `captures`.
///
/// The two rows use the same source and differ only in the predicate:
/// * `#match?` — expected captures are only the first run of comments, all of
///   which start with `///`; the second run contains `// quuz` and is absent.
/// * `#any-match?` — expected captures include every comment from both runs,
///   including `// quuz`, which does not itself match the regex.
#[test]
fn test_query_quantified_captures() {
// One table entry per scenario; `description` is only used for log output.
struct Row {
description: &'static str,
language: Language,
code: &'static str,
pattern: &'static str,
captures: &'static [(&'static str, &'static str)],
}
// #[rustfmt::skip]
let rows = &[
Row {
description: "doc comments where all must match the prefix",
language: get_language("c"),
code: indoc! {"
/// foo
/// bar
/// baz
void main() {}
/// qux
/// quux
// quuz
"},
pattern: r#"
((comment)+ @comment.documentation
(#match? @comment.documentation "^///"))
"#,
// Only the first comment run is expected: the second run includes
// `// quuz`, which does not start with `///`.
captures: &[
("comment.documentation", "/// foo"),
("comment.documentation", "/// bar"),
("comment.documentation", "/// baz"),
],
},
Row {
description: "doc comments where one must match the prefix",
language: get_language("c"),
code: indoc! {"
/// foo
/// bar
/// baz
void main() {}
/// qux
/// quux
// quuz
"},
pattern: r#"
((comment)+ @comment.documentation
(#any-match? @comment.documentation "^///"))
"#,
// Both comment runs are expected in full, `// quuz` included.
captures: &[
("comment.documentation", "/// foo"),
("comment.documentation", "/// bar"),
("comment.documentation", "/// baz"),
("comment.documentation", "/// qux"),
("comment.documentation", "/// quux"),
("comment.documentation", "// quuz"),
],
},
];
allocations::record(|| {
for row in rows {
// Log which scenario is running so a failing assertion can be attributed.
eprintln!(" quantified query example: {:?}", row.description);
let mut parser = Parser::new();
parser.set_language(row.language).unwrap();
let tree = parser.parse(row.code, None).unwrap();
let query = Query::new(row.language, row.pattern).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.captures(&query, tree.root_node(), row.code.as_bytes());
assert_eq!(collect_captures(matches, &query, row.code), row.captures);
}
});
}
#[test]
fn test_query_max_start_depth() {
struct Row {

View file

@ -118,7 +118,7 @@ pub struct Query {
ptr: NonNull<ffi::TSQuery>,
capture_names: Vec<String>,
capture_quantifiers: Vec<Vec<CaptureQuantifier>>,
text_predicates: Vec<Box<[TextPredicate]>>,
text_predicates: Vec<Box<[TextPredicateCapture]>>,
property_settings: Vec<Box<[QueryProperty]>>,
property_predicates: Vec<Box<[(QueryProperty, bool)]>>,
general_predicates: Vec<Box<[QueryPredicate]>>,
@ -250,11 +250,16 @@ pub enum QueryErrorKind {
}
#[derive(Debug)]
enum TextPredicate {
CaptureEqString(u32, String, bool),
CaptureEqCapture(u32, u32, bool),
CaptureMatchString(u32, regex::bytes::Regex, bool),
CaptureAnyString(u32, Vec<String>, bool),
/// A text predicate attached to a query pattern, evaluated against the text
/// of every node captured under a given capture index.
///
/// Field layout shared by the variants:
/// * The first item is the capture index.
/// * The next is capture specific, depending on what item is expected
///   (a string, a second capture index, a regex, or a list of strings).
/// * The first bool is whether the predicate is positive: `true` for the
///   plain forms (`eq?`, `any-eq?`, `match?`, `any-match?`), `false` for the
///   `not-` forms.
/// * The last bool, where present, says whether the predicate is meant to
///   match all captured nodes (`true`: `eq?`, `not-eq?`, `match?`,
///   `not-match?`) or any one of them (`false`: the `any-` forms).
///
/// `AnyString` carries no such any/all flag; it is evaluated against every
/// captured node.
enum TextPredicateCapture {
/// Capture text compared for byte equality against a fixed string.
EqString(u32, String, bool, bool),
/// Text of one capture compared against the text of a second capture.
EqCapture(u32, u32, bool, bool),
/// Capture text tested against a compiled byte regex.
MatchString(u32, regex::bytes::Regex, bool, bool),
/// Capture text tested for equality with any string in the list.
AnyString(u32, Vec<String>, bool),
}
// TODO: Remove this struct at some point. If `core::str::lossy::Utf8Lossy`
@ -1733,7 +1738,7 @@ impl Query {
// Build a predicate for each of the known predicate function names.
let operator_name = &string_values[p[0].value_id as usize];
match operator_name.as_str() {
"eq?" | "not-eq?" => {
"eq?" | "not-eq?" | "any-eq?" | "any-not-eq?" => {
if p.len() != 3 {
return Err(predicate_error(
row,
@ -1750,23 +1755,30 @@ impl Query {
)));
}
let is_positive = operator_name == "eq?";
let is_positive = operator_name == "eq?" || operator_name == "any-eq?";
let match_all = match operator_name.as_str() {
"eq?" | "not-eq?" => true,
"any-eq?" | "any-not-eq?" => false,
_ => unreachable!(),
};
text_predicates.push(if p[2].type_ == type_capture {
TextPredicate::CaptureEqCapture(
TextPredicateCapture::EqCapture(
p[1].value_id,
p[2].value_id,
is_positive,
match_all,
)
} else {
TextPredicate::CaptureEqString(
TextPredicateCapture::EqString(
p[1].value_id,
string_values[p[2].value_id as usize].clone(),
is_positive,
match_all,
)
});
}
"match?" | "not-match?" => {
"match?" | "not-match?" | "any-match?" | "any-not-match?" => {
if p.len() != 3 {
return Err(predicate_error(row, format!(
"Wrong number of arguments to #match? predicate. Expected 2, got {}.",
@ -1786,20 +1798,27 @@ impl Query {
)));
}
let is_positive = operator_name == "match?";
let is_positive =
operator_name == "match?" || operator_name == "any-match?";
let match_all = match operator_name.as_str() {
"match?" | "not-match?" => true,
"any-match?" | "any-not-match?" => false,
_ => unreachable!(),
};
let regex = &string_values[p[2].value_id as usize];
text_predicates.push(TextPredicate::CaptureMatchString(
text_predicates.push(TextPredicateCapture::MatchString(
p[1].value_id,
regex::bytes::Regex::new(regex).map_err(|_| {
predicate_error(row, format!("Invalid regex '{}'", regex))
})?,
is_positive,
match_all,
));
}
"set!" => property_settings.push(Self::parse_property(
row,
&operator_name,
operator_name,
&result.capture_names,
&string_values,
&p[1..],
@ -1808,7 +1827,7 @@ impl Query {
"is?" | "is-not?" => property_predicates.push((
Self::parse_property(
row,
&operator_name,
operator_name,
&result.capture_names,
&string_values,
&p[1..],
@ -1841,7 +1860,7 @@ impl Query {
}
values.push(string_values[arg.value_id as usize].clone());
}
text_predicates.push(TextPredicate::CaptureAnyString(
text_predicates.push(TextPredicateCapture::AnyString(
p[1].value_id,
values,
is_positive,
@ -2203,7 +2222,7 @@ impl<'tree> QueryMatch<'_, 'tree> {
) -> impl Iterator<Item = Node<'tree>> + '_ {
self.captures
.iter()
.filter_map(move |capture| (capture.index == capture_ix).then(|| capture.node))
.filter_map(move |capture| (capture.index == capture_ix).then_some(capture.node))
}
fn new(m: ffi::TSQueryMatch, cursor: *mut ffi::TSQueryCursor) -> Self {
@ -2266,52 +2285,61 @@ impl<'tree> QueryMatch<'_, 'tree> {
query.text_predicates[self.pattern_index]
.iter()
.all(|predicate| match predicate {
TextPredicate::CaptureEqCapture(i, j, is_positive) => {
let node1 = self.nodes_for_capture_index(*i).next();
let node2 = self.nodes_for_capture_index(*j).next();
match (node1, node2) {
(Some(node1), Some(node2)) => {
let mut text1 = text_provider.text(node1);
let mut text2 = text_provider.text(node2);
let text1 = node_text1.get_text(&mut text1);
let text2 = node_text2.get_text(&mut text2);
(text1 == text2) == *is_positive
TextPredicateCapture::EqCapture(i, j, is_positive, match_all_nodes) => {
let mut nodes_1 = self.nodes_for_capture_index(*i);
let mut nodes_2 = self.nodes_for_capture_index(*j);
while let (Some(node1), Some(node2)) = (nodes_1.next(), nodes_2.next()) {
let mut text1 = text_provider.text(node1);
let mut text2 = text_provider.text(node2);
let text1 = node_text1.get_text(&mut text1);
let text2 = node_text2.get_text(&mut text2);
if (text1 == text2) != *is_positive && *match_all_nodes {
return false;
}
if (text1 == text2) == *is_positive && !*match_all_nodes {
return true;
}
_ => true,
}
nodes_1.next().is_none() && nodes_2.next().is_none()
}
TextPredicate::CaptureEqString(i, s, is_positive) => {
let node = self.nodes_for_capture_index(*i).next();
match node {
Some(node) => {
let mut text = text_provider.text(node);
let text = node_text1.get_text(&mut text);
(text == s.as_bytes()) == *is_positive
TextPredicateCapture::EqString(i, s, is_positive, match_all_nodes) => {
let nodes = self.nodes_for_capture_index(*i);
for node in nodes {
let mut text = text_provider.text(node);
let text = node_text1.get_text(&mut text);
if (text == s.as_bytes()) != *is_positive && *match_all_nodes {
return false;
}
if (text == s.as_bytes()) == *is_positive && !*match_all_nodes {
return true;
}
None => true,
}
true
}
TextPredicate::CaptureMatchString(i, r, is_positive) => {
let node = self.nodes_for_capture_index(*i).next();
match node {
Some(node) => {
let mut text = text_provider.text(node);
let text = node_text1.get_text(&mut text);
r.is_match(text) == *is_positive
TextPredicateCapture::MatchString(i, r, is_positive, match_all_nodes) => {
let nodes = self.nodes_for_capture_index(*i);
for node in nodes {
let mut text = text_provider.text(node);
let text = node_text1.get_text(&mut text);
if (r.is_match(text)) != *is_positive && *match_all_nodes {
return false;
}
if (r.is_match(text)) == *is_positive && !*match_all_nodes {
return true;
}
None => true,
}
true
}
TextPredicate::CaptureAnyString(i, v, is_positive) => {
let node = self.nodes_for_capture_index(*i).next();
match node {
Some(node) => {
let mut text = text_provider.text(node);
let text = node_text1.get_text(&mut text);
v.iter().any(|s| text == s.as_bytes()) == *is_positive
TextPredicateCapture::AnyString(i, v, is_positive) => {
let nodes = self.nodes_for_capture_index(*i);
for node in nodes {
let mut text = text_provider.text(node);
let text = node_text1.get_text(&mut text);
if (v.iter().any(|s| text == s.as_bytes())) != *is_positive {
return false;
}
None => true,
}
true
}
})
}

View file

@ -841,7 +841,13 @@ class Language {
}
const operator = steps[0].value;
let isPositive = true;
let matchAll = true;
switch (operator) {
case 'any-not-eq?':
isPositive = false;
matchAll = false;
case 'any-eq?':
matchAll = false;
case 'not-eq?':
isPositive = false;
case 'eq?':
@ -855,28 +861,36 @@ class Language {
const captureName1 = steps[1].name;
const captureName2 = steps[2].name;
textPredicates[i].push(function(captures) {
let node1, node2
let nodes_1 = [];
let nodes_2 = [];
for (const c of captures) {
if (c.name === captureName1) node1 = c.node;
if (c.name === captureName2) node2 = c.node;
if (c.name === captureName1) nodes_1.push(c.node);
if (c.name === captureName2) nodes_2.push(c.node);
}
if(node1 === undefined || node2 === undefined) return true;
return (node1.text === node2.text) === isPositive;
return matchAll
? nodes_1.every(n1 => nodes_2.some(n2 => n1.text === n2.text)) === isPositive
: nodes_1.some(n1 => nodes_2.some(n2 => n1.text === n2.text)) === isPositive;
});
} else {
const captureName = steps[1].name;
const stringValue = steps[2].value;
textPredicates[i].push(function(captures) {
let nodes = [];
for (const c of captures) {
if (c.name === captureName) {
return (c.node.text === stringValue) === isPositive;
};
if (c.name === captureName) nodes.push(c.node);
}
return true;
return matchAll
? nodes.every(n => n.text === stringValue) === isPositive
: nodes.some(n => n.text === stringValue) === isPositive;
});
}
break;
case 'not-any-match?':
isPositive = false;
matchAll = false;
case 'any-match?':
matchAll = false;
case 'not-match?':
isPositive = false;
case 'match?':
@ -892,10 +906,14 @@ class Language {
const captureName = steps[1].name;
const regex = new RegExp(steps[2].value);
textPredicates[i].push(function(captures) {
const nodes = [];
for (const c of captures) {
if (c.name === captureName) return regex.test(c.node.text) === isPositive;
if (c.name === captureName) nodes.push(c.node.text);
}
return true;
if (nodes.length === 0) return !isPositive;
return matchAll
? nodes.every(text => regex.test(text)) === isPositive
: nodes.some(text => regex.test(text)) === isPositive;
});
break;
@ -923,6 +941,32 @@ class Language {
properties[i][steps[1].value] = steps[2] ? steps[2].value : null;
break;
case 'not-any-of?':
isPositive = false;
case 'any-of?':
if (steps.length < 2) throw new Error(
`Wrong number of arguments to \`#${operator}\` predicate. Expected at least 1. Got ${steps.length - 1}.`
);
if (steps[1].type !== 'capture') throw new Error(
`First argument of \`#${operator}\` predicate must be a capture. Got "${steps[1].value}".`
);
for (let i = 2; i < steps.length; i++) {
if (steps[i].type !== 'string') throw new Error(
`Arguments to \`#${operator}\` predicate must be a strings.".`
);
}
captureName = steps[1].name;
const values = steps.slice(2).map(s => s.value);
textPredicates[i].push(function(captures) {
const nodes = [];
for (const c of captures) {
if (c.name === captureName) nodes.push(c.node.text);
}
if (nodes.length === 0) return !isPositive;
return nodes.every(text => values.includes(text)) === isPositive;
});
break;
default:
predicates[i].push({operator, operands: steps.slice(1)});
}