rust: Change QueryCursor::captures to expose the full match

This commit is contained in:
Max Brunsfeld 2019-10-03 12:45:58 -07:00
parent 3e040b8951
commit 9872a083b7
7 changed files with 172 additions and 104 deletions

View file

@ -34,12 +34,15 @@ pub fn query_files_at_paths(
let tree = parser.parse(&source_code, None).unwrap();
if ordered_captures {
for (pattern_index, capture) in query_cursor.captures(&query, tree.root_node(), text_callback) {
for (mat, capture_index) in
query_cursor.captures(&query, tree.root_node(), text_callback)
{
let capture = mat.captures[capture_index];
writeln!(
&mut stdout,
" pattern: {}, capture: {}, row: {}, text: {:?}",
pattern_index,
&query.capture_names()[capture.index],
mat.pattern_index,
&query.capture_names()[capture.index as usize],
capture.node.start_position().row,
capture.node.utf8_text(&source_code).unwrap_or("")
)?;
@ -47,11 +50,11 @@ pub fn query_files_at_paths(
} else {
for m in query_cursor.matches(&query, tree.root_node(), text_callback) {
writeln!(&mut stdout, " pattern: {}", m.pattern_index)?;
for capture in m.captures() {
for capture in m.captures {
writeln!(
&mut stdout,
" capture: {}, row: {}, text: {:?}",
&query.capture_names()[capture.index],
&query.capture_names()[capture.index as usize],
capture.node.start_position().row,
capture.node.utf8_text(&source_code).unwrap_or("")
)?;

View file

@ -874,6 +874,45 @@ fn test_query_captures_ordered_by_both_start_and_end_positions() {
});
}
#[test]
fn test_query_captures_with_matches_removed() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
r#"
(binary_expression
left: (identifier) @left
operator: * @op
right: (identifier) @right)
"#,
)
.unwrap();
let source = "
a === b && c > d && e < f;
";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let mut captured_strings = Vec::new();
for (m, i) in cursor.captures(&query, tree.root_node(), to_callback(source)) {
let capture = m.captures[i];
let text = capture.node.utf8_text(source.as_bytes()).unwrap();
if text == "a" {
m.remove();
continue;
}
captured_strings.push(text);
}
assert_eq!(captured_strings, &["c", ">", "d", "e", "<", "f",]);
});
}
#[test]
fn test_query_start_byte_for_pattern() {
let language = get_language("javascript");
@ -985,22 +1024,30 @@ fn collect_matches<'a>(
.map(|m| {
(
m.pattern_index,
collect_captures(m.captures().map(|c| (m.pattern_index, c)), query, source),
format_captures(m.captures.iter().cloned(), query, source),
)
})
.collect()
}
fn collect_captures<'a, 'b>(
captures: impl Iterator<Item = (usize, QueryCapture<'a>)>,
query: &'b Query,
source: &'b str,
) -> Vec<(&'b str, &'b str)> {
fn collect_captures<'a>(
captures: impl Iterator<Item = (QueryMatch<'a>, usize)>,
query: &'a Query,
source: &'a str,
) -> Vec<(&'a str, &'a str)> {
format_captures(captures.map(|(m, i)| m.captures[i]), query, source)
}
fn format_captures<'a>(
captures: impl Iterator<Item = QueryCapture<'a>>,
query: &'a Query,
source: &'a str,
) -> Vec<(&'a str, &'a str)> {
captures
.map(|(_, QueryCapture { index, node })| {
.map(|capture| {
(
query.capture_names()[index].as_str(),
node.utf8_text(source.as_bytes()).unwrap(),
query.capture_names()[capture.index as usize].as_str(),
capture.node.utf8_text(source.as_bytes()).unwrap(),
)
})
.collect()

4
lib/.ccls Normal file
View file

@ -0,0 +1,4 @@
-std=c99
-Isrc
-Iinclude
-Iutf8proc

View file

@ -701,6 +701,9 @@ extern "C" {
#[doc = " Otherwise, return `false`."]
pub fn ts_query_cursor_next_match(arg1: *mut TSQueryCursor, match_: *mut TSQueryMatch) -> bool;
}
extern "C" {
pub fn ts_query_cursor_remove_match(arg1: *mut TSQueryCursor, id: u32);
}
extern "C" {
#[doc = " Advance to the next capture of the currently running query."]
#[doc = ""]

View file

@ -122,6 +122,7 @@ pub struct PropertySheetJSON<P> {
}
#[derive(Clone, Copy)]
#[repr(transparent)]
pub struct Node<'a>(ffi::TSNode, PhantomData<&'a ()>);
pub struct Parser(NonNull<ffi::TSParser>);
@ -163,15 +164,19 @@ pub struct Query {
pub struct QueryCursor(NonNull<ffi::TSQueryCursor>);
#[derive(Clone)]
pub struct QueryMatch<'a> {
pub pattern_index: usize,
captures: &'a [ffi::TSQueryCapture],
pub captures: &'a [QueryCapture<'a>],
id: u32,
cursor: *mut ffi::TSQueryCursor,
}
#[derive(Clone)]
#[derive(Clone, Copy)]
#[repr(C)]
pub struct QueryCapture<'a> {
pub index: usize,
pub node: Node<'a>,
pub index: u32,
}
#[derive(Debug, PartialEq, Eq)]
@ -1244,16 +1249,6 @@ impl Query {
}
}
impl QueryProperty {
pub fn new(key: &str, value: Option<&str>, capture_id: Option<usize>) -> Self {
QueryProperty {
capture_id,
key: key.to_string().into_boxed_str(),
value: value.map(|s| s.to_string().into_boxed_str()),
}
}
}
impl QueryCursor {
pub fn new() -> Self {
QueryCursor(unsafe { NonNull::new_unchecked(ffi::ts_query_cursor_new()) })
@ -1267,27 +1262,16 @@ impl QueryCursor {
) -> impl Iterator<Item = QueryMatch<'a>> + 'a {
let ptr = self.0.as_ptr();
unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) };
std::iter::from_fn(move || -> Option<QueryMatch<'a>> {
loop {
unsafe {
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
if ffi::ts_query_cursor_next_match(ptr, m.as_mut_ptr()) {
let m = m.assume_init();
let captures = slice::from_raw_parts(m.captures, m.capture_count as usize);
if Self::captures_match_text_predicates(
query,
captures,
m.pattern_index as usize,
&mut text_callback,
) {
return Some(QueryMatch {
pattern_index: m.pattern_index as usize,
captures,
});
}
} else {
return None;
std::iter::from_fn(move || loop {
unsafe {
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
if ffi::ts_query_cursor_next_match(ptr, m.as_mut_ptr()) {
let result = QueryMatch::new(m.assume_init(), ptr);
if result.satisfies_text_predicates(query, &mut text_callback) {
return Some(result);
}
} else {
return None;
}
}
})
@ -1298,34 +1282,23 @@ impl QueryCursor {
query: &'a Query,
node: Node<'a>,
mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a,
) -> impl Iterator<Item = (usize, QueryCapture)> + 'a {
) -> impl Iterator<Item = (QueryMatch<'a>, usize)> + 'a {
let ptr = self.0.as_ptr();
unsafe { ffi::ts_query_cursor_exec(ptr, query.ptr.as_ptr(), node.0) };
std::iter::from_fn(move || loop {
unsafe {
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
let mut capture_index = 0u32;
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
if ffi::ts_query_cursor_next_capture(
ptr,
m.as_mut_ptr(),
&mut capture_index as *mut u32,
) {
let m = m.assume_init();
let captures = slice::from_raw_parts(m.captures, m.capture_count as usize);
if Self::captures_match_text_predicates(
query,
captures,
m.pattern_index as usize,
&mut text_callback,
) {
let capture = captures[capture_index as usize];
return Some((
m.pattern_index as usize,
QueryCapture {
index: capture.index as usize,
node: Node::new(capture.node).unwrap(),
},
));
let result = QueryMatch::new(m.assume_init(), ptr);
if result.satisfies_text_predicates(query, &mut text_callback) {
return Some((result, capture_index as usize));
} else {
result.remove();
}
} else {
return None;
@ -1334,40 +1307,6 @@ impl QueryCursor {
})
}
fn captures_match_text_predicates<'a>(
query: &'a Query,
captures: &'a [ffi::TSQueryCapture],
pattern_index: usize,
text_callback: &mut impl FnMut(Node<'a>) -> &'a [u8],
) -> bool {
query.text_predicates[pattern_index]
.iter()
.all(|predicate| match predicate {
TextPredicate::CaptureEqCapture(i, j) => {
let node1 = Self::capture_for_id(captures, *i).unwrap();
let node2 = Self::capture_for_id(captures, *j).unwrap();
text_callback(node1) == text_callback(node2)
}
TextPredicate::CaptureEqString(i, s) => {
let node = Self::capture_for_id(captures, *i).unwrap();
text_callback(node) == s.as_bytes()
}
TextPredicate::CaptureMatchString(i, r) => {
let node = Self::capture_for_id(captures, *i).unwrap();
r.is_match(text_callback(node))
}
})
}
fn capture_for_id(captures: &[ffi::TSQueryCapture], capture_id: u32) -> Option<Node> {
for c in captures {
if c.index == capture_id {
return Node::new(c.node);
}
}
None
}
pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self {
unsafe {
ffi::ts_query_cursor_set_byte_range(self.0.as_ptr(), start as u32, end as u32);
@ -1384,11 +1323,65 @@ impl QueryCursor {
}
impl<'a> QueryMatch<'a> {
pub fn captures(&self) -> impl ExactSizeIterator<Item = QueryCapture> {
self.captures.iter().map(|capture| QueryCapture {
index: capture.index as usize,
node: Node::new(capture.node).unwrap(),
})
pub fn remove(self) {
unsafe { ffi::ts_query_cursor_remove_match(self.cursor, self.id) }
}
fn new(m: ffi::TSQueryMatch, cursor: *mut ffi::TSQueryCursor) -> Self {
QueryMatch {
cursor,
id: m.id,
pattern_index: m.pattern_index as usize,
captures: unsafe {
slice::from_raw_parts(
m.captures as *const QueryCapture<'a>,
m.capture_count as usize,
)
},
}
}
fn satisfies_text_predicates(
&self,
query: &Query,
text_callback: &mut impl FnMut(Node<'a>) -> &[u8],
) -> bool {
query.text_predicates[self.pattern_index]
.iter()
.all(|predicate| match predicate {
TextPredicate::CaptureEqCapture(i, j) => {
let node1 = self.capture_for_index(*i).unwrap();
let node2 = self.capture_for_index(*j).unwrap();
text_callback(node1) == text_callback(node2)
}
TextPredicate::CaptureEqString(i, s) => {
let node = self.capture_for_index(*i).unwrap();
text_callback(node) == s.as_bytes()
}
TextPredicate::CaptureMatchString(i, r) => {
let node = self.capture_for_index(*i).unwrap();
r.is_match(text_callback(node))
}
})
}
fn capture_for_index(&self, capture_index: u32) -> Option<Node<'a>> {
for c in self.captures {
if c.index == capture_index {
return Some(c.node);
}
}
None
}
}
impl QueryProperty {
pub fn new(key: &str, value: Option<&str>, capture_id: Option<usize>) -> Self {
QueryProperty {
capture_id,
key: key.to_string().into_boxed_str(),
value: value.map(|s| s.to_string().into_boxed_str()),
}
}
}

View file

@ -764,6 +764,7 @@ void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint);
* Otherwise, return `false`.
*/
bool ts_query_cursor_next_match(TSQueryCursor *, TSQueryMatch *match);
void ts_query_cursor_remove_match(TSQueryCursor *, uint32_t id);
/**
* Advance to the next capture of the currently running query.

View file

@ -1234,6 +1234,23 @@ bool ts_query_cursor_next_match(
return true;
}
void ts_query_cursor_remove_match(
TSQueryCursor *self,
uint32_t match_id
) {
for (unsigned i = 0; i < self->finished_states.size; i++) {
const QueryState *state = &self->finished_states.contents[i];
if (state->id == match_id) {
capture_list_pool_release(
&self->capture_list_pool,
state->capture_list_id
);
array_erase(&self->finished_states, i);
return;
}
}
}
bool ts_query_cursor_next_capture(
TSQueryCursor *self,
TSQueryMatch *match,