diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index a9961b87..df892e96 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2002,6 +2002,68 @@ fn test_query_matches_with_unrooted_patterns_intersecting_byte_range() { }); } +#[test] +fn test_query_matches_with_wildcard_at_root_intersecting_byte_range() { + allocations::record(|| { + let language = get_language("python"); + let query = Query::new( + language, + " + [ + (_ body: (block)) + (_ consequence: (block)) + ] @indent + ", + ) + .unwrap(); + + let source = " + class A: + def b(): + if c: + d + else: + e + " + .trim(); + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); + let mut cursor = QueryCursor::new(); + + // After the first line of the class definition + let offset = source.find("A:").unwrap() + 2; + let matches = cursor + .set_byte_range(offset..offset) + .matches(&query, tree.root_node(), source.as_bytes()) + .map(|mat| mat.captures[0].node.kind()) + .collect::>(); + assert_eq!(matches, &["class_definition"]); + + // After the first line of the function definition + let offset = source.find("b():").unwrap() + 4; + let matches = cursor + .set_byte_range(offset..offset) + .matches(&query, tree.root_node(), source.as_bytes()) + .map(|mat| mat.captures[0].node.kind()) + .collect::>(); + assert_eq!(matches, &["class_definition", "function_definition"]); + + // After the first line of the if statement + let offset = source.find("c:").unwrap() + 2; + let matches = cursor + .set_byte_range(offset..offset) + .matches(&query, tree.root_node(), source.as_bytes()) + .map(|mat| mat.captures[0].node.kind()) + .collect::>(); + assert_eq!( + matches, + &["class_definition", "function_definition", "if_statement"] + ); + }); +} + #[test] fn test_query_captures_within_byte_range_assigned_after_iterating() { allocations::record(|| { diff --git a/lib/src/query.c b/lib/src/query.c index 7878cf55..33d67648 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -2631,8 +2631,8 @@ TSQuery *ts_query_new( // Determine whether the pattern has a single root node. This affects // decisions about whether or not to start matching the pattern when // a query cursor has a range restriction. - bool is_rooted = true; uint32_t start_depth = step->depth; + bool is_rooted = start_depth == 0; for (uint32_t step_index = start_step_index + 1; step_index < self->steps.size; step_index++) { QueryStep *step = &self->steps.contents[step_index]; if (step->depth == start_depth) { @@ -3318,7 +3318,6 @@ static inline bool ts_query_cursor__advance( point_gt(ts_node_end_point(node), self->start_point) && point_lt(ts_node_start_point(node), self->end_point) ); - bool parent_intersects_range = ts_node_is_null(parent_node) || ( ts_node_end_byte(parent_node) > self->start_byte && ts_node_start_byte(parent_node) < self->end_byte && @@ -3326,7 +3325,7 @@ static inline bool ts_query_cursor__advance( point_lt(ts_node_start_point(parent_node), self->end_point) ); - // Add new states for any patterns whose root node is a wildcard. + // Add new states for any patterns whose root node is a wildcard. for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { PatternEntry *pattern = &self->query->pattern_map.contents[i]; @@ -3334,7 +3333,7 @@ static inline bool ts_query_cursor__advance( // state at the start of this pattern. QueryStep *step = &self->query->steps.contents[pattern->step_index]; if ( - (node_intersects_range || (!pattern->is_rooted && parent_intersects_range)) && + (pattern->is_rooted ? node_intersects_range : parent_intersects_range) && (!step->field || field_id == step->field) && (!step->supertype_symbol || supertype_count > 0) ) { @@ -3352,7 +3351,7 @@ static inline bool ts_query_cursor__advance( // If this node matches the first step of the pattern, then add a new // state at the start of this pattern. if ( - (node_intersects_range || (!pattern->is_rooted && parent_intersects_range)) && + (pattern->is_rooted ? node_intersects_range : parent_intersects_range) && (!step->field || field_id == step->field) ) { ts_query_cursor__add_state(self, pattern);