diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index daa3b04b..06578ba8 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1918,6 +1918,60 @@ fn test_query_captures_within_byte_range() { }); } +#[test] +fn test_query_matches_with_unrooted_patterns_intersecting_byte_range() { + allocations::record(|| { + let language = get_language("rust"); + let query = Query::new( + language, + r#" + ("{" @left "}" @right) + ("<" @left ">" @right) + "#, + ) + .unwrap(); + + let source = "mod a { fn a(f: B) { g(f) } }"; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + + // within the type parameter list + let offset = source.find("D: E>").unwrap(); + let matches = cursor.set_byte_range(offset, offset).matches( + &query, + tree.root_node(), + source.as_bytes(), + ); + assert_eq!( + collect_matches(matches, &query, source), + &[ + (1, vec![("left", "<"), ("right", ">")]), + (0, vec![("left", "{"), ("right", "}")]), + ] + ); + + // from within the type parameter list to within the function body + let start_offset = source.find("D: E>").unwrap(); + let end_offset = source.find("g(f)").unwrap(); + let matches = cursor.set_byte_range(start_offset, end_offset).matches( + &query, + tree.root_node(), + source.as_bytes(), + ); + assert_eq!( + collect_matches(matches, &query, source), + &[ + (1, vec![("left", "<"), ("right", ">")]), + (0, vec![("left", "{"), ("right", "}")]), + (0, vec![("left", "{"), ("right", "}")]), + ] + ); + }); +} + #[test] fn test_query_captures_within_byte_range_assigned_after_iterating() { allocations::record(|| { diff --git a/lib/src/query.c b/lib/src/query.c index 00f66ec0..9feb1177 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -104,16 +104,20 @@ typedef struct { } SymbolTable; /* - * PatternEntry - Information about the starting point for matching a - * particular pattern, consisting of the index of the pattern within the query, - * and the index of the patter's first step in the shared `steps` array. These - * entries are stored in a 'pattern map' - a sorted array that makes it - * possible to efficiently lookup patterns based on the symbol for their first - * step. + * PatternEntry - Information about the starting point for matching a particular + * pattern. These entries are stored in a 'pattern map' - a sorted array that + * makes it possible to efficiently lookup patterns based on the symbol for their + * first step. The entry consists of the following fields: + * - `pattern_index` - the index of the pattern within the query + * - `step_index` - the index of the pattern's first step in the shared `steps` array + * - `is_rooted` - whether or not the pattern has a single root node. This property + * affects decisions about whether or not to start the pattern for nodes outside + * of a QueryCursor's range restriction. */ typedef struct { uint16_t step_index; uint16_t pattern_index; + bool is_rooted; } PatternEntry; typedef struct { @@ -691,8 +695,7 @@ static inline bool ts_query__pattern_map_search( static inline void ts_query__pattern_map_insert( TSQuery *self, TSSymbol symbol, - uint32_t start_step_index, - uint32_t pattern_index + PatternEntry new_entry ) { uint32_t index; ts_query__pattern_map_search(self, symbol, &index); @@ -705,7 +708,7 @@ static inline void ts_query__pattern_map_insert( PatternEntry *entry = &self->pattern_map.contents[index]; if ( self->steps.contents[entry->step_index].symbol == symbol && - entry->pattern_index < pattern_index + entry->pattern_index < new_entry.pattern_index ) { index++; } else { @@ -713,10 +716,7 @@ static inline void ts_query__pattern_map_insert( } } - array_insert(&self->pattern_map, index, ((PatternEntry) { - .step_index = start_step_index, - .pattern_index = pattern_index, - })); + array_insert(&self->pattern_map, index, new_entry); } static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { @@ -2108,7 +2108,24 @@ TSQuery *ts_query_new( } } - ts_query__pattern_map_insert(self, step->symbol, start_step_index, pattern_index); + // Determine whether the pattern has a single root node. This affects + // decisions about whether or not to start matching the pattern when + // a query cursor has a range restriction. + bool is_rooted = true; + uint32_t start_depth = step->depth; + for (uint32_t step_index = start_step_index + 1; step_index < self->steps.size; step_index++) { + QueryStep *step = &self->steps.contents[step_index]; + if (step->depth == start_depth) { + is_rooted = false; + break; + } + } + + ts_query__pattern_map_insert(self, step->symbol, (PatternEntry) { + .step_index = start_step_index, + .pattern_index = pattern_index, + .is_rooted = is_rooted + }); if (step->symbol == WILDCARD_SYMBOL) { self->wildcard_root_pattern_count++; } @@ -2702,6 +2719,7 @@ static inline bool ts_query_cursor__advance( else { // Get the properties of the current node. TSNode node = ts_tree_cursor_current_node(&self->cursor); + TSNode parent_node = ts_tree_cursor_parent_node(&self->cursor); TSSymbol symbol = ts_node_symbol(node); bool is_named = ts_node_is_named(node); bool has_later_siblings; @@ -2735,44 +2753,51 @@ static inline bool ts_query_cursor__advance( point_lt(ts_node_start_point(node), self->end_point) ); - if (node_intersects_range) { - // Add new states for any patterns whose root node is a wildcard. - for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { - PatternEntry *pattern = &self->query->pattern_map.contents[i]; - QueryStep *step = &self->query->steps.contents[pattern->step_index]; + bool parent_intersects_range = ts_node_is_null(parent_node) || ( + ts_node_end_byte(parent_node) > self->start_byte && + ts_node_start_byte(parent_node) < self->end_byte && + point_gt(ts_node_end_point(parent_node), self->start_point) && + point_lt(ts_node_start_point(parent_node), self->end_point) + ); - // If this node matches the first step of the pattern, then add a new - // state at the start of this pattern. - if (step->field && field_id != step->field) continue; - if (step->supertype_symbol && !supertype_count) continue; + // Add new states for any patterns whose root node is a wildcard. + for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { + PatternEntry *pattern = &self->query->pattern_map.contents[i]; + + // If this node matches the first step of the pattern, then add a new + // state at the start of this pattern. + QueryStep *step = &self->query->steps.contents[pattern->step_index]; + if ( + (node_intersects_range || (!pattern->is_rooted && parent_intersects_range)) && + (!step->field || field_id == step->field) && + (!step->supertype_symbol || supertype_count > 0) + ) { ts_query_cursor__add_state(self, pattern); } + } - // Add new states for any patterns whose root node matches this node. - unsigned i; - if (ts_query__pattern_map_search(self->query, symbol, &i)) { - PatternEntry *pattern = &self->query->pattern_map.contents[i]; - QueryStep *step = &self->query->steps.contents[pattern->step_index]; - do { - // If this node matches the first step of the pattern, then add a new - // state at the start of this pattern. - if (!step->field || field_id == step->field) { - ts_query_cursor__add_state(self, pattern); - } + // Add new states for any patterns whose root node matches this node. + unsigned i; + if (ts_query__pattern_map_search(self->query, symbol, &i)) { + PatternEntry *pattern = &self->query->pattern_map.contents[i]; - // Advance to the next pattern whose root node matches this node. - i++; - if (i == self->query->pattern_map.size) break; - pattern = &self->query->pattern_map.contents[i]; - step = &self->query->steps.contents[pattern->step_index]; - } while (step->symbol == symbol); - } - } else { - LOG( - " not starting new patterns. node end byte: %u, start_byte: %u\n", - ts_node_end_byte(node), - self->start_byte - ); + QueryStep *step = &self->query->steps.contents[pattern->step_index]; + do { + // If this node matches the first step of the pattern, then add a new + // state at the start of this pattern. + if ( + (node_intersects_range || (!pattern->is_rooted && parent_intersects_range)) && + (!step->field || field_id == step->field) + ) { + ts_query_cursor__add_state(self, pattern); + } + + // Advance to the next pattern whose root node matches this node. + i++; + if (i == self->query->pattern_map.size) break; + pattern = &self->query->pattern_map.contents[i]; + step = &self->query->steps.contents[pattern->step_index]; + } while (step->symbol == symbol); } // Update all of the in-progress states with current node.