query: Handle captured wildcard nodes at the root of patterns
This commit is contained in:
parent
1f3248a3e0
commit
857a9ed07b
5 changed files with 212 additions and 35 deletions
|
|
@ -1691,6 +1691,93 @@ fn test_query_matches_with_multiple_captures_on_a_node() {
|
|||
});
|
||||
}
|
||||
|
||||
/// Checks that a wildcard `(_ ...)` at the root of a pattern can be captured,
/// both as a stand-alone pattern and as one branch of a top-level alternation.
///
/// Every match is flattened to `(capture name, node kind, start row)` tuples
/// and compared against the expected list, in match order.
#[test]
fn test_query_matches_with_captured_wildcard_at_root() {
    allocations::record(|| {
        let language = get_language("python");
        let query = Query::new(
            language,
            "
            ; captured wildcard at the root
            (_ [
                (except_clause (block) @block)
                (finally_clause (block) @block)
            ]) @stmt

            [
                (while_statement (block) @block)
                (if_statement (block) @block)

                ; captured wildcard at the root within an alternation
                (_ [
                    (else_clause (block) @block)
                    (elif_clause (block) @block)
                ])

                (try_statement (block) @block)
                (for_statement (block) @block)
            ] @stmt
            ",
        )
        .unwrap();

        // The indentation inside this literal is significant: it is what nests
        // `while` inside `for`, `if` inside `while`, and `try` inside the `if`
        // statement's `else` clause. The final `else` is aligned with `while`
        // so that it belongs to the `while` statement, which is what the
        // expected rows below rely on. `trim()` only strips the whitespace
        // around the whole snippet, so relative indentation is preserved.
        let source = "
        for i in j:
            while True:
                if a:
                    print b
                elif c:
                    print d
                else:
                    try:
                        print f
                    except:
                        print g
                    finally:
                        print h
            else:
                print i
        "
        .trim();

        let mut parser = Parser::new();
        let mut cursor = QueryCursor::new();
        parser.set_language(language).unwrap();
        let tree = parser.parse(&source, None).unwrap();

        // One inner Vec per match; one tuple per capture within that match.
        let match_capture_names_and_rows = cursor
            .matches(&query, tree.root_node(), to_callback(source))
            .map(|m| {
                m.captures
                    .iter()
                    .map(|c| {
                        (
                            query.capture_names()[c.index as usize].as_str(),
                            c.node.kind(),
                            c.node.start_position().row,
                        )
                    })
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        // Rows are counted from 0 at the `for` line of the trimmed source.
        assert_eq!(
            match_capture_names_and_rows,
            &[
                vec![("stmt", "for_statement", 0), ("block", "block", 1)],
                vec![("stmt", "while_statement", 1), ("block", "block", 2)],
                vec![("stmt", "if_statement", 2), ("block", "block", 3)],
                vec![("stmt", "if_statement", 2), ("block", "block", 5)],
                vec![("stmt", "if_statement", 2), ("block", "block", 7)],
                vec![("stmt", "try_statement", 7), ("block", "block", 8)],
                vec![("stmt", "try_statement", 7), ("block", "block", 10)],
                vec![("stmt", "try_statement", 7), ("block", "block", 12)],
                vec![("stmt", "while_statement", 1), ("block", "block", 14)],
            ]
        )
    });
}
|
||||
|
||||
#[test]
|
||||
fn test_query_matches_with_no_captures() {
|
||||
allocations::record(|| {
|
||||
|
|
|
|||
|
|
@ -1273,7 +1273,7 @@ impl Query {
|
|||
let raw_predicates =
|
||||
ffi::ts_query_predicates_for_pattern(ptr, i as u32, &mut length as *mut u32);
|
||||
if length > 0 {
|
||||
slice::from_raw_parts(raw_predicates, length as usize)
|
||||
slice::from_raw_parts(raw_predicates, length as usize)
|
||||
} else {
|
||||
&[]
|
||||
}
|
||||
|
|
@ -1655,10 +1655,10 @@ impl<'a> QueryMatch<'a> {
|
|||
pattern_index: m.pattern_index as usize,
|
||||
captures: if m.capture_count > 0 {
|
||||
unsafe {
|
||||
slice::from_raw_parts(
|
||||
m.captures as *const QueryCapture<'a>,
|
||||
m.capture_count as usize,
|
||||
)
|
||||
slice::from_raw_parts(
|
||||
m.captures as *const QueryCapture<'a>,
|
||||
m.capture_count as usize,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
&[]
|
||||
|
|
|
|||
121
lib/src/query.c
121
lib/src/query.c
|
|
@ -138,6 +138,7 @@ typedef struct {
|
|||
bool seeking_immediate_match: 1;
|
||||
bool has_in_progress_alternatives: 1;
|
||||
bool dead: 1;
|
||||
bool needs_parent: 1;
|
||||
} QueryState;
|
||||
|
||||
typedef Array(TSQueryCapture) CaptureList;
|
||||
|
|
@ -2011,20 +2012,24 @@ TSQuery *ts_query_new(
|
|||
return NULL;
|
||||
}
|
||||
|
||||
// If a pattern has a wildcard at its root, optimize the matching process
|
||||
// by skipping matching the wildcard.
|
||||
if (
|
||||
self->steps.contents[start_step_index].symbol == WILDCARD_SYMBOL
|
||||
) {
|
||||
QueryStep *second_step = &self->steps.contents[start_step_index + 1];
|
||||
if (second_step->symbol != WILDCARD_SYMBOL && second_step->depth != PATTERN_DONE_MARKER) {
|
||||
start_step_index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Maintain a map that can look up patterns for a given root symbol.
|
||||
uint16_t wildcard_root_alternative_index = NONE;
|
||||
for (;;) {
|
||||
QueryStep *step = &self->steps.contents[start_step_index];
|
||||
|
||||
// If a pattern has a wildcard at its root, but it has a non-wildcard child,
|
||||
// then optimize the matching process by skipping matching the wildcard.
|
||||
// Later, during the matching process, the query cursor will check that
|
||||
// there is a parent node, and capture it if necessary.
|
||||
if (step->symbol == WILDCARD_SYMBOL && step->depth == 0) {
|
||||
QueryStep *second_step = &self->steps.contents[start_step_index + 1];
|
||||
if (second_step->symbol != WILDCARD_SYMBOL && second_step->depth == 1) {
|
||||
wildcard_root_alternative_index = step->alternative_index;
|
||||
start_step_index += 1;
|
||||
step = second_step;
|
||||
}
|
||||
}
|
||||
|
||||
ts_query__pattern_map_insert(self, step->symbol, start_step_index, pattern_index);
|
||||
if (step->symbol == WILDCARD_SYMBOL) {
|
||||
self->wildcard_root_pattern_count++;
|
||||
|
|
@ -2035,6 +2040,9 @@ TSQuery *ts_query_new(
|
|||
if (step->alternative_index != NONE) {
|
||||
start_step_index = step->alternative_index;
|
||||
step->alternative_index = NONE;
|
||||
} else if (wildcard_root_alternative_index != NONE) {
|
||||
start_step_index = wildcard_root_alternative_index;
|
||||
wildcard_root_alternative_index = NONE;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
|
|
@ -2386,8 +2394,8 @@ static void ts_query_cursor__add_state(
|
|||
if (prev_state->start_depth == start_depth) {
|
||||
if (prev_state->pattern_index < pattern->pattern_index) break;
|
||||
if (prev_state->pattern_index == pattern->pattern_index) {
|
||||
// Avoid unnecessarily inserting an unnecessary duplicate state,
|
||||
// which would be immediately pruned by the longest-match criteria.
|
||||
// Avoid inserting an unnecessary duplicate state, which would be
|
||||
// immediately pruned by the longest-match criteria.
|
||||
if (prev_state->step_index == pattern->step_index) return;
|
||||
}
|
||||
}
|
||||
|
|
@ -2407,6 +2415,7 @@ static void ts_query_cursor__add_state(
|
|||
.consumed_capture_count = 0,
|
||||
.seeking_immediate_match = true,
|
||||
.has_in_progress_alternatives = false,
|
||||
.needs_parent = step->depth == 1,
|
||||
.dead = false,
|
||||
}));
|
||||
}
|
||||
|
|
@ -2460,6 +2469,33 @@ static CaptureList *ts_query_cursor__prepare_to_capture(
|
|||
return capture_list_pool_get_mut(&self->capture_list_pool, state->capture_list_id);
|
||||
}
|
||||
|
||||
static void ts_query_cursor__capture(
|
||||
TSQueryCursor *self,
|
||||
QueryState *state,
|
||||
QueryStep *step,
|
||||
TSNode node
|
||||
) {
|
||||
if (state->dead) return;
|
||||
CaptureList *capture_list = ts_query_cursor__prepare_to_capture(self, state, UINT32_MAX);
|
||||
if (!capture_list) {
|
||||
state->dead = true;
|
||||
return;
|
||||
}
|
||||
|
||||
for (unsigned j = 0; j < MAX_STEP_CAPTURE_COUNT; j++) {
|
||||
uint16_t capture_id = step->capture_ids[j];
|
||||
if (step->capture_ids[j] == NONE) break;
|
||||
array_push(capture_list, ((TSQueryCapture) { node, capture_id }));
|
||||
LOG(
|
||||
" capture node. type:%s, pattern:%u, capture_id:%u, capture_count:%u\n",
|
||||
ts_node_type(node),
|
||||
state->pattern_index,
|
||||
capture_id,
|
||||
capture_list->size
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Duplicate the given state and insert the newly-created state immediately after
|
||||
// the given state in the `states` array. Ensures that the given state reference is
|
||||
// still valid, even if the states array is reallocated.
|
||||
|
|
@ -2730,26 +2766,45 @@ static inline bool ts_query_cursor__advance(
|
|||
}
|
||||
}
|
||||
|
||||
// If this pattern started with a wildcard, such that the pattern map
|
||||
// actually points to the *second* step of the pattern, then check
|
||||
// that the node has a parent, and capture the parent node if necessary.
|
||||
if (state->needs_parent) {
|
||||
TSNode parent = ts_tree_cursor_parent_node(&self->cursor);
|
||||
if (ts_node_is_null(parent)) {
|
||||
LOG(" missing parent node\n");
|
||||
state->dead = true;
|
||||
} else {
|
||||
state->needs_parent = false;
|
||||
QueryStep *skipped_wildcard_step = step;
|
||||
do {
|
||||
skipped_wildcard_step--;
|
||||
} while (
|
||||
skipped_wildcard_step->is_dead_end ||
|
||||
skipped_wildcard_step->is_pass_through ||
|
||||
skipped_wildcard_step->depth > 0
|
||||
);
|
||||
if (skipped_wildcard_step->capture_ids[0] != NONE) {
|
||||
LOG(" capture wildcard parent\n");
|
||||
ts_query_cursor__capture(
|
||||
self,
|
||||
state,
|
||||
skipped_wildcard_step,
|
||||
parent
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the current node is captured in this pattern, add it to the capture list.
|
||||
if (step->capture_ids[0] != NONE) {
|
||||
CaptureList *capture_list = ts_query_cursor__prepare_to_capture(self, state, UINT32_MAX);
|
||||
if (!capture_list) {
|
||||
array_erase(&self->states, i);
|
||||
i--;
|
||||
continue;
|
||||
}
|
||||
ts_query_cursor__capture(self, state, step, node);
|
||||
}
|
||||
|
||||
for (unsigned j = 0; j < MAX_STEP_CAPTURE_COUNT; j++) {
|
||||
uint16_t capture_id = step->capture_ids[j];
|
||||
if (step->capture_ids[j] == NONE) break;
|
||||
array_push(capture_list, ((TSQueryCapture) { node, capture_id }));
|
||||
LOG(
|
||||
" capture node. pattern:%u, capture_id:%u, capture_count:%u\n",
|
||||
state->pattern_index,
|
||||
capture_id,
|
||||
capture_list->size
|
||||
);
|
||||
}
|
||||
if (state->dead) {
|
||||
array_erase(&self->states, i);
|
||||
i--;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Advance this state to the next step of its pattern.
|
||||
|
|
@ -2772,12 +2827,18 @@ static inline bool ts_query_cursor__advance(
|
|||
QueryState *state = &self->states.contents[j];
|
||||
QueryStep *next_step = &self->query->steps.contents[state->step_index];
|
||||
if (next_step->alternative_index != NONE) {
|
||||
// A "dead-end" step exists only to add a non-sequential jump into the step sequence,
|
||||
// via its alternative index. When a state reaches a dead-end step, it jumps straight
|
||||
// to the step's alternative.
|
||||
if (next_step->is_dead_end) {
|
||||
state->step_index = next_step->alternative_index;
|
||||
j--;
|
||||
continue;
|
||||
}
|
||||
|
||||
// A "pass-through" step exists only to add a branch into the step sequence,
|
||||
// via its alternative_index. When a state reaches a pass-through step, it splits
|
||||
// in order to process the alternative step, and then it advances to the next step.
|
||||
if (next_step->is_pass_through) {
|
||||
state->step_index++;
|
||||
j--;
|
||||
|
|
|
|||
|
|
@ -364,6 +364,33 @@ void ts_tree_cursor_current_status(
|
|||
}
|
||||
}
|
||||
|
||||
TSNode ts_tree_cursor_parent_node(const TSTreeCursor *_self) {
|
||||
const TreeCursor *self = (const TreeCursor *)_self;
|
||||
for (int i = (int)self->stack.size - 2; i >= 0; i--) {
|
||||
TreeCursorEntry *entry = &self->stack.contents[i];
|
||||
bool is_visible = true;
|
||||
TSSymbol alias_symbol = 0;
|
||||
if (i > 0) {
|
||||
TreeCursorEntry *parent_entry = &self->stack.contents[i - 1];
|
||||
alias_symbol = ts_language_alias_at(
|
||||
self->tree->language,
|
||||
parent_entry->subtree->ptr->production_id,
|
||||
entry->structural_child_index
|
||||
);
|
||||
is_visible = (alias_symbol != 0) || ts_subtree_visible(*entry->subtree);
|
||||
}
|
||||
if (is_visible) {
|
||||
return ts_node_new(
|
||||
self->tree,
|
||||
entry->subtree,
|
||||
entry->position,
|
||||
alias_symbol
|
||||
);
|
||||
}
|
||||
}
|
||||
return ts_node_new(NULL, NULL, length_zero(), 0);
|
||||
}
|
||||
|
||||
TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) {
|
||||
const TreeCursor *self = (const TreeCursor *)_self;
|
||||
|
||||
|
|
|
|||
|
|
@ -26,4 +26,6 @@ void ts_tree_cursor_current_status(
|
|||
unsigned *
|
||||
);
|
||||
|
||||
// Returns the nearest visible ancestor found on the cursor's stack, or a node
// constructed from a NULL tree and NULL subtree when there is none.
TSNode ts_tree_cursor_parent_node(const TSTreeCursor *);
|
||||
|
||||
#endif // TREE_SITTER_TREE_CURSOR_H_
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue