Fix bug where patterns with top-level alternatives were not considered 'rooted'

This commit is contained in:
Max Brunsfeld 2022-07-07 17:25:49 -07:00
parent 1401767689
commit 548c12fb88
5 changed files with 169 additions and 22 deletions

View file

@ -1695,31 +1695,54 @@ fn test_query_sibling_patterns_dont_match_children_of_an_error() {
language,
r#"
("{" @open "}" @close)
[
(line_comment)
(block_comment)
] @comment
("<" @first "<" @second)
"#,
)
.unwrap();
// Most of the document will fail to parse, resulting in a
// large number of tokens that are *direct* children of an
// ERROR node.
//
// These children should still match, unless they are part
// of a "non-rooted" pattern, in which there are multiple
// top-level sibling nodes. Those patterns should not match
// directly inside of an error node, because the contents of
// an error node are not syntactically well-structured, so we
// would get many spurious matches.
let source = "
fn a() {}
<<<<<<<<<< add pub b fn () {}
// comment 1
pub fn b() {
/* comment 2 */
==========
pub fn c() {
// comment 3
>>>>>>>>>> add pub c fn () {}
}
";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), source.as_bytes());
assert_eq!(
collect_matches(matches, &query, source),
&[(0, vec![("open", "{"), ("close", "}")])],
&[
(0, vec![("open", "{"), ("close", "}")]),
(1, vec![("comment", "// comment 1")]),
(1, vec![("comment", "/* comment 2 */")]),
(1, vec![("comment", "// comment 3")]),
],
);
});
}
@ -3956,6 +3979,97 @@ fn test_query_is_pattern_guaranteed_at_step() {
});
}
#[test]
fn test_query_is_pattern_rooted() {
    // Table-driven check of `Query::is_pattern_rooted`: each case pairs one
    // query pattern with the rootedness the query is expected to report.
    struct Case {
        description: &'static str,
        pattern: &'static str,
        is_rooted: bool,
    }

    let cases = [
        // Patterns made of exactly one top-level node, including alternations
        // in which every branch is a single node, are expected to be rooted.
        Case {
            description: "simple token",
            pattern: r#"(identifier)"#,
            is_rooted: true,
        },
        Case {
            description: "simple non-terminal",
            pattern: r#"(function_definition name: (identifier))"#,
            is_rooted: true,
        },
        Case {
            description: "alternative of many tokens",
            pattern: r#"["if" "def" (identifier) (comment)]"#,
            is_rooted: true,
        },
        Case {
            description: "alternative of many non-terminals",
            pattern: r#"[
(function_definition name: (identifier))
(class_definition name: (identifier))
(block)
]"#,
            is_rooted: true,
        },
        // Sibling sequences and top-level repetitions are expected to be
        // non-rooted, even when they only appear in one branch of an alternation.
        Case {
            description: "two siblings",
            pattern: r#"("{" "}")"#,
            is_rooted: false,
        },
        Case {
            description: "top-level repetition",
            pattern: r#"(comment)*"#,
            is_rooted: false,
        },
        Case {
            description: "alternative where one option has two siblings",
            pattern: r#"[
(block)
(class_definition)
("(" ")")
(function_definition)
]"#,
            is_rooted: false,
        },
        Case {
            description: "alternative where one option has a top-level repetition",
            pattern: r#"[
(block)
(class_definition)
(comment)*
(function_definition)
]"#,
            is_rooted: false,
        },
    ];

    allocations::record(|| {
        eprintln!("");
        let language = get_language("python");

        for case in cases.iter() {
            // When an example filter is configured, only run the cases whose
            // description contains the filter text.
            let excluded = match EXAMPLE_FILTER.as_ref() {
                Some(filter) => !case.description.contains(filter.as_str()),
                None => false,
            };
            if excluded {
                continue;
            }

            eprintln!(" query example: {:?}", case.description);

            // Each query holds a single pattern, so it is always pattern 0.
            let query = Query::new(language, case.pattern).unwrap();
            assert_eq!(
                query.is_pattern_rooted(0),
                case.is_rooted,
                "Description: {}, Pattern: {:?}",
                case.description,
                // Collapse the multi-line pattern onto one line for the failure message.
                case.pattern
                    .split_ascii_whitespace()
                    .collect::<Vec<_>>()
                    .join(" "),
            );
        }
    });
}
#[test]
fn test_capture_quantifiers() {
struct Row {

View file

@ -658,6 +658,9 @@ extern "C" {
length: *mut u32,
) -> *const TSQueryPredicateStep;
}
extern "C" {
/// Raw binding to the C function `ts_query_is_pattern_rooted`.
///
/// Takes a pointer to a query and the index of one of its patterns, and
/// returns a boolean. The safe wrapper `Query::is_pattern_rooted` describes
/// the result as whether the pattern has a single root node.
pub fn ts_query_is_pattern_rooted(self_: *const TSQuery, pattern_index: u32) -> bool;
}
extern "C" {
pub fn ts_query_is_pattern_guaranteed_at_step(self_: *const TSQuery, byte_offset: u32) -> bool;
}

View file

@ -1699,6 +1699,12 @@ impl Query {
unsafe { ffi::ts_query_disable_pattern(self.ptr.as_ptr(), index as u32) }
}
/// Check if a given pattern within a query has a single root node.
///
/// `index` is the zero-based position of the pattern within the query.
///
/// A pattern such as `(identifier)`, or an alternation in which every branch
/// is a single node, is rooted. A pattern with several top-level sibling
/// nodes, such as `("{" "}")`, or with a top-level repetition, such as
/// `(comment)*`, is not rooted.
#[doc(alias = "ts_query_is_pattern_rooted")]
pub fn is_pattern_rooted(&self, index: usize) -> bool {
    // The C function takes the query as a `*const TSQuery`, so this call
    // only reads through the pointer owned by `self`.
    unsafe { ffi::ts_query_is_pattern_rooted(self.ptr.as_ptr(), index as u32) }
}
/// Check if a given step in a query is 'definite'.
///
/// A query step is 'definite' if its parent pattern will be guaranteed to match

View file

@ -733,6 +733,11 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern(
uint32_t *length
);
/**
 * Check if a given pattern within a query has a single root node.
 *
 * `pattern_index` selects the pattern within the query.
 */
bool ts_query_is_pattern_rooted(
const TSQuery *self,
uint32_t pattern_index
);
bool ts_query_is_pattern_guaranteed_at_step(
const TSQuery *self,
uint32_t byte_offset

View file

@ -2101,7 +2101,7 @@ static TSQueryError ts_query__parse_pattern(
return e;
}
if(start_index == starting_step_index) {
if (start_index == starting_step_index) {
capture_quantifiers_replace(capture_quantifiers, &branch_capture_quantifiers);
} else {
capture_quantifiers_join_all(capture_quantifiers, &branch_capture_quantifiers);
@ -2167,10 +2167,10 @@ static TSQueryError ts_query__parse_pattern(
}
capture_quantifiers_add_all(capture_quantifiers, &child_capture_quantifiers);
child_is_immediate = false;
capture_quantifiers_clear(&child_capture_quantifiers);
child_is_immediate = false;
}
capture_quantifiers_delete(&child_capture_quantifiers);
}
@ -2630,11 +2630,13 @@ TSQuery *ts_query_new(
// Determine whether the pattern has a single root node. This affects
// decisions about whether or not to start matching the pattern when
// a query cursor has a range restriction.
// a query cursor has a range restriction or when immediately within an
// error node.
uint32_t start_depth = step->depth;
bool is_rooted = start_depth == 0;
for (uint32_t step_index = start_step_index + 1; step_index < self->steps.size; step_index++) {
QueryStep *step = &self->steps.contents[step_index];
if (step->is_dead_end) break;
if (step->depth == start_depth) {
is_rooted = false;
break;
@ -2751,6 +2753,19 @@ uint32_t ts_query_start_byte_for_pattern(
return self->patterns.contents[pattern_index].start_byte;
}
// Report whether the pattern at `pattern_index` is rooted.
//
// The pattern map may hold more than one entry for the same pattern index
// (presumably one per top-level alternative -- confirm against the code that
// fills the map). The pattern counts as rooted only if none of its entries is
// marked as non-rooted. If no entry carries the given index, the result is
// `true`.
bool ts_query_is_pattern_rooted(
  const TSQuery *self,
  uint32_t pattern_index
) {
  for (unsigned i = 0; i < self->pattern_map.size; i++) {
    const PatternEntry *candidate = &self->pattern_map.contents[i];

    // Entries that belong to other patterns have no say in the result.
    if (candidate->pattern_index != pattern_index) continue;

    // A single non-rooted entry makes the whole pattern non-rooted.
    if (!candidate->is_rooted) return false;
  }
  return true;
}
bool ts_query_is_pattern_guaranteed_at_step(
const TSQuery *self,
uint32_t byte_offset
@ -3324,26 +3339,28 @@ static inline bool ts_query_cursor__advance(
point_gt(ts_node_end_point(parent_node), self->start_point) &&
point_lt(ts_node_start_point(parent_node), self->end_point)
);
bool node_is_error = symbol != ts_builtin_sym_error;
bool node_is_error = symbol == ts_builtin_sym_error;
bool parent_is_error =
!ts_node_is_null(parent_node) &&
ts_node_symbol(parent_node) == ts_builtin_sym_error;
// Add new states for any patterns whose root node is a wildcard.
for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) {
PatternEntry *pattern = &self->query->pattern_map.contents[i];
if (!node_is_error) {
for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) {
PatternEntry *pattern = &self->query->pattern_map.contents[i];
// If this node matches the first step of the pattern, then add a new
// state at the start of this pattern.
QueryStep *step = &self->query->steps.contents[pattern->step_index];
if (
(pattern->is_rooted ?
(node_intersects_range && !node_is_error) :
(parent_intersects_range && !parent_is_error)) &&
(!step->field || field_id == step->field) &&
(!step->supertype_symbol || supertype_count > 0)
) {
ts_query_cursor__add_state(self, pattern);
// If this node matches the first step of the pattern, then add a new
// state at the start of this pattern.
QueryStep *step = &self->query->steps.contents[pattern->step_index];
if (
(pattern->is_rooted ?
node_intersects_range :
(parent_intersects_range && !parent_is_error)) &&
(!step->field || field_id == step->field) &&
(!step->supertype_symbol || supertype_count > 0)
) {
ts_query_cursor__add_state(self, pattern);
}
}
}
@ -3357,7 +3374,9 @@ static inline bool ts_query_cursor__advance(
// If this node matches the first step of the pattern, then add a new
// state at the start of this pattern.
if (
(pattern->is_rooted ? node_intersects_range : (parent_intersects_range && !parent_is_error)) &&
(pattern->is_rooted ?
node_intersects_range :
(parent_intersects_range && !parent_is_error)) &&
(!step->field || field_id == step->field)
) {
ts_query_cursor__add_state(self, pattern);