diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 05c8a0a1..09aac8a7 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2988,6 +2988,61 @@ fn test_query_matches_with_deeply_nested_patterns_with_fields() { }); } +#[test] +fn test_query_matches_with_alternations_and_predicates() { + allocations::record(|| { + let language = get_language("java"); + let query = Query::new( + &language, + " + (block + [ + (local_variable_declaration + (variable_declarator + (identifier) @def.a + (string_literal) @lit.a + ) + ) + (local_variable_declaration + (variable_declarator + (identifier) @def.b + (null_literal) @lit.b + ) + ) + ] + (expression_statement + (method_invocation [ + (argument_list + (identifier) @ref.a + (string_literal) + ) + (argument_list + (null_literal) + (identifier) @ref.b + ) + ]) + ) + (#eq? @def.a @ref.a ) + (#eq? @def.b @ref.b ) + ) + ", + ) + .unwrap(); + + assert_query_matches( + &language, + &query, + r#" + void test() { + int a = "foo"; + f(null, b); + } + "#, + &[], + ); + }); +} + #[test] fn test_query_matches_with_indefinite_step_containing_no_captures() { allocations::record(|| { diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index b37ecf1d..3574391f 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -3353,9 +3353,11 @@ impl<'tree> QueryMatch<'_, 'tree> { .iter() .all(|predicate| match predicate { TextPredicateCapture::EqCapture(i, j, is_positive, match_all_nodes) => { - let mut nodes_1 = self.nodes_for_capture_index(*i); - let mut nodes_2 = self.nodes_for_capture_index(*j); - while let (Some(node1), Some(node2)) = (nodes_1.next(), nodes_2.next()) { + let mut nodes_1 = self.nodes_for_capture_index(*i).peekable(); + let mut nodes_2 = self.nodes_for_capture_index(*j).peekable(); + while nodes_1.peek().is_some() && nodes_2.peek().is_some() { + let node1 = nodes_1.next().unwrap(); + let node2 = nodes_2.next().unwrap(); let mut text1 = text_provider.text(node1); let mut text2 = text_provider.text(node2); let text1 = node_text1.get_text(&mut text1);