query: Handle captured wildcard nodes at the root of patterns

This commit is contained in:
Max Brunsfeld 2020-10-08 12:34:08 -07:00
parent 1f3248a3e0
commit 857a9ed07b
5 changed files with 212 additions and 35 deletions

View file

@ -1691,6 +1691,93 @@ fn test_query_matches_with_multiple_captures_on_a_node() {
});
}
#[test]
fn test_query_matches_with_captured_wildcard_at_root() {
allocations::record(|| {
let language = get_language("python");
let query = Query::new(
language,
"
; captured wildcard at the root
(_ [
(except_clause (block) @block)
(finally_clause (block) @block)
]) @stmt
[
(while_statement (block) @block)
(if_statement (block) @block)
; captured wildcard at the root within an alternation
(_ [
(else_clause (block) @block)
(elif_clause (block) @block)
])
(try_statement (block) @block)
(for_statement (block) @block)
] @stmt
",
)
.unwrap();
let source = "
for i in j:
while True:
if a:
print b
elif c:
print d
else:
try:
print f
except:
print g
finally:
print h
else:
print i
"
.trim();
let mut parser = Parser::new();
let mut cursor = QueryCursor::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let match_capture_names_and_rows = cursor
.matches(&query, tree.root_node(), to_callback(source))
.map(|m| {
m.captures
.iter()
.map(|c| {
(
query.capture_names()[c.index as usize].as_str(),
c.node.kind(),
c.node.start_position().row,
)
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
assert_eq!(
match_capture_names_and_rows,
&[
vec![("stmt", "for_statement", 0), ("block", "block", 1)],
vec![("stmt", "while_statement", 1), ("block", "block", 2)],
vec![("stmt", "if_statement", 2), ("block", "block", 3)],
vec![("stmt", "if_statement", 2), ("block", "block", 5)],
vec![("stmt", "if_statement", 2), ("block", "block", 7)],
vec![("stmt", "try_statement", 7), ("block", "block", 8)],
vec![("stmt", "try_statement", 7), ("block", "block", 10)],
vec![("stmt", "try_statement", 7), ("block", "block", 12)],
vec![("stmt", "while_statement", 1), ("block", "block", 14)],
]
)
});
}
#[test]
fn test_query_matches_with_no_captures() {
allocations::record(|| {

View file

@ -1273,7 +1273,7 @@ impl Query {
let raw_predicates =
ffi::ts_query_predicates_for_pattern(ptr, i as u32, &mut length as *mut u32);
if length > 0 {
slice::from_raw_parts(raw_predicates, length as usize)
slice::from_raw_parts(raw_predicates, length as usize)
} else {
&[]
}
@ -1655,10 +1655,10 @@ impl<'a> QueryMatch<'a> {
pattern_index: m.pattern_index as usize,
captures: if m.capture_count > 0 {
unsafe {
slice::from_raw_parts(
m.captures as *const QueryCapture<'a>,
m.capture_count as usize,
)
slice::from_raw_parts(
m.captures as *const QueryCapture<'a>,
m.capture_count as usize,
)
}
} else {
&[]

View file

@ -138,6 +138,7 @@ typedef struct {
bool seeking_immediate_match: 1;
bool has_in_progress_alternatives: 1;
bool dead: 1;
bool needs_parent: 1;
} QueryState;
typedef Array(TSQueryCapture) CaptureList;
@ -2011,20 +2012,24 @@ TSQuery *ts_query_new(
return NULL;
}
// If a pattern has a wildcard at its root, optimize the matching process
// by skipping matching the wildcard.
if (
self->steps.contents[start_step_index].symbol == WILDCARD_SYMBOL
) {
QueryStep *second_step = &self->steps.contents[start_step_index + 1];
if (second_step->symbol != WILDCARD_SYMBOL && second_step->depth != PATTERN_DONE_MARKER) {
start_step_index += 1;
}
}
// Maintain a map that can look up patterns for a given root symbol.
uint16_t wildcard_root_alternative_index = NONE;
for (;;) {
QueryStep *step = &self->steps.contents[start_step_index];
// If a pattern has a wildcard at its root, but it has a non-wildcard child,
// then optimize the matching process by skipping matching the wildcard.
// Later, during the matching process, the query cursor will check that
// there is a parent node, and capture it if necessary.
if (step->symbol == WILDCARD_SYMBOL && step->depth == 0) {
QueryStep *second_step = &self->steps.contents[start_step_index + 1];
if (second_step->symbol != WILDCARD_SYMBOL && second_step->depth == 1) {
wildcard_root_alternative_index = step->alternative_index;
start_step_index += 1;
step = second_step;
}
}
ts_query__pattern_map_insert(self, step->symbol, start_step_index, pattern_index);
if (step->symbol == WILDCARD_SYMBOL) {
self->wildcard_root_pattern_count++;
@ -2035,6 +2040,9 @@ TSQuery *ts_query_new(
if (step->alternative_index != NONE) {
start_step_index = step->alternative_index;
step->alternative_index = NONE;
} else if (wildcard_root_alternative_index != NONE) {
start_step_index = wildcard_root_alternative_index;
wildcard_root_alternative_index = NONE;
} else {
break;
}
@ -2386,8 +2394,8 @@ static void ts_query_cursor__add_state(
if (prev_state->start_depth == start_depth) {
if (prev_state->pattern_index < pattern->pattern_index) break;
if (prev_state->pattern_index == pattern->pattern_index) {
// Avoid unnecessarily inserting an unnecessary duplicate state,
// which would be immediately pruned by the longest-match criteria.
// Avoid inserting an unnecessary duplicate state, which would be
// immediately pruned by the longest-match criteria.
if (prev_state->step_index == pattern->step_index) return;
}
}
@ -2407,6 +2415,7 @@ static void ts_query_cursor__add_state(
.consumed_capture_count = 0,
.seeking_immediate_match = true,
.has_in_progress_alternatives = false,
.needs_parent = step->depth == 1,
.dead = false,
}));
}
@ -2460,6 +2469,33 @@ static CaptureList *ts_query_cursor__prepare_to_capture(
return capture_list_pool_get_mut(&self->capture_list_pool, state->capture_list_id);
}
static void ts_query_cursor__capture(
TSQueryCursor *self,
QueryState *state,
QueryStep *step,
TSNode node
) {
if (state->dead) return;
CaptureList *capture_list = ts_query_cursor__prepare_to_capture(self, state, UINT32_MAX);
if (!capture_list) {
state->dead = true;
return;
}
for (unsigned j = 0; j < MAX_STEP_CAPTURE_COUNT; j++) {
uint16_t capture_id = step->capture_ids[j];
if (step->capture_ids[j] == NONE) break;
array_push(capture_list, ((TSQueryCapture) { node, capture_id }));
LOG(
" capture node. type:%s, pattern:%u, capture_id:%u, capture_count:%u\n",
ts_node_type(node),
state->pattern_index,
capture_id,
capture_list->size
);
}
}
// Duplicate the given state and insert the newly-created state immediately after
// the given state in the `states` array. Ensures that the given state reference is
// still valid, even if the states array is reallocated.
@ -2730,26 +2766,45 @@ static inline bool ts_query_cursor__advance(
}
}
// If this pattern started with a wildcard, such that the pattern map
// actually points to the *second* step of the pattern, then check
// that the node has a parent, and capture the parent node if necessary.
if (state->needs_parent) {
TSNode parent = ts_tree_cursor_parent_node(&self->cursor);
if (ts_node_is_null(parent)) {
LOG(" missing parent node\n");
state->dead = true;
} else {
state->needs_parent = false;
QueryStep *skipped_wildcard_step = step;
do {
skipped_wildcard_step--;
} while (
skipped_wildcard_step->is_dead_end ||
skipped_wildcard_step->is_pass_through ||
skipped_wildcard_step->depth > 0
);
if (skipped_wildcard_step->capture_ids[0] != NONE) {
LOG(" capture wildcard parent\n");
ts_query_cursor__capture(
self,
state,
skipped_wildcard_step,
parent
);
}
}
}
// If the current node is captured in this pattern, add it to the capture list.
if (step->capture_ids[0] != NONE) {
CaptureList *capture_list = ts_query_cursor__prepare_to_capture(self, state, UINT32_MAX);
if (!capture_list) {
array_erase(&self->states, i);
i--;
continue;
}
ts_query_cursor__capture(self, state, step, node);
}
for (unsigned j = 0; j < MAX_STEP_CAPTURE_COUNT; j++) {
uint16_t capture_id = step->capture_ids[j];
if (step->capture_ids[j] == NONE) break;
array_push(capture_list, ((TSQueryCapture) { node, capture_id }));
LOG(
" capture node. pattern:%u, capture_id:%u, capture_count:%u\n",
state->pattern_index,
capture_id,
capture_list->size
);
}
if (state->dead) {
array_erase(&self->states, i);
i--;
continue;
}
// Advance this state to the next step of its pattern.
@ -2772,12 +2827,18 @@ static inline bool ts_query_cursor__advance(
QueryState *state = &self->states.contents[j];
QueryStep *next_step = &self->query->steps.contents[state->step_index];
if (next_step->alternative_index != NONE) {
// A "dead-end" step exists only to add a non-sequential jump into the step sequence,
// via its alternative index. When a state reaches a dead-end step, it jumps straight
// to the step's alternative.
if (next_step->is_dead_end) {
state->step_index = next_step->alternative_index;
j--;
continue;
}
// A "pass-through" step exists only to add a branch into the step sequence,
// via its alternative_index. When a state reaches a pass-through step, it splits
// in order to process the alternative step, and then it advances to the next step.
if (next_step->is_pass_through) {
state->step_index++;
j--;

View file

@ -364,6 +364,33 @@ void ts_tree_cursor_current_status(
}
}
TSNode ts_tree_cursor_parent_node(const TSTreeCursor *_self) {
const TreeCursor *self = (const TreeCursor *)_self;
for (int i = (int)self->stack.size - 2; i >= 0; i--) {
TreeCursorEntry *entry = &self->stack.contents[i];
bool is_visible = true;
TSSymbol alias_symbol = 0;
if (i > 0) {
TreeCursorEntry *parent_entry = &self->stack.contents[i - 1];
alias_symbol = ts_language_alias_at(
self->tree->language,
parent_entry->subtree->ptr->production_id,
entry->structural_child_index
);
is_visible = (alias_symbol != 0) || ts_subtree_visible(*entry->subtree);
}
if (is_visible) {
return ts_node_new(
self->tree,
entry->subtree,
entry->position,
alias_symbol
);
}
}
return ts_node_new(NULL, NULL, length_zero(), 0);
}
TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) {
const TreeCursor *self = (const TreeCursor *)_self;

View file

@ -26,4 +26,6 @@ void ts_tree_cursor_current_status(
unsigned *
);
TSNode ts_tree_cursor_parent_node(const TSTreeCursor *);
#endif // TREE_SITTER_TREE_CURSOR_H_