diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index d187c965..c06fa01a 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1721,6 +1721,19 @@ impl<'a, 'tree> QueryMatch<'a, 'tree> { unsafe { ffi::ts_query_cursor_remove_match(self.cursor, self.id) } } + pub fn nodes_for_capture_index( + &self, + capture_ix: u32, + ) -> impl Iterator> + '_ { + self.captures.iter().filter_map(move |capture| { + if capture.index == capture_ix { + Some(capture.node) + } else { + None + } + }) + } + fn new(m: ffi::TSQueryMatch, cursor: *mut ffi::TSQueryCursor) -> Self { QueryMatch { cursor, @@ -1768,33 +1781,24 @@ impl<'a, 'tree> QueryMatch<'a, 'tree> { .iter() .all(|predicate| match predicate { TextPredicate::CaptureEqCapture(i, j, is_positive) => { - let node1 = self.capture_for_index(*i).unwrap(); - let node2 = self.capture_for_index(*j).unwrap(); + let node1 = self.nodes_for_capture_index(*i).next().unwrap(); + let node2 = self.nodes_for_capture_index(*j).next().unwrap(); let text1 = get_text(buffer1, text_provider.text(node1)); let text2 = get_text(buffer2, text_provider.text(node2)); (text1 == text2) == *is_positive } TextPredicate::CaptureEqString(i, s, is_positive) => { - let node = self.capture_for_index(*i).unwrap(); + let node = self.nodes_for_capture_index(*i).next().unwrap(); let text = get_text(buffer1, text_provider.text(node)); (text == s.as_bytes()) == *is_positive } TextPredicate::CaptureMatchString(i, r, is_positive) => { - let node = self.capture_for_index(*i).unwrap(); + let node = self.nodes_for_capture_index(*i).next().unwrap(); let text = get_text(buffer1, text_provider.text(node)); r.is_match(text) == *is_positive } }) } - - fn capture_for_index(&self, capture_index: u32) -> Option> { - for c in self.captures { - if c.index == capture_index { - return Some(c.node); - } - } - None - } } impl QueryProperty {