diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index c691df30..e99fe06e 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -4084,6 +4084,68 @@ fn test_query_is_pattern_rooted() { }); } +#[test] +fn test_query_is_pattern_non_local() { + struct Row { + description: &'static str, + pattern: &'static str, + is_non_local: bool, + } + + let rows = [ + Row { + description: "simple token", + pattern: r#"(identifier)"#, + is_non_local: false, + }, + Row { + description: "siblings that can occur in an argument list", + pattern: r#"((identifier) (identifier))"#, + is_non_local: true, + }, + Row { + description: "siblings that can occur in a statement block", + pattern: r#"((return_statement) (return_statement))"#, + is_non_local: true, + }, + Row { + description: "siblings that can occur in a source file", + pattern: r#"((function_definition) (class_definition))"#, + is_non_local: true, + }, + Row { + description: "siblings that can't occur in any repetition", + pattern: r#"("{" "}")"#, + is_non_local: false, + }, + ]; + + allocations::record(|| { + eprintln!(""); + + let language = get_language("python"); + for row in &rows { + if let Some(filter) = EXAMPLE_FILTER.as_ref() { + if !row.description.contains(filter.as_str()) { + continue; + } + } + eprintln!(" query example: {:?}", row.description); + let query = Query::new(language, row.pattern).unwrap(); + assert_eq!( + query.is_pattern_non_local(0), + row.is_non_local, + "Description: {}, Pattern: {:?}", + row.description, + row.pattern + .split_ascii_whitespace() + .collect::>() + .join(" "), + ) + } + }); +} + #[test] fn test_capture_quantifiers() { struct Row { diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 4591a380..be117f83 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -677,6 +677,9 @@ extern "C" { length: *mut u32, ) -> *const TSQueryPredicateStep; } +extern "C" { + pub fn ts_query_is_pattern_non_local(self_: *const TSQuery, pattern_index: u32) -> bool; +} extern "C" { pub fn ts_query_is_pattern_rooted(self_: *const TSQuery, pattern_index: u32) -> bool; } diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 6f044cca..579bf8e2 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1736,11 +1736,17 @@ impl Query { } /// Check if a given pattern within a query has a single root node. - #[doc(alias = "ts_query_is_pattern_guaranteed_at_step")] + #[doc(alias = "ts_query_is_pattern_rooted")] pub fn is_pattern_rooted(&self, index: usize) -> bool { unsafe { ffi::ts_query_is_pattern_rooted(self.ptr.as_ptr(), index as u32) } } + /// Check if a given pattern within a query has a single root node. + #[doc(alias = "ts_query_is_pattern_non_local")] + pub fn is_pattern_non_local(&self, index: usize) -> bool { + unsafe { ffi::ts_query_is_pattern_non_local(self.ptr.as_ptr(), index as u32) } + } + /// Check if a given step in a query is 'definite'. /// /// A query step is 'definite' if its parent pattern will be guaranteed to match diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 5b48cf60..edc1c36a 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -750,15 +750,26 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( uint32_t *length ); -bool ts_query_is_pattern_rooted( - const TSQuery *self, - uint32_t pattern_index -); +/* + * Check if the given pattern in the query has a single root node. + */ +bool ts_query_is_pattern_rooted(const TSQuery *self, uint32_t pattern_index); -bool ts_query_is_pattern_guaranteed_at_step( - const TSQuery *self, - uint32_t byte_offset -); +/* + * Check if the given pattern in the query is 'non local'. + * + * A non-local pattern has multiple root nodes and can match within a + * repeating sequence of nodes, as specified by the grammar. Non-local + * patterns disable certain optimizations that would otherwise be possible + * when executing a query on a specific range of a syntax tree. + */ +bool ts_query_is_pattern_non_local(const TSQuery *self, uint32_t pattern_index); + +/* + * Check if a given pattern is guaranteed to match once a given step is reached. + * The step is specified by its byte offset in the query's source code. + */ +bool ts_query_is_pattern_guaranteed_at_step(const TSQuery *self, uint32_t byte_offset); /** * Get the name and length of one of the query's captures, or one of the diff --git a/lib/src/query.c b/lib/src/query.c index b2450ce2..cfe11438 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -146,6 +146,7 @@ typedef struct { Slice steps; Slice predicate_steps; uint32_t start_byte; + bool is_non_local; } QueryPattern; typedef struct { @@ -1455,7 +1456,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { if (!pattern->is_rooted) { QueryStep *step = &self->steps.contents[pattern->step_index]; if (step->symbol != WILDCARD_SYMBOL) { - array_push(&non_rooted_pattern_start_steps, pattern->step_index); + array_push(&non_rooted_pattern_start_steps, i); } } } @@ -1868,7 +1869,8 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { // prevent certain optimizations with range restrictions. analysis.did_abort = false; for (uint32_t i = 0; i < non_rooted_pattern_start_steps.size; i++) { - uint16_t step_index = non_rooted_pattern_start_steps.contents[i]; + uint16_t pattern_entry_index = non_rooted_pattern_start_steps.contents[i]; + PatternEntry *pattern_entry = &self->pattern_map.contents[pattern_entry_index]; analysis_state_set__clear(&analysis.states, &analysis.state_pool); analysis_state_set__clear(&analysis.deeper_states, &analysis.state_pool); @@ -1880,7 +1882,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { for (uint32_t k = 0; k < subgraph->start_states.size; k++) { TSStateId parse_state = subgraph->start_states.contents[k]; analysis_state_set__push(&analysis.states, &analysis.state_pool, &((AnalysisState) { - .step_index = step_index, + .step_index = pattern_entry->step_index, .stack = { [0] = { .parse_state = parse_state, @@ -1906,6 +1908,10 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { &analysis ); + if (analysis.finished_parent_symbols.size > 0) { + self->patterns.contents[pattern_entry->pattern_index].is_non_local = true; + } + for (unsigned k = 0; k < analysis.finished_parent_symbols.size; k++) { TSSymbol symbol = analysis.finished_parent_symbols.contents[k]; array_insert_sorted_by(&self->repeat_symbols_with_rootless_patterns, , symbol); @@ -2697,6 +2703,7 @@ TSQuery *ts_query_new( .steps = (Slice) {.offset = start_step_index}, .predicate_steps = (Slice) {.offset = start_predicate_step_index}, .start_byte = stream_offset(&stream), + .is_non_local = false, })); CaptureQuantifiers capture_quantifiers = capture_quantifiers_new(); *error_type = ts_query__parse_pattern(self, &stream, 0, false, &capture_quantifiers); @@ -2876,6 +2883,17 @@ bool ts_query_is_pattern_rooted( return true; } +bool ts_query_is_pattern_non_local( + const TSQuery *self, + uint32_t pattern_index +) { + if (pattern_index < self->patterns.size) { + return self->patterns.contents[pattern_index].is_non_local; + } else { + return false; + } +} + bool ts_query_is_pattern_guaranteed_at_step( const TSQuery *self, uint32_t byte_offset