From 548c12fb8843d095059fc0c8d0d18c3d80dc2361 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 7 Jul 2022 17:25:49 -0700 Subject: [PATCH] Fix bug where patterns with top-level alternatives were not considered 'rooted' --- cli/src/tests/query_test.rs | 120 +++++++++++++++++++++++++++++++++- lib/binding_rust/bindings.rs | 3 + lib/binding_rust/lib.rs | 6 ++ lib/include/tree_sitter/api.h | 5 ++ lib/src/query.c | 57 ++++++++++------ 5 files changed, 169 insertions(+), 22 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index da888c4e..6821da2d 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1695,31 +1695,54 @@ fn test_query_sibling_patterns_dont_match_children_of_an_error() { language, r#" ("{" @open "}" @close) + + [ + (line_comment) + (block_comment) + ] @comment + + ("<" @first "<" @second) "#, ) .unwrap(); + // Most of the document will fail to parse, resulting in a + // large number of tokens that are *direct* children of an + // ERROR node. + // + // These children should still match, unless they are part + // of a "non-rooted" pattern, in which there are multiple + // top-level sibling nodes. Those patterns should not match + // directly inside of an error node, because the contents of + // an error node are not syntactically well-structured, so we + // would get many spurious matches. let source = " fn a() {} <<<<<<<<<< add pub b fn () {} + // comment 1 pub fn b() { + /* comment 2 */ ========== pub fn c() { + // comment 3 >>>>>>>>>> add pub c fn () {} } "; let mut parser = Parser::new(); parser.set_language(language).unwrap(); - let tree = parser.parse(&source, None).unwrap(); - let mut cursor = QueryCursor::new(); let matches = cursor.matches(&query, tree.root_node(), source.as_bytes()); assert_eq!( collect_matches(matches, &query, source), - &[(0, vec![("open", "{"), ("close", "}")])], + &[ + (0, vec![("open", "{"), ("close", "}")]), + (1, vec![("comment", "// comment 1")]), + (1, vec![("comment", "/* comment 2 */")]), + (1, vec![("comment", "// comment 3")]), + ], ); }); } @@ -3956,6 +3979,97 @@ fn test_query_is_pattern_guaranteed_at_step() { }); } +#[test] +fn test_query_is_pattern_rooted() { + struct Row { + description: &'static str, + pattern: &'static str, + is_rooted: bool, + } + + let rows = [ + Row { + description: "simple token", + pattern: r#"(identifier)"#, + is_rooted: true, + }, + Row { + description: "simple non-terminal", + pattern: r#"(function_definition name: (identifier))"#, + is_rooted: true, + }, + Row { + description: "alternative of many tokens", + pattern: r#"["if" "def" (identifier) (comment)]"#, + is_rooted: true, + }, + Row { + description: "alternative of many non-terminals", + pattern: r#"[ + (function_definition name: (identifier)) + (class_definition name: (identifier)) + (block) + ]"#, + is_rooted: true, + }, + Row { + description: "two siblings", + pattern: r#"("{" "}")"#, + is_rooted: false, + }, + Row { + description: "top-level repetition", + pattern: r#"(comment)*"#, + is_rooted: false, + }, + Row { + description: "alternative where one option has two siblings", + pattern: r#"[ + (block) + (class_definition) + ("(" ")") + (function_definition) + ]"#, + is_rooted: false, + }, + Row { + description: "alternative where one option has a top-level repetition", + pattern: r#"[ + (block) + (class_definition) + (comment)* + (function_definition) + ]"#, + is_rooted: false, + }, + ]; + + allocations::record(|| { + eprintln!(""); + + let language = get_language("python"); + for row in &rows { + if let Some(filter) = EXAMPLE_FILTER.as_ref() { + if !row.description.contains(filter.as_str()) { + continue; + } + } + eprintln!(" query example: {:?}", row.description); + let query = Query::new(language, row.pattern).unwrap(); + assert_eq!( + query.is_pattern_rooted(0), + row.is_rooted, + "Description: {}, Pattern: {:?}", + row.description, + row.pattern + .split_ascii_whitespace() + .collect::>() + .join(" "), + ) + } + }); +} + #[test] fn test_capture_quantifiers() { struct Row { diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index ba5dbaf3..1447d09d 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -658,6 +658,9 @@ extern "C" { length: *mut u32, ) -> *const TSQueryPredicateStep; } +extern "C" { + pub fn ts_query_is_pattern_rooted(self_: *const TSQuery, pattern_index: u32) -> bool; +} extern "C" { pub fn ts_query_is_pattern_guaranteed_at_step(self_: *const TSQuery, byte_offset: u32) -> bool; } diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index a2d4c1b4..f757b107 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1699,6 +1699,12 @@ impl Query { unsafe { ffi::ts_query_disable_pattern(self.ptr.as_ptr(), index as u32) } } + /// Check if a given pattern within a query has a single root node. + #[doc(alias = "ts_query_is_pattern_guaranteed_at_step")] + pub fn is_pattern_rooted(&self, index: usize) -> bool { + unsafe { ffi::ts_query_is_pattern_rooted(self.ptr.as_ptr(), index as u32) } + } + /// Check if a given step in a query is 'definite'. /// /// A query step is 'definite' if its parent pattern will be guaranteed to match diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 1ace7beb..e2941532 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -733,6 +733,11 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( uint32_t *length ); +bool ts_query_is_pattern_rooted( + const TSQuery *self, + uint32_t pattern_index +); + bool ts_query_is_pattern_guaranteed_at_step( const TSQuery *self, uint32_t byte_offset diff --git a/lib/src/query.c b/lib/src/query.c index 2504b5ad..80a9e248 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -2101,7 +2101,7 @@ static TSQueryError ts_query__parse_pattern( return e; } - if(start_index == starting_step_index) { + if (start_index == starting_step_index) { capture_quantifiers_replace(capture_quantifiers, &branch_capture_quantifiers); } else { capture_quantifiers_join_all(capture_quantifiers, &branch_capture_quantifiers); @@ -2167,10 +2167,10 @@ static TSQueryError ts_query__parse_pattern( } capture_quantifiers_add_all(capture_quantifiers, &child_capture_quantifiers); - - child_is_immediate = false; capture_quantifiers_clear(&child_capture_quantifiers); + child_is_immediate = false; } + capture_quantifiers_delete(&child_capture_quantifiers); } @@ -2630,11 +2630,13 @@ TSQuery *ts_query_new( // Determine whether the pattern has a single root node. This affects // decisions about whether or not to start matching the pattern when - // a query cursor has a range restriction. + // a query cursor has a range restriction or when immediately within an + // error node. uint32_t start_depth = step->depth; bool is_rooted = start_depth == 0; for (uint32_t step_index = start_step_index + 1; step_index < self->steps.size; step_index++) { QueryStep *step = &self->steps.contents[step_index]; + if (step->is_dead_end) break; if (step->depth == start_depth) { is_rooted = false; break; @@ -2751,6 +2753,19 @@ uint32_t ts_query_start_byte_for_pattern( return self->patterns.contents[pattern_index].start_byte; } +bool ts_query_is_pattern_rooted( + const TSQuery *self, + uint32_t pattern_index +) { + for (unsigned i = 0; i < self->pattern_map.size; i++) { + PatternEntry *entry = &self->pattern_map.contents[i]; + if (entry->pattern_index == pattern_index) { + if (!entry->is_rooted) return false; + } + } + return true; +} + bool ts_query_is_pattern_guaranteed_at_step( const TSQuery *self, uint32_t byte_offset @@ -3324,26 +3339,28 @@ static inline bool ts_query_cursor__advance( point_gt(ts_node_end_point(parent_node), self->start_point) && point_lt(ts_node_start_point(parent_node), self->end_point) ); - bool node_is_error = symbol != ts_builtin_sym_error; + bool node_is_error = symbol == ts_builtin_sym_error; bool parent_is_error = !ts_node_is_null(parent_node) && ts_node_symbol(parent_node) == ts_builtin_sym_error; // Add new states for any patterns whose root node is a wildcard. - for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { - PatternEntry *pattern = &self->query->pattern_map.contents[i]; + if (!node_is_error) { + for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { + PatternEntry *pattern = &self->query->pattern_map.contents[i]; - // If this node matches the first step of the pattern, then add a new - // state at the start of this pattern. - QueryStep *step = &self->query->steps.contents[pattern->step_index]; - if ( - (pattern->is_rooted ? - (node_intersects_range && !node_is_error) : - (parent_intersects_range && !parent_is_error)) && - (!step->field || field_id == step->field) && - (!step->supertype_symbol || supertype_count > 0) - ) { - ts_query_cursor__add_state(self, pattern); + // If this node matches the first step of the pattern, then add a new + // state at the start of this pattern. + QueryStep *step = &self->query->steps.contents[pattern->step_index]; + if ( + (pattern->is_rooted ? + node_intersects_range : + (parent_intersects_range && !parent_is_error)) && + (!step->field || field_id == step->field) && + (!step->supertype_symbol || supertype_count > 0) + ) { + ts_query_cursor__add_state(self, pattern); + } } } @@ -3357,7 +3374,9 @@ static inline bool ts_query_cursor__advance( // If this node matches the first step of the pattern, then add a new // state at the start of this pattern. if ( - (pattern->is_rooted ? node_intersects_range : (parent_intersects_range && !parent_is_error)) && + (pattern->is_rooted ? + node_intersects_range : + (parent_intersects_range && !parent_is_error)) && (!step->field || field_id == step->field) ) { ts_query_cursor__add_state(self, pattern);