From 4c2f36a07b99732c96d474fdae30c1cf158b966e Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 8 Jun 2020 16:07:22 -0700 Subject: [PATCH] Mark steps as definite on query construction * Add a ts_query_pattern_is_definite API, just for debugging this * Store state_count on TSLanguage structs, to allow for scanning parse tables --- cli/src/generate/render.rs | 62 +-- cli/src/main.rs | 10 +- cli/src/tests/query_test.rs | 72 +++- lib/binding_rust/bindings.rs | 13 +- lib/binding_rust/lib.rs | 8 + lib/include/tree_sitter/api.h | 8 +- lib/include/tree_sitter/parser.h | 1 + lib/src/array.h | 24 ++ lib/src/language.h | 1 + lib/src/query.c | 632 +++++++++++++++++++++++++++++-- 10 files changed, 755 insertions(+), 76 deletions(-) diff --git a/cli/src/generate/render.rs b/cli/src/generate/render.rs index 270bd00d..300ad383 100644 --- a/cli/src/generate/render.rs +++ b/cli/src/generate/render.rs @@ -95,11 +95,7 @@ impl Generator { self.add_stats(); self.add_symbol_enum(); self.add_symbol_names_list(); - - if self.next_abi { - self.add_unique_symbol_map(); - } - + self.add_unique_symbol_map(); self.add_symbol_metadata_list(); if !self.field_names.is_empty() { @@ -177,20 +173,16 @@ impl Generator { // If we are opting in to the new unstable language ABI, then use the concept of // "small parse states". Otherwise, use the same representation for all parse // states. - if self.next_abi { - let threshold = cmp::min(SMALL_STATE_THRESHOLD, self.parse_table.symbols.len() / 2); - self.large_state_count = self - .parse_table - .states - .iter() - .enumerate() - .take_while(|(i, s)| { - *i <= 1 || s.terminal_entries.len() + s.nonterminal_entries.len() > threshold - }) - .count(); - } else { - self.large_state_count = self.parse_table.states.len(); - } + let threshold = cmp::min(SMALL_STATE_THRESHOLD, self.parse_table.symbols.len() / 2); + self.large_state_count = self + .parse_table + .states + .iter() + .enumerate() + .take_while(|(i, s)| { + *i <= 1 || s.terminal_entries.len() + s.nonterminal_entries.len() > threshold + }) + .count(); } fn add_includes(&mut self) { @@ -256,10 +248,7 @@ impl Generator { "#define STATE_COUNT {}", self.parse_table.states.len() ); - - if self.next_abi { - add_line!(self, "#define LARGE_STATE_COUNT {}", self.large_state_count); - } + add_line!(self, "#define LARGE_STATE_COUNT {}", self.large_state_count); add_line!( self, @@ -689,17 +678,12 @@ impl Generator { name ); indent!(self); + add_line!(self, "START_LEXER();"); - - if self.next_abi { - add_line!(self, "eof = lexer->eof(lexer);"); - } else { - add_line!(self, "eof = lookahead == 0;"); - } - + add_line!(self, "eof = lexer->eof(lexer);"); add_line!(self, "switch (state) {{"); - indent!(self); + indent!(self); for (i, state) in lex_table.states.into_iter().enumerate() { add_line!(self, "case {}:", i); indent!(self); @@ -714,6 +698,7 @@ impl Generator { dedent!(self); add_line!(self, "}}"); + dedent!(self); add_line!(self, "}}"); add_line!(self, ""); @@ -967,12 +952,7 @@ impl Generator { add_line!( self, - "static uint16_t ts_parse_table[{}][SYMBOL_COUNT] = {{", - if self.next_abi { - "LARGE_STATE_COUNT" - } else { - "STATE_COUNT" - } + "static uint16_t ts_parse_table[LARGE_STATE_COUNT][SYMBOL_COUNT] = {{", ); indent!(self); @@ -1224,9 +1204,10 @@ impl Generator { add_line!(self, ".symbol_count = SYMBOL_COUNT,"); add_line!(self, ".alias_count = ALIAS_COUNT,"); add_line!(self, ".token_count = TOKEN_COUNT,"); + add_line!(self, ".large_state_count = LARGE_STATE_COUNT,"); if self.next_abi { - add_line!(self, ".large_state_count = LARGE_STATE_COUNT,"); + add_line!(self, ".state_count = STATE_COUNT,"); } add_line!(self, ".symbol_metadata = ts_symbol_metadata,"); @@ -1249,10 +1230,7 @@ impl Generator { add_line!(self, ".parse_actions = ts_parse_actions,"); add_line!(self, ".lex_modes = ts_lex_modes,"); add_line!(self, ".symbol_names = ts_symbol_names,"); - - if self.next_abi { - add_line!(self, ".public_symbol_map = ts_symbol_map,"); - } + add_line!(self, ".public_symbol_map = ts_symbol_map,"); if !self.parse_table.production_infos.is_empty() { add_line!( diff --git a/cli/src/main.rs b/cli/src/main.rs index 757c70eb..04cd34cd 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -149,8 +149,14 @@ fn run() -> error::Result<()> { .arg(Arg::with_name("path").index(1).multiple(true)), ) .subcommand( - SubCommand::with_name("web-ui").about("Test a parser interactively in the browser") - .arg(Arg::with_name("quiet").long("quiet").short("q").help("open in default browser")), + SubCommand::with_name("web-ui") + .about("Test a parser interactively in the browser") + .arg( + Arg::with_name("quiet") + .long("quiet") + .short("q") + .help("open in default browser"), + ), ) .subcommand( SubCommand::with_name("dump-languages") diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index d4f18c7d..92aff5fb 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1919,7 +1919,7 @@ fn test_query_start_byte_for_pattern() { let patterns_3 = " ((identifier) @b (#match? @b i)) (function_declaration name: (identifier) @c) - (method_definition name: (identifier) @d) + (method_definition name: (property_identifier) @d) " .trim_start(); @@ -2048,6 +2048,76 @@ fn test_query_disable_pattern() { }); } +#[test] +fn test_query_is_definite() { + struct Row { + pattern: &'static str, + results_by_step_index: &'static [(usize, bool)], + } + + let rows = &[ + Row { + pattern: r#"(object "{" "}")"#, + results_by_step_index: &[ + (0, false), + (1, true), // "{" + (2, true), // "}" + ], + }, + Row { + pattern: r#"(pair (property_identifier) ":")"#, + results_by_step_index: &[ + (0, false), + (1, false), // property_identifier + (2, true), // ":"" + ], + }, + Row { + pattern: r#"(object "{" (_) "}")"#, + results_by_step_index: &[ + (0, false), + (1, false), // "{"" + (2, false), // (_) + (3, true), // "}" + ], + }, + Row { + // Named wildcards, fields + pattern: r#"(binary_expression left: (identifier) right: (_))"#, + results_by_step_index: &[ + (0, false), + (1, false), // identifier + (2, true), // (_) + ], + }, + Row { + pattern: r#"(function_declaration name: (identifier) body: (statement_block))"#, + results_by_step_index: &[ + (0, false), + (1, true), // identifier + (2, true), // statement_block + ], + }, + ]; + + allocations::record(|| { + let language = get_language("javascript"); + for row in rows.iter() { + let query = Query::new(language, row.pattern).unwrap(); + for (step_index, is_definite) in row.results_by_step_index { + assert_eq!( + query.pattern_is_definite(0, *step_index), + *is_definite, + "Pattern: {:?}, step: {}, expected is_definite to be {}", + row.pattern, + step_index, + is_definite, + ) + } + } + }); +} + fn assert_query_matches( language: Language, query: &Query, diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index cba87fa3..7dc48660 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -172,9 +172,9 @@ extern "C" { #[doc = " the given ranges must be ordered from earliest to latest in the document,"] #[doc = " and they must not overlap. That is, the following must hold for all"] #[doc = " `i` < `length - 1`:"] - #[doc = " ```text"] + #[doc = ""] #[doc = " ranges[i].end_byte <= ranges[i + 1].start_byte"] - #[doc = " ```"] + #[doc = ""] #[doc = " If this requirement is not satisfied, the operation will fail, the ranges"] #[doc = " will not be assigned, and this function will return `false`. On success,"] #[doc = " this function returns `true`"] @@ -649,6 +649,13 @@ extern "C" { length: *mut u32, ) -> *const TSQueryPredicateStep; } +extern "C" { + pub fn ts_query_pattern_is_definite( + self_: *const TSQuery, + pattern_index: u32, + step_index: u32, + ) -> bool; +} extern "C" { #[doc = " Get the name and length of one of the query\'s captures, or one of the"] #[doc = " query\'s string literals. Each capture and string is associated with a"] @@ -800,5 +807,5 @@ extern "C" { pub fn ts_language_version(arg1: *const TSLanguage) -> u32; } -pub const TREE_SITTER_LANGUAGE_VERSION: usize = 11; +pub const TREE_SITTER_LANGUAGE_VERSION: usize = 12; pub const TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION: usize = 9; diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index c0aba32f..453cb8e7 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1449,6 +1449,14 @@ impl Query { unsafe { ffi::ts_query_disable_pattern(self.ptr.as_ptr(), index as u32) } } + /// Check if a pattern will definitely match after a certain number of steps + /// have matched. + pub fn pattern_is_definite(&self, index: usize, step_index: usize) -> bool { + unsafe { + ffi::ts_query_pattern_is_definite(self.ptr.as_ptr(), index as u32, step_index as u32) + } + } + fn parse_property( function_name: &str, capture_names: &[String], diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 9d832e6e..1b2533fc 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -21,7 +21,7 @@ extern "C" { * The Tree-sitter library is generally backwards-compatible with languages * generated using older CLI versions, but is not forwards-compatible. */ -#define TREE_SITTER_LANGUAGE_VERSION 11 +#define TREE_SITTER_LANGUAGE_VERSION 12 /** * The earliest ABI version that is supported by the current version of the @@ -718,6 +718,12 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( uint32_t *length ); +bool ts_query_pattern_is_definite( + const TSQuery *self, + uint32_t pattern_index, + uint32_t step_index +); + /** * Get the name and length of one of the query's captures, or one of the * query's string literals. Each capture and string is associated with a diff --git a/lib/include/tree_sitter/parser.h b/lib/include/tree_sitter/parser.h index 11bf4fc4..360e012f 100644 --- a/lib/include/tree_sitter/parser.h +++ b/lib/include/tree_sitter/parser.h @@ -119,6 +119,7 @@ struct TSLanguage { const uint16_t *small_parse_table; const uint32_t *small_parse_table_map; const TSSymbol *public_symbol_map; + uint32_t state_count; }; /* diff --git a/lib/src/array.h b/lib/src/array.h index 26cb8448..c7e0ae4a 100644 --- a/lib/src/array.h +++ b/lib/src/array.h @@ -66,6 +66,30 @@ extern "C" { #define array_assign(self, other) \ array__assign((VoidArray *)(self), (const VoidArray *)(other), array__elem_size(self)) +#define array_search_sorted_by(self, start, field, needle, out_index, out_exists) \ + do { \ + *(out_exists) = false; \ + for (*(out_index) = start; *(out_index) < (self)->size; (*(out_index))++) { \ + int _comparison = (int)((self)->contents[*(out_index)] field) - (int)(needle); \ + if (_comparison >= 0) { \ + if (_comparison == 0) *(out_exists) = true; \ + break; \ + } \ + } \ + } while (0); + +#define array_search_sorted_with(self, start, compare, needle, out_index, out_exists) \ + do { \ + *(out_exists) = false; \ + for (*(out_index) = start; *(out_index) < (self)->size; (*(out_index))++) { \ + int _comparison = compare(&(self)->contents[*(out_index)], (needle)); \ + if (_comparison >= 0) { \ + if (_comparison == 0) *(out_exists) = true; \ + break; \ + } \ + } \ + } while (0); + // Private typedef Array(void) VoidArray; diff --git a/lib/src/language.h b/lib/src/language.h index 2bb9a6f9..288c2a2b 100644 --- a/lib/src/language.h +++ b/lib/src/language.h @@ -12,6 +12,7 @@ extern "C" { #define TREE_SITTER_LANGUAGE_VERSION_WITH_FIELDS 10 #define TREE_SITTER_LANGUAGE_VERSION_WITH_SYMBOL_DEDUPING 11 #define TREE_SITTER_LANGUAGE_VERSION_WITH_SMALL_STATES 11 +#define TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT 12 typedef struct { const TSParseAction *actions; diff --git a/lib/src/query.c b/lib/src/query.c index ff243494..10ab5371 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -14,6 +14,8 @@ #define MAX_STATE_COUNT 256 #define MAX_CAPTURE_LIST_COUNT 32 #define MAX_STEP_CAPTURE_COUNT 3 +#define MAX_STATE_PREDECESSOR_COUNT 100 +#define MAX_WALK_STATE_DEPTH 4 /* * Stream - A sequence of unicode characters derived from a UTF8 string. @@ -55,6 +57,7 @@ typedef struct { bool is_pass_through: 1; bool is_dead_end: 1; bool alternative_is_immediate: 1; + bool is_definite: 1; } QueryStep; /* @@ -89,6 +92,12 @@ typedef struct { uint16_t pattern_index; } PatternEntry; +typedef struct { + Slice predicate_steps; + uint32_t start_byte; + uint32_t start_step; +} QueryPattern; + /* * QueryState - The state of an in-progress match of a particular pattern * in a query. While executing, a `TSQueryCursor` must keep track of a number @@ -138,6 +147,31 @@ typedef struct { uint32_t usage_map; } CaptureListPool; +/* + * WalkState - The state needed for walking the parse table when analyzing + * a query pattern, to determine the steps where the pattern could fail + * to match. + */ +typedef struct { + TSStateId state; + TSSymbol parent_symbol; + uint16_t child_index; + TSFieldId field; +} WalkStateEntry; + +typedef struct { + WalkStateEntry stack[MAX_WALK_STATE_DEPTH]; + uint16_t depth; + uint16_t step_index; +} WalkState; + +/* + * StatePredecessorMap - A map that stores the predecessors of each parse state. + */ +typedef struct { + TSStateId *contents; +} StatePredecessorMap; + /* * TSQuery - A tree query, compiled from a string of S-expressions. The query * itself is immutable. The mutable state used in the process of executing the @@ -149,8 +183,7 @@ struct TSQuery { Array(QueryStep) steps; Array(PatternEntry) pattern_map; Array(TSQueryPredicateStep) predicate_steps; - Array(Slice) predicates_by_pattern; - Array(uint32_t) start_bytes_by_pattern; + Array(QueryPattern) patterns; const TSLanguage *language; uint16_t wildcard_root_pattern_count; TSSymbol *symbol_map; @@ -451,6 +484,7 @@ static QueryStep query_step__new( .is_pattern_start = false, .is_pass_through = false, .is_dead_end = false, + .is_definite = false, .is_immediate = is_immediate, .alternative_is_immediate = false, }; @@ -480,6 +514,67 @@ static void query_step__remove_capture(QueryStep *self, uint16_t capture_id) { } } +/********************** + * StatePredecessorMap + **********************/ + +static inline StatePredecessorMap state_predecessor_map_new(const TSLanguage *language) { + return (StatePredecessorMap) { + .contents = ts_calloc(language->state_count * (MAX_STATE_PREDECESSOR_COUNT + 1), sizeof(TSStateId)), + }; +} + +static inline void state_predecessor_map_delete(StatePredecessorMap *self) { + ts_free(self->contents); +} + +static inline void state_predecessor_map_add( + StatePredecessorMap *self, + TSStateId state, + TSStateId predecessor +) { + unsigned index = state * (MAX_STATE_PREDECESSOR_COUNT + 1); + TSStateId *count = &self->contents[index]; + if (*count == 0 || (*count < MAX_STATE_PREDECESSOR_COUNT && self->contents[index + *count] != predecessor)) { + (*count)++; + self->contents[index + *count] = predecessor; + } +} + +static inline const TSStateId *state_predecessor_map_get( + const StatePredecessorMap *self, + TSStateId state, + unsigned *count +) { + unsigned index = state * (MAX_STATE_PREDECESSOR_COUNT + 1); + *count = self->contents[index]; + return &self->contents[index + 1]; +} + +/************ + * WalkState + ************/ + +static inline int walk_state__compare(WalkState *self, WalkState *other) { + if (self->depth < other->depth) return -1; + if (self->depth > other->depth) return 1; + if (self->step_index < other->step_index) return -1; + if (self->step_index > other->step_index) return 1; + for (unsigned i = 0; i < self->depth; i++) { + if (self->stack[i].state < other->stack[i].state) return -1; + if (self->stack[i].state > other->stack[i].state) return 1; + if (self->stack[i].parent_symbol < other->stack[i].parent_symbol) return -1; + if (self->stack[i].parent_symbol > other->stack[i].parent_symbol) return 1; + if (self->stack[i].child_index < other->stack[i].child_index) return -1; + if (self->stack[i].child_index > other->stack[i].child_index) return 1; + } + return 0; +} + +static inline WalkStateEntry *walk_state__top(WalkState *self) { + return &self->stack[self->depth - 1]; +} + /********* * Query *********/ @@ -552,6 +647,466 @@ static inline void ts_query__pattern_map_insert( })); } +static void ts_query__analyze_patterns(TSQuery *self) { + typedef struct { + TSSymbol parent_symbol; + uint32_t parent_step_index; + Array(uint32_t) child_step_indices; + } ParentPattern; + + typedef struct { + TSStateId state; + uint8_t child_index; + uint8_t production_id; + bool done; + } SubgraphNode; + + typedef struct { + TSSymbol symbol; + Array(TSStateId) start_states; + Array(SubgraphNode) nodes; + } SymbolSubgraph; + + typedef Array(WalkState) WalkStateList; + + // Identify all of the patterns in the query that have child patterns. This + // includes both top-level patterns and patterns that are nested within some + // larger pattern. For each of these, record the parent symbol, the step index + // and all of the immediate child step indices in reverse order. + Array(ParentPattern) parent_patterns = array_new(); + Array(uint32_t) stack = array_new(); + for (unsigned i = 0; i < self->steps.size; i++) { + QueryStep *step = &self->steps.contents[i]; + if (step->depth == PATTERN_DONE_MARKER) { + array_clear(&stack); + } else { + uint32_t parent_pattern_index = 0; + while (stack.size > 0) { + parent_pattern_index = *array_back(&stack); + ParentPattern *parent_pattern = &parent_patterns.contents[parent_pattern_index]; + QueryStep *parent_step = &self->steps.contents[parent_pattern->parent_step_index]; + if (parent_step->depth >= step->depth) { + stack.size--; + } else { + break; + } + } + + if (stack.size > 0) { + ParentPattern *parent_pattern = &parent_patterns.contents[parent_pattern_index]; + step->is_definite = true; + array_push(&parent_pattern->child_step_indices, i); + } + + array_push(&stack, parent_patterns.size); + array_push(&parent_patterns, ((ParentPattern) { + .parent_symbol = step->symbol, + .parent_step_index = i, + })); + } + } + for (unsigned i = 0; i < parent_patterns.size; i++) { + ParentPattern *parent_pattern = &parent_patterns.contents[i]; + if (parent_pattern->child_step_indices.size == 0) { + array_erase(&parent_patterns, i); + i--; + } + } + + // Debug + // { + // printf("\nParent pattern entries\n"); + // for (unsigned i = 0; i < parent_patterns.size; i++) { + // ParentPattern *parent_pattern = &parent_patterns.contents[i]; + // printf(" %s ->", ts_language_symbol_name(self->language, parent_pattern->parent_symbol)); + // for (unsigned j = 0; j < parent_pattern->child_step_indices.size; j++) { + // QueryStep *step = &self->steps.contents[parent_pattern->child_step_indices.contents[j]]; + // printf(" %s", ts_language_symbol_name(self->language, step->symbol)); + // } + // printf("\n"); + // } + // } + + // Initialize a set of subgraphs, with one subgraph for each parent symbol, + // in the query, and one subgraph for each hidden symbol. + unsigned subgraph_index = 0, exists; + Array(SymbolSubgraph) subgraphs = array_new(); + for (unsigned i = 0; i < parent_patterns.size; i++) { + TSSymbol parent_symbol = parent_patterns.contents[i].parent_symbol; + array_search_sorted_by(&subgraphs, 0, .symbol, parent_symbol, &subgraph_index, &exists); + if (!exists) { + array_insert(&subgraphs, subgraph_index, ((SymbolSubgraph) { .symbol = parent_symbol, })); + } + } + subgraph_index = 0; + for (TSSymbol sym = 0; sym < self->language->symbol_count; sym++) { + if (!ts_language_symbol_metadata(self->language, sym).visible) { + array_search_sorted_by( + &subgraphs, subgraph_index, + .symbol, sym, + &subgraph_index, &exists + ); + if (!exists) { + array_insert(&subgraphs, subgraph_index, ((SymbolSubgraph) { .symbol = sym, })); + subgraph_index++; + } + } + } + + // Scan the parse table to find the data needed for these subgraphs. + // Collect three things during this scan: + // 1) All of the parse states where one of these symbols can start. + // 2) All of the parse states where one of these symbols can end, along + // with information about the node that would be created. + // 3) A list of predecessor states for each state. + StatePredecessorMap predecessor_map = state_predecessor_map_new(self->language); + for (TSStateId state = 1; state < self->language->state_count; state++) { + unsigned subgraph_index = 0, exists; + for (TSSymbol sym = 0; sym < self->language->token_count; sym++) { + unsigned count; + const TSParseAction *actions = ts_language_actions(self->language, state, sym, &count); + for (unsigned i = 0; i < count; i++) { + const TSParseAction *action = &actions[i]; + if (action->type == TSParseActionTypeReduce) { + unsigned exists; + array_search_sorted_by( + &subgraphs, + subgraph_index, + .symbol, + action->params.reduce.symbol, + &subgraph_index, + &exists + ); + if (exists) { + SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + if (subgraph->nodes.size == 0 || array_back(&subgraph->nodes)->state != state) { + array_push(&subgraph->nodes, ((SubgraphNode) { + .state = state, + .production_id = action->params.reduce.production_id, + .child_index = action->params.reduce.child_count, + .done = true, + })); + } + } + } else if ( + action->type == TSParseActionTypeShift && + !action->params.shift.extra + ) { + TSStateId next_state = action->params.shift.state; + state_predecessor_map_add(&predecessor_map, next_state, state); + } + } + } + for (TSSymbol sym = self->language->token_count; sym < self->language->symbol_count; sym++) { + TSStateId next_state = ts_language_next_state(self->language, state, sym); + if (next_state != 0) { + state_predecessor_map_add(&predecessor_map, next_state, state); + array_search_sorted_by( + &subgraphs, + subgraph_index, + .symbol, + sym, + &subgraph_index, + &exists + ); + if (exists) { + SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + array_push(&subgraph->start_states, state); + } + } + } + } + + // For each subgraph, compute the remainder of the nodes by walking backward + // from the end states using the predecessor map. + Array(SubgraphNode) next_nodes = array_new(); + for (unsigned i = 0; i < subgraphs.size; i++) { + SymbolSubgraph *subgraph = &subgraphs.contents[i]; + if (subgraph->nodes.size == 0) { + array_delete(&subgraph->start_states); + array_erase(&subgraphs, i); + i--; + continue; + } + array_assign(&next_nodes, &subgraph->nodes); + while (next_nodes.size > 0) { + SubgraphNode node = array_pop(&next_nodes); + if (node.child_index > 1) { + unsigned predecessor_count; + const TSStateId *predecessors = state_predecessor_map_get( + &predecessor_map, + node.state, + &predecessor_count + ); + for (unsigned j = 0; j < predecessor_count; j++) { + SubgraphNode predecessor_node = { + .state = predecessors[j], + .child_index = node.child_index - 1, + .production_id = node.production_id, + .done = false, + }; + unsigned index, exists; + array_search_sorted_by(&subgraph->nodes, 0, .state, predecessor_node.state, &index, &exists); + if (!exists) { + array_insert(&subgraph->nodes, index, predecessor_node); + array_push(&next_nodes, predecessor_node); + } + } + } + } + } + + // Debug + // { + // printf("\nSubgraphs:\n"); + // for (unsigned i = 0; i < subgraphs.size; i++) { + // SymbolSubgraph *subgraph = &subgraphs.contents[i]; + // printf(" %u, %s:\n", subgraph->symbol, ts_language_symbol_name(self->language, subgraph->symbol)); + // for (unsigned j = 0; j < subgraph->nodes.size; j++) { + // SubgraphNode *node = &subgraph->nodes.contents[j]; + // printf(" {state: %u, child_index: %u}\n", node->state, node->child_index); + // } + // printf("\n"); + // } + // } + + // For each non-terminal pattern, determine if the pattern can successfully match, + // and all of the possible children within the pattern where matching could fail. + WalkStateList walk_states = array_new(); + WalkStateList next_walk_states = array_new(); + Array(uint16_t) finished_step_indices = array_new(); + for (unsigned i = 0; i < parent_patterns.size; i++) { + ParentPattern *parent_pattern = &parent_patterns.contents[i]; + unsigned subgraph_index, exists; + array_search_sorted_by(&subgraphs, 0, .symbol, parent_pattern->parent_symbol, &subgraph_index, &exists); + if (!exists) { + // TODO - what to do for ERROR patterns + continue; + } + SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + + // Initialize a walk at every possible parse state where this non-terminal + // symbol can start. + array_clear(&walk_states); + for (unsigned j = 0; j < subgraph->start_states.size; j++) { + TSStateId state = subgraph->start_states.contents[j]; + array_push(&walk_states, ((WalkState) { + .step_index = 0, + .stack = { + [0] = { + .state = state, + .child_index = 0, + .parent_symbol = subgraph->symbol, + .field = 0, + }, + }, + .depth = 1, + })); + } + + // Walk the subgraph for this non-terminal, tracking all of the possible + // sequences of progress within the pattern. + array_clear(&finished_step_indices); + while (walk_states.size > 0) { + // Debug + // { + // printf("Walk states for %u %s:\n", i, ts_language_symbol_name(self->language, parent_pattern->parent_symbol)); + // for (unsigned j = 0; j < walk_states.size; j++) { + // WalkState *walk_state = &walk_states.contents[j]; + // printf( + // " %u: {depth: %u, step: %u, state: %u, child_index: %u, parent: %s}\n", + // j, + // walk_state->depth, + // walk_state->step_index, + // walk_state->stack[walk_state->depth - 1].state, + // walk_state->stack[walk_state->depth - 1].child_index, + // ts_language_symbol_name(self->language, walk_state->stack[walk_state->depth - 1].parent_symbol) + // ); + // } + + // printf("\nFinished step indices for %u %s:", i, ts_language_symbol_name(self->language, parent_pattern->parent_symbol)); + // for (unsigned j = 0; j < finished_step_indices.size; j++) { + // printf(" %u", finished_step_indices.contents[j]); + // } + // printf("\n\n"); + // } + + array_clear(&next_walk_states); + for (unsigned j = 0; j < walk_states.size; j++) { + WalkState *walk_state = &walk_states.contents[j]; + TSStateId state = walk_state->stack[walk_state->depth - 1].state; + unsigned child_index = walk_state->stack[walk_state->depth - 1].child_index; + TSSymbol parent_symbol = walk_state->stack[walk_state->depth - 1].parent_symbol; + + unsigned subgraph_index, exists; + array_search_sorted_by(&subgraphs, 0, .symbol, parent_symbol, &subgraph_index, &exists); + if (!exists) continue; + SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + + for (TSSymbol sym = 0; sym < self->language->symbol_count; sym++) { + TSStateId successor_state = ts_language_next_state(self->language, state, sym); + if (successor_state && successor_state != state) { + unsigned node_index; + array_search_sorted_by(&subgraph->nodes, 0, .state, successor_state, &node_index, &exists); + if (exists) { + SubgraphNode *node = &subgraph->nodes.contents[node_index]; + if (node->child_index != child_index + 1) continue; + + WalkState next_walk_state = *walk_state; + walk_state__top(&next_walk_state)->child_index++; + walk_state__top(&next_walk_state)->state = successor_state; + + bool does_match = true; + unsigned step_index = parent_pattern->child_step_indices.contents[walk_state->step_index]; + QueryStep *step = &self->steps.contents[step_index]; + TSSymbol alias = ts_language_alias_at(self->language, node->production_id, child_index); + TSSymbol visible_symbol = alias + ? alias + : self->language->symbol_metadata[sym].visible + ? self->language->public_symbol_map[sym] + : 0; + if (visible_symbol) { + if (step->symbol == NAMED_WILDCARD_SYMBOL) { + if (!ts_language_symbol_metadata(self->language, visible_symbol).named) does_match = false; + } else if (step->symbol != WILDCARD_SYMBOL) { + if (step->symbol != visible_symbol) does_match = false; + } + } else if (next_walk_state.depth < MAX_WALK_STATE_DEPTH) { + does_match = false; + next_walk_state.depth++; + walk_state__top(&next_walk_state)->state = state; + walk_state__top(&next_walk_state)->child_index = 0; + walk_state__top(&next_walk_state)->parent_symbol = sym; + } else { + continue; + } + + TSFieldId field_id = 0; + const TSFieldMapEntry *field_map, *field_map_end; + ts_language_field_map(self->language, node->production_id, &field_map, &field_map_end); + for (; field_map != field_map_end; field_map++) { + if (field_map->child_index == child_index) { + field_id = field_map->field_id; + break; + } + } + + if (does_match) { + next_walk_state.step_index++; + } + + if (node->done) { + next_walk_state.depth--; + } + + if ( + next_walk_state.depth == 0 || + next_walk_state.step_index == parent_pattern->child_step_indices.size + ) { + unsigned index, exists; + array_search_sorted_by(&finished_step_indices, 0, , next_walk_state.step_index, &index, &exists); + if (!exists) array_insert(&finished_step_indices, index, next_walk_state.step_index); + continue; + } + + unsigned index, exists; + array_search_sorted_with( + &next_walk_states, + 0, + walk_state__compare, + &next_walk_state, + &index, + &exists + ); + if (!exists) { + array_insert(&next_walk_states, index, next_walk_state); + } + } + } + } + } + + WalkStateList _walk_states = walk_states; + walk_states = next_walk_states; + next_walk_states = _walk_states; + } + + // Debug + // { + // printf("Finished step indices for %u %s:", i, ts_language_symbol_name(self->language, parent_pattern->parent_symbol)); + // for (unsigned j = 0; j < finished_step_indices.size; j++) { + // printf(" %u", finished_step_indices.contents[j]); + // } + // printf("\n\n"); + // } + + // A query step is definite if the containing pattern will definitely match + // once the step is reached. In other words, a step is *not* definite if + // it's possible to create a syntax node that matches up to until that step, + // but does not match the entire pattern. + for (unsigned j = 0, n = parent_pattern->child_step_indices.size; j < n; j++) { + uint32_t step_index = parent_pattern->child_step_indices.contents[j]; + for (unsigned k = 0; k < finished_step_indices.size; k++) { + uint32_t finished_step_index = finished_step_indices.contents[k]; + if (finished_step_index >= j && finished_step_index < n) { + QueryStep *step = &self->steps.contents[step_index]; + step->is_definite = false; + break; + } + } + } + } + + // In order for a parent step to be definite, all of its child steps must + // be definite. Propagate the definiteness up the pattern trees by walking + // the query's steps in reverse. + for (unsigned i = self->steps.size - 1; i + 1 > 0; i--) { + QueryStep *step = &self->steps.contents[i]; + for (unsigned j = i + 1; j < self->steps.size; j++) { + QueryStep *child_step = &self->steps.contents[j]; + if (child_step->depth <= step->depth) break; + if (child_step->depth == step->depth + 1 && !child_step->is_definite) { + step->is_definite = false; + break; + } + } + } + + // Debug + // { + // printf("\nSteps:\n"); + // for (unsigned i = 0; i < self->steps.size; i++) { + // QueryStep *step = &self->steps.contents[i]; + // if (step->depth == PATTERN_DONE_MARKER) { + // printf("\n"); + // continue; + // } + // printf( + // " {symbol: %s, is_definite: %d}\n", + // (step->symbol == WILDCARD_SYMBOL || step->symbol == NAMED_WILDCARD_SYMBOL) ? "ANY" : ts_language_symbol_name(self->language, step->symbol), + // step->is_definite + // ); + // } + // } + + // Cleanup + for (unsigned i = 0; i < parent_patterns.size; i++) { + array_delete(&parent_patterns.contents[i].child_step_indices); + } + for (unsigned i = 0; i < subgraphs.size; i++) { + array_delete(&subgraphs.contents[i].start_states); + array_delete(&subgraphs.contents[i].nodes); + } + array_delete(&stack); + array_delete(&subgraphs); + array_delete(&next_nodes); + array_delete(&walk_states); + array_delete(&parent_patterns); + array_delete(&next_walk_states); + array_delete(&finished_step_indices); + state_predecessor_map_delete(&predecessor_map); +} + static void ts_query__finalize_steps(TSQuery *self) { for (unsigned i = 0; i < self->steps.size; i++) { QueryStep *step = &self->steps.contents[i]; @@ -588,7 +1143,7 @@ static TSQueryError ts_query__parse_predicate( predicate_name, length ); - array_back(&self->predicates_by_pattern)->length++; + array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, .value_id = id, @@ -599,7 +1154,7 @@ static TSQueryError ts_query__parse_predicate( if (stream->next == ')') { stream_advance(stream); stream_skip_whitespace(stream); - array_back(&self->predicates_by_pattern)->length++; + array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeDone, .value_id = 0, @@ -628,7 +1183,7 @@ static TSQueryError ts_query__parse_predicate( return TSQueryErrorCapture; } - array_back(&self->predicates_by_pattern)->length++; + array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeCapture, .value_id = capture_id, @@ -668,7 +1223,7 @@ static TSQueryError ts_query__parse_predicate( string_content, length ); - array_back(&self->predicates_by_pattern)->length++; + array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, .value_id = id, @@ -688,7 +1243,7 @@ static TSQueryError ts_query__parse_predicate( symbol_start, length ); - array_back(&self->predicates_by_pattern)->length++; + array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, .value_id = id, @@ -712,7 +1267,6 @@ static TSQueryError ts_query__parse_pattern( TSQuery *self, Stream *stream, uint32_t depth, - uint32_t *capture_count, bool is_immediate ) { const uint32_t starting_step_index = self->steps.size; @@ -737,7 +1291,6 @@ static TSQueryError ts_query__parse_pattern( self, stream, depth, - capture_count, is_immediate ); @@ -790,7 +1343,6 @@ static TSQueryError ts_query__parse_pattern( self, stream, depth, - capture_count, child_is_immediate ); if (e == PARENT_DONE && stream->next == ')') { @@ -871,7 +1423,6 @@ static TSQueryError ts_query__parse_pattern( self, stream, depth + 1, - capture_count, child_is_immediate ); if (e == PARENT_DONE && stream->next == ')') { @@ -955,7 +1506,6 @@ static TSQueryError ts_query__parse_pattern( self, stream, depth, - capture_count, is_immediate ); if (e == PARENT_DONE) return TSQueryErrorSyntax; @@ -1069,8 +1619,6 @@ static TSQueryError ts_query__parse_pattern( break; } } - - (*capture_count)++; } // No more suffix modifiers @@ -1123,7 +1671,7 @@ TSQuery *ts_query_new( .captures = symbol_table_new(), .predicate_values = symbol_table_new(), .predicate_steps = array_new(), - .predicates_by_pattern = array_new(), + .patterns = array_new(), .symbol_map = symbol_map, .wildcard_root_pattern_count = 0, .language = language, @@ -1133,15 +1681,14 @@ TSQuery *ts_query_new( Stream stream = stream_new(source, source_len); stream_skip_whitespace(&stream); while (stream.input < stream.end) { - uint32_t pattern_index = self->predicates_by_pattern.size; + uint32_t pattern_index = self->patterns.size; uint32_t start_step_index = self->steps.size; - uint32_t capture_count = 0; - array_push(&self->start_bytes_by_pattern, stream.input - source); - array_push(&self->predicates_by_pattern, ((Slice) { - .offset = self->predicate_steps.size, - .length = 0, + array_push(&self->patterns, ((QueryPattern) { + .predicate_steps = (Slice) {.offset = self->predicate_steps.size, .length = 0}, + .start_byte = stream.input - source, + .start_step = self->steps.size, })); - *error_type = ts_query__parse_pattern(self, &stream, 0, &capture_count, false); + *error_type = ts_query__parse_pattern(self, &stream, 0, false); array_push(&self->steps, query_step__new(0, PATTERN_DONE_MARKER, false)); // If any pattern could not be parsed, then report the error information @@ -1183,6 +1730,10 @@ TSQuery *ts_query_new( } } + if (self->language->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT) { + ts_query__analyze_patterns(self); + } + ts_query__finalize_steps(self); return self; } @@ -1192,8 +1743,7 @@ void ts_query_delete(TSQuery *self) { array_delete(&self->steps); array_delete(&self->pattern_map); array_delete(&self->predicate_steps); - array_delete(&self->predicates_by_pattern); - array_delete(&self->start_bytes_by_pattern); + array_delete(&self->patterns); symbol_table_delete(&self->captures); symbol_table_delete(&self->predicate_values); ts_free(self->symbol_map); @@ -1202,7 +1752,7 @@ void ts_query_delete(TSQuery *self) { } uint32_t ts_query_pattern_count(const TSQuery *self) { - return self->predicates_by_pattern.size; + return self->patterns.size; } uint32_t ts_query_capture_count(const TSQuery *self) { @@ -1234,7 +1784,7 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( uint32_t pattern_index, uint32_t *step_count ) { - Slice slice = self->predicates_by_pattern.contents[pattern_index]; + Slice slice = self->patterns.contents[pattern_index].predicate_steps; *step_count = slice.length; return &self->predicate_steps.contents[slice.offset]; } @@ -1243,7 +1793,35 @@ uint32_t ts_query_start_byte_for_pattern( const TSQuery *self, uint32_t pattern_index ) { - return self->start_bytes_by_pattern.contents[pattern_index]; + return self->patterns.contents[pattern_index].start_byte; +} + +bool ts_query_pattern_is_definite( + const TSQuery *self, + uint32_t pattern_index, + uint32_t step_count +) { + uint32_t step_index = self->patterns.contents[pattern_index].start_step; + for (;;) { + QueryStep *start_step = &self->steps.contents[step_index]; + if (step_index + step_count < self->steps.size) { + QueryStep *step = start_step; + for (unsigned i = 0; i < step_count; i++) { + if (step->depth == PATTERN_DONE_MARKER) { + step = NULL; + break; + } + step++; + } + if (step && !step->is_definite) return false; + } + if (start_step->alternative_index != NONE && start_step->alternative_index > step_index) { + step_index = start_step->alternative_index; + } else { + break; + } + } + return true; } void ts_query_disable_capture(