Allow QueryCursor's text callback to return an iterator

This commit is contained in:
Max Brunsfeld 2021-05-23 15:12:24 -07:00
parent 0e445c47fa
commit 8c3d1466ec
5 changed files with 272 additions and 111 deletions

View file

@ -102,7 +102,9 @@ pub struct Query {
}
/// A stateful object for executing a `Query` on a syntax `Tree`.
pub struct QueryCursor(NonNull<ffi::TSQueryCursor>);
pub struct QueryCursor {
ptr: NonNull<ffi::TSQueryCursor>,
}
/// A key-value pair associated with a particular pattern in a `Query`.
#[derive(Debug, PartialEq, Eq)]
@ -126,18 +128,36 @@ pub struct QueryPredicate {
}
/// A match of a `Query` to a particular set of `Node`s.
pub struct QueryMatch<'a> {
pub struct QueryMatch<'cursor, 'tree> {
pub pattern_index: usize,
pub captures: &'a [QueryCapture<'a>],
pub captures: &'cursor [QueryCapture<'tree>],
id: u32,
cursor: *mut ffi::TSQueryCursor,
}
/// A sequence of `QueryCapture`s within a `QueryMatch`.
pub struct QueryCaptures<'a, 'tree: 'a, T: AsRef<[u8]>> {
/// A sequence of `QueryMatch`es associated with a given `QueryCursor`.
pub struct QueryMatches<'a, 'tree: 'a, T: TextProvider<'a>> {
ptr: *mut ffi::TSQueryCursor,
query: &'a Query,
text_callback: Box<dyn FnMut(Node<'tree>) -> T + 'a>,
text_provider: T,
buffer1: Vec<u8>,
buffer2: Vec<u8>,
_tree: PhantomData<&'tree ()>,
}
/// A sequence of `QueryCapture`s associated with a given `QueryCursor`.
///
/// Iterating yields `(QueryMatch, usize)` pairs; the `usize` is the capture
/// index written by `ts_query_cursor_next_capture`. Matches are checked
/// against the query's text predicates, using `text_provider` to obtain the
/// text of captured nodes, before being yielded.
pub struct QueryCaptures<'a, 'tree: 'a, T: TextProvider<'a>> {
// Raw handle to the underlying C query cursor that is being advanced.
ptr: *mut ffi::TSQueryCursor,
// The query being executed; its text predicates are consulted per match.
query: &'a Query,
// Supplies the text of a node as an iterator of byte chunks.
text_provider: T,
// Scratch space used when a node's text arrives in more than one chunk and
// has to be concatenated before comparison.
buffer1: Vec<u8>,
// Second scratch buffer, needed when two captures are compared to each other.
buffer2: Vec<u8>,
// Zero-sized marker that keeps the otherwise unused `'tree` lifetime in the type.
_tree: PhantomData<&'tree ()>,
}
/// A source of text for syntax nodes, used when evaluating a query's text
/// predicates.
///
/// The text of a node is returned as an iterator of byte chunks rather than
/// a single slice, so a provider does not need to hold the text in one
/// contiguous buffer. When more than one chunk is produced, the chunks are
/// concatenated in iteration order before being compared.
pub trait TextProvider<'a> {
/// The iterator type produced by `text`; each item is one chunk of bytes.
type I: Iterator<Item = &'a [u8]> + 'a;
/// Returns the chunks that make up the text of `node`.
fn text(&mut self, node: Node) -> Self::I;
}
/// A particular `Node` that has been captured with a particular name within a `Query`.
@ -178,6 +198,11 @@ pub enum QueryErrorKind {
Structure,
}
/// Private two-step interface for fetching a node's text chunk by chunk.
///
/// NOTE(review): no implementor or caller of this trait is visible in this
/// part of the file; confirm it is still needed.
trait TextCallback<'a> {
/// Takes the node of interest.
/// Presumably selects the node whose text `next_chunk` will yield — TODO confirm.
fn call(&mut self, node: Node);
/// Returns a chunk of bytes, or `None`.
/// Presumably `None` means the current node's text is exhausted — TODO confirm.
fn next_chunk(&mut self) -> Option<&'a [u8]>;
}
#[derive(Debug)]
enum TextPredicate {
CaptureEqString(u32, String, bool),
@ -1590,18 +1615,20 @@ impl Query {
}
}
impl<'a> QueryCursor {
impl QueryCursor {
/// Create a new cursor for executing a given query.
///
/// The cursor stores the state that is needed to iteratively search for matches.
pub fn new() -> Self {
QueryCursor(unsafe { NonNull::new_unchecked(ffi::ts_query_cursor_new()) })
QueryCursor {
ptr: unsafe { NonNull::new_unchecked(ffi::ts_query_cursor_new()) },
}
}
/// Check if, on its last execution, this cursor exceeded its maximum number of
/// in-progress matches.
pub fn did_exceed_match_limit(&self) -> bool {
unsafe { ffi::ts_query_cursor_did_exceed_match_limit(self.0.as_ptr()) }
unsafe { ffi::ts_query_cursor_did_exceed_match_limit(self.ptr.as_ptr()) }
}
/// Iterate over all of the matches in the order that they were found.
@ -1609,52 +1636,50 @@ impl<'a> QueryCursor {
/// Each match contains the index of the pattern that matched, and a list of captures.
/// Because multiple patterns can match the same set of nodes, one match may contain
/// captures that appear *before* some of the captures from a previous match.
pub fn matches<'tree: 'a, T: AsRef<[u8]>>(
pub fn matches<'a, 'tree: 'a, T: TextProvider<'a> + 'a>(
&'a mut self,
query: &'a Query,
node: Node<'tree>,
mut text_callback: impl FnMut(Node<'tree>) -> T + 'a,
) -> impl Iterator<Item = QueryMatch<'tree>> + 'a {
let ptr = self.0.as_ptr();
text_provider: T,
) -> QueryMatches<'a, 'tree, T> {
let ptr = self.ptr.as_ptr();
unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) };
std::iter::from_fn(move || loop {
unsafe {
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
if ffi::ts_query_cursor_next_match(ptr, m.as_mut_ptr()) {
let result = QueryMatch::new(m.assume_init(), ptr);
if result.satisfies_text_predicates(query, &mut text_callback) {
return Some(result);
}
} else {
return None;
}
}
})
QueryMatches {
ptr,
query,
text_provider,
buffer1: Default::default(),
buffer2: Default::default(),
_tree: PhantomData,
}
}
/// Iterate over all of the individual captures in the order that they appear.
///
/// This is useful if you don't care about which pattern matched, and just want a single,
/// ordered sequence of captures.
pub fn captures<'tree, T: AsRef<[u8]>>(
pub fn captures<'a, 'tree: 'a, T: TextProvider<'a> + 'a>(
&'a mut self,
query: &'a Query,
node: Node<'tree>,
text_callback: impl FnMut(Node<'tree>) -> T + 'a,
text_provider: T,
) -> QueryCaptures<'a, 'tree, T> {
let ptr = self.0.as_ptr();
unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) };
let ptr = self.ptr.as_ptr();
unsafe { ffi::ts_query_cursor_exec(self.ptr.as_ptr(), query.ptr.as_ptr(), node.0) };
QueryCaptures {
ptr,
query,
text_callback: Box::new(text_callback),
text_provider,
buffer1: Default::default(),
buffer2: Default::default(),
_tree: PhantomData,
}
}
/// Set the range in which the query will be executed, in terms of byte offsets.
pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self {
unsafe {
ffi::ts_query_cursor_set_byte_range(self.0.as_ptr(), start as u32, end as u32);
ffi::ts_query_cursor_set_byte_range(self.ptr.as_ptr(), start as u32, end as u32);
}
self
}
@ -1662,13 +1687,13 @@ impl<'a> QueryCursor {
/// Set the range in which the query will be executed, in terms of rows and columns.
pub fn set_point_range(&mut self, start: Point, end: Point) -> &mut Self {
unsafe {
ffi::ts_query_cursor_set_point_range(self.0.as_ptr(), start.into(), end.into());
ffi::ts_query_cursor_set_point_range(self.ptr.as_ptr(), start.into(), end.into());
}
self
}
}
impl<'a> QueryMatch<'a> {
impl<'a, 'tree> QueryMatch<'a, 'tree> {
pub fn remove(self) {
unsafe { ffi::ts_query_cursor_remove_match(self.cursor, self.id) }
}
@ -1681,7 +1706,7 @@ impl<'a> QueryMatch<'a> {
captures: if m.capture_count > 0 {
unsafe {
slice::from_raw_parts(
m.captures as *const QueryCapture<'a>,
m.captures as *const QueryCapture<'tree>,
m.capture_count as usize,
)
}
@ -1691,31 +1716,55 @@ impl<'a> QueryMatch<'a> {
}
}
fn satisfies_text_predicates<T: AsRef<[u8]>>(
fn satisfies_text_predicates(
&self,
query: &Query,
text_callback: &mut impl FnMut(Node<'a>) -> T,
buffer1: &mut Vec<u8>,
buffer2: &mut Vec<u8>,
text_provider: &mut impl TextProvider<'a>,
) -> bool {
/// Collapses a sequence of byte chunks into one contiguous slice.
///
/// * No chunks: returns an empty slice and leaves `buffer` untouched.
/// * Exactly one chunk: returns that chunk directly, without copying, and
///   leaves `buffer` untouched.
/// * Two or more chunks: clears `buffer`, appends every chunk in order, and
///   returns the buffer's contents.
fn get_text<'a, 'b: 'a, I: Iterator<Item = &'b [u8]>>(
    buffer: &'a mut Vec<u8>,
    mut chunks: I,
) -> &'a [u8] {
    let head: &'b [u8] = chunks.next().unwrap_or(&[]);
    match chunks.next() {
        // Zero or one chunk: the borrowed data is already contiguous.
        None => head,
        // Several chunks: stitch them together in the scratch buffer.
        Some(second) => {
            buffer.clear();
            buffer.extend_from_slice(head);
            buffer.extend_from_slice(second);
            chunks.for_each(|rest| buffer.extend_from_slice(rest));
            buffer.as_slice()
        }
    }
}
query.text_predicates[self.pattern_index]
.iter()
.all(|predicate| match predicate {
TextPredicate::CaptureEqCapture(i, j, is_positive) => {
let node1 = self.capture_for_index(*i).unwrap();
let node2 = self.capture_for_index(*j).unwrap();
(text_callback(node1).as_ref() == text_callback(node2).as_ref()) == *is_positive
let text1 = get_text(buffer1, text_provider.text(node1));
let text2 = get_text(buffer2, text_provider.text(node2));
(text1 == text2) == *is_positive
}
TextPredicate::CaptureEqString(i, s, is_positive) => {
let node = self.capture_for_index(*i).unwrap();
(text_callback(node).as_ref() == s.as_bytes()) == *is_positive
let text = get_text(buffer1, text_provider.text(node));
(text == s.as_bytes()) == *is_positive
}
TextPredicate::CaptureMatchString(i, r, is_positive) => {
let node = self.capture_for_index(*i).unwrap();
r.is_match(text_callback(node).as_ref()) == *is_positive
let text = get_text(buffer1, text_provider.text(node));
r.is_match(text) == *is_positive
}
})
}
fn capture_for_index(&self, capture_index: u32) -> Option<Node<'a>> {
fn capture_for_index(&self, capture_index: u32) -> Option<Node<'tree>> {
for c in self.captures {
if c.index == capture_index {
return Some(c.node);
@ -1735,12 +1784,37 @@ impl QueryProperty {
}
}
impl<'a, 'tree: 'a, T: AsRef<[u8]>> Iterator for QueryCaptures<'a, 'tree, T> {
type Item = (QueryMatch<'tree>, usize);
impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryMatches<'a, 'tree, T> {
type Item = QueryMatch<'a, 'tree>;
fn next(&mut self) -> Option<Self::Item> {
loop {
unsafe {
unsafe {
loop {
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
if ffi::ts_query_cursor_next_match(self.ptr, m.as_mut_ptr()) {
let result = QueryMatch::new(m.assume_init(), self.ptr);
if result.satisfies_text_predicates(
self.query,
&mut self.buffer1,
&mut self.buffer2,
&mut self.text_provider,
) {
return Some(result);
}
} else {
return None;
}
}
}
}
}
impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryCaptures<'a, 'tree, T> {
type Item = (QueryMatch<'a, 'tree>, usize);
fn next(&mut self) -> Option<Self::Item> {
unsafe {
loop {
let mut capture_index = 0u32;
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
if ffi::ts_query_cursor_next_capture(
@ -1749,7 +1823,12 @@ impl<'a, 'tree: 'a, T: AsRef<[u8]>> Iterator for QueryCaptures<'a, 'tree, T> {
&mut capture_index as *mut u32,
) {
let result = QueryMatch::new(m.assume_init(), self.ptr);
if result.satisfies_text_predicates(self.query, &mut self.text_callback) {
if result.satisfies_text_predicates(
self.query,
&mut self.buffer1,
&mut self.buffer2,
&mut self.text_provider,
) {
return Some((result, capture_index as usize));
} else {
result.remove();
@ -1762,7 +1841,7 @@ impl<'a, 'tree: 'a, T: AsRef<[u8]>> Iterator for QueryCaptures<'a, 'tree, T> {
}
}
impl<'a> fmt::Debug for QueryMatch<'a> {
impl<'cursor, 'tree> fmt::Debug for QueryMatch<'cursor, 'tree> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
@ -1772,6 +1851,26 @@ impl<'a> fmt::Debug for QueryMatch<'a> {
}
}
/// Allows any closure that maps a `Node` to an iterator of byte chunks to be
/// used directly as a `TextProvider`.
impl<'a, F, I> TextProvider<'a> for F
where
    F: FnMut(Node) -> I,
    I: Iterator<Item = &'a [u8]> + 'a,
{
    type I = I;

    /// Invokes the closure with `node` and returns whatever iterator it produces.
    fn text(&mut self, node: Node) -> Self::I {
        (*self)(node)
    }
}
/// Allows a byte slice to be used directly as a `TextProvider`.
impl<'a> TextProvider<'a> for &'a [u8] {
    type I = std::option::IntoIter<&'a [u8]>;

    /// Yields a single chunk: the sub-slice of `self` covered by `node`'s
    /// byte range. The slice is indexed directly, so a range that falls
    /// outside the slice panics.
    fn text(&mut self, node: Node) -> Self::I {
        let source: &'a [u8] = *self;
        let chunk = &source[node.byte_range()];
        Some(chunk).into_iter()
    }
}
impl PartialEq for Query {
fn eq(&self, other: &Self) -> bool {
self.ptr == other.ptr
@ -1786,7 +1885,7 @@ impl Drop for Query {
impl Drop for QueryCursor {
fn drop(&mut self) {
unsafe { ffi::ts_query_cursor_delete(self.0.as_ptr()) }
unsafe { ffi::ts_query_cursor_delete(self.ptr.as_ptr()) }
}
}