From 857a9ed07b983e2e6cff6dc851fcf9b37aec8e5a Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 8 Oct 2020 12:34:08 -0700 Subject: [PATCH] query: Handle captured wildcard nodes at the root of patterns --- cli/src/tests/query_test.rs | 87 ++++++++++++++++++++++++++ lib/binding_rust/lib.rs | 10 +-- lib/src/query.c | 121 +++++++++++++++++++++++++++--------- lib/src/tree_cursor.c | 27 ++++++++ lib/src/tree_cursor.h | 2 + 5 files changed, 212 insertions(+), 35 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 323a13fc..1f7ddaff 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1691,6 +1691,93 @@ fn test_query_matches_with_multiple_captures_on_a_node() { }); } +#[test] +fn test_query_matches_with_captured_wildcard_at_root() { + allocations::record(|| { + let language = get_language("python"); + let query = Query::new( + language, + " + ; captured wildcard at the root + (_ [ + (except_clause (block) @block) + (finally_clause (block) @block) + ]) @stmt + + [ + (while_statement (block) @block) + (if_statement (block) @block) + + ; captured wildcard at the root within an alternation + (_ [ + (else_clause (block) @block) + (elif_clause (block) @block) + ]) + + (try_statement (block) @block) + (for_statement (block) @block) + ] @stmt + ", + ) + .unwrap(); + + let source = " + for i in j: + while True: + if a: + print b + elif c: + print d + else: + try: + print f + except: + print g + finally: + print h + else: + print i + " + .trim(); + + let mut parser = Parser::new(); + let mut cursor = QueryCursor::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + + let match_capture_names_and_rows = cursor + .matches(&query, tree.root_node(), to_callback(source)) + .map(|m| { + m.captures + .iter() + .map(|c| { + ( + query.capture_names()[c.index as usize].as_str(), + c.node.kind(), + c.node.start_position().row, + ) + }) + .collect::<Vec<_>>() + }) + .collect::<Vec<_>>(); + + 
assert_eq!( + match_capture_names_and_rows, + &[ + vec![("stmt", "for_statement", 0), ("block", "block", 1)], + vec![("stmt", "while_statement", 1), ("block", "block", 2)], + vec![("stmt", "if_statement", 2), ("block", "block", 3)], + vec![("stmt", "if_statement", 2), ("block", "block", 5)], + vec![("stmt", "if_statement", 2), ("block", "block", 7)], + vec![("stmt", "try_statement", 7), ("block", "block", 8)], + vec![("stmt", "try_statement", 7), ("block", "block", 10)], + vec![("stmt", "try_statement", 7), ("block", "block", 12)], + vec![("stmt", "while_statement", 1), ("block", "block", 14)], + ] + ) + }); +} + #[test] fn test_query_matches_with_no_captures() { allocations::record(|| { diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 372d937f..b33beded 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1273,7 +1273,7 @@ impl Query { let raw_predicates = ffi::ts_query_predicates_for_pattern(ptr, i as u32, &mut length as *mut u32); if length > 0 { - slice::from_raw_parts(raw_predicates, length as usize) + slice::from_raw_parts(raw_predicates, length as usize) } else { &[] } @@ -1655,10 +1655,10 @@ impl<'a> QueryMatch<'a> { pattern_index: m.pattern_index as usize, captures: if m.capture_count > 0 { unsafe { - slice::from_raw_parts( - m.captures as *const QueryCapture<'a>, - m.capture_count as usize, - ) + slice::from_raw_parts( + m.captures as *const QueryCapture<'a>, + m.capture_count as usize, + ) } } else { &[] diff --git a/lib/src/query.c b/lib/src/query.c index ce0e4cdf..133762b9 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -138,6 +138,7 @@ typedef struct { bool seeking_immediate_match: 1; bool has_in_progress_alternatives: 1; bool dead: 1; + bool needs_parent: 1; } QueryState; typedef Array(TSQueryCapture) CaptureList; @@ -2011,20 +2012,24 @@ TSQuery *ts_query_new( return NULL; } - // If a pattern has a wildcard at its root, optimize the matching process - // by skipping matching the wildcard. 
- if ( - self->steps.contents[start_step_index].symbol == WILDCARD_SYMBOL - ) { - QueryStep *second_step = &self->steps.contents[start_step_index + 1]; - if (second_step->symbol != WILDCARD_SYMBOL && second_step->depth != PATTERN_DONE_MARKER) { - start_step_index += 1; - } - } - // Maintain a map that can look up patterns for a given root symbol. + uint16_t wildcard_root_alternative_index = NONE; for (;;) { QueryStep *step = &self->steps.contents[start_step_index]; + + // If a pattern has a wildcard at its root, but it has a non-wildcard child, + // then optimize the matching process by skipping matching the wildcard. + // Later, during the matching process, the query cursor will check that + // there is a parent node, and capture it if necessary. + if (step->symbol == WILDCARD_SYMBOL && step->depth == 0) { + QueryStep *second_step = &self->steps.contents[start_step_index + 1]; + if (second_step->symbol != WILDCARD_SYMBOL && second_step->depth == 1) { + wildcard_root_alternative_index = step->alternative_index; + start_step_index += 1; + step = second_step; + } + } + ts_query__pattern_map_insert(self, step->symbol, start_step_index, pattern_index); if (step->symbol == WILDCARD_SYMBOL) { self->wildcard_root_pattern_count++; @@ -2035,6 +2040,9 @@ TSQuery *ts_query_new( if (step->alternative_index != NONE) { start_step_index = step->alternative_index; step->alternative_index = NONE; + } else if (wildcard_root_alternative_index != NONE) { + start_step_index = wildcard_root_alternative_index; + wildcard_root_alternative_index = NONE; } else { break; } @@ -2386,8 +2394,8 @@ static void ts_query_cursor__add_state( if (prev_state->start_depth == start_depth) { if (prev_state->pattern_index < pattern->pattern_index) break; if (prev_state->pattern_index == pattern->pattern_index) { - // Avoid unnecessarily inserting an unnecessary duplicate state, - // which would be immediately pruned by the longest-match criteria. 
+ // Avoid inserting an unnecessary duplicate state, which would be + // immediately pruned by the longest-match criteria. if (prev_state->step_index == pattern->step_index) return; } } @@ -2407,6 +2415,7 @@ static void ts_query_cursor__add_state( .consumed_capture_count = 0, .seeking_immediate_match = true, .has_in_progress_alternatives = false, + .needs_parent = step->depth == 1, .dead = false, })); } @@ -2460,6 +2469,33 @@ static CaptureList *ts_query_cursor__prepare_to_capture( return capture_list_pool_get_mut(&self->capture_list_pool, state->capture_list_id); } +static void ts_query_cursor__capture( + TSQueryCursor *self, + QueryState *state, + QueryStep *step, + TSNode node +) { + if (state->dead) return; + CaptureList *capture_list = ts_query_cursor__prepare_to_capture(self, state, UINT32_MAX); + if (!capture_list) { + state->dead = true; + return; + } + + for (unsigned j = 0; j < MAX_STEP_CAPTURE_COUNT; j++) { + uint16_t capture_id = step->capture_ids[j]; + if (step->capture_ids[j] == NONE) break; + array_push(capture_list, ((TSQueryCapture) { node, capture_id })); + LOG( + " capture node. type:%s, pattern:%u, capture_id:%u, capture_count:%u\n", + ts_node_type(node), + state->pattern_index, + capture_id, + capture_list->size + ); + } +} + // Duplicate the given state and insert the newly-created state immediately after // the given state in the `states` array. Ensures that the given state reference is // still valid, even if the states array is reallocated. @@ -2730,26 +2766,45 @@ static inline bool ts_query_cursor__advance( } } + // If this pattern started with a wildcard, such that the pattern map + // actually points to the *second* step of the pattern, then check + // that the node has a parent, and capture the parent node if necessary. 
+ if (state->needs_parent) { + TSNode parent = ts_tree_cursor_parent_node(&self->cursor); + if (ts_node_is_null(parent)) { + LOG(" missing parent node\n"); + state->dead = true; + } else { + state->needs_parent = false; + QueryStep *skipped_wildcard_step = step; + do { + skipped_wildcard_step--; + } while ( + skipped_wildcard_step->is_dead_end || + skipped_wildcard_step->is_pass_through || + skipped_wildcard_step->depth > 0 + ); + if (skipped_wildcard_step->capture_ids[0] != NONE) { + LOG(" capture wildcard parent\n"); + ts_query_cursor__capture( + self, + state, + skipped_wildcard_step, + parent + ); + } + } + } + // If the current node is captured in this pattern, add it to the capture list. if (step->capture_ids[0] != NONE) { - CaptureList *capture_list = ts_query_cursor__prepare_to_capture(self, state, UINT32_MAX); - if (!capture_list) { - array_erase(&self->states, i); - i--; - continue; - } + ts_query_cursor__capture(self, state, step, node); + } - for (unsigned j = 0; j < MAX_STEP_CAPTURE_COUNT; j++) { - uint16_t capture_id = step->capture_ids[j]; - if (step->capture_ids[j] == NONE) break; - array_push(capture_list, ((TSQueryCapture) { node, capture_id })); - LOG( - " capture node. pattern:%u, capture_id:%u, capture_count:%u\n", - state->pattern_index, - capture_id, - capture_list->size - ); - } + if (state->dead) { + array_erase(&self->states, i); + i--; + continue; } // Advance this state to the next step of its pattern. @@ -2772,12 +2827,18 @@ static inline bool ts_query_cursor__advance( QueryState *state = &self->states.contents[j]; QueryStep *next_step = &self->query->steps.contents[state->step_index]; if (next_step->alternative_index != NONE) { + // A "dead-end" step exists only to add a non-sequential jump into the step sequence, + // via its alternative index. When a state reaches a dead-end step, it jumps straight + // to the step's alternative. 
if (next_step->is_dead_end) { state->step_index = next_step->alternative_index; j--; continue; } + // A "pass-through" step exists only to add a branch into the step sequence, + // via its alternative_index. When a state reaches a pass-through step, it splits + // in order to process the alternative step, and then it advances to the next step. if (next_step->is_pass_through) { state->step_index++; j--; diff --git a/lib/src/tree_cursor.c b/lib/src/tree_cursor.c index 64e8b414..f109524e 100644 --- a/lib/src/tree_cursor.c +++ b/lib/src/tree_cursor.c @@ -364,6 +364,33 @@ void ts_tree_cursor_current_status( } } +TSNode ts_tree_cursor_parent_node(const TSTreeCursor *_self) { + const TreeCursor *self = (const TreeCursor *)_self; + for (int i = (int)self->stack.size - 2; i >= 0; i--) { + TreeCursorEntry *entry = &self->stack.contents[i]; + bool is_visible = true; + TSSymbol alias_symbol = 0; + if (i > 0) { + TreeCursorEntry *parent_entry = &self->stack.contents[i - 1]; + alias_symbol = ts_language_alias_at( + self->tree->language, + parent_entry->subtree->ptr->production_id, + entry->structural_child_index + ); + is_visible = (alias_symbol != 0) || ts_subtree_visible(*entry->subtree); + } + if (is_visible) { + return ts_node_new( + self->tree, + entry->subtree, + entry->position, + alias_symbol + ); + } + } + return ts_node_new(NULL, NULL, length_zero(), 0); +} + TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) { const TreeCursor *self = (const TreeCursor *)_self; diff --git a/lib/src/tree_cursor.h b/lib/src/tree_cursor.h index 7c9c05d5..69647d1d 100644 --- a/lib/src/tree_cursor.h +++ b/lib/src/tree_cursor.h @@ -26,4 +26,6 @@ void ts_tree_cursor_current_status( unsigned * ); +TSNode ts_tree_cursor_parent_node(const TSTreeCursor *); + #endif // TREE_SITTER_TREE_CURSOR_H_