feat!: properly handle predicates used on quantified captures

This commit is contained in:
Amaan Qureshi 2023-08-18 19:45:00 -04:00
parent a0cf0a7104
commit 09ac28c77d
No known key found for this signature in database
GPG key ID: E67890ADC4227273
3 changed files with 218 additions and 63 deletions

View file

@ -118,7 +118,7 @@ pub struct Query {
ptr: NonNull<ffi::TSQuery>,
capture_names: Vec<String>,
capture_quantifiers: Vec<Vec<CaptureQuantifier>>,
text_predicates: Vec<Box<[TextPredicate]>>,
text_predicates: Vec<Box<[TextPredicateCapture]>>,
property_settings: Vec<Box<[QueryProperty]>>,
property_predicates: Vec<Box<[(QueryProperty, bool)]>>,
general_predicates: Vec<Box<[QueryPredicate]>>,
@ -250,11 +250,16 @@ pub enum QueryErrorKind {
}
#[derive(Debug)]
enum TextPredicate {
CaptureEqString(u32, String, bool),
CaptureEqCapture(u32, u32, bool),
CaptureMatchString(u32, regex::bytes::Regex, bool),
CaptureAnyString(u32, Vec<String>, bool),
/// The first item is the capture index
/// The next is capture specific, depending on what item is expected
/// The first bool is if the capture is positive
/// The last item is a bool signifying whether or not it's meant to match
/// any or all captures
enum TextPredicateCapture {
EqString(u32, String, bool, bool),
EqCapture(u32, u32, bool, bool),
MatchString(u32, regex::bytes::Regex, bool, bool),
AnyString(u32, Vec<String>, bool),
}
// TODO: Remove this struct at at some point. If `core::str::lossy::Utf8Lossy`
@ -1733,7 +1738,7 @@ impl Query {
// Build a predicate for each of the known predicate function names.
let operator_name = &string_values[p[0].value_id as usize];
match operator_name.as_str() {
"eq?" | "not-eq?" => {
"eq?" | "not-eq?" | "any-eq?" | "any-not-eq?" => {
if p.len() != 3 {
return Err(predicate_error(
row,
@ -1750,23 +1755,30 @@ impl Query {
)));
}
let is_positive = operator_name == "eq?";
let is_positive = operator_name == "eq?" || operator_name == "any-eq?";
let match_all = match operator_name.as_str() {
"eq?" | "not-eq?" => true,
"any-eq?" | "any-not-eq?" => false,
_ => unreachable!(),
};
text_predicates.push(if p[2].type_ == type_capture {
TextPredicate::CaptureEqCapture(
TextPredicateCapture::EqCapture(
p[1].value_id,
p[2].value_id,
is_positive,
match_all,
)
} else {
TextPredicate::CaptureEqString(
TextPredicateCapture::EqString(
p[1].value_id,
string_values[p[2].value_id as usize].clone(),
is_positive,
match_all,
)
});
}
"match?" | "not-match?" => {
"match?" | "not-match?" | "any-match?" | "any-not-match?" => {
if p.len() != 3 {
return Err(predicate_error(row, format!(
"Wrong number of arguments to #match? predicate. Expected 2, got {}.",
@ -1786,20 +1798,27 @@ impl Query {
)));
}
let is_positive = operator_name == "match?";
let is_positive =
operator_name == "match?" || operator_name == "any-match?";
let match_all = match operator_name.as_str() {
"match?" | "not-match?" => true,
"any-match?" | "any-not-match?" => false,
_ => unreachable!(),
};
let regex = &string_values[p[2].value_id as usize];
text_predicates.push(TextPredicate::CaptureMatchString(
text_predicates.push(TextPredicateCapture::MatchString(
p[1].value_id,
regex::bytes::Regex::new(regex).map_err(|_| {
predicate_error(row, format!("Invalid regex '{}'", regex))
})?,
is_positive,
match_all,
));
}
"set!" => property_settings.push(Self::parse_property(
row,
&operator_name,
operator_name,
&result.capture_names,
&string_values,
&p[1..],
@ -1808,7 +1827,7 @@ impl Query {
"is?" | "is-not?" => property_predicates.push((
Self::parse_property(
row,
&operator_name,
operator_name,
&result.capture_names,
&string_values,
&p[1..],
@ -1841,7 +1860,7 @@ impl Query {
}
values.push(string_values[arg.value_id as usize].clone());
}
text_predicates.push(TextPredicate::CaptureAnyString(
text_predicates.push(TextPredicateCapture::AnyString(
p[1].value_id,
values,
is_positive,
@ -2203,7 +2222,7 @@ impl<'tree> QueryMatch<'_, 'tree> {
) -> impl Iterator<Item = Node<'tree>> + '_ {
self.captures
.iter()
.filter_map(move |capture| (capture.index == capture_ix).then(|| capture.node))
.filter_map(move |capture| (capture.index == capture_ix).then_some(capture.node))
}
fn new(m: ffi::TSQueryMatch, cursor: *mut ffi::TSQueryCursor) -> Self {
@ -2266,52 +2285,61 @@ impl<'tree> QueryMatch<'_, 'tree> {
query.text_predicates[self.pattern_index]
.iter()
.all(|predicate| match predicate {
TextPredicate::CaptureEqCapture(i, j, is_positive) => {
let node1 = self.nodes_for_capture_index(*i).next();
let node2 = self.nodes_for_capture_index(*j).next();
match (node1, node2) {
(Some(node1), Some(node2)) => {
let mut text1 = text_provider.text(node1);
let mut text2 = text_provider.text(node2);
let text1 = node_text1.get_text(&mut text1);
let text2 = node_text2.get_text(&mut text2);
(text1 == text2) == *is_positive
TextPredicateCapture::EqCapture(i, j, is_positive, match_all_nodes) => {
let mut nodes_1 = self.nodes_for_capture_index(*i);
let mut nodes_2 = self.nodes_for_capture_index(*j);
while let (Some(node1), Some(node2)) = (nodes_1.next(), nodes_2.next()) {
let mut text1 = text_provider.text(node1);
let mut text2 = text_provider.text(node2);
let text1 = node_text1.get_text(&mut text1);
let text2 = node_text2.get_text(&mut text2);
if (text1 == text2) != *is_positive && *match_all_nodes {
return false;
}
if (text1 == text2) == *is_positive && !*match_all_nodes {
return true;
}
_ => true,
}
nodes_1.next().is_none() && nodes_2.next().is_none()
}
TextPredicate::CaptureEqString(i, s, is_positive) => {
let node = self.nodes_for_capture_index(*i).next();
match node {
Some(node) => {
let mut text = text_provider.text(node);
let text = node_text1.get_text(&mut text);
(text == s.as_bytes()) == *is_positive
TextPredicateCapture::EqString(i, s, is_positive, match_all_nodes) => {
let nodes = self.nodes_for_capture_index(*i);
for node in nodes {
let mut text = text_provider.text(node);
let text = node_text1.get_text(&mut text);
if (text == s.as_bytes()) != *is_positive && *match_all_nodes {
return false;
}
if (text == s.as_bytes()) == *is_positive && !*match_all_nodes {
return true;
}
None => true,
}
true
}
TextPredicate::CaptureMatchString(i, r, is_positive) => {
let node = self.nodes_for_capture_index(*i).next();
match node {
Some(node) => {
let mut text = text_provider.text(node);
let text = node_text1.get_text(&mut text);
r.is_match(text) == *is_positive
TextPredicateCapture::MatchString(i, r, is_positive, match_all_nodes) => {
let nodes = self.nodes_for_capture_index(*i);
for node in nodes {
let mut text = text_provider.text(node);
let text = node_text1.get_text(&mut text);
if (r.is_match(text)) != *is_positive && *match_all_nodes {
return false;
}
if (r.is_match(text)) == *is_positive && !*match_all_nodes {
return true;
}
None => true,
}
true
}
TextPredicate::CaptureAnyString(i, v, is_positive) => {
let node = self.nodes_for_capture_index(*i).next();
match node {
Some(node) => {
let mut text = text_provider.text(node);
let text = node_text1.get_text(&mut text);
v.iter().any(|s| text == s.as_bytes()) == *is_positive
TextPredicateCapture::AnyString(i, v, is_positive) => {
let nodes = self.nodes_for_capture_index(*i);
for node in nodes {
let mut text = text_provider.text(node);
let text = node_text1.get_text(&mut text);
if (v.iter().any(|s| text == s.as_bytes())) != *is_positive {
return false;
}
None => true,
}
true
}
})
}