Stop matching new patterns past the end of QueryCursor's range
This restores the original signatures of the `set_byte_range` and `set_point_range` functions. Now, the QueryCursor will properly report matches that intersect, but are not fully contained by its range. Co-Authored-By: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
parent
f597cc6a75
commit
fda35894d4
5 changed files with 122 additions and 152 deletions
|
|
@ -1918,6 +1918,92 @@ fn test_query_captures_within_byte_range() {
|
|||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_captures_within_byte_range_assigned_after_iterating() {
|
||||
allocations::record(|| {
|
||||
let language = get_language("rust");
|
||||
let query = Query::new(
|
||||
language,
|
||||
r#"
|
||||
(function_item
|
||||
name: (identifier) @fn_name)
|
||||
|
||||
(mod_item
|
||||
name: (identifier) @mod_name
|
||||
body: (declaration_list
|
||||
"{" @lbrace
|
||||
"}" @rbrace))
|
||||
|
||||
; functions that return Result<()>
|
||||
((function_item
|
||||
return_type: (generic_type
|
||||
type: (type_identifier) @result
|
||||
type_arguments: (type_arguments
|
||||
(unit_type)))
|
||||
body: _ @fallible_fn_body)
|
||||
(#eq? @result "Result"))
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
let source = "
|
||||
mod m1 {
|
||||
mod m2 {
|
||||
fn f1() -> Option<()> { Some(()) }
|
||||
}
|
||||
fn f2() -> Result<()> { Ok(()) }
|
||||
fn f3() {}
|
||||
}
|
||||
";
|
||||
|
||||
let mut parser = Parser::new();
|
||||
parser.set_language(language).unwrap();
|
||||
let tree = parser.parse(&source, None).unwrap();
|
||||
let mut cursor = QueryCursor::new();
|
||||
let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
|
||||
|
||||
// Retrieve some captures
|
||||
let mut results = Vec::new();
|
||||
for (mat, capture_ix) in captures.by_ref().take(5) {
|
||||
let capture = mat.captures[capture_ix as usize];
|
||||
results.push((
|
||||
query.capture_names()[capture.index as usize].as_str(),
|
||||
&source[capture.node.byte_range()],
|
||||
));
|
||||
}
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
("mod_name", "m1"),
|
||||
("lbrace", "{"),
|
||||
("mod_name", "m2"),
|
||||
("lbrace", "{"),
|
||||
("fn_name", "f1"),
|
||||
]
|
||||
);
|
||||
|
||||
// Advance to a range that only partially intersects some matches.
|
||||
// Captures from these matches are reported, but only those that
|
||||
// intersect the range.
|
||||
results.clear();
|
||||
captures.set_byte_range(source.find("Ok").unwrap(), source.len());
|
||||
for (mat, capture_ix) in captures {
|
||||
let capture = mat.captures[capture_ix as usize];
|
||||
results.push((
|
||||
query.capture_names()[capture.index as usize].as_str(),
|
||||
&source[capture.node.byte_range()],
|
||||
));
|
||||
}
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
("fallible_fn_body", "{ Ok(()) }"),
|
||||
("fn_name", "f3"),
|
||||
("rbrace", "}")
|
||||
]
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_matches_different_queries_same_cursor() {
|
||||
allocations::record(|| {
|
||||
|
|
@ -3036,137 +3122,6 @@ fn test_query_text_callback_returns_chunks() {
|
|||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_captures_advance_to_byte() {
|
||||
allocations::record(|| {
|
||||
let language = get_language("rust");
|
||||
let query = Query::new(
|
||||
language,
|
||||
r#"
|
||||
(function_item
|
||||
name: (identifier) @fn_name)
|
||||
|
||||
(mod_item
|
||||
name: (identifier) @mod_name
|
||||
body: (declaration_list
|
||||
"{" @lbrace
|
||||
"}" @rbrace))
|
||||
|
||||
; functions that return Result<()>
|
||||
((function_item
|
||||
return_type: (generic_type
|
||||
type: (type_identifier) @result
|
||||
type_arguments: (type_arguments
|
||||
(unit_type)))
|
||||
body: _ @fallible_fn_body)
|
||||
(#eq? @result "Result"))
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
let source = "
|
||||
mod m1 {
|
||||
mod m2 {
|
||||
fn f1() -> Option<()> { Some(()) }
|
||||
}
|
||||
fn f2() -> Result<()> { Ok(()) }
|
||||
fn f3() {}
|
||||
}
|
||||
";
|
||||
|
||||
let mut parser = Parser::new();
|
||||
parser.set_language(language).unwrap();
|
||||
let tree = parser.parse(&source, None).unwrap();
|
||||
let mut cursor = QueryCursor::new();
|
||||
let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
|
||||
|
||||
// Retrieve some captures
|
||||
let mut results = Vec::new();
|
||||
for (mat, capture_ix) in captures.by_ref().take(5) {
|
||||
let capture = mat.captures[capture_ix as usize];
|
||||
results.push((
|
||||
query.capture_names()[capture.index as usize].as_str(),
|
||||
&source[capture.node.byte_range()],
|
||||
));
|
||||
}
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
("mod_name", "m1"),
|
||||
("lbrace", "{"),
|
||||
("mod_name", "m2"),
|
||||
("lbrace", "{"),
|
||||
("fn_name", "f1"),
|
||||
]
|
||||
);
|
||||
|
||||
results.clear();
|
||||
captures.advance_to_byte(source.find("Ok").unwrap());
|
||||
|
||||
// Advance further ahead in the source, retrieve the remaining captures.
|
||||
for (mat, capture_ix) in captures {
|
||||
let capture = mat.captures[capture_ix as usize];
|
||||
results.push((
|
||||
query.capture_names()[capture.index as usize].as_str(),
|
||||
&source[capture.node.byte_range()],
|
||||
));
|
||||
}
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
("fallible_fn_body", "{ Ok(()) }"),
|
||||
("fn_name", "f3"),
|
||||
("rbrace", "}")
|
||||
]
|
||||
);
|
||||
|
||||
// Advance past the last capture. There are no more captures.
|
||||
let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
|
||||
captures.advance_to_byte(source.len());
|
||||
assert!(captures.next().is_none());
|
||||
assert!(captures.next().is_none());
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_advance_to_byte_within_node() {
|
||||
allocations::record(|| {
|
||||
let language = get_language("rust");
|
||||
let query = Query::new(
|
||||
language,
|
||||
r#"
|
||||
(fn_item
|
||||
name: (identifier) @name
|
||||
return_type: _? @ret)
|
||||
|
||||
(mod_item
|
||||
name: (identifier) @name
|
||||
body: _ @body)
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
let source = "
|
||||
fn foo() -> i32 {}
|
||||
|
||||
...
|
||||
|
||||
mod foo {}
|
||||
";
|
||||
|
||||
let mut parser = Parser::new();
|
||||
parser.set_language(language).unwrap();
|
||||
let tree = parser.parse(&source, None).unwrap();
|
||||
let mut cursor = QueryCursor::new();
|
||||
let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
|
||||
|
||||
captures.advance_to_byte(source.find("{").unwrap());
|
||||
|
||||
assert_eq!(
|
||||
collect_captures(captures, &query, source),
|
||||
&[("body", "{}"),]
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_start_byte_for_pattern() {
|
||||
let language = get_language("javascript");
|
||||
|
|
|
|||
|
|
@ -1810,9 +1810,9 @@ impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryMatches<'a, 'tree, T> {
|
|||
}
|
||||
|
||||
impl<'a, 'tree, T: TextProvider<'a>> QueryCaptures<'a, 'tree, T> {
|
||||
pub fn advance_to_byte(&mut self, offset: usize) {
|
||||
pub fn set_byte_range(&mut self, start: usize, end: usize) {
|
||||
unsafe {
|
||||
ffi::ts_query_cursor_advance_to_byte(self.ptr, offset as u32);
|
||||
ffi::ts_query_cursor_set_byte_range(self.ptr, start as u32, end as u32);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -824,8 +824,6 @@ void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint);
|
|||
bool ts_query_cursor_next_match(TSQueryCursor *, TSQueryMatch *match);
|
||||
void ts_query_cursor_remove_match(TSQueryCursor *, uint32_t id);
|
||||
|
||||
void ts_query_cursor_advance_to_byte(TSQueryCursor *, uint32_t offset);
|
||||
|
||||
/**
|
||||
* Advance to the next capture of the currently running query.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -33,6 +33,10 @@ static inline bool point_lt(TSPoint a, TSPoint b) {
|
|||
return (a.row < b.row) || (a.row == b.row && a.column < b.column);
|
||||
}
|
||||
|
||||
static inline bool point_gt(TSPoint a, TSPoint b) {
|
||||
return (a.row > b.row) || (a.row == b.row && a.column > b.column);
|
||||
}
|
||||
|
||||
static inline bool point_eq(TSPoint a, TSPoint b) {
|
||||
return a.row == b.row && a.column == b.column;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -256,6 +256,9 @@ struct TSQueryCursor {
|
|||
CaptureListPool capture_list_pool;
|
||||
uint32_t depth;
|
||||
uint32_t start_byte;
|
||||
uint32_t end_byte;
|
||||
TSPoint start_point;
|
||||
TSPoint end_point;
|
||||
uint32_t next_state_id;
|
||||
bool ascending;
|
||||
bool halted;
|
||||
|
|
@ -2261,6 +2264,9 @@ TSQueryCursor *ts_query_cursor_new(void) {
|
|||
.finished_states = array_new(),
|
||||
.capture_list_pool = capture_list_pool_new(),
|
||||
.start_byte = 0,
|
||||
.end_byte = UINT32_MAX,
|
||||
.start_point = {0, 0},
|
||||
.end_point = POINT_MAX,
|
||||
};
|
||||
array_reserve(&self->states, 8);
|
||||
array_reserve(&self->finished_states, 8);
|
||||
|
|
@ -2290,7 +2296,6 @@ void ts_query_cursor_exec(
|
|||
capture_list_pool_reset(&self->capture_list_pool);
|
||||
self->next_state_id = 0;
|
||||
self->depth = 0;
|
||||
self->start_byte = 0;
|
||||
self->ascending = false;
|
||||
self->halted = false;
|
||||
self->query = query;
|
||||
|
|
@ -2302,6 +2307,11 @@ void ts_query_cursor_set_byte_range(
|
|||
uint32_t start_byte,
|
||||
uint32_t end_byte
|
||||
) {
|
||||
if (end_byte == 0) {
|
||||
end_byte = UINT32_MAX;
|
||||
}
|
||||
self->start_byte = start_byte;
|
||||
self->end_byte = end_byte;
|
||||
}
|
||||
|
||||
void ts_query_cursor_set_point_range(
|
||||
|
|
@ -2309,6 +2319,11 @@ void ts_query_cursor_set_point_range(
|
|||
TSPoint start_point,
|
||||
TSPoint end_point
|
||||
) {
|
||||
if (end_point.row == 0 && end_point.column == 0) {
|
||||
end_point = POINT_MAX;
|
||||
}
|
||||
self->start_point = start_point;
|
||||
self->end_point = end_point;
|
||||
}
|
||||
|
||||
// Search through all of the in-progress states, and find the captured
|
||||
|
|
@ -2337,7 +2352,10 @@ static bool ts_query_cursor__first_in_progress_capture(
|
|||
}
|
||||
|
||||
TSNode node = captures->contents[state->consumed_capture_count].node;
|
||||
if (ts_node_end_byte(node) <= self->start_byte) {
|
||||
if (
|
||||
ts_node_end_byte(node) <= self->start_byte ||
|
||||
point_lte(ts_node_end_point(node), self->start_point)
|
||||
) {
|
||||
state->consumed_capture_count++;
|
||||
i--;
|
||||
continue;
|
||||
|
|
@ -2682,12 +2700,8 @@ static inline bool ts_query_cursor__advance(
|
|||
|
||||
// Enter a new node.
|
||||
else {
|
||||
// If this node is before the selected range, then avoid descending into it.
|
||||
TSNode node = ts_tree_cursor_current_node(&self->cursor);
|
||||
|
||||
bool node_exceeds_start_byte = ts_node_end_byte(node) > self->start_byte;
|
||||
|
||||
// Get the properties of the current node.
|
||||
TSNode node = ts_tree_cursor_current_node(&self->cursor);
|
||||
TSSymbol symbol = ts_node_symbol(node);
|
||||
bool is_named = ts_node_is_named(node);
|
||||
bool has_later_siblings;
|
||||
|
|
@ -2714,7 +2728,14 @@ static inline bool ts_query_cursor__advance(
|
|||
self->finished_states.size
|
||||
);
|
||||
|
||||
if (node_exceeds_start_byte) {
|
||||
bool node_intersects_range = (
|
||||
ts_node_end_byte(node) > self->start_byte &&
|
||||
ts_node_start_byte(node) < self->end_byte &&
|
||||
point_gt(ts_node_end_point(node), self->start_point) &&
|
||||
point_lt(ts_node_start_point(node), self->end_point)
|
||||
);
|
||||
|
||||
if (node_intersects_range) {
|
||||
// Add new states for any patterns whose root node is a wildcard.
|
||||
for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) {
|
||||
PatternEntry *pattern = &self->query->pattern_map.contents[i];
|
||||
|
|
@ -3039,7 +3060,7 @@ static inline bool ts_query_cursor__advance(
|
|||
|
||||
// When the current node ends prior to the desired start offset,
|
||||
// only descend for the purpose of continuing in-progress matches.
|
||||
bool should_descend = node_exceeds_start_byte;
|
||||
bool should_descend = node_intersects_range;
|
||||
if (!should_descend) {
|
||||
for (unsigned i = 0; i < self->states.size; i++) {
|
||||
QueryState *state = &self->states.contents[i];;
|
||||
|
|
@ -3071,14 +3092,6 @@ static inline bool ts_query_cursor__advance(
|
|||
}
|
||||
}
|
||||
|
||||
void ts_query_cursor_advance_to_byte(
|
||||
TSQueryCursor *self,
|
||||
uint32_t offset
|
||||
) {
|
||||
LOG("advance_to_byte %u\n", offset);
|
||||
self->start_byte = offset;
|
||||
}
|
||||
|
||||
bool ts_query_cursor_next_match(
|
||||
TSQueryCursor *self,
|
||||
TSQueryMatch *match
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue