feat!: implement StreamingIterator instead of Iterator for QueryMatches and QueryCaptures

This fixes undefined behavior (UB) that occurred when `collect` was called on either `QueryMatches` or `QueryCaptures`.

Co-authored-by: Amaan Qureshi <amaanq12@gmail.com>
This commit is contained in:
Lukas Seidel 2024-09-29 23:34:48 +02:00 committed by GitHub
parent 12007d3ebe
commit 6b1ebd3d29
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 271 additions and 105 deletions

View file

@ -1,18 +1,23 @@
#![doc = include_str!("../README.md")]
pub mod c_lib;
use core::slice;
use std::{
collections::HashSet,
iter, mem, ops, str,
iter,
marker::PhantomData,
mem::{self, MaybeUninit},
ops, str,
sync::atomic::{AtomicUsize, Ordering},
};
pub use c_lib as c;
use lazy_static::lazy_static;
use streaming_iterator::StreamingIterator;
use thiserror::Error;
use tree_sitter::{
Language, LossyUtf8, Node, Parser, Point, Query, QueryCaptures, QueryCursor, QueryError,
QueryMatch, Range, Tree,
ffi, Language, LossyUtf8, Node, Parser, Point, Query, QueryCapture, QueryCaptures, QueryCursor,
QueryError, QueryMatch, Range, TextProvider, Tree,
};
const CANCELLATION_CHECK_INTERVAL: usize = 100;
@ -171,7 +176,7 @@ where
struct HighlightIterLayer<'a> {
_tree: Tree,
cursor: QueryCursor,
captures: iter::Peekable<QueryCaptures<'a, 'a, &'a [u8], &'a [u8]>>,
captures: iter::Peekable<_QueryCaptures<'a, 'a, &'a [u8], &'a [u8]>>,
config: &'a HighlightConfiguration,
highlight_end_stack: Vec<usize>,
scope_stack: Vec<LocalScope<'a>>,
@ -179,6 +184,77 @@ struct HighlightIterLayer<'a> {
depth: usize,
}
/// Local stand-in for `tree_sitter::QueryCaptures` that implements the standard
/// `Iterator` trait, so it can be wrapped in `iter::Peekable` by
/// `HighlightIterLayer`.
///
/// Values of this type are not constructed field by field in this file; they
/// are obtained by `transmute`-ing a real `QueryCaptures` returned from
/// `QueryCursor::captures`.
// NOTE(review): because of that transmute, the field order and field types here
// presumably have to mirror the upstream `QueryCaptures` definition exactly --
// confirm against the tree-sitter crate version in use before changing anything.
pub struct _QueryCaptures<'query, 'tree: 'query, T: TextProvider<I>, I: AsRef<[u8]>> {
/// Raw query cursor passed to `ffi::ts_query_cursor_next_capture` on each step.
ptr: *mut ffi::TSQueryCursor,
/// Query handed to `satisfies_text_predicates` when filtering matches.
query: &'query Query,
/// Source-text provider handed (mutably) to `satisfies_text_predicates`.
text_provider: T,
/// First scratch buffer handed to `satisfies_text_predicates`.
buffer1: Vec<u8>,
/// Second scratch buffer handed to `satisfies_text_predicates`.
buffer2: Vec<u8>,
// Not read by the `Iterator` impl in this file.
// NOTE(review): presumably kept only so the layout matches upstream -- confirm.
_current_match: Option<(QueryMatch<'query, 'tree>, usize)>,
/// Zero-sized marker that makes the struct use the `'tree` lifetime and the `I` parameter.
_phantom: PhantomData<(&'tree (), I)>,
}
/// Local stand-in for `tree_sitter::QueryMatch`.
///
/// Instances are built by `_QueryMatch::new` from a raw `ffi::TSQueryMatch` and
/// then `transmute`d into a real `QueryMatch` by the `_QueryCaptures` iterator.
// NOTE(review): because of that transmute, the field order and field types here
// presumably have to mirror the upstream `QueryMatch` definition exactly --
// confirm against the tree-sitter crate version in use before changing anything.
struct _QueryMatch<'cursor, 'tree> {
/// Pattern index copied from `TSQueryMatch::pattern_index`.
pub _pattern_index: usize,
/// Captures of this match; an empty slice when the raw match reports zero captures.
pub _captures: &'cursor [QueryCapture<'tree>],
/// Match id copied from `TSQueryMatch::id`.
_id: u32,
/// Raw query cursor this match was produced by (stored as given to `new`).
_cursor: *mut ffi::TSQueryCursor,
}
impl<'tree> _QueryMatch<'_, 'tree> {
    /// Builds a `_QueryMatch` from a raw match reported by the C query cursor.
    ///
    /// `m` supplies the match id, the pattern index and the capture array;
    /// `cursor` is stored unchanged. When `m` reports zero captures the capture
    /// slice is empty and `m.captures` is never dereferenced.
    fn new(m: &ffi::TSQueryMatch, cursor: *mut ffi::TSQueryCursor) -> Self {
        let capture_count = m.capture_count as usize;

        let captures: &[QueryCapture<'tree>] = if capture_count == 0 {
            Default::default()
        } else {
            // SAFETY (assumed, TODO confirm against the tree-sitter FFI contract):
            // `m.captures` points to `capture_count` initialized captures whose
            // layout matches `QueryCapture`, and they stay valid for as long as
            // the returned value is used.
            unsafe {
                slice::from_raw_parts(m.captures.cast::<QueryCapture<'tree>>(), capture_count)
            }
        };

        Self {
            _pattern_index: m.pattern_index as usize,
            _captures: captures,
            _id: m.id,
            _cursor: cursor,
        }
    }
}
impl<'query, 'tree: 'query, T: TextProvider<I>, I: AsRef<[u8]>> Iterator
    for _QueryCaptures<'query, 'tree, T, I>
{
    /// A match together with the capture index reported by the cursor for it.
    type Item = (QueryMatch<'query, 'tree>, usize);

    /// Advances the underlying cursor to the next capture whose match passes
    /// the query's text predicates.
    ///
    /// Returns `None` as soon as `ts_query_cursor_next_capture` reports that
    /// nothing is left. Matches that fail the text predicates are removed from
    /// the cursor via `QueryMatch::remove` and the search continues.
    fn next(&mut self) -> Option<Self::Item> {
        // NOTE(review): soundness relies on `self.ptr` being a live cursor and on
        // `_QueryMatch` being layout-compatible with `QueryMatch` -- confirm
        // against the tree-sitter crate version in use.
        unsafe {
            loop {
                let mut capture_index = 0u32;
                let mut raw_match = MaybeUninit::<ffi::TSQueryMatch>::uninit();

                let has_capture = ffi::ts_query_cursor_next_capture(
                    self.ptr,
                    raw_match.as_mut_ptr(),
                    core::ptr::addr_of_mut!(capture_index),
                );
                if !has_capture {
                    return None;
                }

                // `raw_match` is only read after the cursor reported a capture.
                let mirrored = _QueryMatch::new(&raw_match.assume_init(), self.ptr);
                let candidate = std::mem::transmute::<_QueryMatch, QueryMatch>(mirrored);

                let passes_predicates = candidate.satisfies_text_predicates(
                    self.query,
                    &mut self.buffer1,
                    &mut self.buffer2,
                    &mut self.text_provider,
                );
                if passes_predicates {
                    return Some((candidate, capture_index as usize));
                }

                // Rejected by a text predicate: drop it from the cursor and retry.
                candidate.remove();
            }
        }
    }
}
impl Default for Highlighter {
fn default() -> Self {
Self::new()
@ -456,15 +532,15 @@ impl<'a> HighlightIterLayer<'a> {
if let Some(combined_injections_query) = &config.combined_injections_query {
let mut injections_by_pattern_index =
vec![(None, Vec::new(), false); combined_injections_query.pattern_count()];
let matches =
let mut matches =
cursor.matches(combined_injections_query, tree.root_node(), source);
for mat in matches {
while let Some(mat) = matches.next() {
let entry = &mut injections_by_pattern_index[mat.pattern_index];
let (language_name, content_node, include_children) = injection_for_match(
config,
parent_name,
combined_injections_query,
&mat,
mat,
source,
);
if language_name.is_some() {
@ -499,9 +575,12 @@ impl<'a> HighlightIterLayer<'a> {
let cursor_ref = unsafe {
mem::transmute::<&mut QueryCursor, &'static mut QueryCursor>(&mut cursor)
};
let captures = cursor_ref
.captures(&config.query, tree_ref.root_node(), source)
.peekable();
let captures = unsafe {
std::mem::transmute::<QueryCaptures<_, _>, _QueryCaptures<_, _>>(
cursor_ref.captures(&config.query, tree_ref.root_node(), source),
)
}
.peekable();
result.push(HighlightIterLayer {
highlight_end_stack: Vec::new(),