From 015be68c9423e97b17df68c4a35beb5a2f1c36bd Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 26 Sep 2019 15:53:40 -0700 Subject: [PATCH] rust: Generalize text_callback type for query methods --- lib/binding_rust/lib.rs | 79 +++++++++++++++++++++++++---------------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 6dd398e2..d824f964 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -24,7 +24,7 @@ use std::{char, fmt, ptr, slice, str, u16}; pub const LANGUAGE_VERSION: usize = ffi::TREE_SITTER_LANGUAGE_VERSION; pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/parser.h"); -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[repr(transparent)] pub struct Language(*const ffi::TSLanguage); @@ -164,7 +164,6 @@ pub struct Query { pub struct QueryCursor(NonNull); -#[derive(Clone)] pub struct QueryMatch<'a> { pub pattern_index: usize, pub captures: &'a [QueryCapture<'a>], @@ -172,6 +171,12 @@ pub struct QueryMatch<'a> { cursor: *mut ffi::TSQueryCursor, } +pub struct QueryCaptures<'a, T: AsRef<[u8]>> { + ptr: *mut ffi::TSQueryCursor, + query: &'a Query, + text_callback: Box) -> T + 'a>, +} + #[derive(Clone, Copy)] #[repr(C)] pub struct QueryCapture<'a> { @@ -1258,7 +1263,7 @@ impl QueryCursor { &'a mut self, query: &'a Query, node: Node<'a>, - mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a, + mut text_callback: impl FnMut(Node<'a>) -> &[u8] + 'a, ) -> impl Iterator> + 'a { let ptr = self.0.as_ptr(); unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) }; @@ -1277,34 +1282,19 @@ impl QueryCursor { }) } - pub fn captures<'a>( + pub fn captures<'a, T: AsRef<[u8]>>( &'a mut self, query: &'a Query, node: Node<'a>, - mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a, - ) -> impl Iterator, usize)> + 'a { + text_callback: impl FnMut(Node<'a>) -> T + 'a, + ) -> QueryCaptures<'a, T> { let ptr = self.0.as_ptr(); unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) }; - std::iter::from_fn(move || loop { - unsafe { - let mut capture_index = 0u32; - let mut m = MaybeUninit::::uninit(); - if ffi::ts_query_cursor_next_capture( - ptr, - m.as_mut_ptr(), - &mut capture_index as *mut u32, - ) { - let result = QueryMatch::new(m.assume_init(), ptr); - if result.satisfies_text_predicates(query, &mut text_callback) { - return Some((result, capture_index as usize)); - } else { - result.remove(); - } - } else { - return None; - } - } - }) + QueryCaptures { + ptr, + query, + text_callback: Box::new(text_callback), + } } pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self { @@ -1341,10 +1331,10 @@ impl<'a> QueryMatch<'a> { } } - fn satisfies_text_predicates( + fn satisfies_text_predicates>( &self, query: &Query, - text_callback: &mut impl FnMut(Node<'a>) -> &[u8], + text_callback: &mut impl FnMut(Node<'a>) -> T, ) -> bool { query.text_predicates[self.pattern_index] .iter() @@ -1352,15 +1342,15 @@ impl<'a> QueryMatch<'a> { TextPredicate::CaptureEqCapture(i, j) => { let node1 = self.capture_for_index(*i).unwrap(); let node2 = self.capture_for_index(*j).unwrap(); - text_callback(node1) == text_callback(node2) + text_callback(node1).as_ref() == text_callback(node2).as_ref() } TextPredicate::CaptureEqString(i, s) => { let node = self.capture_for_index(*i).unwrap(); - text_callback(node) == s.as_bytes() + text_callback(node).as_ref() == s.as_bytes() } TextPredicate::CaptureMatchString(i, r) => { let node = self.capture_for_index(*i).unwrap(); - r.is_match(text_callback(node)) + r.is_match(text_callback(node).as_ref()) } }) } @@ -1385,6 +1375,33 @@ impl QueryProperty { } } +impl<'a, T: AsRef<[u8]>> Iterator for QueryCaptures<'a, T> { + type Item = (QueryMatch<'a>, usize); + + fn next(&mut self) -> Option { + loop { + unsafe { + let mut capture_index = 0u32; + let mut m = MaybeUninit::::uninit(); + if ffi::ts_query_cursor_next_capture( + self.ptr, + m.as_mut_ptr(), + &mut capture_index as *mut u32, + ) { + let result = QueryMatch::new(m.assume_init(), self.ptr); + if result.satisfies_text_predicates(self.query, &mut self.text_callback) { + return Some((result, capture_index as usize)); + } else { + result.remove(); + } + } else { + return None; + } + } + } + } +} + impl PartialEq for Query { fn eq(&self, other: &Self) -> bool { self.ptr == other.ptr