feat!: properly handle predicates used on quantified captures
This commit is contained in:
parent
a0cf0a7104
commit
09ac28c77d
3 changed files with 218 additions and 63 deletions
|
|
@ -4574,6 +4574,89 @@ fn test_capture_quantifiers() {
|
|||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_quantified_captures() {
|
||||
struct Row {
|
||||
description: &'static str,
|
||||
language: Language,
|
||||
code: &'static str,
|
||||
pattern: &'static str,
|
||||
captures: &'static [(&'static str, &'static str)],
|
||||
}
|
||||
|
||||
// #[rustfmt::skip]
|
||||
let rows = &[
|
||||
Row {
|
||||
description: "doc comments where all must match the prefiix",
|
||||
language: get_language("c"),
|
||||
code: indoc! {"
|
||||
/// foo
|
||||
/// bar
|
||||
/// baz
|
||||
|
||||
void main() {}
|
||||
|
||||
/// qux
|
||||
/// quux
|
||||
// quuz
|
||||
"},
|
||||
pattern: r#"
|
||||
((comment)+ @comment.documentation
|
||||
(#match? @comment.documentation "^///"))
|
||||
"#,
|
||||
captures: &[
|
||||
("comment.documentation", "/// foo"),
|
||||
("comment.documentation", "/// bar"),
|
||||
("comment.documentation", "/// baz"),
|
||||
],
|
||||
},
|
||||
Row {
|
||||
description: "doc comments where one must match the prefix",
|
||||
language: get_language("c"),
|
||||
code: indoc! {"
|
||||
/// foo
|
||||
/// bar
|
||||
/// baz
|
||||
|
||||
void main() {}
|
||||
|
||||
/// qux
|
||||
/// quux
|
||||
// quuz
|
||||
"},
|
||||
pattern: r#"
|
||||
((comment)+ @comment.documentation
|
||||
(#any-match? @comment.documentation "^///"))
|
||||
"#,
|
||||
captures: &[
|
||||
("comment.documentation", "/// foo"),
|
||||
("comment.documentation", "/// bar"),
|
||||
("comment.documentation", "/// baz"),
|
||||
("comment.documentation", "/// qux"),
|
||||
("comment.documentation", "/// quux"),
|
||||
("comment.documentation", "// quuz"),
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
allocations::record(|| {
|
||||
for row in rows {
|
||||
eprintln!(" quantified query example: {:?}", row.description);
|
||||
|
||||
let mut parser = Parser::new();
|
||||
parser.set_language(row.language).unwrap();
|
||||
let tree = parser.parse(row.code, None).unwrap();
|
||||
|
||||
let query = Query::new(row.language, row.pattern).unwrap();
|
||||
|
||||
let mut cursor = QueryCursor::new();
|
||||
let matches = cursor.captures(&query, tree.root_node(), row.code.as_bytes());
|
||||
|
||||
assert_eq!(collect_captures(matches, &query, row.code), row.captures);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_max_start_depth() {
|
||||
struct Row {
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ pub struct Query {
|
|||
ptr: NonNull<ffi::TSQuery>,
|
||||
capture_names: Vec<String>,
|
||||
capture_quantifiers: Vec<Vec<CaptureQuantifier>>,
|
||||
text_predicates: Vec<Box<[TextPredicate]>>,
|
||||
text_predicates: Vec<Box<[TextPredicateCapture]>>,
|
||||
property_settings: Vec<Box<[QueryProperty]>>,
|
||||
property_predicates: Vec<Box<[(QueryProperty, bool)]>>,
|
||||
general_predicates: Vec<Box<[QueryPredicate]>>,
|
||||
|
|
@ -250,11 +250,16 @@ pub enum QueryErrorKind {
|
|||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum TextPredicate {
|
||||
CaptureEqString(u32, String, bool),
|
||||
CaptureEqCapture(u32, u32, bool),
|
||||
CaptureMatchString(u32, regex::bytes::Regex, bool),
|
||||
CaptureAnyString(u32, Vec<String>, bool),
|
||||
/// The first item is the capture index
|
||||
/// The next is capture specific, depending on what item is expected
|
||||
/// The first bool is if the capture is positive
|
||||
/// The last item is a bool signifying whether or not it's meant to match
|
||||
/// any or all captures
|
||||
enum TextPredicateCapture {
|
||||
EqString(u32, String, bool, bool),
|
||||
EqCapture(u32, u32, bool, bool),
|
||||
MatchString(u32, regex::bytes::Regex, bool, bool),
|
||||
AnyString(u32, Vec<String>, bool),
|
||||
}
|
||||
|
||||
// TODO: Remove this struct at at some point. If `core::str::lossy::Utf8Lossy`
|
||||
|
|
@ -1733,7 +1738,7 @@ impl Query {
|
|||
// Build a predicate for each of the known predicate function names.
|
||||
let operator_name = &string_values[p[0].value_id as usize];
|
||||
match operator_name.as_str() {
|
||||
"eq?" | "not-eq?" => {
|
||||
"eq?" | "not-eq?" | "any-eq?" | "any-not-eq?" => {
|
||||
if p.len() != 3 {
|
||||
return Err(predicate_error(
|
||||
row,
|
||||
|
|
@ -1750,23 +1755,30 @@ impl Query {
|
|||
)));
|
||||
}
|
||||
|
||||
let is_positive = operator_name == "eq?";
|
||||
let is_positive = operator_name == "eq?" || operator_name == "any-eq?";
|
||||
let match_all = match operator_name.as_str() {
|
||||
"eq?" | "not-eq?" => true,
|
||||
"any-eq?" | "any-not-eq?" => false,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
text_predicates.push(if p[2].type_ == type_capture {
|
||||
TextPredicate::CaptureEqCapture(
|
||||
TextPredicateCapture::EqCapture(
|
||||
p[1].value_id,
|
||||
p[2].value_id,
|
||||
is_positive,
|
||||
match_all,
|
||||
)
|
||||
} else {
|
||||
TextPredicate::CaptureEqString(
|
||||
TextPredicateCapture::EqString(
|
||||
p[1].value_id,
|
||||
string_values[p[2].value_id as usize].clone(),
|
||||
is_positive,
|
||||
match_all,
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
"match?" | "not-match?" => {
|
||||
"match?" | "not-match?" | "any-match?" | "any-not-match?" => {
|
||||
if p.len() != 3 {
|
||||
return Err(predicate_error(row, format!(
|
||||
"Wrong number of arguments to #match? predicate. Expected 2, got {}.",
|
||||
|
|
@ -1786,20 +1798,27 @@ impl Query {
|
|||
)));
|
||||
}
|
||||
|
||||
let is_positive = operator_name == "match?";
|
||||
let is_positive =
|
||||
operator_name == "match?" || operator_name == "any-match?";
|
||||
let match_all = match operator_name.as_str() {
|
||||
"match?" | "not-match?" => true,
|
||||
"any-match?" | "any-not-match?" => false,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let regex = &string_values[p[2].value_id as usize];
|
||||
text_predicates.push(TextPredicate::CaptureMatchString(
|
||||
text_predicates.push(TextPredicateCapture::MatchString(
|
||||
p[1].value_id,
|
||||
regex::bytes::Regex::new(regex).map_err(|_| {
|
||||
predicate_error(row, format!("Invalid regex '{}'", regex))
|
||||
})?,
|
||||
is_positive,
|
||||
match_all,
|
||||
));
|
||||
}
|
||||
|
||||
"set!" => property_settings.push(Self::parse_property(
|
||||
row,
|
||||
&operator_name,
|
||||
operator_name,
|
||||
&result.capture_names,
|
||||
&string_values,
|
||||
&p[1..],
|
||||
|
|
@ -1808,7 +1827,7 @@ impl Query {
|
|||
"is?" | "is-not?" => property_predicates.push((
|
||||
Self::parse_property(
|
||||
row,
|
||||
&operator_name,
|
||||
operator_name,
|
||||
&result.capture_names,
|
||||
&string_values,
|
||||
&p[1..],
|
||||
|
|
@ -1841,7 +1860,7 @@ impl Query {
|
|||
}
|
||||
values.push(string_values[arg.value_id as usize].clone());
|
||||
}
|
||||
text_predicates.push(TextPredicate::CaptureAnyString(
|
||||
text_predicates.push(TextPredicateCapture::AnyString(
|
||||
p[1].value_id,
|
||||
values,
|
||||
is_positive,
|
||||
|
|
@ -2203,7 +2222,7 @@ impl<'tree> QueryMatch<'_, 'tree> {
|
|||
) -> impl Iterator<Item = Node<'tree>> + '_ {
|
||||
self.captures
|
||||
.iter()
|
||||
.filter_map(move |capture| (capture.index == capture_ix).then(|| capture.node))
|
||||
.filter_map(move |capture| (capture.index == capture_ix).then_some(capture.node))
|
||||
}
|
||||
|
||||
fn new(m: ffi::TSQueryMatch, cursor: *mut ffi::TSQueryCursor) -> Self {
|
||||
|
|
@ -2266,52 +2285,61 @@ impl<'tree> QueryMatch<'_, 'tree> {
|
|||
query.text_predicates[self.pattern_index]
|
||||
.iter()
|
||||
.all(|predicate| match predicate {
|
||||
TextPredicate::CaptureEqCapture(i, j, is_positive) => {
|
||||
let node1 = self.nodes_for_capture_index(*i).next();
|
||||
let node2 = self.nodes_for_capture_index(*j).next();
|
||||
match (node1, node2) {
|
||||
(Some(node1), Some(node2)) => {
|
||||
let mut text1 = text_provider.text(node1);
|
||||
let mut text2 = text_provider.text(node2);
|
||||
let text1 = node_text1.get_text(&mut text1);
|
||||
let text2 = node_text2.get_text(&mut text2);
|
||||
(text1 == text2) == *is_positive
|
||||
TextPredicateCapture::EqCapture(i, j, is_positive, match_all_nodes) => {
|
||||
let mut nodes_1 = self.nodes_for_capture_index(*i);
|
||||
let mut nodes_2 = self.nodes_for_capture_index(*j);
|
||||
while let (Some(node1), Some(node2)) = (nodes_1.next(), nodes_2.next()) {
|
||||
let mut text1 = text_provider.text(node1);
|
||||
let mut text2 = text_provider.text(node2);
|
||||
let text1 = node_text1.get_text(&mut text1);
|
||||
let text2 = node_text2.get_text(&mut text2);
|
||||
if (text1 == text2) != *is_positive && *match_all_nodes {
|
||||
return false;
|
||||
}
|
||||
if (text1 == text2) == *is_positive && !*match_all_nodes {
|
||||
return true;
|
||||
}
|
||||
_ => true,
|
||||
}
|
||||
nodes_1.next().is_none() && nodes_2.next().is_none()
|
||||
}
|
||||
TextPredicate::CaptureEqString(i, s, is_positive) => {
|
||||
let node = self.nodes_for_capture_index(*i).next();
|
||||
match node {
|
||||
Some(node) => {
|
||||
let mut text = text_provider.text(node);
|
||||
let text = node_text1.get_text(&mut text);
|
||||
(text == s.as_bytes()) == *is_positive
|
||||
TextPredicateCapture::EqString(i, s, is_positive, match_all_nodes) => {
|
||||
let nodes = self.nodes_for_capture_index(*i);
|
||||
for node in nodes {
|
||||
let mut text = text_provider.text(node);
|
||||
let text = node_text1.get_text(&mut text);
|
||||
if (text == s.as_bytes()) != *is_positive && *match_all_nodes {
|
||||
return false;
|
||||
}
|
||||
if (text == s.as_bytes()) == *is_positive && !*match_all_nodes {
|
||||
return true;
|
||||
}
|
||||
None => true,
|
||||
}
|
||||
true
|
||||
}
|
||||
TextPredicate::CaptureMatchString(i, r, is_positive) => {
|
||||
let node = self.nodes_for_capture_index(*i).next();
|
||||
match node {
|
||||
Some(node) => {
|
||||
let mut text = text_provider.text(node);
|
||||
let text = node_text1.get_text(&mut text);
|
||||
r.is_match(text) == *is_positive
|
||||
TextPredicateCapture::MatchString(i, r, is_positive, match_all_nodes) => {
|
||||
let nodes = self.nodes_for_capture_index(*i);
|
||||
for node in nodes {
|
||||
let mut text = text_provider.text(node);
|
||||
let text = node_text1.get_text(&mut text);
|
||||
if (r.is_match(text)) != *is_positive && *match_all_nodes {
|
||||
return false;
|
||||
}
|
||||
if (r.is_match(text)) == *is_positive && !*match_all_nodes {
|
||||
return true;
|
||||
}
|
||||
None => true,
|
||||
}
|
||||
true
|
||||
}
|
||||
TextPredicate::CaptureAnyString(i, v, is_positive) => {
|
||||
let node = self.nodes_for_capture_index(*i).next();
|
||||
match node {
|
||||
Some(node) => {
|
||||
let mut text = text_provider.text(node);
|
||||
let text = node_text1.get_text(&mut text);
|
||||
v.iter().any(|s| text == s.as_bytes()) == *is_positive
|
||||
TextPredicateCapture::AnyString(i, v, is_positive) => {
|
||||
let nodes = self.nodes_for_capture_index(*i);
|
||||
for node in nodes {
|
||||
let mut text = text_provider.text(node);
|
||||
let text = node_text1.get_text(&mut text);
|
||||
if (v.iter().any(|s| text == s.as_bytes())) != *is_positive {
|
||||
return false;
|
||||
}
|
||||
None => true,
|
||||
}
|
||||
true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -841,7 +841,13 @@ class Language {
|
|||
}
|
||||
const operator = steps[0].value;
|
||||
let isPositive = true;
|
||||
let matchAll = true;
|
||||
switch (operator) {
|
||||
case 'any-not-eq?':
|
||||
isPositive = false;
|
||||
matchAll = false;
|
||||
case 'any-eq?':
|
||||
matchAll = false;
|
||||
case 'not-eq?':
|
||||
isPositive = false;
|
||||
case 'eq?':
|
||||
|
|
@ -855,28 +861,36 @@ class Language {
|
|||
const captureName1 = steps[1].name;
|
||||
const captureName2 = steps[2].name;
|
||||
textPredicates[i].push(function(captures) {
|
||||
let node1, node2
|
||||
let nodes_1 = [];
|
||||
let nodes_2 = [];
|
||||
for (const c of captures) {
|
||||
if (c.name === captureName1) node1 = c.node;
|
||||
if (c.name === captureName2) node2 = c.node;
|
||||
if (c.name === captureName1) nodes_1.push(c.node);
|
||||
if (c.name === captureName2) nodes_2.push(c.node);
|
||||
}
|
||||
if(node1 === undefined || node2 === undefined) return true;
|
||||
return (node1.text === node2.text) === isPositive;
|
||||
return matchAll
|
||||
? nodes_1.every(n1 => nodes_2.some(n2 => n1.text === n2.text)) === isPositive
|
||||
: nodes_1.some(n1 => nodes_2.some(n2 => n1.text === n2.text)) === isPositive;
|
||||
});
|
||||
} else {
|
||||
const captureName = steps[1].name;
|
||||
const stringValue = steps[2].value;
|
||||
textPredicates[i].push(function(captures) {
|
||||
let nodes = [];
|
||||
for (const c of captures) {
|
||||
if (c.name === captureName) {
|
||||
return (c.node.text === stringValue) === isPositive;
|
||||
};
|
||||
if (c.name === captureName) nodes.push(c.node);
|
||||
}
|
||||
return true;
|
||||
return matchAll
|
||||
? nodes.every(n => n.text === stringValue) === isPositive
|
||||
: nodes.some(n => n.text === stringValue) === isPositive;
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
||||
case 'not-any-match?':
|
||||
isPositive = false;
|
||||
matchAll = false;
|
||||
case 'any-match?':
|
||||
matchAll = false;
|
||||
case 'not-match?':
|
||||
isPositive = false;
|
||||
case 'match?':
|
||||
|
|
@ -892,10 +906,14 @@ class Language {
|
|||
const captureName = steps[1].name;
|
||||
const regex = new RegExp(steps[2].value);
|
||||
textPredicates[i].push(function(captures) {
|
||||
const nodes = [];
|
||||
for (const c of captures) {
|
||||
if (c.name === captureName) return regex.test(c.node.text) === isPositive;
|
||||
if (c.name === captureName) nodes.push(c.node.text);
|
||||
}
|
||||
return true;
|
||||
if (nodes.length === 0) return !isPositive;
|
||||
return matchAll
|
||||
? nodes.every(text => regex.test(text)) === isPositive
|
||||
: nodes.some(text => regex.test(text)) === isPositive;
|
||||
});
|
||||
break;
|
||||
|
||||
|
|
@ -923,6 +941,32 @@ class Language {
|
|||
properties[i][steps[1].value] = steps[2] ? steps[2].value : null;
|
||||
break;
|
||||
|
||||
case 'not-any-of?':
|
||||
isPositive = false;
|
||||
case 'any-of?':
|
||||
if (steps.length < 2) throw new Error(
|
||||
`Wrong number of arguments to \`#${operator}\` predicate. Expected at least 1. Got ${steps.length - 1}.`
|
||||
);
|
||||
if (steps[1].type !== 'capture') throw new Error(
|
||||
`First argument of \`#${operator}\` predicate must be a capture. Got "${steps[1].value}".`
|
||||
);
|
||||
for (let i = 2; i < steps.length; i++) {
|
||||
if (steps[i].type !== 'string') throw new Error(
|
||||
`Arguments to \`#${operator}\` predicate must be a strings.".`
|
||||
);
|
||||
}
|
||||
captureName = steps[1].name;
|
||||
const values = steps.slice(2).map(s => s.value);
|
||||
textPredicates[i].push(function(captures) {
|
||||
const nodes = [];
|
||||
for (const c of captures) {
|
||||
if (c.name === captureName) nodes.push(c.node.text);
|
||||
}
|
||||
if (nodes.length === 0) return !isPositive;
|
||||
return nodes.every(text => values.includes(text)) === isPositive;
|
||||
});
|
||||
break;
|
||||
|
||||
default:
|
||||
predicates[i].push({operator, operands: steps.slice(1)});
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue