From 09ac28c77d216964636ea054ba76bcf96a670933 Mon Sep 17 00:00:00 2001 From: Amaan Qureshi Date: Fri, 18 Aug 2023 19:45:00 -0400 Subject: [PATCH] feat!: properly handle predicates used on quantified captures --- cli/src/tests/query_test.rs | 83 +++++++++++++++++++++++ lib/binding_rust/lib.rs | 132 ++++++++++++++++++++++-------------- lib/binding_web/binding.js | 66 +++++++++++++++--- 3 files changed, 218 insertions(+), 63 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 94d5ca97..34cf40a9 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -4574,6 +4574,89 @@ fn test_capture_quantifiers() { }); } +#[test] +fn test_query_quantified_captures() { + struct Row { + description: &'static str, + language: Language, + code: &'static str, + pattern: &'static str, + captures: &'static [(&'static str, &'static str)], + } + + // #[rustfmt::skip] + let rows = &[ + Row { + description: "doc comments where all must match the prefiix", + language: get_language("c"), + code: indoc! {" + /// foo + /// bar + /// baz + + void main() {} + + /// qux + /// quux + // quuz + "}, + pattern: r#" + ((comment)+ @comment.documentation + (#match? @comment.documentation "^///")) + "#, + captures: &[ + ("comment.documentation", "/// foo"), + ("comment.documentation", "/// bar"), + ("comment.documentation", "/// baz"), + ], + }, + Row { + description: "doc comments where one must match the prefix", + language: get_language("c"), + code: indoc! {" + /// foo + /// bar + /// baz + + void main() {} + + /// qux + /// quux + // quuz + "}, + pattern: r#" + ((comment)+ @comment.documentation + (#any-match? @comment.documentation "^///")) + "#, + captures: &[ + ("comment.documentation", "/// foo"), + ("comment.documentation", "/// bar"), + ("comment.documentation", "/// baz"), + ("comment.documentation", "/// qux"), + ("comment.documentation", "/// quux"), + ("comment.documentation", "// quuz"), + ], + }, + ]; + + allocations::record(|| { + for row in rows { + eprintln!(" quantified query example: {:?}", row.description); + + let mut parser = Parser::new(); + parser.set_language(row.language).unwrap(); + let tree = parser.parse(row.code, None).unwrap(); + + let query = Query::new(row.language, row.pattern).unwrap(); + + let mut cursor = QueryCursor::new(); + let matches = cursor.captures(&query, tree.root_node(), row.code.as_bytes()); + + assert_eq!(collect_captures(matches, &query, row.code), row.captures); + } + }); +} + #[test] fn test_query_max_start_depth() { struct Row { diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 9cd04563..8762c7f6 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -118,7 +118,7 @@ pub struct Query { ptr: NonNull, capture_names: Vec, capture_quantifiers: Vec>, - text_predicates: Vec>, + text_predicates: Vec>, property_settings: Vec>, property_predicates: Vec>, general_predicates: Vec>, @@ -250,11 +250,16 @@ pub enum QueryErrorKind { } #[derive(Debug)] -enum TextPredicate { - CaptureEqString(u32, String, bool), - CaptureEqCapture(u32, u32, bool), - CaptureMatchString(u32, regex::bytes::Regex, bool), - CaptureAnyString(u32, Vec, bool), +/// The first item is the capture index +/// The next is capture specific, depending on what item is expected +/// The first bool is if the capture is positive +/// The last item is a bool signifying whether or not it's meant to match +/// any or all captures +enum TextPredicateCapture { + EqString(u32, String, bool, bool), + EqCapture(u32, u32, bool, bool), + MatchString(u32, regex::bytes::Regex, bool, bool), + AnyString(u32, Vec, bool), } // TODO: Remove this struct at at some point. If `core::str::lossy::Utf8Lossy` @@ -1733,7 +1738,7 @@ impl Query { // Build a predicate for each of the known predicate function names. let operator_name = &string_values[p[0].value_id as usize]; match operator_name.as_str() { - "eq?" | "not-eq?" => { + "eq?" | "not-eq?" | "any-eq?" | "any-not-eq?" => { if p.len() != 3 { return Err(predicate_error( row, @@ -1750,23 +1755,30 @@ impl Query { ))); } - let is_positive = operator_name == "eq?"; + let is_positive = operator_name == "eq?" || operator_name == "any-eq?"; + let match_all = match operator_name.as_str() { + "eq?" | "not-eq?" => true, + "any-eq?" | "any-not-eq?" => false, + _ => unreachable!(), + }; text_predicates.push(if p[2].type_ == type_capture { - TextPredicate::CaptureEqCapture( + TextPredicateCapture::EqCapture( p[1].value_id, p[2].value_id, is_positive, + match_all, ) } else { - TextPredicate::CaptureEqString( + TextPredicateCapture::EqString( p[1].value_id, string_values[p[2].value_id as usize].clone(), is_positive, + match_all, ) }); } - "match?" | "not-match?" => { + "match?" | "not-match?" | "any-match?" | "any-not-match?" => { if p.len() != 3 { return Err(predicate_error(row, format!( "Wrong number of arguments to #match? predicate. Expected 2, got {}.", @@ -1786,20 +1798,27 @@ impl Query { ))); } - let is_positive = operator_name == "match?"; + let is_positive = + operator_name == "match?" || operator_name == "any-match?"; + let match_all = match operator_name.as_str() { + "match?" | "not-match?" => true, + "any-match?" | "any-not-match?" => false, + _ => unreachable!(), + }; let regex = &string_values[p[2].value_id as usize]; - text_predicates.push(TextPredicate::CaptureMatchString( + text_predicates.push(TextPredicateCapture::MatchString( p[1].value_id, regex::bytes::Regex::new(regex).map_err(|_| { predicate_error(row, format!("Invalid regex '{}'", regex)) })?, is_positive, + match_all, )); } "set!" => property_settings.push(Self::parse_property( row, - &operator_name, + operator_name, &result.capture_names, &string_values, &p[1..], @@ -1808,7 +1827,7 @@ impl Query { "is?" | "is-not?" => property_predicates.push(( Self::parse_property( row, - &operator_name, + operator_name, &result.capture_names, &string_values, &p[1..], @@ -1841,7 +1860,7 @@ impl Query { } values.push(string_values[arg.value_id as usize].clone()); } - text_predicates.push(TextPredicate::CaptureAnyString( + text_predicates.push(TextPredicateCapture::AnyString( p[1].value_id, values, is_positive, @@ -2203,7 +2222,7 @@ impl<'tree> QueryMatch<'_, 'tree> { ) -> impl Iterator> + '_ { self.captures .iter() - .filter_map(move |capture| (capture.index == capture_ix).then(|| capture.node)) + .filter_map(move |capture| (capture.index == capture_ix).then_some(capture.node)) } fn new(m: ffi::TSQueryMatch, cursor: *mut ffi::TSQueryCursor) -> Self { @@ -2266,52 +2285,61 @@ impl<'tree> QueryMatch<'_, 'tree> { query.text_predicates[self.pattern_index] .iter() .all(|predicate| match predicate { - TextPredicate::CaptureEqCapture(i, j, is_positive) => { - let node1 = self.nodes_for_capture_index(*i).next(); - let node2 = self.nodes_for_capture_index(*j).next(); - match (node1, node2) { - (Some(node1), Some(node2)) => { - let mut text1 = text_provider.text(node1); - let mut text2 = text_provider.text(node2); - let text1 = node_text1.get_text(&mut text1); - let text2 = node_text2.get_text(&mut text2); - (text1 == text2) == *is_positive + TextPredicateCapture::EqCapture(i, j, is_positive, match_all_nodes) => { + let mut nodes_1 = self.nodes_for_capture_index(*i); + let mut nodes_2 = self.nodes_for_capture_index(*j); + while let (Some(node1), Some(node2)) = (nodes_1.next(), nodes_2.next()) { + let mut text1 = text_provider.text(node1); + let mut text2 = text_provider.text(node2); + let text1 = node_text1.get_text(&mut text1); + let text2 = node_text2.get_text(&mut text2); + if (text1 == text2) != *is_positive && *match_all_nodes { + return false; + } + if (text1 == text2) == *is_positive && !*match_all_nodes { + return true; } - _ => true, } + nodes_1.next().is_none() && nodes_2.next().is_none() } - TextPredicate::CaptureEqString(i, s, is_positive) => { - let node = self.nodes_for_capture_index(*i).next(); - match node { - Some(node) => { - let mut text = text_provider.text(node); - let text = node_text1.get_text(&mut text); - (text == s.as_bytes()) == *is_positive + TextPredicateCapture::EqString(i, s, is_positive, match_all_nodes) => { + let nodes = self.nodes_for_capture_index(*i); + for node in nodes { + let mut text = text_provider.text(node); + let text = node_text1.get_text(&mut text); + if (text == s.as_bytes()) != *is_positive && *match_all_nodes { + return false; + } + if (text == s.as_bytes()) == *is_positive && !*match_all_nodes { + return true; } - None => true, } + true } - TextPredicate::CaptureMatchString(i, r, is_positive) => { - let node = self.nodes_for_capture_index(*i).next(); - match node { - Some(node) => { - let mut text = text_provider.text(node); - let text = node_text1.get_text(&mut text); - r.is_match(text) == *is_positive + TextPredicateCapture::MatchString(i, r, is_positive, match_all_nodes) => { + let nodes = self.nodes_for_capture_index(*i); + for node in nodes { + let mut text = text_provider.text(node); + let text = node_text1.get_text(&mut text); + if (r.is_match(text)) != *is_positive && *match_all_nodes { + return false; + } + if (r.is_match(text)) == *is_positive && !*match_all_nodes { + return true; } - None => true, } + true } - TextPredicate::CaptureAnyString(i, v, is_positive) => { - let node = self.nodes_for_capture_index(*i).next(); - match node { - Some(node) => { - let mut text = text_provider.text(node); - let text = node_text1.get_text(&mut text); - v.iter().any(|s| text == s.as_bytes()) == *is_positive + TextPredicateCapture::AnyString(i, v, is_positive) => { + let nodes = self.nodes_for_capture_index(*i); + for node in nodes { + let mut text = text_provider.text(node); + let text = node_text1.get_text(&mut text); + if (v.iter().any(|s| text == s.as_bytes())) != *is_positive { + return false; } - None => true, } + true } }) } diff --git a/lib/binding_web/binding.js b/lib/binding_web/binding.js index 8443bf25..0ba30106 100644 --- a/lib/binding_web/binding.js +++ b/lib/binding_web/binding.js @@ -841,7 +841,13 @@ class Language { } const operator = steps[0].value; let isPositive = true; + let matchAll = true; switch (operator) { + case 'any-not-eq?': + isPositive = false; + matchAll = false; + case 'any-eq?': + matchAll = false; case 'not-eq?': isPositive = false; case 'eq?': @@ -855,28 +861,36 @@ class Language { const captureName1 = steps[1].name; const captureName2 = steps[2].name; textPredicates[i].push(function(captures) { - let node1, node2 + let nodes_1 = []; + let nodes_2 = []; for (const c of captures) { - if (c.name === captureName1) node1 = c.node; - if (c.name === captureName2) node2 = c.node; + if (c.name === captureName1) nodes_1.push(c.node); + if (c.name === captureName2) nodes_2.push(c.node); } - if(node1 === undefined || node2 === undefined) return true; - return (node1.text === node2.text) === isPositive; + return matchAll + ? nodes_1.every(n1 => nodes_2.some(n2 => n1.text === n2.text)) === isPositive + : nodes_1.some(n1 => nodes_2.some(n2 => n1.text === n2.text)) === isPositive; }); } else { const captureName = steps[1].name; const stringValue = steps[2].value; textPredicates[i].push(function(captures) { + let nodes = []; for (const c of captures) { - if (c.name === captureName) { - return (c.node.text === stringValue) === isPositive; - }; + if (c.name === captureName) nodes.push(c.node); } - return true; + return matchAll + ? nodes.every(n => n.text === stringValue) === isPositive + : nodes.some(n => n.text === stringValue) === isPositive; }); } break; + case 'not-any-match?': + isPositive = false; + matchAll = false; + case 'any-match?': + matchAll = false; case 'not-match?': isPositive = false; case 'match?': @@ -892,10 +906,14 @@ class Language { const captureName = steps[1].name; const regex = new RegExp(steps[2].value); textPredicates[i].push(function(captures) { + const nodes = []; for (const c of captures) { - if (c.name === captureName) return regex.test(c.node.text) === isPositive; + if (c.name === captureName) nodes.push(c.node.text); } - return true; + if (nodes.length === 0) return !isPositive; + return matchAll + ? nodes.every(text => regex.test(text)) === isPositive + : nodes.some(text => regex.test(text)) === isPositive; }); break; @@ -923,6 +941,32 @@ class Language { properties[i][steps[1].value] = steps[2] ? steps[2].value : null; break; + case 'not-any-of?': + isPositive = false; + case 'any-of?': + if (steps.length < 2) throw new Error( + `Wrong number of arguments to \`#${operator}\` predicate. Expected at least 1. Got ${steps.length - 1}.` + ); + if (steps[1].type !== 'capture') throw new Error( + `First argument of \`#${operator}\` predicate must be a capture. Got "${steps[1].value}".` + ); + for (let i = 2; i < steps.length; i++) { + if (steps[i].type !== 'string') throw new Error( + `Arguments to \`#${operator}\` predicate must be a strings.".` + ); + } + captureName = steps[1].name; + const values = steps.slice(2).map(s => s.value); + textPredicates[i].push(function(captures) { + const nodes = []; + for (const c of captures) { + if (c.name === captureName) nodes.push(c.node.text); + } + if (nodes.length === 0) return !isPositive; + return nodes.every(text => values.includes(text)) === isPositive; + }); + break; + default: predicates[i].push({operator, operands: steps.slice(1)}); }