From cc37da7457da79795e47a41878342758b443004b Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 26 Jun 2020 16:31:08 -0700 Subject: [PATCH] Query analysis: fix propagation of uncertainty from later siblings --- cli/src/tests/query_test.rs | 89 ++++++++++++++++++++++------------- lib/binding_rust/bindings.rs | 1 + lib/binding_rust/lib.rs | 4 +- lib/include/tree_sitter/api.h | 1 + lib/src/query.c | 71 ++++++++++++++++------------ 5 files changed, 99 insertions(+), 67 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index c73931ce..5f6979a2 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2079,90 +2079,111 @@ fn test_query_is_definite() { struct Row { language: Language, pattern: &'static str, - results_by_step_index: &'static [(usize, bool)], + results_by_symbol: &'static [(&'static str, bool)], } let rows = &[ Row { language: get_language("python"), pattern: r#"(expression_statement (string))"#, - results_by_step_index: &[ - (0, false), - (1, false), // string + results_by_symbol: &[ + ("expression_statement", false), + ("string", false), ], }, Row { language: get_language("javascript"), pattern: r#"(expression_statement (string))"#, - results_by_step_index: &[ - (0, false), - (1, false), // string + results_by_symbol: &[ + ("expression_statement", false), + ("string", false), // string ], }, Row { language: get_language("javascript"), pattern: r#"(object "{" "}")"#, - results_by_step_index: &[ - (0, false), - (1, true), // "{" - (2, true), // "}" + results_by_symbol: &[ + ("object", false), + ("{", true), + ("}", true), ], }, Row { language: get_language("javascript"), pattern: r#"(pair (property_identifier) ":")"#, - results_by_step_index: &[ - (0, false), - (1, false), // property_identifier - (2, true), // ":"" + results_by_symbol: &[ + ("pair", false), + ("property_identifier", false), + (":", true), ], }, Row { language: get_language("javascript"), pattern: r#"(object "{" (_) "}")"#, - results_by_step_index: &[ - (0, false), - (1, false), // "{"" - (2, false), // (_) - (3, true), // "}" + results_by_symbol: &[ + ("object", false), + ("{", false), + ("", false), + ("}", true), ], }, Row { - // Named wildcards, fields language: get_language("javascript"), pattern: r#"(binary_expression left: (identifier) right: (_))"#, - results_by_step_index: &[ - (0, false), - (1, false), // identifier - (2, true), // (_) + results_by_symbol: &[ + ("binary_expression", false), + ("identifier", false), + ("", true), ], }, Row { language: get_language("javascript"), pattern: r#"(function_declaration name: (identifier) body: (statement_block))"#, - results_by_step_index: &[ - (0, false), - (1, true), // identifier - (2, true), // statement_block + results_by_symbol: &[ + ("function_declaration", false), + ("identifier", true), + ("statement_block", true), + ], + }, + Row { + language: get_language("javascript"), + pattern: r#" + (function_declaration + name: (identifier) + body: (statement_block "{" (expression_statement) "}"))"#, + results_by_symbol: &[ + ("function_declaration", false), + ("identifier", false), + ("statement_block", false), + ("{", false), + ("expression_statement", false), + ("}", true), ], }, Row { language: get_language("javascript"), pattern: r#""#, - results_by_step_index: &[], + results_by_symbol: &[], }, ]; allocations::record(|| { for row in rows.iter() { let query = Query::new(row.language, row.pattern).unwrap(); - for (step_index, is_definite) in row.results_by_step_index { + for (symbol_name, is_definite) in row.results_by_symbol { + let mut symbol = 0; + if !symbol_name.is_empty() { + symbol = row.language.id_for_node_kind(symbol_name, true); + if symbol == 0 { + symbol = row.language.id_for_node_kind(symbol_name, false); + } + } assert_eq!( - query.pattern_is_definite(0, *step_index), + query.pattern_is_definite(0, symbol, 0), *is_definite, - "Pattern: {:?}, step: {}, expected is_definite to be {}", + "Pattern: {:?}, symbol: {}, expected is_definite to be {}", row.pattern, - step_index, + symbol_name, is_definite, ) } diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 167edebf..b5ff7a9e 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -654,6 +654,7 @@ extern "C" { pub fn ts_query_pattern_is_definite( self_: *const TSQuery, pattern_index: u32, + symbol: TSSymbol, step_index: u32, ) -> bool; } diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index d3284974..b4d6f8c5 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1467,9 +1467,9 @@ impl Query { /// Check if a pattern will definitely match after a certain number of steps /// have matched. - pub fn pattern_is_definite(&self, index: usize, step_index: usize) -> bool { + pub fn pattern_is_definite(&self, pattern_index: usize, symbol: u16, step_index: usize) -> bool { unsafe { - ffi::ts_query_pattern_is_definite(self.ptr.as_ptr(), index as u32, step_index as u32) + ffi::ts_query_pattern_is_definite(self.ptr.as_ptr(), pattern_index as u32, symbol, step_index as u32) } } diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 1abbf28c..850cd31e 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -722,6 +722,7 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( bool ts_query_pattern_is_definite( const TSQuery *self, uint32_t pattern_index, + TSSymbol symbol, uint32_t step_index ); diff --git a/lib/src/query.c b/lib/src/query.c index 64a1b8a0..dd6ad8c0 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -149,8 +149,7 @@ typedef struct { /* * AnalysisState - The state needed for walking the parse table when analyzing - * a query pattern, to determine the steps where the pattern could fail - * to match. + * a query pattern, to determine at which steps the pattern might fail to match. */ typedef struct { TSStateId parse_state; @@ -166,6 +165,12 @@ typedef struct { uint16_t step_index; } AnalysisState; +/* + * AnalysisSubgraph - A subset of the states in the parse table that are used + * in constructing nodes with a certain symbol. Each state is accompanied by + * some information about the possible node that could be produced in + * downstream states. + */ typedef struct { TSStateId state; uint8_t production_id; @@ -914,7 +919,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index " {parent: %s, child_index: %u, field: %s, state: %3u, done:%d}", self->language->symbol_names[state->stack[k].parent_symbol], state->stack[k].child_index, - self->language->field_names[state->stack[k].field_id], + state->stack[k].field_id ? self->language->field_names[state->stack[k].field_id] : "", state->stack[k].parse_state, state->stack[k].done ); @@ -1018,7 +1023,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index // If this is a hidden child, then push a new entry to the stack, in order to // walk through the children of this child. - else if (next_state.depth < MAX_ANALYSIS_STATE_DEPTH) { + else if (sym >= self->language->token_count && next_state.depth < MAX_ANALYSIS_STATE_DEPTH) { next_state.depth++; analysis_state__top(&next_state)->parse_state = parse_state; analysis_state__top(&next_state)->child_index = 0; @@ -1122,17 +1127,29 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index } } - // In order for a parent step to be definite, all of its child steps must - // be definite. Propagate the definiteness up the pattern trees by walking - // the query's steps in reverse. + // In order for a step to be definite, all of its child steps must be definite, + // and all of its later sibling steps must be definite. Propagate any indefiniteness + // upward and backward through the pattern trees. for (unsigned i = self->steps.size - 1; i + 1 > 0; i--) { QueryStep *step = &self->steps.contents[i]; - for (unsigned j = i + 1; j < self->steps.size; j++) { + bool all_later_children_definite = true; + unsigned end_step_index = i + 1; + while (end_step_index < self->steps.size) { + QueryStep *child_step = &self->steps.contents[end_step_index]; + if (child_step->depth <= step->depth || child_step->depth == PATTERN_DONE_MARKER) break; + end_step_index++; + } + for (unsigned j = end_step_index - 1; j > i; j--) { QueryStep *child_step = &self->steps.contents[j]; - if (child_step->depth <= step->depth) break; - if (child_step->depth == step->depth + 1 && !child_step->is_definite) { - step->is_definite = false; - break; + if (child_step->depth == step->depth + 1) { + if (all_later_children_definite) { + if (!child_step->is_definite) { + all_later_children_definite = false; + step->is_definite = false; + } + } else { + child_step->is_definite = false; + } } } } @@ -1870,29 +1887,21 @@ uint32_t ts_query_start_byte_for_pattern( bool ts_query_pattern_is_definite( const TSQuery *self, uint32_t pattern_index, - uint32_t step_count + TSSymbol symbol, + uint32_t index ) { uint32_t step_index = self->patterns.contents[pattern_index].start_step; - for (;;) { - QueryStep *start_step = &self->steps.contents[step_index]; - if (step_index + step_count < self->steps.size) { - QueryStep *step = start_step; - for (unsigned i = 0; i < step_count; i++) { - if (step->depth == PATTERN_DONE_MARKER) { - step = NULL; - break; - } - step++; - } - if (step && !step->is_definite) return false; - } - if (start_step->alternative_index != NONE && start_step->alternative_index > step_index) { - step_index = start_step->alternative_index; - } else { - break; + QueryStep *step = &self->steps.contents[step_index]; + for (; step->depth != PATTERN_DONE_MARKER; step++) { + bool does_match = symbol ? + step->symbol == symbol : + step->symbol == WILDCARD_SYMBOL || step->symbol == NAMED_WILDCARD_SYMBOL; + if (does_match) { + if (index == 0) return step->is_definite; + index--; } } - return true; + return false; } void ts_query_disable_capture(