Preserve matches that contain the QueryCursor's start byte

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
Co-Authored-By: Antonio Scandurra <me@as-cii.com>
This commit is contained in:
Max Brunsfeld 2021-05-25 13:06:24 -07:00
parent a61f25bc58
commit f597cc6a75
5 changed files with 214 additions and 171 deletions

View file

@ -3039,32 +3039,49 @@ fn test_query_text_callback_returns_chunks() {
#[test]
fn test_query_captures_advance_to_byte() {
allocations::record(|| {
let language = get_language("javascript");
let language = get_language("rust");
let query = Query::new(
language,
r#"
(identifier) @id
(array
"[" @lbracket
"]" @rbracket)
(function_item
name: (identifier) @fn_name)
(mod_item
name: (identifier) @mod_name
body: (declaration_list
"{" @lbrace
"}" @rbrace))
; functions that return Result<()>
((function_item
return_type: (generic_type
type: (type_identifier) @result
type_arguments: (type_arguments
(unit_type)))
body: _ @fallible_fn_body)
(#eq? @result "Result"))
"#,
)
.unwrap();
let source = "[one, two, [three, four, five, six, seven, eight, nine, ten], eleven, twelve, thirteen]";
let source = "
mod m1 {
mod m2 {
fn f1() -> Option<()> { Some(()) }
}
fn f2() -> Result<()> { Ok(()) }
fn f3() {}
}
";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
cursor.set_byte_range(
source.find("two").unwrap() + 1,
source.find(", twelve").unwrap(),
);
let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
// Retrieve four captures.
// Retrieve some captures
let mut results = Vec::new();
for (mat, capture_ix) in captures.by_ref().take(4) {
for (mat, capture_ix) in captures.by_ref().take(5) {
let capture = mat.captures[capture_ix as usize];
results.push((
query.capture_names()[capture.index as usize].as_str(),
@ -3074,16 +3091,18 @@ fn test_query_captures_advance_to_byte() {
assert_eq!(
results,
vec![
("id", "two"),
("lbracket", "["),
("id", "three"),
("id", "four")
("mod_name", "m1"),
("lbrace", "{"),
("mod_name", "m2"),
("lbrace", "{"),
("fn_name", "f1"),
]
);
// Advance further ahead in the source, retrieve the remaining captures.
results.clear();
captures.advance_to_byte(source.find("ten").unwrap() + 1);
captures.advance_to_byte(source.find("Ok").unwrap());
// Advance further ahead in the source, retrieve the remaining captures.
for (mat, capture_ix) in captures {
let capture = mat.captures[capture_ix as usize];
results.push((
@ -3093,7 +3112,11 @@ fn test_query_captures_advance_to_byte() {
}
assert_eq!(
results,
vec![("id", "ten"), ("rbracket", "]"), ("id", "eleven"),]
vec![
("fallible_fn_body", "{ Ok(()) }"),
("fn_name", "f3"),
("rbrace", "}")
]
);
// Advance past the last capture. There are no more captures.
@ -3104,6 +3127,46 @@ fn test_query_captures_advance_to_byte() {
});
}
// Checks which captures are yielded after calling `advance_to_byte` with an
// offset that falls inside a node: the offset used is the first `{` in
// `source`, which is the opening brace on the `fn foo() -> i32 {}` line.
#[test]
fn test_query_advance_to_byte_within_node() {
allocations::record(|| {
let language = get_language("rust");
// NOTE(review): the first pattern names `fn_item`, whereas the Rust query in
// `test_query_captures_advance_to_byte` uses `function_item`. If `fn_item` is
// not a node type in the Rust grammar, the `.unwrap()` below would panic
// before the test reaches its assertion — confirm the node name.
let query = Query::new(
language,
r#"
(fn_item
name: (identifier) @name
return_type: _? @ret)
(mod_item
name: (identifier) @name
body: _ @body)
"#,
)
.unwrap();
let source = "
fn foo() -> i32 {}
...
mod foo {}
";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes());
// `find` returns the first occurrence, so the cursor is advanced before any
// capture has been pulled from the iterator.
captures.advance_to_byte(source.find("{").unwrap());
// NOTE(review): only `@body` is expected here, although the `mod_item`
// pattern also captures `@name` and `mod foo` appears after the offset —
// confirm this expectation is intentional.
assert_eq!(
collect_captures(captures, &query, source),
&[("body", "{}"),]
);
})
}
#[test]
fn test_query_start_byte_for_pattern() {
let language = get_language("javascript");

View file

@ -737,18 +737,8 @@ extern "C" {
pub fn ts_query_cursor_did_exceed_match_limit(arg1: *const TSQueryCursor) -> bool;
}
extern "C" {
#[doc = " Get or set the range of bytes or (row, column) positions in which the query"]
#[doc = " Set the range of bytes or (row, column) positions in which the query"]
#[doc = " will be executed."]
pub fn ts_query_cursor_byte_range(arg1: *const TSQueryCursor, arg2: *mut u32, arg3: *mut u32);
}
extern "C" {
pub fn ts_query_cursor_point_range(
arg1: *const TSQueryCursor,
arg2: *mut TSPoint,
arg3: *mut TSPoint,
);
}
extern "C" {
pub fn ts_query_cursor_set_byte_range(arg1: *mut TSQueryCursor, arg2: u32, arg3: u32);
}
extern "C" {
@ -764,6 +754,9 @@ extern "C" {
extern "C" {
pub fn ts_query_cursor_remove_match(arg1: *mut TSQueryCursor, id: u32);
}
extern "C" {
// Raw binding for the C function `ts_query_cursor_advance_to_byte`.
// `QueryCaptures::advance_to_byte` calls this with its `usize` offset cast
// to `u32`.
pub fn ts_query_cursor_advance_to_byte(arg1: *mut TSQueryCursor, offset: u32);
}
extern "C" {
#[doc = " Advance to the next capture of the currently running query."]
#[doc = ""]

View file

@ -1812,27 +1812,7 @@ impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryMatches<'a, 'tree, T> {
impl<'a, 'tree, T: TextProvider<'a>> QueryCaptures<'a, 'tree, T> {
pub fn advance_to_byte(&mut self, offset: usize) {
unsafe {
let mut current_start = 0u32;
let mut current_end = 0u32;
ffi::ts_query_cursor_byte_range(
self.ptr,
&mut current_start as *mut u32,
&mut current_end as *mut u32,
);
ffi::ts_query_cursor_set_byte_range(self.ptr, offset as u32, current_end);
}
}
pub fn advance_to_point(&mut self, point: Point) {
unsafe {
let mut current_start = ffi::TSPoint { row: 0, column: 0 };
let mut current_end = current_start;
ffi::ts_query_cursor_point_range(
self.ptr,
&mut current_start as *mut _,
&mut current_end as *mut _,
);
ffi::ts_query_cursor_set_point_range(self.ptr, point.into(), current_end);
ffi::ts_query_cursor_advance_to_byte(self.ptr, offset as u32);
}
}
}

View file

@ -809,11 +809,9 @@ void ts_query_cursor_exec(TSQueryCursor *, const TSQuery *, TSNode);
bool ts_query_cursor_did_exceed_match_limit(const TSQueryCursor *);
/**
* Get or set the range of bytes or (row, column) positions in which the query
* Set the range of bytes or (row, column) positions in which the query
* will be executed.
*/
void ts_query_cursor_byte_range(const TSQueryCursor *, uint32_t *, uint32_t *);
void ts_query_cursor_point_range(const TSQueryCursor *, TSPoint *, TSPoint *);
void ts_query_cursor_set_byte_range(TSQueryCursor *, uint32_t, uint32_t);
void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint);
@ -826,6 +824,8 @@ void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint);
bool ts_query_cursor_next_match(TSQueryCursor *, TSQueryMatch *match);
void ts_query_cursor_remove_match(TSQueryCursor *, uint32_t id);
/**
* Set the query cursor's start byte to the given offset.
*
* Captures whose node ends at or before this offset are skipped when
* captures are subsequently retrieved from the cursor.
*/
void ts_query_cursor_advance_to_byte(TSQueryCursor *, uint32_t offset);
/**
* Advance to the next capture of the currently running query.
*

View file

@ -256,10 +256,7 @@ struct TSQueryCursor {
CaptureListPool capture_list_pool;
uint32_t depth;
uint32_t start_byte;
uint32_t end_byte;
uint32_t next_state_id;
TSPoint start_point;
TSPoint end_point;
bool ascending;
bool halted;
bool did_exceed_match_limit;
@ -2264,9 +2261,6 @@ TSQueryCursor *ts_query_cursor_new(void) {
.finished_states = array_new(),
.capture_list_pool = capture_list_pool_new(),
.start_byte = 0,
.end_byte = UINT32_MAX,
.start_point = {0, 0},
.end_point = POINT_MAX,
};
array_reserve(&self->states, 8);
array_reserve(&self->finished_states, 8);
@ -2296,40 +2290,18 @@ void ts_query_cursor_exec(
capture_list_pool_reset(&self->capture_list_pool);
self->next_state_id = 0;
self->depth = 0;
self->start_byte = 0;
self->ascending = false;
self->halted = false;
self->query = query;
self->did_exceed_match_limit = false;
}
void ts_query_cursor_byte_range(
const TSQueryCursor *self,
uint32_t *start_byte,
uint32_t *end_byte
) {
*start_byte = self->start_byte;
*end_byte = self->end_byte;
}
void ts_query_cursor_point_range(
const TSQueryCursor *self,
TSPoint *start_point,
TSPoint *end_point
) {
*start_point = self->start_point;
*end_point = self->end_point;
}
void ts_query_cursor_set_byte_range(
TSQueryCursor *self,
uint32_t start_byte,
uint32_t end_byte
) {
if (end_byte == 0) {
end_byte = UINT32_MAX;
}
self->start_byte = start_byte;
self->end_byte = end_byte;
}
void ts_query_cursor_set_point_range(
@ -2337,11 +2309,6 @@ void ts_query_cursor_set_point_range(
TSPoint start_point,
TSPoint end_point
) {
if (end_point.row == 0 && end_point.column == 0) {
end_point = POINT_MAX;
}
self->start_point = start_point;
self->end_point = end_point;
}
// Search through all of the in-progress states, and find the captured
@ -2358,31 +2325,41 @@ static bool ts_query_cursor__first_in_progress_capture(
*byte_offset = UINT32_MAX;
*pattern_index = UINT32_MAX;
for (unsigned i = 0; i < self->states.size; i++) {
const QueryState *state = &self->states.contents[i];
QueryState *state = &self->states.contents[i];
if (state->dead) continue;
const CaptureList *captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
if (captures->size > state->consumed_capture_count) {
uint32_t capture_byte = ts_node_start_byte(captures->contents[state->consumed_capture_count].node);
if (
!result ||
capture_byte < *byte_offset ||
(capture_byte == *byte_offset && state->pattern_index < *pattern_index)
) {
QueryStep *step = &self->query->steps.contents[state->step_index];
if (is_definite) {
*is_definite = step->is_definite;
} else if (step->is_definite) {
continue;
}
if (state->consumed_capture_count >= captures->size) {
continue;
}
result = true;
*state_index = i;
*byte_offset = capture_byte;
*pattern_index = state->pattern_index;
TSNode node = captures->contents[state->consumed_capture_count].node;
if (ts_node_end_byte(node) <= self->start_byte) {
state->consumed_capture_count++;
i--;
continue;
}
uint32_t node_start_byte = ts_node_start_byte(node);
if (
!result ||
node_start_byte < *byte_offset ||
(node_start_byte == *byte_offset && state->pattern_index < *pattern_index)
) {
QueryStep *step = &self->query->steps.contents[state->step_index];
if (is_definite) {
*is_definite = step->is_definite;
} else if (step->is_definite) {
continue;
}
result = true;
*state_index = i;
*byte_offset = node_start_byte;
*pattern_index = state->pattern_index;
}
}
return result;
@ -2707,26 +2684,8 @@ static inline bool ts_query_cursor__advance(
else {
// If this node is before the selected range, then avoid descending into it.
TSNode node = ts_tree_cursor_current_node(&self->cursor);
if (
ts_node_end_byte(node) <= self->start_byte ||
point_lte(ts_node_end_point(node), self->start_point)
) {
if (!ts_tree_cursor_goto_next_sibling(&self->cursor)) {
self->ascending = true;
}
LOG("skip until start of range\n");
continue;
}
// If this node is after the selected range, then stop walking.
if (
self->end_byte <= ts_node_start_byte(node) ||
point_lte(self->end_point, ts_node_start_point(node))
) {
LOG("halt at end of range\n");
self->halted = true;
continue;
}
bool node_exceeds_start_byte = ts_node_end_byte(node) > self->start_byte;
// Get the properties of the current node.
TSSymbol symbol = ts_node_symbol(node);
@ -2755,36 +2714,44 @@ static inline bool ts_query_cursor__advance(
self->finished_states.size
);
// Add new states for any patterns whose root node is a wildcard.
for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) {
PatternEntry *pattern = &self->query->pattern_map.contents[i];
QueryStep *step = &self->query->steps.contents[pattern->step_index];
if (node_exceeds_start_byte) {
// Add new states for any patterns whose root node is a wildcard.
for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) {
PatternEntry *pattern = &self->query->pattern_map.contents[i];
QueryStep *step = &self->query->steps.contents[pattern->step_index];
// If this node matches the first step of the pattern, then add a new
// state at the start of this pattern.
if (step->field && field_id != step->field) continue;
if (step->supertype_symbol && !supertype_count) continue;
ts_query_cursor__add_state(self, pattern);
}
// Add new states for any patterns whose root node matches this node.
unsigned i;
if (ts_query__pattern_map_search(self->query, symbol, &i)) {
PatternEntry *pattern = &self->query->pattern_map.contents[i];
QueryStep *step = &self->query->steps.contents[pattern->step_index];
do {
// If this node matches the first step of the pattern, then add a new
// state at the start of this pattern.
if (!step->field || field_id == step->field) {
ts_query_cursor__add_state(self, pattern);
}
if (step->field && field_id != step->field) continue;
if (step->supertype_symbol && !supertype_count) continue;
ts_query_cursor__add_state(self, pattern);
}
// Advance to the next pattern whose root node matches this node.
i++;
if (i == self->query->pattern_map.size) break;
pattern = &self->query->pattern_map.contents[i];
step = &self->query->steps.contents[pattern->step_index];
} while (step->symbol == symbol);
// Add new states for any patterns whose root node matches this node.
unsigned i;
if (ts_query__pattern_map_search(self->query, symbol, &i)) {
PatternEntry *pattern = &self->query->pattern_map.contents[i];
QueryStep *step = &self->query->steps.contents[pattern->step_index];
do {
// If this node matches the first step of the pattern, then add a new
// state at the start of this pattern.
if (!step->field || field_id == step->field) {
ts_query_cursor__add_state(self, pattern);
}
// Advance to the next pattern whose root node matches this node.
i++;
if (i == self->query->pattern_map.size) break;
pattern = &self->query->pattern_map.contents[i];
step = &self->query->steps.contents[pattern->step_index];
} while (step->symbol == symbol);
}
} else {
LOG(
" not starting new patterns. node end byte: %u, start_byte: %u\n",
ts_node_end_byte(node),
self->start_byte
);
}
// Update all of the in-progress states with current node.
@ -3070,8 +3037,32 @@ static inline bool ts_query_cursor__advance(
}
}
// Continue descending if possible.
if (ts_tree_cursor_goto_first_child(&self->cursor)) {
// When the current node ends prior to the desired start offset,
// only descend for the purpose of continuing in-progress matches.
bool should_descend = node_exceeds_start_byte;
if (!should_descend) {
for (unsigned i = 0; i < self->states.size; i++) {
QueryState *state = &self->states.contents[i];;
QueryStep *next_step = &self->query->steps.contents[state->step_index];
if (
next_step->depth != PATTERN_DONE_MARKER &&
state->start_depth + next_step->depth > self->depth
) {
should_descend = true;
break;
}
}
}
if (!should_descend) {
LOG(
" not descending. node end byte: %u, start byte: %u\n",
ts_node_end_byte(node),
self->start_byte
);
}
if (should_descend && ts_tree_cursor_goto_first_child(&self->cursor)) {
self->depth++;
} else {
self->ascending = true;
@ -3080,6 +3071,14 @@ static inline bool ts_query_cursor__advance(
}
}
// Record `offset` as the cursor's start byte.
//
// This function only logs the request and stores the new value; it does not
// touch any other cursor state itself. The stored `start_byte` is consulted
// elsewhere in this file: captures whose node end byte is <= `start_byte`
// are skipped, and nodes that do not extend past `start_byte` do not start
// new pattern states.
void ts_query_cursor_advance_to_byte(
TSQueryCursor *self,
uint32_t offset
) {
LOG("advance_to_byte %u\n", offset);
self->start_byte = offset;
}
bool ts_query_cursor_next_match(
TSQueryCursor *self,
TSQueryMatch *match
@ -3148,35 +3147,43 @@ bool ts_query_cursor_next_capture(
QueryState *first_finished_state = NULL;
uint32_t first_finished_capture_byte = first_unfinished_capture_byte;
uint32_t first_finished_pattern_index = first_unfinished_pattern_index;
for (unsigned i = 0; i < self->finished_states.size; i++) {
for (unsigned i = 0; i < self->finished_states.size;) {
QueryState *state = &self->finished_states.contents[i];
const CaptureList *captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
if (captures->size > state->consumed_capture_count) {
uint32_t capture_byte = ts_node_start_byte(
captures->contents[state->consumed_capture_count].node
);
if (
capture_byte < first_finished_capture_byte ||
(
capture_byte == first_finished_capture_byte &&
state->pattern_index < first_finished_pattern_index
)
) {
first_finished_state = state;
first_finished_capture_byte = capture_byte;
first_finished_pattern_index = state->pattern_index;
}
} else {
// Remove states whose captures are all consumed.
if (state->consumed_capture_count >= captures->size) {
capture_list_pool_release(
&self->capture_list_pool,
state->capture_list_id
);
array_erase(&self->finished_states, i);
i--;
continue;
}
// Skip captures that precede the cursor's start byte.
TSNode node = captures->contents[state->consumed_capture_count].node;
if (ts_node_end_byte(node) <= self->start_byte) {
state->consumed_capture_count++;
continue;
}
uint32_t node_start_byte = ts_node_start_byte(node);
if (
node_start_byte < first_finished_capture_byte ||
(
node_start_byte == first_finished_capture_byte &&
state->pattern_index < first_finished_pattern_index
)
) {
first_finished_state = state;
first_finished_capture_byte = node_start_byte;
first_finished_pattern_index = state->pattern_index;
}
i++;
}
// If there is finished capture that is clearly before any unfinished