rust: Generalize text_callback type for query methods

This commit is contained in:
Max Brunsfeld 2019-09-26 15:53:40 -07:00
parent 9872a083b7
commit 015be68c94

View file

@ -24,7 +24,7 @@ use std::{char, fmt, ptr, slice, str, u16};
pub const LANGUAGE_VERSION: usize = ffi::TREE_SITTER_LANGUAGE_VERSION;
pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/parser.h");
#[derive(Clone, Copy, PartialEq, Eq)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(transparent)]
pub struct Language(*const ffi::TSLanguage);
@ -164,7 +164,6 @@ pub struct Query {
pub struct QueryCursor(NonNull<ffi::TSQueryCursor>);
#[derive(Clone)]
pub struct QueryMatch<'a> {
pub pattern_index: usize,
pub captures: &'a [QueryCapture<'a>],
@ -172,6 +171,12 @@ pub struct QueryMatch<'a> {
cursor: *mut ffi::TSQueryCursor,
}
pub struct QueryCaptures<'a, T: AsRef<[u8]>> {
ptr: *mut ffi::TSQueryCursor,
query: &'a Query,
text_callback: Box<dyn FnMut(Node<'a>) -> T + 'a>,
}
#[derive(Clone, Copy)]
#[repr(C)]
pub struct QueryCapture<'a> {
@ -1258,7 +1263,7 @@ impl QueryCursor {
&'a mut self,
query: &'a Query,
node: Node<'a>,
mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a,
mut text_callback: impl FnMut(Node<'a>) -> &[u8] + 'a,
) -> impl Iterator<Item = QueryMatch<'a>> + 'a {
let ptr = self.0.as_ptr();
unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) };
@ -1277,34 +1282,19 @@ impl QueryCursor {
})
}
pub fn captures<'a>(
pub fn captures<'a, T: AsRef<[u8]>>(
&'a mut self,
query: &'a Query,
node: Node<'a>,
mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a,
) -> impl Iterator<Item = (QueryMatch<'a>, usize)> + 'a {
text_callback: impl FnMut(Node<'a>) -> T + 'a,
) -> QueryCaptures<'a, T> {
let ptr = self.0.as_ptr();
unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) };
std::iter::from_fn(move || loop {
unsafe {
let mut capture_index = 0u32;
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
if ffi::ts_query_cursor_next_capture(
ptr,
m.as_mut_ptr(),
&mut capture_index as *mut u32,
) {
let result = QueryMatch::new(m.assume_init(), ptr);
if result.satisfies_text_predicates(query, &mut text_callback) {
return Some((result, capture_index as usize));
} else {
result.remove();
}
} else {
return None;
}
}
})
QueryCaptures {
ptr,
query,
text_callback: Box::new(text_callback),
}
}
pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self {
@ -1341,10 +1331,10 @@ impl<'a> QueryMatch<'a> {
}
}
fn satisfies_text_predicates(
fn satisfies_text_predicates<T: AsRef<[u8]>>(
&self,
query: &Query,
text_callback: &mut impl FnMut(Node<'a>) -> &[u8],
text_callback: &mut impl FnMut(Node<'a>) -> T,
) -> bool {
query.text_predicates[self.pattern_index]
.iter()
@ -1352,15 +1342,15 @@ impl<'a> QueryMatch<'a> {
TextPredicate::CaptureEqCapture(i, j) => {
let node1 = self.capture_for_index(*i).unwrap();
let node2 = self.capture_for_index(*j).unwrap();
text_callback(node1) == text_callback(node2)
text_callback(node1).as_ref() == text_callback(node2).as_ref()
}
TextPredicate::CaptureEqString(i, s) => {
let node = self.capture_for_index(*i).unwrap();
text_callback(node) == s.as_bytes()
text_callback(node).as_ref() == s.as_bytes()
}
TextPredicate::CaptureMatchString(i, r) => {
let node = self.capture_for_index(*i).unwrap();
r.is_match(text_callback(node))
r.is_match(text_callback(node).as_ref())
}
})
}
@ -1385,6 +1375,33 @@ impl QueryProperty {
}
}
impl<'a, T: AsRef<[u8]>> Iterator for QueryCaptures<'a, T> {
type Item = (QueryMatch<'a>, usize);
fn next(&mut self) -> Option<Self::Item> {
loop {
unsafe {
let mut capture_index = 0u32;
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
if ffi::ts_query_cursor_next_capture(
self.ptr,
m.as_mut_ptr(),
&mut capture_index as *mut u32,
) {
let result = QueryMatch::new(m.assume_init(), self.ptr);
if result.satisfies_text_predicates(self.query, &mut self.text_callback) {
return Some((result, capture_index as usize));
} else {
result.remove();
}
} else {
return None;
}
}
}
}
}
impl PartialEq for Query {
fn eq(&self, other: &Self) -> bool {
self.ptr == other.ptr