From 6a46dff89a9d9bd9ceb13d7838c1a801974ac08d Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 15 Jun 2020 09:58:07 -0700 Subject: [PATCH 01/26] Add ts_language_alias_at helper function --- lib/src/get_changed_ranges.c | 24 +++++++------- lib/src/language.h | 51 +++++++++++++++++++---------- lib/src/subtree.c | 8 ++--- lib/src/tree_cursor.c | 62 +++++++++++++++++++----------------- 4 files changed, 84 insertions(+), 61 deletions(-) diff --git a/lib/src/get_changed_ranges.c b/lib/src/get_changed_ranges.c index 5bd1d814..b24f3149 100644 --- a/lib/src/get_changed_ranges.c +++ b/lib/src/get_changed_ranges.c @@ -146,17 +146,21 @@ static bool iterator_tree_is_visible(const Iterator *self) { if (ts_subtree_visible(*entry.subtree)) return true; if (self->cursor.stack.size > 1) { Subtree parent = *self->cursor.stack.contents[self->cursor.stack.size - 2].subtree; - const TSSymbol *alias_sequence = ts_language_alias_sequence( + return ts_language_alias_at( self->language, - parent.ptr->production_id - ); - return alias_sequence && alias_sequence[entry.structural_child_index] != 0; + parent.ptr->production_id, + entry.structural_child_index + ) != 0; } return false; } -static void iterator_get_visible_state(const Iterator *self, Subtree *tree, - TSSymbol *alias_symbol, uint32_t *start_byte) { +static void iterator_get_visible_state( + const Iterator *self, + Subtree *tree, + TSSymbol *alias_symbol, + uint32_t *start_byte +) { uint32_t i = self->cursor.stack.size - 1; if (self->in_padding) { @@ -169,13 +173,11 @@ static void iterator_get_visible_state(const Iterator *self, Subtree *tree, if (i > 0) { const Subtree *parent = self->cursor.stack.contents[i - 1].subtree; - const TSSymbol *alias_sequence = ts_language_alias_sequence( + *alias_symbol = ts_language_alias_at( self->language, - parent->ptr->production_id + parent->ptr->production_id, + entry.structural_child_index ); - if (alias_sequence) { - *alias_symbol = alias_sequence[entry.structural_child_index]; - } } if 
(ts_subtree_visible(*entry.subtree) || *alias_symbol) { diff --git a/lib/src/language.h b/lib/src/language.h index 341f0f85..2bb9a6f9 100644 --- a/lib/src/language.h +++ b/lib/src/language.h @@ -41,17 +41,21 @@ static inline const TSParseAction *ts_language_actions( return entry.actions; } -static inline bool ts_language_has_actions(const TSLanguage *self, - TSStateId state, - TSSymbol symbol) { +static inline bool ts_language_has_actions( + const TSLanguage *self, + TSStateId state, + TSSymbol symbol +) { TableEntry entry; ts_language_table_entry(self, state, symbol, &entry); return entry.action_count > 0; } -static inline bool ts_language_has_reduce_action(const TSLanguage *self, - TSStateId state, - TSSymbol symbol) { +static inline bool ts_language_has_reduce_action( + const TSLanguage *self, + TSStateId state, + TSSymbol symbol +) { TableEntry entry; ts_language_table_entry(self, state, symbol, &entry); return entry.action_count > 0 && entry.actions[0].type == TSParseActionTypeReduce; @@ -82,9 +86,11 @@ static inline uint16_t ts_language_lookup( } } -static inline TSStateId ts_language_next_state(const TSLanguage *self, - TSStateId state, - TSSymbol symbol) { +static inline TSStateId ts_language_next_state( + const TSLanguage *self, + TSStateId state, + TSSymbol symbol +) { if (symbol == ts_builtin_sym_error || symbol == ts_builtin_sym_error_repeat) { return 0; } else if (symbol < self->token_count) { @@ -102,9 +108,10 @@ static inline TSStateId ts_language_next_state(const TSLanguage *self, } } -static inline const bool * -ts_language_enabled_external_tokens(const TSLanguage *self, - unsigned external_scanner_state) { +static inline const bool *ts_language_enabled_external_tokens( + const TSLanguage *self, + unsigned external_scanner_state +) { if (external_scanner_state == 0) { return NULL; } else { @@ -112,13 +119,25 @@ ts_language_enabled_external_tokens(const TSLanguage *self, } } -static inline const TSSymbol * -ts_language_alias_sequence(const 
TSLanguage *self, uint32_t production_id) { - return production_id > 0 ? - self->alias_sequences + production_id * self->max_alias_sequence_length : +static inline const TSSymbol *ts_language_alias_sequence( + const TSLanguage *self, + uint32_t production_id +) { + return production_id ? + &self->alias_sequences[production_id * self->max_alias_sequence_length] : NULL; } +static inline TSSymbol ts_language_alias_at( + const TSLanguage *self, + uint32_t production_id, + uint32_t child_index +) { + return production_id ? + self->alias_sequences[production_id * self->max_alias_sequence_length + child_index] : + 0; +} + static inline void ts_language_field_map( const TSLanguage *self, uint32_t production_id, diff --git a/lib/src/subtree.c b/lib/src/subtree.c index ef92a32f..24dc06b2 100644 --- a/lib/src/subtree.c +++ b/lib/src/subtree.c @@ -360,7 +360,7 @@ void ts_subtree_set_children( self.ptr->has_external_tokens = false; self.ptr->dynamic_precedence = 0; - uint32_t non_extra_index = 0; + uint32_t structural_index = 0; const TSSymbol *alias_sequence = ts_language_alias_sequence(language, self.ptr->production_id); uint32_t lookahead_end_byte = 0; @@ -387,9 +387,9 @@ void ts_subtree_set_children( self.ptr->dynamic_precedence += ts_subtree_dynamic_precedence(child); self.ptr->node_count += ts_subtree_node_count(child); - if (alias_sequence && alias_sequence[non_extra_index] != 0 && !ts_subtree_extra(child)) { + if (alias_sequence && alias_sequence[structural_index] != 0 && !ts_subtree_extra(child)) { self.ptr->visible_child_count++; - if (ts_language_symbol_metadata(language, alias_sequence[non_extra_index]).named) { + if (ts_language_symbol_metadata(language, alias_sequence[structural_index]).named) { self.ptr->named_child_count++; } } else if (ts_subtree_visible(child)) { @@ -407,7 +407,7 @@ void ts_subtree_set_children( self.ptr->parse_state = TS_TREE_STATE_NONE; } - if (!ts_subtree_extra(child)) non_extra_index++; + if (!ts_subtree_extra(child)) structural_index++; } 
self.ptr->lookahead_bytes = lookahead_end_byte - self.ptr->size.bytes - self.ptr->padding.bytes; diff --git a/lib/src/tree_cursor.c b/lib/src/tree_cursor.c index 00b9679d..06c724d2 100644 --- a/lib/src/tree_cursor.c +++ b/lib/src/tree_cursor.c @@ -205,19 +205,21 @@ bool ts_tree_cursor_goto_parent(TSTreeCursor *_self) { TreeCursor *self = (TreeCursor *)_self; for (unsigned i = self->stack.size - 2; i + 1 > 0; i--) { TreeCursorEntry *entry = &self->stack.contents[i]; - bool is_aliased = false; - if (i > 0) { - TreeCursorEntry *parent_entry = &self->stack.contents[i - 1]; - const TSSymbol *alias_sequence = ts_language_alias_sequence( - self->tree->language, - parent_entry->subtree->ptr->production_id - ); - is_aliased = alias_sequence && alias_sequence[entry->structural_child_index]; - } - if (ts_subtree_visible(*entry->subtree) || is_aliased) { + if (ts_subtree_visible(*entry->subtree)) { self->stack.size = i + 1; return true; } + if (i > 0 && !ts_subtree_extra(*entry->subtree)) { + TreeCursorEntry *parent_entry = &self->stack.contents[i - 1]; + if (ts_language_alias_at( + self->tree->language, + parent_entry->subtree->ptr->production_id, + entry->structural_child_index + )) { + self->stack.size = i + 1; + return true; + } + } } return false; } @@ -226,15 +228,13 @@ TSNode ts_tree_cursor_current_node(const TSTreeCursor *_self) { const TreeCursor *self = (const TreeCursor *)_self; TreeCursorEntry *last_entry = array_back(&self->stack); TSSymbol alias_symbol = 0; - if (self->stack.size > 1) { + if (self->stack.size > 1 && !ts_subtree_extra(*last_entry->subtree)) { TreeCursorEntry *parent_entry = &self->stack.contents[self->stack.size - 2]; - const TSSymbol *alias_sequence = ts_language_alias_sequence( + alias_symbol = ts_language_alias_at( self->tree->language, - parent_entry->subtree->ptr->production_id + parent_entry->subtree->ptr->production_id, + last_entry->structural_child_index ); - if (alias_sequence && !ts_subtree_extra(*last_entry->subtree)) { - alias_symbol 
= alias_sequence[last_entry->structural_child_index]; - } } return ts_node_new( self->tree, @@ -263,13 +263,14 @@ TSFieldId ts_tree_cursor_current_status( // Stop walking up when a visible ancestor is found. if (i != self->stack.size - 1) { if (ts_subtree_visible(*entry->subtree)) break; - const TSSymbol *alias_sequence = ts_language_alias_sequence( - self->tree->language, - parent_entry->subtree->ptr->production_id - ); - if (alias_sequence && alias_sequence[entry->structural_child_index]) { - break; - } + if ( + !ts_subtree_extra(*entry->subtree) && + ts_language_alias_at( + self->tree->language, + parent_entry->subtree->ptr->production_id, + entry->structural_child_index + ) + ) break; } if (ts_subtree_child_count(*parent_entry->subtree) > entry->child_index + 1) { @@ -321,13 +322,14 @@ TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) { // Stop walking up when another visible node is found. if (i != self->stack.size - 1) { if (ts_subtree_visible(*entry->subtree)) break; - const TSSymbol *alias_sequence = ts_language_alias_sequence( - self->tree->language, - parent_entry->subtree->ptr->production_id - ); - if (alias_sequence && alias_sequence[entry->structural_child_index]) { - break; - } + if ( + !ts_subtree_extra(*entry->subtree) && + ts_language_alias_at( + self->tree->language, + parent_entry->subtree->ptr->production_id, + entry->structural_child_index + ) + ) break; } if (ts_subtree_extra(*entry->subtree)) break; From 4c2f36a07b99732c96d474fdae30c1cf158b966e Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 8 Jun 2020 16:07:22 -0700 Subject: [PATCH 02/26] Mark steps as definite on query construction * Add a ts_query_pattern_is_definite API, just for debugging this * Store state_count on TSLanguage structs, to allow for scanning parse tables --- cli/src/generate/render.rs | 62 +-- cli/src/main.rs | 10 +- cli/src/tests/query_test.rs | 72 +++- lib/binding_rust/bindings.rs | 13 +- lib/binding_rust/lib.rs | 8 + 
lib/include/tree_sitter/api.h | 8 +- lib/include/tree_sitter/parser.h | 1 + lib/src/array.h | 24 ++ lib/src/language.h | 1 + lib/src/query.c | 632 +++++++++++++++++++++++++++++-- 10 files changed, 755 insertions(+), 76 deletions(-) diff --git a/cli/src/generate/render.rs b/cli/src/generate/render.rs index 270bd00d..300ad383 100644 --- a/cli/src/generate/render.rs +++ b/cli/src/generate/render.rs @@ -95,11 +95,7 @@ impl Generator { self.add_stats(); self.add_symbol_enum(); self.add_symbol_names_list(); - - if self.next_abi { - self.add_unique_symbol_map(); - } - + self.add_unique_symbol_map(); self.add_symbol_metadata_list(); if !self.field_names.is_empty() { @@ -177,20 +173,16 @@ impl Generator { // If we are opting in to the new unstable language ABI, then use the concept of // "small parse states". Otherwise, use the same representation for all parse // states. - if self.next_abi { - let threshold = cmp::min(SMALL_STATE_THRESHOLD, self.parse_table.symbols.len() / 2); - self.large_state_count = self - .parse_table - .states - .iter() - .enumerate() - .take_while(|(i, s)| { - *i <= 1 || s.terminal_entries.len() + s.nonterminal_entries.len() > threshold - }) - .count(); - } else { - self.large_state_count = self.parse_table.states.len(); - } + let threshold = cmp::min(SMALL_STATE_THRESHOLD, self.parse_table.symbols.len() / 2); + self.large_state_count = self + .parse_table + .states + .iter() + .enumerate() + .take_while(|(i, s)| { + *i <= 1 || s.terminal_entries.len() + s.nonterminal_entries.len() > threshold + }) + .count(); } fn add_includes(&mut self) { @@ -256,10 +248,7 @@ impl Generator { "#define STATE_COUNT {}", self.parse_table.states.len() ); - - if self.next_abi { - add_line!(self, "#define LARGE_STATE_COUNT {}", self.large_state_count); - } + add_line!(self, "#define LARGE_STATE_COUNT {}", self.large_state_count); add_line!( self, @@ -689,17 +678,12 @@ impl Generator { name ); indent!(self); + add_line!(self, "START_LEXER();"); - - if self.next_abi { - 
add_line!(self, "eof = lexer->eof(lexer);"); - } else { - add_line!(self, "eof = lookahead == 0;"); - } - + add_line!(self, "eof = lexer->eof(lexer);"); add_line!(self, "switch (state) {{"); - indent!(self); + indent!(self); for (i, state) in lex_table.states.into_iter().enumerate() { add_line!(self, "case {}:", i); indent!(self); @@ -714,6 +698,7 @@ impl Generator { dedent!(self); add_line!(self, "}}"); + dedent!(self); add_line!(self, "}}"); add_line!(self, ""); @@ -967,12 +952,7 @@ impl Generator { add_line!( self, - "static uint16_t ts_parse_table[{}][SYMBOL_COUNT] = {{", - if self.next_abi { - "LARGE_STATE_COUNT" - } else { - "STATE_COUNT" - } + "static uint16_t ts_parse_table[LARGE_STATE_COUNT][SYMBOL_COUNT] = {{", ); indent!(self); @@ -1224,9 +1204,10 @@ impl Generator { add_line!(self, ".symbol_count = SYMBOL_COUNT,"); add_line!(self, ".alias_count = ALIAS_COUNT,"); add_line!(self, ".token_count = TOKEN_COUNT,"); + add_line!(self, ".large_state_count = LARGE_STATE_COUNT,"); if self.next_abi { - add_line!(self, ".large_state_count = LARGE_STATE_COUNT,"); + add_line!(self, ".state_count = STATE_COUNT,"); } add_line!(self, ".symbol_metadata = ts_symbol_metadata,"); @@ -1249,10 +1230,7 @@ impl Generator { add_line!(self, ".parse_actions = ts_parse_actions,"); add_line!(self, ".lex_modes = ts_lex_modes,"); add_line!(self, ".symbol_names = ts_symbol_names,"); - - if self.next_abi { - add_line!(self, ".public_symbol_map = ts_symbol_map,"); - } + add_line!(self, ".public_symbol_map = ts_symbol_map,"); if !self.parse_table.production_infos.is_empty() { add_line!( diff --git a/cli/src/main.rs b/cli/src/main.rs index 757c70eb..04cd34cd 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -149,8 +149,14 @@ fn run() -> error::Result<()> { .arg(Arg::with_name("path").index(1).multiple(true)), ) .subcommand( - SubCommand::with_name("web-ui").about("Test a parser interactively in the browser") - .arg(Arg::with_name("quiet").long("quiet").short("q").help("open in default 
browser")), + SubCommand::with_name("web-ui") + .about("Test a parser interactively in the browser") + .arg( + Arg::with_name("quiet") + .long("quiet") + .short("q") + .help("open in default browser"), + ), ) .subcommand( SubCommand::with_name("dump-languages") diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index d4f18c7d..92aff5fb 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1919,7 +1919,7 @@ fn test_query_start_byte_for_pattern() { let patterns_3 = " ((identifier) @b (#match? @b i)) (function_declaration name: (identifier) @c) - (method_definition name: (identifier) @d) + (method_definition name: (property_identifier) @d) " .trim_start(); @@ -2048,6 +2048,76 @@ fn test_query_disable_pattern() { }); } +#[test] +fn test_query_is_definite() { + struct Row { + pattern: &'static str, + results_by_step_index: &'static [(usize, bool)], + } + + let rows = &[ + Row { + pattern: r#"(object "{" "}")"#, + results_by_step_index: &[ + (0, false), + (1, true), // "{" + (2, true), // "}" + ], + }, + Row { + pattern: r#"(pair (property_identifier) ":")"#, + results_by_step_index: &[ + (0, false), + (1, false), // property_identifier + (2, true), // ":"" + ], + }, + Row { + pattern: r#"(object "{" (_) "}")"#, + results_by_step_index: &[ + (0, false), + (1, false), // "{"" + (2, false), // (_) + (3, true), // "}" + ], + }, + Row { + // Named wildcards, fields + pattern: r#"(binary_expression left: (identifier) right: (_))"#, + results_by_step_index: &[ + (0, false), + (1, false), // identifier + (2, true), // (_) + ], + }, + Row { + pattern: r#"(function_declaration name: (identifier) body: (statement_block))"#, + results_by_step_index: &[ + (0, false), + (1, true), // identifier + (2, true), // statement_block + ], + }, + ]; + + allocations::record(|| { + let language = get_language("javascript"); + for row in rows.iter() { + let query = Query::new(language, row.pattern).unwrap(); + for (step_index, is_definite) in 
row.results_by_step_index { + assert_eq!( + query.pattern_is_definite(0, *step_index), + *is_definite, + "Pattern: {:?}, step: {}, expected is_definite to be {}", + row.pattern, + step_index, + is_definite, + ) + } + } + }); +} + fn assert_query_matches( language: Language, query: &Query, diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index cba87fa3..7dc48660 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -172,9 +172,9 @@ extern "C" { #[doc = " the given ranges must be ordered from earliest to latest in the document,"] #[doc = " and they must not overlap. That is, the following must hold for all"] #[doc = " `i` < `length - 1`:"] - #[doc = " ```text"] + #[doc = ""] #[doc = " ranges[i].end_byte <= ranges[i + 1].start_byte"] - #[doc = " ```"] + #[doc = ""] #[doc = " If this requirement is not satisfied, the operation will fail, the ranges"] #[doc = " will not be assigned, and this function will return `false`. On success,"] #[doc = " this function returns `true`"] @@ -649,6 +649,13 @@ extern "C" { length: *mut u32, ) -> *const TSQueryPredicateStep; } +extern "C" { + pub fn ts_query_pattern_is_definite( + self_: *const TSQuery, + pattern_index: u32, + step_index: u32, + ) -> bool; +} extern "C" { #[doc = " Get the name and length of one of the query\'s captures, or one of the"] #[doc = " query\'s string literals. 
Each capture and string is associated with a"] @@ -800,5 +807,5 @@ extern "C" { pub fn ts_language_version(arg1: *const TSLanguage) -> u32; } -pub const TREE_SITTER_LANGUAGE_VERSION: usize = 11; +pub const TREE_SITTER_LANGUAGE_VERSION: usize = 12; pub const TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION: usize = 9; diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index c0aba32f..453cb8e7 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1449,6 +1449,14 @@ impl Query { unsafe { ffi::ts_query_disable_pattern(self.ptr.as_ptr(), index as u32) } } + /// Check if a pattern will definitely match after a certain number of steps + /// have matched. + pub fn pattern_is_definite(&self, index: usize, step_index: usize) -> bool { + unsafe { + ffi::ts_query_pattern_is_definite(self.ptr.as_ptr(), index as u32, step_index as u32) + } + } + fn parse_property( function_name: &str, capture_names: &[String], diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 9d832e6e..1b2533fc 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -21,7 +21,7 @@ extern "C" { * The Tree-sitter library is generally backwards-compatible with languages * generated using older CLI versions, but is not forwards-compatible. */ -#define TREE_SITTER_LANGUAGE_VERSION 11 +#define TREE_SITTER_LANGUAGE_VERSION 12 /** * The earliest ABI version that is supported by the current version of the @@ -718,6 +718,12 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( uint32_t *length ); +bool ts_query_pattern_is_definite( + const TSQuery *self, + uint32_t pattern_index, + uint32_t step_index +); + /** * Get the name and length of one of the query's captures, or one of the * query's string literals. 
Each capture and string is associated with a diff --git a/lib/include/tree_sitter/parser.h b/lib/include/tree_sitter/parser.h index 11bf4fc4..360e012f 100644 --- a/lib/include/tree_sitter/parser.h +++ b/lib/include/tree_sitter/parser.h @@ -119,6 +119,7 @@ struct TSLanguage { const uint16_t *small_parse_table; const uint32_t *small_parse_table_map; const TSSymbol *public_symbol_map; + uint32_t state_count; }; /* diff --git a/lib/src/array.h b/lib/src/array.h index 26cb8448..c7e0ae4a 100644 --- a/lib/src/array.h +++ b/lib/src/array.h @@ -66,6 +66,30 @@ extern "C" { #define array_assign(self, other) \ array__assign((VoidArray *)(self), (const VoidArray *)(other), array__elem_size(self)) +#define array_search_sorted_by(self, start, field, needle, out_index, out_exists) \ + do { \ + *(out_exists) = false; \ + for (*(out_index) = start; *(out_index) < (self)->size; (*(out_index))++) { \ + int _comparison = (int)((self)->contents[*(out_index)] field) - (int)(needle); \ + if (_comparison >= 0) { \ + if (_comparison == 0) *(out_exists) = true; \ + break; \ + } \ + } \ + } while (0); + +#define array_search_sorted_with(self, start, compare, needle, out_index, out_exists) \ + do { \ + *(out_exists) = false; \ + for (*(out_index) = start; *(out_index) < (self)->size; (*(out_index))++) { \ + int _comparison = compare(&(self)->contents[*(out_index)], (needle)); \ + if (_comparison >= 0) { \ + if (_comparison == 0) *(out_exists) = true; \ + break; \ + } \ + } \ + } while (0); + // Private typedef Array(void) VoidArray; diff --git a/lib/src/language.h b/lib/src/language.h index 2bb9a6f9..288c2a2b 100644 --- a/lib/src/language.h +++ b/lib/src/language.h @@ -12,6 +12,7 @@ extern "C" { #define TREE_SITTER_LANGUAGE_VERSION_WITH_FIELDS 10 #define TREE_SITTER_LANGUAGE_VERSION_WITH_SYMBOL_DEDUPING 11 #define TREE_SITTER_LANGUAGE_VERSION_WITH_SMALL_STATES 11 +#define TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT 12 typedef struct { const TSParseAction *actions; diff --git 
a/lib/src/query.c b/lib/src/query.c index ff243494..10ab5371 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -14,6 +14,8 @@ #define MAX_STATE_COUNT 256 #define MAX_CAPTURE_LIST_COUNT 32 #define MAX_STEP_CAPTURE_COUNT 3 +#define MAX_STATE_PREDECESSOR_COUNT 100 +#define MAX_WALK_STATE_DEPTH 4 /* * Stream - A sequence of unicode characters derived from a UTF8 string. @@ -55,6 +57,7 @@ typedef struct { bool is_pass_through: 1; bool is_dead_end: 1; bool alternative_is_immediate: 1; + bool is_definite: 1; } QueryStep; /* @@ -89,6 +92,12 @@ typedef struct { uint16_t pattern_index; } PatternEntry; +typedef struct { + Slice predicate_steps; + uint32_t start_byte; + uint32_t start_step; +} QueryPattern; + /* * QueryState - The state of an in-progress match of a particular pattern * in a query. While executing, a `TSQueryCursor` must keep track of a number @@ -138,6 +147,31 @@ typedef struct { uint32_t usage_map; } CaptureListPool; +/* + * WalkState - The state needed for walking the parse table when analyzing + * a query pattern, to determine the steps where the pattern could fail + * to match. + */ +typedef struct { + TSStateId state; + TSSymbol parent_symbol; + uint16_t child_index; + TSFieldId field; +} WalkStateEntry; + +typedef struct { + WalkStateEntry stack[MAX_WALK_STATE_DEPTH]; + uint16_t depth; + uint16_t step_index; +} WalkState; + +/* + * StatePredecessorMap - A map that stores the predecessors of each parse state. + */ +typedef struct { + TSStateId *contents; +} StatePredecessorMap; + /* * TSQuery - A tree query, compiled from a string of S-expressions. The query * itself is immutable. 
The mutable state used in the process of executing the @@ -149,8 +183,7 @@ struct TSQuery { Array(QueryStep) steps; Array(PatternEntry) pattern_map; Array(TSQueryPredicateStep) predicate_steps; - Array(Slice) predicates_by_pattern; - Array(uint32_t) start_bytes_by_pattern; + Array(QueryPattern) patterns; const TSLanguage *language; uint16_t wildcard_root_pattern_count; TSSymbol *symbol_map; @@ -451,6 +484,7 @@ static QueryStep query_step__new( .is_pattern_start = false, .is_pass_through = false, .is_dead_end = false, + .is_definite = false, .is_immediate = is_immediate, .alternative_is_immediate = false, }; @@ -480,6 +514,67 @@ static void query_step__remove_capture(QueryStep *self, uint16_t capture_id) { } } +/********************** + * StatePredecessorMap + **********************/ + +static inline StatePredecessorMap state_predecessor_map_new(const TSLanguage *language) { + return (StatePredecessorMap) { + .contents = ts_calloc(language->state_count * (MAX_STATE_PREDECESSOR_COUNT + 1), sizeof(TSStateId)), + }; +} + +static inline void state_predecessor_map_delete(StatePredecessorMap *self) { + ts_free(self->contents); +} + +static inline void state_predecessor_map_add( + StatePredecessorMap *self, + TSStateId state, + TSStateId predecessor +) { + unsigned index = state * (MAX_STATE_PREDECESSOR_COUNT + 1); + TSStateId *count = &self->contents[index]; + if (*count == 0 || (*count < MAX_STATE_PREDECESSOR_COUNT && self->contents[index + *count] != predecessor)) { + (*count)++; + self->contents[index + *count] = predecessor; + } +} + +static inline const TSStateId *state_predecessor_map_get( + const StatePredecessorMap *self, + TSStateId state, + unsigned *count +) { + unsigned index = state * (MAX_STATE_PREDECESSOR_COUNT + 1); + *count = self->contents[index]; + return &self->contents[index + 1]; +} + +/************ + * WalkState + ************/ + +static inline int walk_state__compare(WalkState *self, WalkState *other) { + if (self->depth < other->depth) return -1; 
+ if (self->depth > other->depth) return 1; + if (self->step_index < other->step_index) return -1; + if (self->step_index > other->step_index) return 1; + for (unsigned i = 0; i < self->depth; i++) { + if (self->stack[i].state < other->stack[i].state) return -1; + if (self->stack[i].state > other->stack[i].state) return 1; + if (self->stack[i].parent_symbol < other->stack[i].parent_symbol) return -1; + if (self->stack[i].parent_symbol > other->stack[i].parent_symbol) return 1; + if (self->stack[i].child_index < other->stack[i].child_index) return -1; + if (self->stack[i].child_index > other->stack[i].child_index) return 1; + } + return 0; +} + +static inline WalkStateEntry *walk_state__top(WalkState *self) { + return &self->stack[self->depth - 1]; +} + /********* * Query *********/ @@ -552,6 +647,466 @@ static inline void ts_query__pattern_map_insert( })); } +static void ts_query__analyze_patterns(TSQuery *self) { + typedef struct { + TSSymbol parent_symbol; + uint32_t parent_step_index; + Array(uint32_t) child_step_indices; + } ParentPattern; + + typedef struct { + TSStateId state; + uint8_t child_index; + uint8_t production_id; + bool done; + } SubgraphNode; + + typedef struct { + TSSymbol symbol; + Array(TSStateId) start_states; + Array(SubgraphNode) nodes; + } SymbolSubgraph; + + typedef Array(WalkState) WalkStateList; + + // Identify all of the patterns in the query that have child patterns. This + // includes both top-level patterns and patterns that are nested within some + // larger pattern. For each of these, record the parent symbol, the step index + // and all of the immediate child step indices in reverse order. 
+ Array(ParentPattern) parent_patterns = array_new(); + Array(uint32_t) stack = array_new(); + for (unsigned i = 0; i < self->steps.size; i++) { + QueryStep *step = &self->steps.contents[i]; + if (step->depth == PATTERN_DONE_MARKER) { + array_clear(&stack); + } else { + uint32_t parent_pattern_index = 0; + while (stack.size > 0) { + parent_pattern_index = *array_back(&stack); + ParentPattern *parent_pattern = &parent_patterns.contents[parent_pattern_index]; + QueryStep *parent_step = &self->steps.contents[parent_pattern->parent_step_index]; + if (parent_step->depth >= step->depth) { + stack.size--; + } else { + break; + } + } + + if (stack.size > 0) { + ParentPattern *parent_pattern = &parent_patterns.contents[parent_pattern_index]; + step->is_definite = true; + array_push(&parent_pattern->child_step_indices, i); + } + + array_push(&stack, parent_patterns.size); + array_push(&parent_patterns, ((ParentPattern) { + .parent_symbol = step->symbol, + .parent_step_index = i, + })); + } + } + for (unsigned i = 0; i < parent_patterns.size; i++) { + ParentPattern *parent_pattern = &parent_patterns.contents[i]; + if (parent_pattern->child_step_indices.size == 0) { + array_erase(&parent_patterns, i); + i--; + } + } + + // Debug + // { + // printf("\nParent pattern entries\n"); + // for (unsigned i = 0; i < parent_patterns.size; i++) { + // ParentPattern *parent_pattern = &parent_patterns.contents[i]; + // printf(" %s ->", ts_language_symbol_name(self->language, parent_pattern->parent_symbol)); + // for (unsigned j = 0; j < parent_pattern->child_step_indices.size; j++) { + // QueryStep *step = &self->steps.contents[parent_pattern->child_step_indices.contents[j]]; + // printf(" %s", ts_language_symbol_name(self->language, step->symbol)); + // } + // printf("\n"); + // } + // } + + // Initialize a set of subgraphs, with one subgraph for each parent symbol, + // in the query, and one subgraph for each hidden symbol. 
+ unsigned subgraph_index = 0, exists; + Array(SymbolSubgraph) subgraphs = array_new(); + for (unsigned i = 0; i < parent_patterns.size; i++) { + TSSymbol parent_symbol = parent_patterns.contents[i].parent_symbol; + array_search_sorted_by(&subgraphs, 0, .symbol, parent_symbol, &subgraph_index, &exists); + if (!exists) { + array_insert(&subgraphs, subgraph_index, ((SymbolSubgraph) { .symbol = parent_symbol, })); + } + } + subgraph_index = 0; + for (TSSymbol sym = 0; sym < self->language->symbol_count; sym++) { + if (!ts_language_symbol_metadata(self->language, sym).visible) { + array_search_sorted_by( + &subgraphs, subgraph_index, + .symbol, sym, + &subgraph_index, &exists + ); + if (!exists) { + array_insert(&subgraphs, subgraph_index, ((SymbolSubgraph) { .symbol = sym, })); + subgraph_index++; + } + } + } + + // Scan the parse table to find the data needed for these subgraphs. + // Collect three things during this scan: + // 1) All of the parse states where one of these symbols can start. + // 2) All of the parse states where one of these symbols can end, along + // with information about the node that would be created. + // 3) A list of predecessor states for each state. 
+ StatePredecessorMap predecessor_map = state_predecessor_map_new(self->language); + for (TSStateId state = 1; state < self->language->state_count; state++) { + unsigned subgraph_index = 0, exists; + for (TSSymbol sym = 0; sym < self->language->token_count; sym++) { + unsigned count; + const TSParseAction *actions = ts_language_actions(self->language, state, sym, &count); + for (unsigned i = 0; i < count; i++) { + const TSParseAction *action = &actions[i]; + if (action->type == TSParseActionTypeReduce) { + unsigned exists; + array_search_sorted_by( + &subgraphs, + subgraph_index, + .symbol, + action->params.reduce.symbol, + &subgraph_index, + &exists + ); + if (exists) { + SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + if (subgraph->nodes.size == 0 || array_back(&subgraph->nodes)->state != state) { + array_push(&subgraph->nodes, ((SubgraphNode) { + .state = state, + .production_id = action->params.reduce.production_id, + .child_index = action->params.reduce.child_count, + .done = true, + })); + } + } + } else if ( + action->type == TSParseActionTypeShift && + !action->params.shift.extra + ) { + TSStateId next_state = action->params.shift.state; + state_predecessor_map_add(&predecessor_map, next_state, state); + } + } + } + for (TSSymbol sym = self->language->token_count; sym < self->language->symbol_count; sym++) { + TSStateId next_state = ts_language_next_state(self->language, state, sym); + if (next_state != 0) { + state_predecessor_map_add(&predecessor_map, next_state, state); + array_search_sorted_by( + &subgraphs, + subgraph_index, + .symbol, + sym, + &subgraph_index, + &exists + ); + if (exists) { + SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + array_push(&subgraph->start_states, state); + } + } + } + } + + // For each subgraph, compute the remainder of the nodes by walking backward + // from the end states using the predecessor map. 
+ Array(SubgraphNode) next_nodes = array_new(); + for (unsigned i = 0; i < subgraphs.size; i++) { + SymbolSubgraph *subgraph = &subgraphs.contents[i]; + if (subgraph->nodes.size == 0) { + array_delete(&subgraph->start_states); + array_erase(&subgraphs, i); + i--; + continue; + } + array_assign(&next_nodes, &subgraph->nodes); + while (next_nodes.size > 0) { + SubgraphNode node = array_pop(&next_nodes); + if (node.child_index > 1) { + unsigned predecessor_count; + const TSStateId *predecessors = state_predecessor_map_get( + &predecessor_map, + node.state, + &predecessor_count + ); + for (unsigned j = 0; j < predecessor_count; j++) { + SubgraphNode predecessor_node = { + .state = predecessors[j], + .child_index = node.child_index - 1, + .production_id = node.production_id, + .done = false, + }; + unsigned index, exists; + array_search_sorted_by(&subgraph->nodes, 0, .state, predecessor_node.state, &index, &exists); + if (!exists) { + array_insert(&subgraph->nodes, index, predecessor_node); + array_push(&next_nodes, predecessor_node); + } + } + } + } + } + + // Debug + // { + // printf("\nSubgraphs:\n"); + // for (unsigned i = 0; i < subgraphs.size; i++) { + // SymbolSubgraph *subgraph = &subgraphs.contents[i]; + // printf(" %u, %s:\n", subgraph->symbol, ts_language_symbol_name(self->language, subgraph->symbol)); + // for (unsigned j = 0; j < subgraph->nodes.size; j++) { + // SubgraphNode *node = &subgraph->nodes.contents[j]; + // printf(" {state: %u, child_index: %u}\n", node->state, node->child_index); + // } + // printf("\n"); + // } + // } + + // For each non-terminal pattern, determine if the pattern can successfully match, + // and all of the possible children within the pattern where matching could fail. 
+ WalkStateList walk_states = array_new(); + WalkStateList next_walk_states = array_new(); + Array(uint16_t) finished_step_indices = array_new(); + for (unsigned i = 0; i < parent_patterns.size; i++) { + ParentPattern *parent_pattern = &parent_patterns.contents[i]; + unsigned subgraph_index, exists; + array_search_sorted_by(&subgraphs, 0, .symbol, parent_pattern->parent_symbol, &subgraph_index, &exists); + if (!exists) { + // TODO - what to do for ERROR patterns + continue; + } + SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + + // Initialize a walk at every possible parse state where this non-terminal + // symbol can start. + array_clear(&walk_states); + for (unsigned j = 0; j < subgraph->start_states.size; j++) { + TSStateId state = subgraph->start_states.contents[j]; + array_push(&walk_states, ((WalkState) { + .step_index = 0, + .stack = { + [0] = { + .state = state, + .child_index = 0, + .parent_symbol = subgraph->symbol, + .field = 0, + }, + }, + .depth = 1, + })); + } + + // Walk the subgraph for this non-terminal, tracking all of the possible + // sequences of progress within the pattern. 
+ array_clear(&finished_step_indices); + while (walk_states.size > 0) { + // Debug + // { + // printf("Walk states for %u %s:\n", i, ts_language_symbol_name(self->language, parent_pattern->parent_symbol)); + // for (unsigned j = 0; j < walk_states.size; j++) { + // WalkState *walk_state = &walk_states.contents[j]; + // printf( + // " %u: {depth: %u, step: %u, state: %u, child_index: %u, parent: %s}\n", + // j, + // walk_state->depth, + // walk_state->step_index, + // walk_state->stack[walk_state->depth - 1].state, + // walk_state->stack[walk_state->depth - 1].child_index, + // ts_language_symbol_name(self->language, walk_state->stack[walk_state->depth - 1].parent_symbol) + // ); + // } + + // printf("\nFinished step indices for %u %s:", i, ts_language_symbol_name(self->language, parent_pattern->parent_symbol)); + // for (unsigned j = 0; j < finished_step_indices.size; j++) { + // printf(" %u", finished_step_indices.contents[j]); + // } + // printf("\n\n"); + // } + + array_clear(&next_walk_states); + for (unsigned j = 0; j < walk_states.size; j++) { + WalkState *walk_state = &walk_states.contents[j]; + TSStateId state = walk_state->stack[walk_state->depth - 1].state; + unsigned child_index = walk_state->stack[walk_state->depth - 1].child_index; + TSSymbol parent_symbol = walk_state->stack[walk_state->depth - 1].parent_symbol; + + unsigned subgraph_index, exists; + array_search_sorted_by(&subgraphs, 0, .symbol, parent_symbol, &subgraph_index, &exists); + if (!exists) continue; + SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + + for (TSSymbol sym = 0; sym < self->language->symbol_count; sym++) { + TSStateId successor_state = ts_language_next_state(self->language, state, sym); + if (successor_state && successor_state != state) { + unsigned node_index; + array_search_sorted_by(&subgraph->nodes, 0, .state, successor_state, &node_index, &exists); + if (exists) { + SubgraphNode *node = &subgraph->nodes.contents[node_index]; + if (node->child_index != 
child_index + 1) continue; + + WalkState next_walk_state = *walk_state; + walk_state__top(&next_walk_state)->child_index++; + walk_state__top(&next_walk_state)->state = successor_state; + + bool does_match = true; + unsigned step_index = parent_pattern->child_step_indices.contents[walk_state->step_index]; + QueryStep *step = &self->steps.contents[step_index]; + TSSymbol alias = ts_language_alias_at(self->language, node->production_id, child_index); + TSSymbol visible_symbol = alias + ? alias + : self->language->symbol_metadata[sym].visible + ? self->language->public_symbol_map[sym] + : 0; + if (visible_symbol) { + if (step->symbol == NAMED_WILDCARD_SYMBOL) { + if (!ts_language_symbol_metadata(self->language, visible_symbol).named) does_match = false; + } else if (step->symbol != WILDCARD_SYMBOL) { + if (step->symbol != visible_symbol) does_match = false; + } + } else if (next_walk_state.depth < MAX_WALK_STATE_DEPTH) { + does_match = false; + next_walk_state.depth++; + walk_state__top(&next_walk_state)->state = state; + walk_state__top(&next_walk_state)->child_index = 0; + walk_state__top(&next_walk_state)->parent_symbol = sym; + } else { + continue; + } + + TSFieldId field_id = 0; + const TSFieldMapEntry *field_map, *field_map_end; + ts_language_field_map(self->language, node->production_id, &field_map, &field_map_end); + for (; field_map != field_map_end; field_map++) { + if (field_map->child_index == child_index) { + field_id = field_map->field_id; + break; + } + } + + if (does_match) { + next_walk_state.step_index++; + } + + if (node->done) { + next_walk_state.depth--; + } + + if ( + next_walk_state.depth == 0 || + next_walk_state.step_index == parent_pattern->child_step_indices.size + ) { + unsigned index, exists; + array_search_sorted_by(&finished_step_indices, 0, , next_walk_state.step_index, &index, &exists); + if (!exists) array_insert(&finished_step_indices, index, next_walk_state.step_index); + continue; + } + + unsigned index, exists; + 
array_search_sorted_with( + &next_walk_states, + 0, + walk_state__compare, + &next_walk_state, + &index, + &exists + ); + if (!exists) { + array_insert(&next_walk_states, index, next_walk_state); + } + } + } + } + } + + WalkStateList _walk_states = walk_states; + walk_states = next_walk_states; + next_walk_states = _walk_states; + } + + // Debug + // { + // printf("Finished step indices for %u %s:", i, ts_language_symbol_name(self->language, parent_pattern->parent_symbol)); + // for (unsigned j = 0; j < finished_step_indices.size; j++) { + // printf(" %u", finished_step_indices.contents[j]); + // } + // printf("\n\n"); + // } + + // A query step is definite if the containing pattern will definitely match + // once the step is reached. In other words, a step is *not* definite if + // it's possible to create a syntax node that matches up to until that step, + // but does not match the entire pattern. + for (unsigned j = 0, n = parent_pattern->child_step_indices.size; j < n; j++) { + uint32_t step_index = parent_pattern->child_step_indices.contents[j]; + for (unsigned k = 0; k < finished_step_indices.size; k++) { + uint32_t finished_step_index = finished_step_indices.contents[k]; + if (finished_step_index >= j && finished_step_index < n) { + QueryStep *step = &self->steps.contents[step_index]; + step->is_definite = false; + break; + } + } + } + } + + // In order for a parent step to be definite, all of its child steps must + // be definite. Propagate the definiteness up the pattern trees by walking + // the query's steps in reverse. 
+ for (unsigned i = self->steps.size - 1; i + 1 > 0; i--) { + QueryStep *step = &self->steps.contents[i]; + for (unsigned j = i + 1; j < self->steps.size; j++) { + QueryStep *child_step = &self->steps.contents[j]; + if (child_step->depth <= step->depth) break; + if (child_step->depth == step->depth + 1 && !child_step->is_definite) { + step->is_definite = false; + break; + } + } + } + + // Debug + // { + // printf("\nSteps:\n"); + // for (unsigned i = 0; i < self->steps.size; i++) { + // QueryStep *step = &self->steps.contents[i]; + // if (step->depth == PATTERN_DONE_MARKER) { + // printf("\n"); + // continue; + // } + // printf( + // " {symbol: %s, is_definite: %d}\n", + // (step->symbol == WILDCARD_SYMBOL || step->symbol == NAMED_WILDCARD_SYMBOL) ? "ANY" : ts_language_symbol_name(self->language, step->symbol), + // step->is_definite + // ); + // } + // } + + // Cleanup + for (unsigned i = 0; i < parent_patterns.size; i++) { + array_delete(&parent_patterns.contents[i].child_step_indices); + } + for (unsigned i = 0; i < subgraphs.size; i++) { + array_delete(&subgraphs.contents[i].start_states); + array_delete(&subgraphs.contents[i].nodes); + } + array_delete(&stack); + array_delete(&subgraphs); + array_delete(&next_nodes); + array_delete(&walk_states); + array_delete(&parent_patterns); + array_delete(&next_walk_states); + array_delete(&finished_step_indices); + state_predecessor_map_delete(&predecessor_map); +} + static void ts_query__finalize_steps(TSQuery *self) { for (unsigned i = 0; i < self->steps.size; i++) { QueryStep *step = &self->steps.contents[i]; @@ -588,7 +1143,7 @@ static TSQueryError ts_query__parse_predicate( predicate_name, length ); - array_back(&self->predicates_by_pattern)->length++; + array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, .value_id = id, @@ -599,7 +1154,7 @@ static TSQueryError ts_query__parse_predicate( if (stream->next == ')') 
{ stream_advance(stream); stream_skip_whitespace(stream); - array_back(&self->predicates_by_pattern)->length++; + array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeDone, .value_id = 0, @@ -628,7 +1183,7 @@ static TSQueryError ts_query__parse_predicate( return TSQueryErrorCapture; } - array_back(&self->predicates_by_pattern)->length++; + array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeCapture, .value_id = capture_id, @@ -668,7 +1223,7 @@ static TSQueryError ts_query__parse_predicate( string_content, length ); - array_back(&self->predicates_by_pattern)->length++; + array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, .value_id = id, @@ -688,7 +1243,7 @@ static TSQueryError ts_query__parse_predicate( symbol_start, length ); - array_back(&self->predicates_by_pattern)->length++; + array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, .value_id = id, @@ -712,7 +1267,6 @@ static TSQueryError ts_query__parse_pattern( TSQuery *self, Stream *stream, uint32_t depth, - uint32_t *capture_count, bool is_immediate ) { const uint32_t starting_step_index = self->steps.size; @@ -737,7 +1291,6 @@ static TSQueryError ts_query__parse_pattern( self, stream, depth, - capture_count, is_immediate ); @@ -790,7 +1343,6 @@ static TSQueryError ts_query__parse_pattern( self, stream, depth, - capture_count, child_is_immediate ); if (e == PARENT_DONE && stream->next == ')') { @@ -871,7 +1423,6 @@ static TSQueryError ts_query__parse_pattern( self, stream, depth + 1, - capture_count, child_is_immediate ); if (e == PARENT_DONE && stream->next == ')') { @@ -955,7 +1506,6 @@ static TSQueryError 
ts_query__parse_pattern( self, stream, depth, - capture_count, is_immediate ); if (e == PARENT_DONE) return TSQueryErrorSyntax; @@ -1069,8 +1619,6 @@ static TSQueryError ts_query__parse_pattern( break; } } - - (*capture_count)++; } // No more suffix modifiers @@ -1123,7 +1671,7 @@ TSQuery *ts_query_new( .captures = symbol_table_new(), .predicate_values = symbol_table_new(), .predicate_steps = array_new(), - .predicates_by_pattern = array_new(), + .patterns = array_new(), .symbol_map = symbol_map, .wildcard_root_pattern_count = 0, .language = language, @@ -1133,15 +1681,14 @@ TSQuery *ts_query_new( Stream stream = stream_new(source, source_len); stream_skip_whitespace(&stream); while (stream.input < stream.end) { - uint32_t pattern_index = self->predicates_by_pattern.size; + uint32_t pattern_index = self->patterns.size; uint32_t start_step_index = self->steps.size; - uint32_t capture_count = 0; - array_push(&self->start_bytes_by_pattern, stream.input - source); - array_push(&self->predicates_by_pattern, ((Slice) { - .offset = self->predicate_steps.size, - .length = 0, + array_push(&self->patterns, ((QueryPattern) { + .predicate_steps = (Slice) {.offset = self->predicate_steps.size, .length = 0}, + .start_byte = stream.input - source, + .start_step = self->steps.size, })); - *error_type = ts_query__parse_pattern(self, &stream, 0, &capture_count, false); + *error_type = ts_query__parse_pattern(self, &stream, 0, false); array_push(&self->steps, query_step__new(0, PATTERN_DONE_MARKER, false)); // If any pattern could not be parsed, then report the error information @@ -1183,6 +1730,10 @@ TSQuery *ts_query_new( } } + if (self->language->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT) { + ts_query__analyze_patterns(self); + } + ts_query__finalize_steps(self); return self; } @@ -1192,8 +1743,7 @@ void ts_query_delete(TSQuery *self) { array_delete(&self->steps); array_delete(&self->pattern_map); array_delete(&self->predicate_steps); - 
array_delete(&self->predicates_by_pattern); - array_delete(&self->start_bytes_by_pattern); + array_delete(&self->patterns); symbol_table_delete(&self->captures); symbol_table_delete(&self->predicate_values); ts_free(self->symbol_map); @@ -1202,7 +1752,7 @@ void ts_query_delete(TSQuery *self) { } uint32_t ts_query_pattern_count(const TSQuery *self) { - return self->predicates_by_pattern.size; + return self->patterns.size; } uint32_t ts_query_capture_count(const TSQuery *self) { @@ -1234,7 +1784,7 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( uint32_t pattern_index, uint32_t *step_count ) { - Slice slice = self->predicates_by_pattern.contents[pattern_index]; + Slice slice = self->patterns.contents[pattern_index].predicate_steps; *step_count = slice.length; return &self->predicate_steps.contents[slice.offset]; } @@ -1243,7 +1793,35 @@ uint32_t ts_query_start_byte_for_pattern( const TSQuery *self, uint32_t pattern_index ) { - return self->start_bytes_by_pattern.contents[pattern_index]; + return self->patterns.contents[pattern_index].start_byte; +} + +bool ts_query_pattern_is_definite( + const TSQuery *self, + uint32_t pattern_index, + uint32_t step_count +) { + uint32_t step_index = self->patterns.contents[pattern_index].start_step; + for (;;) { + QueryStep *start_step = &self->steps.contents[step_index]; + if (step_index + step_count < self->steps.size) { + QueryStep *step = start_step; + for (unsigned i = 0; i < step_count; i++) { + if (step->depth == PATTERN_DONE_MARKER) { + step = NULL; + break; + } + step++; + } + if (step && !step->is_definite) return false; + } + if (start_step->alternative_index != NONE && start_step->alternative_index > step_index) { + step_index = start_step->alternative_index; + } else { + break; + } + } + return true; } void ts_query_disable_capture( From 7f955419a88caada8455f8f73230a7b32712b30c Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 22 Jun 2020 16:20:49 -0700 Subject: [PATCH 03/26] Start work on 
recognizing impossible patterns --- cli/src/error.rs | 4 + cli/src/tests/query_test.rs | 46 +++++++- lib/binding_rust/bindings.rs | 1 + lib/binding_rust/lib.rs | 56 ++++++---- lib/binding_web/binding.js | 3 + lib/include/tree_sitter/api.h | 1 + lib/src/query.c | 190 ++++++++++++++++++++++------------ 7 files changed, 212 insertions(+), 89 deletions(-) diff --git a/cli/src/error.rs b/cli/src/error.rs index 824bd92f..4b493019 100644 --- a/cli/src/error.rs +++ b/cli/src/error.rs @@ -70,6 +70,10 @@ impl<'a> From for Error { "Query error on line {}. Invalid syntax:\n{}", row, l )), + QueryError::Pattern(row, l) => Error::new(format!( + "Query error on line {}. Impossible pattern:\n{}", + row, l + )), QueryError::Predicate(p) => Error::new(format!("Query error: {}", p)), } } diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 92aff5fb..cc42a70d 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -12,7 +12,11 @@ fn test_query_errors_on_invalid_syntax() { let language = get_language("javascript"); assert!(Query::new(language, "(if_statement)").is_ok()); - assert!(Query::new(language, "(if_statement condition:(identifier))").is_ok()); + assert!(Query::new( + language, + "(if_statement condition:(parenthesized_expression (identifier)))" + ) + .is_ok()); // Mismatched parens assert_eq!( @@ -180,6 +184,28 @@ fn test_query_errors_on_invalid_conditions() { }); } +#[test] +fn test_query_errors_on_impossible_patterns() { + allocations::record(|| { + let language = get_language("javascript"); + + assert_eq!( + Query::new( + language, + "(binary_expression left:(identifier) left:(identifier))" + ), + Err(QueryError::Pattern( + 1, + [ + "(binary_expression left:(identifier) left:(identifier))", // + "^" + ] + .join("\n") + )) + ); + }); +} + #[test] fn test_query_matches_with_simple_pattern() { allocations::record(|| { @@ -1946,10 +1972,10 @@ fn test_query_capture_names() { language, r#" (if_statement - condition: (binary_expression 
+ condition: (parenthesized_expression (binary_expression left: _ @left-operand operator: "||" - right: _ @right-operand) + right: _ @right-operand)) consequence: (statement_block) @body) (while_statement @@ -2051,12 +2077,14 @@ fn test_query_disable_pattern() { #[test] fn test_query_is_definite() { struct Row { + language: Language, pattern: &'static str, results_by_step_index: &'static [(usize, bool)], } let rows = &[ Row { + language: get_language("javascript"), pattern: r#"(object "{" "}")"#, results_by_step_index: &[ (0, false), @@ -2065,6 +2093,7 @@ fn test_query_is_definite() { ], }, Row { + language: get_language("javascript"), pattern: r#"(pair (property_identifier) ":")"#, results_by_step_index: &[ (0, false), @@ -2073,6 +2102,7 @@ fn test_query_is_definite() { ], }, Row { + language: get_language("javascript"), pattern: r#"(object "{" (_) "}")"#, results_by_step_index: &[ (0, false), @@ -2083,6 +2113,7 @@ fn test_query_is_definite() { }, Row { // Named wildcards, fields + language: get_language("javascript"), pattern: r#"(binary_expression left: (identifier) right: (_))"#, results_by_step_index: &[ (0, false), @@ -2091,6 +2122,7 @@ fn test_query_is_definite() { ], }, Row { + language: get_language("javascript"), pattern: r#"(function_declaration name: (identifier) body: (statement_block))"#, results_by_step_index: &[ (0, false), @@ -2098,12 +2130,16 @@ fn test_query_is_definite() { (2, true), // statement_block ], }, + Row { + language: get_language("javascript"), + pattern: r#""#, + results_by_step_index: &[], + }, ]; allocations::record(|| { - let language = get_language("javascript"); for row in rows.iter() { - let query = Query::new(language, row.pattern).unwrap(); + let query = Query::new(row.language, row.pattern).unwrap(); for (step_index, is_definite) in row.results_by_step_index { assert_eq!( query.pattern_is_definite(0, *step_index), diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 7dc48660..167edebf 100644 --- 
a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -132,6 +132,7 @@ pub const TSQueryError_TSQueryErrorSyntax: TSQueryError = 1; pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2; pub const TSQueryError_TSQueryErrorField: TSQueryError = 3; pub const TSQueryError_TSQueryErrorCapture: TSQueryError = 4; +pub const TSQueryError_TSQueryErrorPattern: TSQueryError = 5; pub type TSQueryError = u32; extern "C" { #[doc = " Create a new parser."] diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 453cb8e7..d3284974 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -163,6 +163,7 @@ pub enum QueryError { Field(usize, String), Capture(usize, String), Predicate(String), + Pattern(usize, String), } #[derive(Debug)] @@ -1175,27 +1176,42 @@ impl Query { } }); - let message = if let Some(line) = line_containing_error { - line.to_string() + "\n" + &" ".repeat(offset - line_start) + "^" - } else { - "Unexpected EOF".to_string() - }; - - // if line_containing_error - return if error_type != ffi::TSQueryError_TSQueryErrorSyntax { - let suffix = source.split_at(offset).1; - let end_offset = suffix - .find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-') - .unwrap_or(source.len()); - let name = suffix.split_at(end_offset).0.to_string(); - match error_type { - ffi::TSQueryError_TSQueryErrorNodeType => Err(QueryError::NodeType(row, name)), - ffi::TSQueryError_TSQueryErrorField => Err(QueryError::Field(row, name)), - ffi::TSQueryError_TSQueryErrorCapture => Err(QueryError::Capture(row, name)), - _ => Err(QueryError::Syntax(row, message)), + return match error_type { + // Error types that report names + ffi::TSQueryError_TSQueryErrorNodeType + | ffi::TSQueryError_TSQueryErrorField + | ffi::TSQueryError_TSQueryErrorCapture => { + let suffix = source.split_at(offset).1; + let end_offset = suffix + .find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-') + .unwrap_or(source.len()); + let name = 
suffix.split_at(end_offset).0.to_string(); + match error_type { + ffi::TSQueryError_TSQueryErrorNodeType => { + Err(QueryError::NodeType(row, name)) + } + ffi::TSQueryError_TSQueryErrorField => Err(QueryError::Field(row, name)), + ffi::TSQueryError_TSQueryErrorCapture => { + Err(QueryError::Capture(row, name)) + } + _ => unreachable!(), + } + } + + // Error types that report positions + _ => { + let message = if let Some(line) = line_containing_error { + line.to_string() + "\n" + &" ".repeat(offset - line_start) + "^" + } else { + "Unexpected EOF".to_string() + }; + match error_type { + ffi::TSQueryError_TSQueryErrorPattern => { + Err(QueryError::Pattern(row, message)) + } + _ => Err(QueryError::Syntax(row, message)), + } } - } else { - Err(QueryError::Syntax(row, message)) }; } diff --git a/lib/binding_web/binding.js b/lib/binding_web/binding.js index 567b7eb3..cd8bec75 100644 --- a/lib/binding_web/binding.js +++ b/lib/binding_web/binding.js @@ -680,6 +680,9 @@ class Language { case 4: error = new RangeError(`Bad capture name @${word}`); break; + case 5: + error = new SyntaxError(`Impossible pattern at offset ${errorIndex}: '${suffix}'...`); + break; default: error = new SyntaxError(`Bad syntax at offset ${errorIndex}: '${suffix}'...`); break; diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 1b2533fc..1abbf28c 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -130,6 +130,7 @@ typedef enum { TSQueryErrorNodeType, TSQueryErrorField, TSQueryErrorCapture, + TSQueryErrorPattern, } TSQueryError; /********************/ diff --git a/lib/src/query.c b/lib/src/query.c index 10ab5371..0b7530da 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -156,7 +156,8 @@ typedef struct { TSStateId state; TSSymbol parent_symbol; uint16_t child_index; - TSFieldId field; + TSFieldId field_id: 15; + bool done: 1; } WalkStateEntry; typedef struct { @@ -165,6 +166,19 @@ typedef struct { uint16_t step_index; } WalkState; 
+typedef struct { + TSStateId state; + uint8_t production_id; + uint8_t child_index: 7; + bool done: 1; +} SubgraphNode; + +typedef struct { + TSSymbol symbol; + Array(TSStateId) start_states; + Array(SubgraphNode) nodes; +} SymbolSubgraph; + /* * StatePredecessorMap - A map that stores the predecessors of each parse state. */ @@ -571,6 +585,16 @@ static inline int walk_state__compare(WalkState *self, WalkState *other) { return 0; } +static inline int subgraph_node__compare(SubgraphNode *self, SubgraphNode *other) { + if (self->state < other->state) return -1; + if (self->state > other->state) return 1; + if (self->child_index < other->child_index) return -1; + if (self->child_index > other->child_index) return 1; + if (self->production_id < other->production_id) return -1; + if (self->production_id > other->production_id) return 1; + return 0; +} + static inline WalkStateEntry *walk_state__top(WalkState *self) { return &self->stack[self->depth - 1]; } @@ -647,28 +671,17 @@ static inline void ts_query__pattern_map_insert( })); } -static void ts_query__analyze_patterns(TSQuery *self) { +static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index) { typedef struct { TSSymbol parent_symbol; uint32_t parent_step_index; Array(uint32_t) child_step_indices; } ParentPattern; - typedef struct { - TSStateId state; - uint8_t child_index; - uint8_t production_id; - bool done; - } SubgraphNode; - - typedef struct { - TSSymbol symbol; - Array(TSStateId) start_states; - Array(SubgraphNode) nodes; - } SymbolSubgraph; - typedef Array(WalkState) WalkStateList; + bool result = true; + // Identify all of the patterns in the query that have child patterns. This // includes both top-level patterns and patterns that are nested within some // larger pattern. 
For each of these, record the parent symbol, the step index @@ -846,7 +859,11 @@ static void ts_query__analyze_patterns(TSQuery *self) { .done = false, }; unsigned index, exists; - array_search_sorted_by(&subgraph->nodes, 0, .state, predecessor_node.state, &index, &exists); + array_search_sorted_with( + &subgraph->nodes, 0, + subgraph_node__compare, &predecessor_node, + &index, &exists + ); if (!exists) { array_insert(&subgraph->nodes, index, predecessor_node); array_push(&next_nodes, predecessor_node); @@ -897,7 +914,8 @@ static void ts_query__analyze_patterns(TSQuery *self) { .state = state, .child_index = 0, .parent_symbol = subgraph->symbol, - .field = 0, + .field_id = 0, + .done = false, }, }, .depth = 1, @@ -923,20 +941,14 @@ static void ts_query__analyze_patterns(TSQuery *self) { // ts_language_symbol_name(self->language, walk_state->stack[walk_state->depth - 1].parent_symbol) // ); // } - - // printf("\nFinished step indices for %u %s:", i, ts_language_symbol_name(self->language, parent_pattern->parent_symbol)); - // for (unsigned j = 0; j < finished_step_indices.size; j++) { - // printf(" %u", finished_step_indices.contents[j]); - // } - // printf("\n\n"); // } array_clear(&next_walk_states); for (unsigned j = 0; j < walk_states.size; j++) { WalkState *walk_state = &walk_states.contents[j]; - TSStateId state = walk_state->stack[walk_state->depth - 1].state; - unsigned child_index = walk_state->stack[walk_state->depth - 1].child_index; - TSSymbol parent_symbol = walk_state->stack[walk_state->depth - 1].parent_symbol; + TSStateId state = walk_state__top(walk_state)->state; + unsigned child_index = walk_state__top(walk_state)->child_index; + TSSymbol parent_symbol = walk_state__top(walk_state)->parent_symbol; unsigned subgraph_index, exists; array_search_sorted_by(&subgraphs, 0, .symbol, parent_symbol, &subgraph_index, &exists); @@ -948,15 +960,14 @@ static void ts_query__analyze_patterns(TSQuery *self) { if (successor_state && successor_state != state) { 
unsigned node_index; array_search_sorted_by(&subgraph->nodes, 0, .state, successor_state, &node_index, &exists); - if (exists) { - SubgraphNode *node = &subgraph->nodes.contents[node_index]; - if (node->child_index != child_index + 1) continue; + while (exists && node_index < subgraph->nodes.size) { + SubgraphNode *node = &subgraph->nodes.contents[node_index++]; + if (node->state != successor_state || node->child_index != child_index + 1) continue; WalkState next_walk_state = *walk_state; walk_state__top(&next_walk_state)->child_index++; walk_state__top(&next_walk_state)->state = successor_state; - bool does_match = true; unsigned step_index = parent_pattern->child_step_indices.contents[walk_state->step_index]; QueryStep *step = &self->steps.contents[step_index]; TSSymbol alias = ts_language_alias_at(self->language, node->production_id, child_index); @@ -965,21 +976,6 @@ static void ts_query__analyze_patterns(TSQuery *self) { : self->language->symbol_metadata[sym].visible ? self->language->public_symbol_map[sym] : 0; - if (visible_symbol) { - if (step->symbol == NAMED_WILDCARD_SYMBOL) { - if (!ts_language_symbol_metadata(self->language, visible_symbol).named) does_match = false; - } else if (step->symbol != WILDCARD_SYMBOL) { - if (step->symbol != visible_symbol) does_match = false; - } - } else if (next_walk_state.depth < MAX_WALK_STATE_DEPTH) { - does_match = false; - next_walk_state.depth++; - walk_state__top(&next_walk_state)->state = state; - walk_state__top(&next_walk_state)->child_index = 0; - walk_state__top(&next_walk_state)->parent_symbol = sym; - } else { - continue; - } TSFieldId field_id = 0; const TSFieldMapEntry *field_map, *field_map_end; @@ -991,18 +987,45 @@ static void ts_query__analyze_patterns(TSQuery *self) { } } + if (node->done) { + walk_state__top(&next_walk_state)->done = true; + } + + bool does_match = true; + if (visible_symbol) { + if (step->symbol == NAMED_WILDCARD_SYMBOL) { + if (!self->language->symbol_metadata[visible_symbol].named) 
does_match = false; + } else if (step->symbol != WILDCARD_SYMBOL) { + if (step->symbol != visible_symbol) does_match = false; + } + + if (step->field) { + bool does_match_field = step->field == field_id; + if (!does_match_field) { + for (unsigned i = 0; i < walk_state->depth; i++) { + if (walk_state->stack[i].field_id == step->field) { + does_match_field = true; + } + } + } + does_match &= does_match_field; + } + } else if (next_walk_state.depth < MAX_WALK_STATE_DEPTH) { + does_match = false; + next_walk_state.depth++; + walk_state__top(&next_walk_state)->state = state; + walk_state__top(&next_walk_state)->child_index = 0; + walk_state__top(&next_walk_state)->parent_symbol = sym; + walk_state__top(&next_walk_state)->field_id = field_id; + } else { + continue; + } + if (does_match) { next_walk_state.step_index++; } - if (node->done) { - next_walk_state.depth--; - } - - if ( - next_walk_state.depth == 0 || - next_walk_state.step_index == parent_pattern->child_step_indices.size - ) { + if (next_walk_state.step_index == parent_pattern->child_step_indices.size) { unsigned index, exists; array_search_sorted_by(&finished_step_indices, 0, , next_walk_state.step_index, &index, &exists); if (!exists) array_insert(&finished_step_indices, index, next_walk_state.step_index); @@ -1011,19 +1034,39 @@ static void ts_query__analyze_patterns(TSQuery *self) { unsigned index, exists; array_search_sorted_with( - &next_walk_states, - 0, - walk_state__compare, - &next_walk_state, - &index, - &exists + &next_walk_states, 0, + walk_state__compare, &next_walk_state, + &index, &exists ); - if (!exists) { - array_insert(&next_walk_states, index, next_walk_state); - } + if (!exists) array_insert(&next_walk_states, index, next_walk_state); } } } + + bool did_pop = false; + while (walk_state->depth > 0 && walk_state__top(walk_state)->done) { + walk_state->depth--; + did_pop = true; + } + + if (did_pop) { + if (walk_state->depth == 0) { + unsigned index, exists; + 
array_search_sorted_by(&finished_step_indices, 0, , walk_state->step_index, &index, &exists); + if (!exists) array_insert(&finished_step_indices, index, walk_state->step_index); + } else { + unsigned index, exists; + array_search_sorted_with( + &next_walk_states, + 0, + walk_state__compare, + walk_state, + &index, + &exists + ); + if (!exists) array_insert(&next_walk_states, index, *walk_state); + } + } } WalkStateList _walk_states = walk_states; @@ -1037,7 +1080,7 @@ static void ts_query__analyze_patterns(TSQuery *self) { // for (unsigned j = 0; j < finished_step_indices.size; j++) { // printf(" %u", finished_step_indices.contents[j]); // } - // printf("\n\n"); + // printf(". Length: %u\n\n", parent_pattern->child_step_indices.size); // } // A query step is definite if the containing pattern will definitely match @@ -1055,6 +1098,16 @@ static void ts_query__analyze_patterns(TSQuery *self) { } } } + + if (finished_step_indices.size == 0 || *array_back(&finished_step_indices) < parent_pattern->child_step_indices.size) { + unsigned exists; + array_search_sorted_by( + &self->patterns, 0, + .start_step, + parent_pattern->parent_step_index, impossible_index, &exists); + result = false; + goto cleanup; + } } // In order for a parent step to be definite, all of its child steps must @@ -1090,6 +1143,7 @@ static void ts_query__analyze_patterns(TSQuery *self) { // } // Cleanup +cleanup: for (unsigned i = 0; i < parent_patterns.size; i++) { array_delete(&parent_patterns.contents[i].child_step_indices); } @@ -1105,6 +1159,8 @@ static void ts_query__analyze_patterns(TSQuery *self) { array_delete(&next_walk_states); array_delete(&finished_step_indices); state_predecessor_map_delete(&predecessor_map); + + return result; } static void ts_query__finalize_steps(TSQuery *self) { @@ -1731,7 +1787,13 @@ TSQuery *ts_query_new( } if (self->language->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT) { - ts_query__analyze_patterns(self); + unsigned impossible_pattern_index = 0; + 
if (!ts_query__analyze_patterns(self, &impossible_pattern_index)) { + *error_type = TSQueryErrorPattern; + *error_offset = self->patterns.contents[impossible_pattern_index].start_byte; + ts_query_delete(self); + return NULL; + } } ts_query__finalize_steps(self); From e3cf5df039c599d3515b8d112fa1524df9734b5a Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 25 Jun 2020 13:09:38 -0700 Subject: [PATCH 04/26] Use actual step indices when walking subgraphs --- lib/src/array.h | 14 ++ lib/src/query.c | 384 ++++++++++++++++++++++-------------------------- 2 files changed, 193 insertions(+), 205 deletions(-) diff --git a/lib/src/array.h b/lib/src/array.h index c7e0ae4a..e95867cf 100644 --- a/lib/src/array.h +++ b/lib/src/array.h @@ -90,6 +90,20 @@ extern "C" { } \ } while (0); +#define array_insert_sorted_by(self, start, field, value) \ + do { \ + unsigned index, exists; \ + array_search_sorted_by(self, start, field, (value) field, &index, &exists); \ + if (!exists) array_insert(self, index, value); \ + } while (0); + +#define array_insert_sorted_with(self, start, compare, value) \ + do { \ + unsigned index, exists; \ + array_search_sorted_with(self, start, compare, &(value), &index, &exists); \ + if (!exists) array_insert(self, index, value); \ + } while (0); + // Private typedef Array(void) VoidArray; diff --git a/lib/src/query.c b/lib/src/query.c index 0b7530da..563ffe8d 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -569,25 +569,37 @@ static inline const TSStateId *state_predecessor_map_get( * WalkState ************/ -static inline int walk_state__compare(WalkState *self, WalkState *other) { - if (self->depth < other->depth) return -1; - if (self->depth > other->depth) return 1; - if (self->step_index < other->step_index) return -1; - if (self->step_index > other->step_index) return 1; +static inline int walk_state__compare_position(const WalkState *self, const WalkState *other) { for (unsigned i = 0; i < self->depth; i++) { - if (self->stack[i].state < 
other->stack[i].state) return -1; - if (self->stack[i].state > other->stack[i].state) return 1; - if (self->stack[i].parent_symbol < other->stack[i].parent_symbol) return -1; - if (self->stack[i].parent_symbol > other->stack[i].parent_symbol) return 1; + if (i >= other->depth) return -1; if (self->stack[i].child_index < other->stack[i].child_index) return -1; if (self->stack[i].child_index > other->stack[i].child_index) return 1; } + if (self->depth < other->depth) return 1; return 0; } -static inline int subgraph_node__compare(SubgraphNode *self, SubgraphNode *other) { +static inline int walk_state__compare(const WalkState *self, const WalkState *other) { + int result = walk_state__compare_position(self, other); + if (result != 0) return result; + for (unsigned i = 0; i < self->depth; i++) { + if (self->stack[i].parent_symbol < other->stack[i].parent_symbol) return -1; + if (self->stack[i].parent_symbol > other->stack[i].parent_symbol) return 1; + if (self->stack[i].state < other->stack[i].state) return -1; + if (self->stack[i].state > other->stack[i].state) return 1; + if (self->stack[i].field_id < other->stack[i].field_id) return -1; + if (self->stack[i].field_id > other->stack[i].field_id) return 1; + } + if (self->step_index < other->step_index) return -1; + if (self->step_index > other->step_index) return 1; + return 0; +} + +static inline int subgraph_node__compare(const SubgraphNode *self, const SubgraphNode *other) { if (self->state < other->state) return -1; if (self->state > other->state) return 1; + if (self->done && !other->done) return -1; + if (!self->done && other->done) return 1; if (self->child_index < other->child_index) return -1; if (self->child_index > other->child_index) return 1; if (self->production_id < other->production_id) return -1; @@ -672,97 +684,52 @@ static inline void ts_query__pattern_map_insert( } static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index) { - typedef struct { - TSSymbol parent_symbol; - 
uint32_t parent_step_index; - Array(uint32_t) child_step_indices; - } ParentPattern; - - typedef Array(WalkState) WalkStateList; - - bool result = true; - // Identify all of the patterns in the query that have child patterns. This // includes both top-level patterns and patterns that are nested within some // larger pattern. For each of these, record the parent symbol, the step index // and all of the immediate child step indices in reverse order. - Array(ParentPattern) parent_patterns = array_new(); - Array(uint32_t) stack = array_new(); + Array(uint32_t) parent_step_indices = array_new(); for (unsigned i = 0; i < self->steps.size; i++) { QueryStep *step = &self->steps.contents[i]; - if (step->depth == PATTERN_DONE_MARKER) { - array_clear(&stack); - } else { - uint32_t parent_pattern_index = 0; - while (stack.size > 0) { - parent_pattern_index = *array_back(&stack); - ParentPattern *parent_pattern = &parent_patterns.contents[parent_pattern_index]; - QueryStep *parent_step = &self->steps.contents[parent_pattern->parent_step_index]; - if (parent_step->depth >= step->depth) { - stack.size--; - } else { - break; - } + if (i + 1 < self->steps.size) { + QueryStep *next_step = &self->steps.contents[i + 1]; + if ( + step->symbol != WILDCARD_SYMBOL && + step->symbol != NAMED_WILDCARD_SYMBOL && + next_step->depth > step->depth && + next_step->depth != PATTERN_DONE_MARKER + ) { + array_push(&parent_step_indices, i); } - - if (stack.size > 0) { - ParentPattern *parent_pattern = &parent_patterns.contents[parent_pattern_index]; - step->is_definite = true; - array_push(&parent_pattern->child_step_indices, i); - } - - array_push(&stack, parent_patterns.size); - array_push(&parent_patterns, ((ParentPattern) { - .parent_symbol = step->symbol, - .parent_step_index = i, - })); } - } - for (unsigned i = 0; i < parent_patterns.size; i++) { - ParentPattern *parent_pattern = &parent_patterns.contents[i]; - if (parent_pattern->child_step_indices.size == 0) { - 
array_erase(&parent_patterns, i); - i--; + if (step->depth > 0) { + step->is_definite = true; } } // Debug // { - // printf("\nParent pattern entries\n"); - // for (unsigned i = 0; i < parent_patterns.size; i++) { - // ParentPattern *parent_pattern = &parent_patterns.contents[i]; - // printf(" %s ->", ts_language_symbol_name(self->language, parent_pattern->parent_symbol)); - // for (unsigned j = 0; j < parent_pattern->child_step_indices.size; j++) { - // QueryStep *step = &self->steps.contents[parent_pattern->child_step_indices.contents[j]]; - // printf(" %s", ts_language_symbol_name(self->language, step->symbol)); - // } - // printf("\n"); + // printf("\nParent steps\n"); + // for (unsigned i = 0; i < parent_step_indices.size; i++) { + // uint32_t parent_step_index = parent_step_indices.contents[i]; + // TSSymbol parent_symbol = self->steps.contents[parent_step_index].symbol; + // printf(" %s %u\n", ts_language_symbol_name(self->language, parent_symbol), parent_step_index); // } // } // Initialize a set of subgraphs, with one subgraph for each parent symbol, // in the query, and one subgraph for each hidden symbol. 
- unsigned subgraph_index = 0, exists; Array(SymbolSubgraph) subgraphs = array_new(); - for (unsigned i = 0; i < parent_patterns.size; i++) { - TSSymbol parent_symbol = parent_patterns.contents[i].parent_symbol; - array_search_sorted_by(&subgraphs, 0, .symbol, parent_symbol, &subgraph_index, &exists); - if (!exists) { - array_insert(&subgraphs, subgraph_index, ((SymbolSubgraph) { .symbol = parent_symbol, })); - } + for (unsigned i = 0; i < parent_step_indices.size; i++) { + uint32_t parent_step_index = parent_step_indices.contents[i]; + TSSymbol parent_symbol = self->steps.contents[parent_step_index].symbol; + SymbolSubgraph subgraph = { .symbol = parent_symbol }; + array_insert_sorted_by(&subgraphs, 0, .symbol, subgraph); } - subgraph_index = 0; - for (TSSymbol sym = 0; sym < self->language->symbol_count; sym++) { + for (TSSymbol sym = self->language->token_count; sym < self->language->symbol_count; sym++) { if (!ts_language_symbol_metadata(self->language, sym).visible) { - array_search_sorted_by( - &subgraphs, subgraph_index, - .symbol, sym, - &subgraph_index, &exists - ); - if (!exists) { - array_insert(&subgraphs, subgraph_index, ((SymbolSubgraph) { .symbol = sym, })); - subgraph_index++; - } + SymbolSubgraph subgraph = { .symbol = sym }; + array_insert_sorted_by(&subgraphs, 0, .symbol, subgraph); } } @@ -889,13 +856,18 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index // For each non-terminal pattern, determine if the pattern can successfully match, // and all of the possible children within the pattern where matching could fail. 
+ bool result = true; + typedef Array(WalkState) WalkStateList; WalkStateList walk_states = array_new(); WalkStateList next_walk_states = array_new(); - Array(uint16_t) finished_step_indices = array_new(); - for (unsigned i = 0; i < parent_patterns.size; i++) { - ParentPattern *parent_pattern = &parent_patterns.contents[i]; + Array(uint16_t) final_step_indices = array_new(); + for (unsigned i = 0; i < parent_step_indices.size; i++) { + bool can_finish_pattern = false; + uint16_t parent_step_index = parent_step_indices.contents[i]; + uint16_t parent_depth = self->steps.contents[parent_step_index].depth; + TSSymbol parent_symbol = self->steps.contents[parent_step_index].symbol; unsigned subgraph_index, exists; - array_search_sorted_by(&subgraphs, 0, .symbol, parent_pattern->parent_symbol, &subgraph_index, &exists); + array_search_sorted_by(&subgraphs, 0, .symbol, parent_symbol, &subgraph_index, &exists); if (!exists) { // TODO - what to do for ERROR patterns continue; @@ -908,7 +880,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index for (unsigned j = 0; j < subgraph->start_states.size; j++) { TSStateId state = subgraph->start_states.contents[j]; array_push(&walk_states, ((WalkState) { - .step_index = 0, + .step_index = parent_step_index + 1, .stack = { [0] = { .state = state, @@ -924,36 +896,57 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index // Walk the subgraph for this non-terminal, tracking all of the possible // sequences of progress within the pattern. 
- array_clear(&finished_step_indices); - while (walk_states.size > 0) { + array_clear(&final_step_indices); + for (;;) { // Debug // { - // printf("Walk states for %u %s:\n", i, ts_language_symbol_name(self->language, parent_pattern->parent_symbol)); + // printf("Final step indices:"); + // for (unsigned j = 0; j < final_step_indices.size; j++) { + // printf(" %u", final_step_indices.contents[j]); + // } + // printf("\nWalk states for %u %s:\n", i, ts_language_symbol_name(self->language, parent_symbol)); // for (unsigned j = 0; j < walk_states.size; j++) { // WalkState *walk_state = &walk_states.contents[j]; - // printf( - // " %u: {depth: %u, step: %u, state: %u, child_index: %u, parent: %s}\n", - // j, - // walk_state->depth, - // walk_state->step_index, - // walk_state->stack[walk_state->depth - 1].state, - // walk_state->stack[walk_state->depth - 1].child_index, - // ts_language_symbol_name(self->language, walk_state->stack[walk_state->depth - 1].parent_symbol) - // ); + // printf(" %3u: {step: %u, stack: [", j, walk_state->step_index); + // for (unsigned k = 0; k < walk_state->depth; k++) { + // printf( + // " {parent: %s, child_index: %u, field: %s, state: %3u, done:%d}", + // self->language->symbol_names[walk_state->stack[k].parent_symbol], + // walk_state->stack[k].child_index, + // self->language->field_names[walk_state->stack[k].field_id], + // walk_state->stack[k].state, + // walk_state->stack[k].done + // ); + // } + // printf(" ]}\n"); // } + // printf("\n"); // } + if (walk_states.size == 0) break; array_clear(&next_walk_states); - for (unsigned j = 0; j < walk_states.size; j++) { - WalkState *walk_state = &walk_states.contents[j]; - TSStateId state = walk_state__top(walk_state)->state; - unsigned child_index = walk_state__top(walk_state)->child_index; - TSSymbol parent_symbol = walk_state__top(walk_state)->parent_symbol; + + unsigned j = 0; + for (; j < walk_states.size; j++) { + WalkState * const walk_state = &walk_states.contents[j]; + if ( + 
next_walk_states.size > 0 && + walk_state__compare_position(walk_state, array_back(&next_walk_states)) >= 0 + ) { + array_insert_sorted_with(&next_walk_states, 0, walk_state__compare, *walk_state); + continue; + } + + const TSStateId state = walk_state__top(walk_state)->state; + const TSSymbol parent_symbol = walk_state__top(walk_state)->parent_symbol; + const TSFieldId parent_field_id = walk_state__top(walk_state)->field_id; + const unsigned child_index = walk_state__top(walk_state)->child_index; + const QueryStep * const step = &self->steps.contents[walk_state->step_index]; unsigned subgraph_index, exists; array_search_sorted_by(&subgraphs, 0, .symbol, parent_symbol, &subgraph_index, &exists); if (!exists) continue; - SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + const SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; for (TSSymbol sym = 0; sym < self->language->symbol_count; sym++) { TSStateId successor_state = ts_language_next_state(self->language, state, sym); @@ -962,111 +955,93 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index array_search_sorted_by(&subgraph->nodes, 0, .state, successor_state, &node_index, &exists); while (exists && node_index < subgraph->nodes.size) { SubgraphNode *node = &subgraph->nodes.contents[node_index++]; - if (node->state != successor_state || node->child_index != child_index + 1) continue; - - WalkState next_walk_state = *walk_state; - walk_state__top(&next_walk_state)->child_index++; - walk_state__top(&next_walk_state)->state = successor_state; - - unsigned step_index = parent_pattern->child_step_indices.contents[walk_state->step_index]; - QueryStep *step = &self->steps.contents[step_index]; + if (node->state != successor_state) break; + if (node->child_index != child_index + 1) continue; TSSymbol alias = ts_language_alias_at(self->language, node->production_id, child_index); + TSSymbol visible_symbol = alias ? 
alias : self->language->symbol_metadata[sym].visible ? self->language->public_symbol_map[sym] : 0; - TSFieldId field_id = 0; - const TSFieldMapEntry *field_map, *field_map_end; - ts_language_field_map(self->language, node->production_id, &field_map, &field_map_end); - for (; field_map != field_map_end; field_map++) { - if (field_map->child_index == child_index) { - field_id = field_map->field_id; - break; + TSFieldId field_id = parent_field_id; + if (!field_id) { + const TSFieldMapEntry *field_map, *field_map_end; + ts_language_field_map(self->language, node->production_id, &field_map, &field_map_end); + for (; field_map != field_map_end; field_map++) { + if (field_map->child_index == child_index) { + field_id = field_map->field_id; + break; + } } } - if (node->done) { - walk_state__top(&next_walk_state)->done = true; - } + WalkState next_walk_state = *walk_state; + walk_state__top(&next_walk_state)->child_index++; + walk_state__top(&next_walk_state)->state = successor_state; + if (node->done) walk_state__top(&next_walk_state)->done = true; - bool does_match = true; + bool does_match = false; if (visible_symbol) { + does_match = true; if (step->symbol == NAMED_WILDCARD_SYMBOL) { if (!self->language->symbol_metadata[visible_symbol].named) does_match = false; } else if (step->symbol != WILDCARD_SYMBOL) { if (step->symbol != visible_symbol) does_match = false; } - - if (step->field) { - bool does_match_field = step->field == field_id; - if (!does_match_field) { - for (unsigned i = 0; i < walk_state->depth; i++) { - if (walk_state->stack[i].field_id == step->field) { - does_match_field = true; - } - } - } - does_match &= does_match_field; + if (step->field && step->field != field_id) { + does_match = false; } } else if (next_walk_state.depth < MAX_WALK_STATE_DEPTH) { - does_match = false; next_walk_state.depth++; walk_state__top(&next_walk_state)->state = state; walk_state__top(&next_walk_state)->child_index = 0; walk_state__top(&next_walk_state)->parent_symbol = sym; 
walk_state__top(&next_walk_state)->field_id = field_id; + walk_state__top(&next_walk_state)->done = false; } else { continue; } if (does_match) { - next_walk_state.step_index++; + for (;;) { + next_walk_state.step_index++; + const QueryStep *step = &self->steps.contents[next_walk_state.step_index]; + if ( + step->depth == PATTERN_DONE_MARKER || + step->depth == parent_depth + ) { + can_finish_pattern = true; + break; + } + if (step->depth == parent_depth + 1) { + break; + } + } } - if (next_walk_state.step_index == parent_pattern->child_step_indices.size) { - unsigned index, exists; - array_search_sorted_by(&finished_step_indices, 0, , next_walk_state.step_index, &index, &exists); - if (!exists) array_insert(&finished_step_indices, index, next_walk_state.step_index); - continue; + while (next_walk_state.depth > 0 && walk_state__top(&next_walk_state)->done) { + memset(walk_state__top(&next_walk_state), 0, sizeof(WalkStateEntry)); + next_walk_state.depth--; } - unsigned index, exists; - array_search_sorted_with( - &next_walk_states, 0, - walk_state__compare, &next_walk_state, - &index, &exists - ); - if (!exists) array_insert(&next_walk_states, index, next_walk_state); + if ( + next_walk_state.depth == 0 || + self->steps.contents[next_walk_state.step_index].depth != parent_depth + 1 + ) { + array_insert_sorted_by(&final_step_indices, 0, , next_walk_state.step_index); + } else { + array_insert_sorted_with(&next_walk_states, 0, walk_state__compare, next_walk_state); + } } } } + } - bool did_pop = false; - while (walk_state->depth > 0 && walk_state__top(walk_state)->done) { - walk_state->depth--; - did_pop = true; - } - - if (did_pop) { - if (walk_state->depth == 0) { - unsigned index, exists; - array_search_sorted_by(&finished_step_indices, 0, , walk_state->step_index, &index, &exists); - if (!exists) array_insert(&finished_step_indices, index, walk_state->step_index); - } else { - unsigned index, exists; - array_search_sorted_with( - &next_walk_states, - 0, - 
walk_state__compare, - walk_state, - &index, - &exists - ); - if (!exists) array_insert(&next_walk_states, index, *walk_state); - } - } + for (; j < walk_states.size; j++) { + WalkState *walk_state = &walk_states.contents[j]; + array_push(&next_walk_states, *walk_state); } WalkStateList _walk_states = walk_states; @@ -1074,39 +1049,40 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index next_walk_states = _walk_states; } - // Debug - // { - // printf("Finished step indices for %u %s:", i, ts_language_symbol_name(self->language, parent_pattern->parent_symbol)); - // for (unsigned j = 0; j < finished_step_indices.size; j++) { - // printf(" %u", finished_step_indices.contents[j]); - // } - // printf(". Length: %u\n\n", parent_pattern->child_step_indices.size); - // } - // A query step is definite if the containing pattern will definitely match // once the step is reached. In other words, a step is *not* definite if // it's possible to create a syntax node that matches up to until that step, // but does not match the entire pattern. 
- for (unsigned j = 0, n = parent_pattern->child_step_indices.size; j < n; j++) { - uint32_t step_index = parent_pattern->child_step_indices.contents[j]; - for (unsigned k = 0; k < finished_step_indices.size; k++) { - uint32_t finished_step_index = finished_step_indices.contents[k]; - if (finished_step_index >= j && finished_step_index < n) { - QueryStep *step = &self->steps.contents[step_index]; - step->is_definite = false; + uint32_t child_step_index = parent_step_index + 1; + QueryStep *child_step = &self->steps.contents[child_step_index]; + while (child_step->depth == parent_depth + 1) { + for (unsigned k = 0; k < final_step_indices.size; k++) { + uint32_t final_step_index = final_step_indices.contents[k]; + if ( + final_step_index >= child_step_index && + self->steps.contents[final_step_index].depth != PATTERN_DONE_MARKER + ) { + child_step->is_definite = false; break; } } + do { + child_step_index++; + child_step++; + } while ( + child_step->depth != PATTERN_DONE_MARKER && + child_step->depth > parent_depth + 1 + ); } - if (finished_step_indices.size == 0 || *array_back(&finished_step_indices) < parent_pattern->child_step_indices.size) { + if (result && !can_finish_pattern) { unsigned exists; array_search_sorted_by( &self->patterns, 0, - .start_step, - parent_pattern->parent_step_index, impossible_index, &exists); + .start_step, parent_step_index, + impossible_index, &exists + ); result = false; - goto cleanup; } } @@ -1131,33 +1107,31 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index // for (unsigned i = 0; i < self->steps.size; i++) { // QueryStep *step = &self->steps.contents[i]; // if (step->depth == PATTERN_DONE_MARKER) { - // printf("\n"); - // continue; + // printf(" %u: DONE\n", i); + // } else { + // printf( + // " %u: {symbol: %s, is_definite: %d}\n", + // i, + // (step->symbol == WILDCARD_SYMBOL || step->symbol == NAMED_WILDCARD_SYMBOL) + // ? 
"ANY" + // : ts_language_symbol_name(self->language, step->symbol), + // step->is_definite + // ); // } - // printf( - // " {symbol: %s, is_definite: %d}\n", - // (step->symbol == WILDCARD_SYMBOL || step->symbol == NAMED_WILDCARD_SYMBOL) ? "ANY" : ts_language_symbol_name(self->language, step->symbol), - // step->is_definite - // ); // } // } // Cleanup -cleanup: - for (unsigned i = 0; i < parent_patterns.size; i++) { - array_delete(&parent_patterns.contents[i].child_step_indices); - } for (unsigned i = 0; i < subgraphs.size; i++) { array_delete(&subgraphs.contents[i].start_states); array_delete(&subgraphs.contents[i].nodes); } - array_delete(&stack); array_delete(&subgraphs); array_delete(&next_nodes); array_delete(&walk_states); - array_delete(&parent_patterns); array_delete(&next_walk_states); - array_delete(&finished_step_indices); + array_delete(&final_step_indices); + array_delete(&parent_step_indices); state_predecessor_map_delete(&predecessor_map); return result; From 9fb39b89545c5fa53650dbac522fa3709065f7e4 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 25 Jun 2020 13:49:07 -0700 Subject: [PATCH 05/26] Start work on handling alternatives when analyzing queries --- lib/src/query.c | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/lib/src/query.c b/lib/src/query.c index 563ffe8d..bc277bf9 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -1004,6 +1004,11 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index continue; } + while (next_walk_state.depth > 0 && walk_state__top(&next_walk_state)->done) { + memset(walk_state__top(&next_walk_state), 0, sizeof(WalkStateEntry)); + next_walk_state.depth--; + } + if (does_match) { for (;;) { next_walk_state.step_index++; @@ -1021,18 +1026,23 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index } } - while (next_walk_state.depth > 0 && walk_state__top(&next_walk_state)->done) { - 
memset(walk_state__top(&next_walk_state), 0, sizeof(WalkStateEntry)); - next_walk_state.depth--; - } - if ( next_walk_state.depth == 0 || self->steps.contents[next_walk_state.step_index].depth != parent_depth + 1 ) { array_insert_sorted_by(&final_step_indices, 0, , next_walk_state.step_index); } else { - array_insert_sorted_with(&next_walk_states, 0, walk_state__compare, next_walk_state); + for (;;) { + const QueryStep *step = &self->steps.contents[next_walk_state.step_index]; + if (!step->is_dead_end) { + array_insert_sorted_with(&next_walk_states, 0, walk_state__compare, next_walk_state); + } + if (step->alternative_index != NONE && step->alternative_index > next_walk_state.step_index) { + next_walk_state.step_index = step->alternative_index; + } else { + break; + } + } } } } From 891de051e2a33afd1f0b677e72965618348980f3 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 25 Jun 2020 15:05:44 -0700 Subject: [PATCH 06/26] Fix population of subgraph nodes when analyzing queries --- cli/src/tests/query_test.rs | 8 ++++++++ lib/src/query.c | 16 ++++++++-------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index cc42a70d..4c2f65ab 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2083,6 +2083,14 @@ fn test_query_is_definite() { } let rows = &[ + Row { + language: get_language("python"), + pattern: r#"(expression_statement (string))"#, + results_by_step_index: &[ + (0, false), + (1, false), // string + ], + }, Row { language: get_language("javascript"), pattern: r#"(object "{" "}")"#, diff --git a/lib/src/query.c b/lib/src/query.c index bc277bf9..fa0edba1 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -748,10 +748,9 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index for (unsigned i = 0; i < count; i++) { const TSParseAction *action = &actions[i]; if (action->type == TSParseActionTypeReduce) { - unsigned exists; 
array_search_sorted_by( &subgraphs, - subgraph_index, + 0, .symbol, action->params.reduce.symbol, &subgraph_index, @@ -759,13 +758,14 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index ); if (exists) { SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + SubgraphNode node = { + .state = state, + .production_id = action->params.reduce.production_id, + .child_index = action->params.reduce.child_count, + .done = true, + }; if (subgraph->nodes.size == 0 || array_back(&subgraph->nodes)->state != state) { - array_push(&subgraph->nodes, ((SubgraphNode) { - .state = state, - .production_id = action->params.reduce.production_id, - .child_index = action->params.reduce.child_count, - .done = true, - })); + array_push(&subgraph->nodes, node); } } } else if ( From 19baa5fd5e34257c1b34bcd899812b1b30f2d455 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 25 Jun 2020 17:56:43 -0700 Subject: [PATCH 07/26] Clean up and document query analysis code --- cli/src/tests/query_test.rs | 8 + lib/src/query.c | 427 +++++++++++++++++++----------------- 2 files changed, 230 insertions(+), 205 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 4c2f65ab..c73931ce 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2091,6 +2091,14 @@ fn test_query_is_definite() { (1, false), // string ], }, + Row { + language: get_language("javascript"), + pattern: r#"(expression_statement (string))"#, + results_by_step_index: &[ + (0, false), + (1, false), // string + ], + }, Row { language: get_language("javascript"), pattern: r#"(object "{" "}")"#, diff --git a/lib/src/query.c b/lib/src/query.c index fa0edba1..cf84115f 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -15,7 +15,7 @@ #define MAX_CAPTURE_LIST_COUNT 32 #define MAX_STEP_CAPTURE_COUNT 3 #define MAX_STATE_PREDECESSOR_COUNT 100 -#define MAX_WALK_STATE_DEPTH 4 +#define MAX_ANALYSIS_STATE_DEPTH 4 /* * Stream - A sequence of unicode 
characters derived from a UTF8 string. @@ -148,36 +148,36 @@ typedef struct { } CaptureListPool; /* - * WalkState - The state needed for walking the parse table when analyzing + * AnalysisState - The state needed for walking the parse table when analyzing * a query pattern, to determine the steps where the pattern could fail * to match. */ typedef struct { - TSStateId state; + TSStateId parse_state; TSSymbol parent_symbol; uint16_t child_index; TSFieldId field_id: 15; bool done: 1; -} WalkStateEntry; +} AnalysisStateEntry; typedef struct { - WalkStateEntry stack[MAX_WALK_STATE_DEPTH]; + AnalysisStateEntry stack[MAX_ANALYSIS_STATE_DEPTH]; uint16_t depth; uint16_t step_index; -} WalkState; +} AnalysisState; typedef struct { TSStateId state; uint8_t production_id; uint8_t child_index: 7; bool done: 1; -} SubgraphNode; +} AnalysisSubgraphNode; typedef struct { TSSymbol symbol; Array(TSStateId) start_states; - Array(SubgraphNode) nodes; -} SymbolSubgraph; + Array(AnalysisSubgraphNode) nodes; +} AnalysisSubgraph; /* * StatePredecessorMap - A map that stores the predecessors of each parse state. 
@@ -565,11 +565,14 @@ static inline const TSStateId *state_predecessor_map_get( return &self->contents[index + 1]; } -/************ - * WalkState - ************/ +/**************** + * AnalysisState + ****************/ -static inline int walk_state__compare_position(const WalkState *self, const WalkState *other) { +static inline int analysis_state__compare_position( + const AnalysisState *self, + const AnalysisState *other +) { for (unsigned i = 0; i < self->depth; i++) { if (i >= other->depth) return -1; if (self->stack[i].child_index < other->stack[i].child_index) return -1; @@ -579,14 +582,17 @@ static inline int walk_state__compare_position(const WalkState *self, const Walk return 0; } -static inline int walk_state__compare(const WalkState *self, const WalkState *other) { - int result = walk_state__compare_position(self, other); +static inline int analysis_state__compare( + const AnalysisState *self, + const AnalysisState *other +) { + int result = analysis_state__compare_position(self, other); if (result != 0) return result; for (unsigned i = 0; i < self->depth; i++) { if (self->stack[i].parent_symbol < other->stack[i].parent_symbol) return -1; if (self->stack[i].parent_symbol > other->stack[i].parent_symbol) return 1; - if (self->stack[i].state < other->stack[i].state) return -1; - if (self->stack[i].state > other->stack[i].state) return 1; + if (self->stack[i].parse_state < other->stack[i].parse_state) return -1; + if (self->stack[i].parse_state > other->stack[i].parse_state) return 1; if (self->stack[i].field_id < other->stack[i].field_id) return -1; if (self->stack[i].field_id > other->stack[i].field_id) return 1; } @@ -595,7 +601,15 @@ static inline int walk_state__compare(const WalkState *self, const WalkState *ot return 0; } -static inline int subgraph_node__compare(const SubgraphNode *self, const SubgraphNode *other) { +static inline AnalysisStateEntry *analysis_state__top(AnalysisState *self) { + return &self->stack[self->depth - 1]; +} + 
+/*********************** + * AnalysisSubgraphNode + ***********************/ + +static inline int analysis_subgraph_node__compare(const AnalysisSubgraphNode *self, const AnalysisSubgraphNode *other) { if (self->state < other->state) return -1; if (self->state > other->state) return 1; if (self->done && !other->done) return -1; @@ -607,10 +621,6 @@ static inline int subgraph_node__compare(const SubgraphNode *self, const Subgrap return 0; } -static inline WalkStateEntry *walk_state__top(WalkState *self) { - return &self->stack[self->depth - 1]; -} - /********* * Query *********/ @@ -683,11 +693,12 @@ static inline void ts_query__pattern_map_insert( })); } +// #define DEBUG_ANALYZE_QUERY + static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index) { - // Identify all of the patterns in the query that have child patterns. This - // includes both top-level patterns and patterns that are nested within some - // larger pattern. For each of these, record the parent symbol, the step index - // and all of the immediate child step indices in reverse order. + // Identify all of the patterns in the query that have child patterns, both at the + // top level and nested within other larger patterns. Record the step index where + // each pattern starts. 
Array(uint32_t) parent_step_indices = array_new(); for (unsigned i = 0; i < self->steps.size; i++) { QueryStep *step = &self->steps.contents[i]; @@ -707,33 +718,29 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index } } - // Debug - // { - // printf("\nParent steps\n"); - // for (unsigned i = 0; i < parent_step_indices.size; i++) { - // uint32_t parent_step_index = parent_step_indices.contents[i]; - // TSSymbol parent_symbol = self->steps.contents[parent_step_index].symbol; - // printf(" %s %u\n", ts_language_symbol_name(self->language, parent_symbol), parent_step_index); - // } - // } - - // Initialize a set of subgraphs, with one subgraph for each parent symbol, - // in the query, and one subgraph for each hidden symbol. - Array(SymbolSubgraph) subgraphs = array_new(); + // For every parent symbol in the query, initialize an 'analysis subgraph'. + // This subgraph lists all of the states in the parse table that are directly + // involved in building subtrees for this symbol. + // + // In addition to the parent symbols in the query, construct subgraphs for all + // of the hidden symbols in the grammar, because these might occur within + // one of the parent nodes, such that their children appear to belong to the + // parent. 
+ Array(AnalysisSubgraph) subgraphs = array_new(); for (unsigned i = 0; i < parent_step_indices.size; i++) { uint32_t parent_step_index = parent_step_indices.contents[i]; TSSymbol parent_symbol = self->steps.contents[parent_step_index].symbol; - SymbolSubgraph subgraph = { .symbol = parent_symbol }; + AnalysisSubgraph subgraph = { .symbol = parent_symbol }; array_insert_sorted_by(&subgraphs, 0, .symbol, subgraph); } for (TSSymbol sym = self->language->token_count; sym < self->language->symbol_count; sym++) { if (!ts_language_symbol_metadata(self->language, sym).visible) { - SymbolSubgraph subgraph = { .symbol = sym }; + AnalysisSubgraph subgraph = { .symbol = sym }; array_insert_sorted_by(&subgraphs, 0, .symbol, subgraph); } } - // Scan the parse table to find the data needed for these subgraphs. + // Scan the parse table to find the data needed to populate these subgraphs. // Collect three things during this scan: // 1) All of the parse states where one of these symbols can start. // 2) All of the parse states where one of these symbols can end, along @@ -757,21 +764,17 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index &exists ); if (exists) { - SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; - SubgraphNode node = { - .state = state, - .production_id = action->params.reduce.production_id, - .child_index = action->params.reduce.child_count, - .done = true, - }; + AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; if (subgraph->nodes.size == 0 || array_back(&subgraph->nodes)->state != state) { - array_push(&subgraph->nodes, node); + array_push(&subgraph->nodes, ((AnalysisSubgraphNode) { + .state = state, + .production_id = action->params.reduce.production_id, + .child_index = action->params.reduce.child_count, + .done = true, + })); } } - } else if ( - action->type == TSParseActionTypeShift && - !action->params.shift.extra - ) { + } else if (action->type == TSParseActionTypeShift && 
!action->params.shift.extra) { TSStateId next_state = action->params.shift.state; state_predecessor_map_add(&predecessor_map, next_state, state); } @@ -790,18 +793,18 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index &exists ); if (exists) { - SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; array_push(&subgraph->start_states, state); } } } } - // For each subgraph, compute the remainder of the nodes by walking backward + // For each subgraph, compute the preceding states by walking backward // from the end states using the predecessor map. - Array(SubgraphNode) next_nodes = array_new(); + Array(AnalysisSubgraphNode) next_nodes = array_new(); for (unsigned i = 0; i < subgraphs.size; i++) { - SymbolSubgraph *subgraph = &subgraphs.contents[i]; + AnalysisSubgraph *subgraph = &subgraphs.contents[i]; if (subgraph->nodes.size == 0) { array_delete(&subgraph->start_states); array_erase(&subgraphs, i); @@ -810,7 +813,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index } array_assign(&next_nodes, &subgraph->nodes); while (next_nodes.size > 0) { - SubgraphNode node = array_pop(&next_nodes); + AnalysisSubgraphNode node = array_pop(&next_nodes); if (node.child_index > 1) { unsigned predecessor_count; const TSStateId *predecessors = state_predecessor_map_get( @@ -819,7 +822,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index &predecessor_count ); for (unsigned j = 0; j < predecessor_count; j++) { - SubgraphNode predecessor_node = { + AnalysisSubgraphNode predecessor_node = { .state = predecessors[j], .child_index = node.child_index - 1, .production_id = node.production_id, @@ -828,7 +831,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index unsigned index, exists; array_search_sorted_with( &subgraph->nodes, 0, - subgraph_node__compare, &predecessor_node, + 
analysis_subgraph_node__compare, &predecessor_node, &index, &exists ); if (!exists) { @@ -840,52 +843,48 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index } } - // Debug - // { - // printf("\nSubgraphs:\n"); - // for (unsigned i = 0; i < subgraphs.size; i++) { - // SymbolSubgraph *subgraph = &subgraphs.contents[i]; - // printf(" %u, %s:\n", subgraph->symbol, ts_language_symbol_name(self->language, subgraph->symbol)); - // for (unsigned j = 0; j < subgraph->nodes.size; j++) { - // SubgraphNode *node = &subgraph->nodes.contents[j]; - // printf(" {state: %u, child_index: %u}\n", node->state, node->child_index); - // } - // printf("\n"); - // } - // } + #ifdef DEBUG_ANALYZE_QUERY + printf("\nSubgraphs:\n"); + for (unsigned i = 0; i < subgraphs.size; i++) { + AnalysisSubgraph *subgraph = &subgraphs.contents[i]; + printf(" %u, %s:\n", subgraph->symbol, ts_language_symbol_name(self->language, subgraph->symbol)); + for (unsigned j = 0; j < subgraph->nodes.size; j++) { + AnalysisSubgraphNode *node = &subgraph->nodes.contents[j]; + printf(" {state: %u, child_index: %u}\n", node->state, node->child_index); + } + printf("\n"); + } + #endif // For each non-terminal pattern, determine if the pattern can successfully match, - // and all of the possible children within the pattern where matching could fail. + // and identify all of the possible children within the pattern where matching could fail. bool result = true; - typedef Array(WalkState) WalkStateList; - WalkStateList walk_states = array_new(); - WalkStateList next_walk_states = array_new(); + typedef Array(AnalysisState) AnalysisStateList; + AnalysisStateList states = array_new(); + AnalysisStateList next_states = array_new(); Array(uint16_t) final_step_indices = array_new(); for (unsigned i = 0; i < parent_step_indices.size; i++) { - bool can_finish_pattern = false; + // Find the subgraph that corresponds to this pattern's root symbol. 
uint16_t parent_step_index = parent_step_indices.contents[i]; uint16_t parent_depth = self->steps.contents[parent_step_index].depth; TSSymbol parent_symbol = self->steps.contents[parent_step_index].symbol; unsigned subgraph_index, exists; array_search_sorted_by(&subgraphs, 0, .symbol, parent_symbol, &subgraph_index, &exists); - if (!exists) { - // TODO - what to do for ERROR patterns - continue; - } - SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + if (!exists) continue; + AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; - // Initialize a walk at every possible parse state where this non-terminal - // symbol can start. - array_clear(&walk_states); + // Initialize an analysis state at every parse state in the table where + // this parent symbol can occur. + array_clear(&states); for (unsigned j = 0; j < subgraph->start_states.size; j++) { - TSStateId state = subgraph->start_states.contents[j]; - array_push(&walk_states, ((WalkState) { + TSStateId parse_state = subgraph->start_states.contents[j]; + array_push(&states, ((AnalysisState) { .step_index = parent_step_index + 1, .stack = { [0] = { - .state = state, + .parse_state = parse_state, + .parent_symbol = parent_symbol, .child_index = 0, - .parent_symbol = subgraph->symbol, .field_id = 0, .done = false, }, @@ -896,75 +895,87 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index // Walk the subgraph for this non-terminal, tracking all of the possible // sequences of progress within the pattern. 
+ bool can_finish_pattern = false; array_clear(&final_step_indices); for (;;) { - // Debug - // { - // printf("Final step indices:"); - // for (unsigned j = 0; j < final_step_indices.size; j++) { - // printf(" %u", final_step_indices.contents[j]); - // } - // printf("\nWalk states for %u %s:\n", i, ts_language_symbol_name(self->language, parent_symbol)); - // for (unsigned j = 0; j < walk_states.size; j++) { - // WalkState *walk_state = &walk_states.contents[j]; - // printf(" %3u: {step: %u, stack: [", j, walk_state->step_index); - // for (unsigned k = 0; k < walk_state->depth; k++) { - // printf( - // " {parent: %s, child_index: %u, field: %s, state: %3u, done:%d}", - // self->language->symbol_names[walk_state->stack[k].parent_symbol], - // walk_state->stack[k].child_index, - // self->language->field_names[walk_state->stack[k].field_id], - // walk_state->stack[k].state, - // walk_state->stack[k].done - // ); - // } - // printf(" ]}\n"); - // } - // printf("\n"); - // } + #ifdef DEBUG_ANALYZE_QUERY + printf("Final step indices:"); + for (unsigned j = 0; j < final_step_indices.size; j++) { + printf(" %u", final_step_indices.contents[j]); + } + printf("\nWalk states for %u %s:\n", i, ts_language_symbol_name(self->language, parent_symbol)); + for (unsigned j = 0; j < states.size; j++) { + AnalysisState *state = &states.contents[j]; + printf(" %3u: {step: %u, stack: [", j, state->step_index); + for (unsigned k = 0; k < state->depth; k++) { + printf( + " {parent: %s, child_index: %u, field: %s, state: %3u, done:%d}", + self->language->symbol_names[state->stack[k].parent_symbol], + state->stack[k].child_index, + self->language->field_names[state->stack[k].field_id], + state->stack[k].parse_state, + state->stack[k].done + ); + } + printf(" ]}\n"); + } + #endif - if (walk_states.size == 0) break; - array_clear(&next_walk_states); + if (states.size == 0) break; + array_clear(&next_states); + for (unsigned j = 0; j < states.size; j++) { + AnalysisState * const state = 
&states.contents[j]; - unsigned j = 0; - for (; j < walk_states.size; j++) { - WalkState * const walk_state = &walk_states.contents[j]; - if ( - next_walk_states.size > 0 && - walk_state__compare_position(walk_state, array_back(&next_walk_states)) >= 0 - ) { - array_insert_sorted_with(&next_walk_states, 0, walk_state__compare, *walk_state); - continue; + // For efficiency, it's important to avoid processing the same analysis state more + // than once. To achieve this, keep the states in order of ascending position within + // their hypothetical syntax trees. In each iteration of this loop, start by advancing + // the states that have made the least progress. Avoid advancing states that have already + // made more progress. + if (next_states.size > 0) { + int comparison = analysis_state__compare_position(state, array_back(&next_states)); + if (comparison == 0) { + array_insert_sorted_with(&next_states, 0, analysis_state__compare, *state); + continue; + } else if (comparison > 0) { + while (j < states.size) { + array_push(&next_states, states.contents[j]); + j++; + } + break; + } } - const TSStateId state = walk_state__top(walk_state)->state; - const TSSymbol parent_symbol = walk_state__top(walk_state)->parent_symbol; - const TSFieldId parent_field_id = walk_state__top(walk_state)->field_id; - const unsigned child_index = walk_state__top(walk_state)->child_index; - const QueryStep * const step = &self->steps.contents[walk_state->step_index]; + const TSStateId parse_state = analysis_state__top(state)->parse_state; + const TSSymbol parent_symbol = analysis_state__top(state)->parent_symbol; + const TSFieldId parent_field_id = analysis_state__top(state)->field_id; + const unsigned child_index = analysis_state__top(state)->child_index; + const QueryStep * const step = &self->steps.contents[state->step_index]; unsigned subgraph_index, exists; array_search_sorted_by(&subgraphs, 0, .symbol, parent_symbol, &subgraph_index, &exists); if (!exists) continue; - const 
SymbolSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + const AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + // Follow every possible path in the parse table, but only visit states that + // are part of the subgraph for the current symbol. for (TSSymbol sym = 0; sym < self->language->symbol_count; sym++) { - TSStateId successor_state = ts_language_next_state(self->language, state, sym); - if (successor_state && successor_state != state) { + TSStateId successor_state = ts_language_next_state(self->language, parse_state, sym); + if (successor_state && successor_state != parse_state) { unsigned node_index; array_search_sorted_by(&subgraph->nodes, 0, .state, successor_state, &node_index, &exists); while (exists && node_index < subgraph->nodes.size) { - SubgraphNode *node = &subgraph->nodes.contents[node_index++]; + AnalysisSubgraphNode *node = &subgraph->nodes.contents[node_index++]; if (node->state != successor_state) break; if (node->child_index != child_index + 1) continue; - TSSymbol alias = ts_language_alias_at(self->language, node->production_id, child_index); + // Use the subgraph to determine what alias and field will eventually be applied + // to this child node. + TSSymbol alias = ts_language_alias_at(self->language, node->production_id, child_index); TSSymbol visible_symbol = alias ? alias : self->language->symbol_metadata[sym].visible ? 
self->language->public_symbol_map[sym] : 0; - TSFieldId field_id = parent_field_id; if (!field_id) { const TSFieldMapEntry *field_map, *field_map_end; @@ -977,11 +988,13 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index } } - WalkState next_walk_state = *walk_state; - walk_state__top(&next_walk_state)->child_index++; - walk_state__top(&next_walk_state)->state = successor_state; - if (node->done) walk_state__top(&next_walk_state)->done = true; + AnalysisState next_state = *state; + analysis_state__top(&next_state)->child_index++; + analysis_state__top(&next_state)->parse_state = successor_state; + if (node->done) analysis_state__top(&next_state)->done = true; + // Determine if this hypothetical child node would match the current step + // of the query pattern. bool does_match = false; if (visible_symbol) { does_match = true; @@ -993,70 +1006,75 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index if (step->field && step->field != field_id) { does_match = false; } - } else if (next_walk_state.depth < MAX_WALK_STATE_DEPTH) { - next_walk_state.depth++; - walk_state__top(&next_walk_state)->state = state; - walk_state__top(&next_walk_state)->child_index = 0; - walk_state__top(&next_walk_state)->parent_symbol = sym; - walk_state__top(&next_walk_state)->field_id = field_id; - walk_state__top(&next_walk_state)->done = false; + } + + // If this is a hidden child, then push a new entry to the stack, in order to + // walk through the children of this child. 
+ else if (next_state.depth < MAX_ANALYSIS_STATE_DEPTH) { + next_state.depth++; + analysis_state__top(&next_state)->parse_state = parse_state; + analysis_state__top(&next_state)->child_index = 0; + analysis_state__top(&next_state)->parent_symbol = sym; + analysis_state__top(&next_state)->field_id = field_id; + analysis_state__top(&next_state)->done = false; } else { continue; } - while (next_walk_state.depth > 0 && walk_state__top(&next_walk_state)->done) { - memset(walk_state__top(&next_walk_state), 0, sizeof(WalkStateEntry)); - next_walk_state.depth--; + // Pop from the stack when this state reached the end of its current syntax node. + while (next_state.depth > 0 && analysis_state__top(&next_state)->done) { + next_state.depth--; } + // If this hypothetical child did match the current step of the query pattern, + // then advance to the next step at the current depth. This involves skipping + // over any descendant steps of the current child. + const QueryStep *next_step = step; if (does_match) { for (;;) { - next_walk_state.step_index++; - const QueryStep *step = &self->steps.contents[next_walk_state.step_index]; + next_state.step_index++; + next_step = &self->steps.contents[next_state.step_index]; if ( - step->depth == PATTERN_DONE_MARKER || - step->depth == parent_depth - ) { - can_finish_pattern = true; - break; - } - if (step->depth == parent_depth + 1) { - break; - } + next_step->depth == PATTERN_DONE_MARKER || + next_step->depth <= parent_depth + 1 + ) break; } } - if ( - next_walk_state.depth == 0 || - self->steps.contents[next_walk_state.step_index].depth != parent_depth + 1 - ) { - array_insert_sorted_by(&final_step_indices, 0, , next_walk_state.step_index); - } else { - for (;;) { - const QueryStep *step = &self->steps.contents[next_walk_state.step_index]; - if (!step->is_dead_end) { - array_insert_sorted_with(&next_walk_states, 0, walk_state__compare, next_walk_state); - } - if (step->alternative_index != NONE && step->alternative_index > 
next_walk_state.step_index) { - next_walk_state.step_index = step->alternative_index; + for (;;) { + // If this state can make further progress, then add it to the states for the next iteration. + // Otherwise, record the fact that matching can fail at this step of the pattern. + if (!next_step->is_dead_end) { + bool did_finish_pattern = self->steps.contents[next_state.step_index].depth != parent_depth + 1; + if (did_finish_pattern) can_finish_pattern = true; + if (next_state.depth > 0 && !did_finish_pattern) { + array_insert_sorted_with(&next_states, 0, analysis_state__compare, next_state); } else { - break; + array_insert_sorted_by(&final_step_indices, 0, , next_state.step_index); } } + + // If the state has advanced to a step with an alternative step, then add another state at + // that alternative step to the next iteration. + if ( + does_match && + next_step->alternative_index != NONE && + next_step->alternative_index > next_state.step_index + ) { + next_state.step_index = next_step->alternative_index; + next_step = &self->steps.contents[next_state.step_index]; + } else { + break; + } } } } } } - for (; j < walk_states.size; j++) { - WalkState *walk_state = &walk_states.contents[j]; - array_push(&next_walk_states, *walk_state); - } - - WalkStateList _walk_states = walk_states; - walk_states = next_walk_states; - next_walk_states = _walk_states; + AnalysisStateList _states = states; + states = next_states; + next_states = _states; } // A query step is definite if the containing pattern will definitely match @@ -1111,25 +1129,24 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index } } - // Debug - // { - // printf("\nSteps:\n"); - // for (unsigned i = 0; i < self->steps.size; i++) { - // QueryStep *step = &self->steps.contents[i]; - // if (step->depth == PATTERN_DONE_MARKER) { - // printf(" %u: DONE\n", i); - // } else { - // printf( - // " %u: {symbol: %s, is_definite: %d}\n", - // i, - // (step->symbol == WILDCARD_SYMBOL || 
step->symbol == NAMED_WILDCARD_SYMBOL) - // ? "ANY" - // : ts_language_symbol_name(self->language, step->symbol), - // step->is_definite - // ); - // } - // } - // } + #ifdef DEBUG_ANALYZE_QUERY + printf("Steps:\n"); + for (unsigned i = 0; i < self->steps.size; i++) { + QueryStep *step = &self->steps.contents[i]; + if (step->depth == PATTERN_DONE_MARKER) { + printf(" %u: DONE\n", i); + } else { + printf( + " %u: {symbol: %s, is_definite: %d}\n", + i, + (step->symbol == WILDCARD_SYMBOL || step->symbol == NAMED_WILDCARD_SYMBOL) + ? "ANY" + : ts_language_symbol_name(self->language, step->symbol), + step->is_definite + ); + } + } + #endif // Cleanup for (unsigned i = 0; i < subgraphs.size; i++) { @@ -1138,8 +1155,8 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index } array_delete(&subgraphs); array_delete(&next_nodes); - array_delete(&walk_states); - array_delete(&next_walk_states); + array_delete(&states); + array_delete(&next_states); array_delete(&final_step_indices); array_delete(&parent_step_indices); state_predecessor_map_delete(&predecessor_map); From 997ef45992c2bdf33927fdff65c56fb11dc6ab6c Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 26 Jun 2020 15:05:10 -0700 Subject: [PATCH 08/26] Handle parent nodes with simple aliases in query analysis --- lib/src/query.c | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lib/src/query.c b/lib/src/query.c index cf84115f..bf781204 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -755,11 +755,12 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index for (unsigned i = 0; i < count; i++) { const TSParseAction *action = &actions[i]; if (action->type == TSParseActionTypeReduce) { + TSSymbol symbol = self->language->public_symbol_map[action->params.reduce.symbol]; array_search_sorted_by( &subgraphs, 0, .symbol, - action->params.reduce.symbol, + symbol, &subgraph_index, &exists ); @@ -784,11 +785,12 @@ static bool 
ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index TSStateId next_state = ts_language_next_state(self->language, state, sym); if (next_state != 0) { state_predecessor_map_add(&predecessor_map, next_state, state); + TSSymbol symbol = self->language->public_symbol_map[sym]; array_search_sorted_by( &subgraphs, subgraph_index, .symbol, - sym, + symbol, &subgraph_index, &exists ); @@ -850,7 +852,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index printf(" %u, %s:\n", subgraph->symbol, ts_language_symbol_name(self->language, subgraph->symbol)); for (unsigned j = 0; j < subgraph->nodes.size; j++) { AnalysisSubgraphNode *node = &subgraph->nodes.contents[j]; - printf(" {state: %u, child_index: %u}\n", node->state, node->child_index); + printf(" {state: %u, child_index: %u, production_id: %u}\n", node->state, node->child_index, node->production_id); } printf("\n"); } From a317199215f1bf1848f12c57ee3713c721a4a392 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 26 Jun 2020 15:05:27 -0700 Subject: [PATCH 09/26] Add query construction to benchmark --- cli/benches/benchmark.rs | 124 ++++++++++++++++++++++++--------------- script/benchmark | 20 ++++++- 2 files changed, 95 insertions(+), 49 deletions(-) diff --git a/cli/benches/benchmark.rs b/cli/benches/benchmark.rs index 50ee5370..53ab3fea 100644 --- a/cli/benches/benchmark.rs +++ b/cli/benches/benchmark.rs @@ -2,8 +2,8 @@ use lazy_static::lazy_static; use std::collections::BTreeMap; use std::path::{Path, PathBuf}; use std::time::Instant; -use std::{env, fs, usize}; -use tree_sitter::{Language, Parser}; +use std::{env, fs, str, usize}; +use tree_sitter::{Language, Parser, Query}; use tree_sitter_cli::error::Error; use tree_sitter_cli::loader::Loader; @@ -18,26 +18,33 @@ lazy_static! 
{ .map(|s| usize::from_str_radix(&s, 10).unwrap()) .unwrap_or(5); static ref TEST_LOADER: Loader = Loader::new(SCRATCH_DIR.clone()); - static ref EXAMPLE_PATHS_BY_LANGUAGE_DIR: BTreeMap<PathBuf, Vec<PathBuf>> = { - fn process_dir(result: &mut BTreeMap<PathBuf, Vec<PathBuf>>, dir: &Path) { + static ref EXAMPLE_AND_QUERY_PATHS_BY_LANGUAGE_DIR: BTreeMap<PathBuf, (Vec<PathBuf>, Vec<PathBuf>)> = { + fn process_dir(result: &mut BTreeMap<PathBuf, (Vec<PathBuf>, Vec<PathBuf>)>, dir: &Path) { if dir.join("grammar.js").exists() { let relative_path = dir.strip_prefix(GRAMMARS_DIR.as_path()).unwrap(); + let (example_paths, query_paths) = + result.entry(relative_path.to_owned()).or_default(); + if let Ok(example_files) = fs::read_dir(&dir.join("examples")) { - result.insert( - relative_path.to_owned(), - example_files - .filter_map(|p| { - let p = p.unwrap().path(); - if p.is_file() { - Some(p) - } else { - None - } - }) - .collect(), - ); - } else { - result.insert(relative_path.to_owned(), Vec::new()); + example_paths.extend(example_files.filter_map(|p| { + let p = p.unwrap().path(); + if p.is_file() { + Some(p.to_owned()) + } else { + None + } + })); + } + + if let Ok(query_files) = fs::read_dir(&dir.join("queries")) { + query_paths.extend(query_files.filter_map(|p| { + let p = p.unwrap().path(); + if p.is_file() { + Some(p.to_owned()) + } else { + None + } + })); } } else { for entry in fs::read_dir(&dir).unwrap() { @@ -56,20 +63,25 @@ lazy_static! 
{ } fn main() { - let mut parser = Parser::new(); - let max_path_length = EXAMPLE_PATHS_BY_LANGUAGE_DIR - .iter() - .flat_map(|(_, paths)| paths.iter()) - .map(|p| p.file_name().unwrap().to_str().unwrap().chars().count()) + let max_path_length = EXAMPLE_AND_QUERY_PATHS_BY_LANGUAGE_DIR + .values() + .flat_map(|(e, q)| { + e.iter() + .chain(q.iter()) + .map(|s| s.file_name().unwrap().to_str().unwrap().len()) + }) .max() - .unwrap(); - - let mut all_normal_speeds = Vec::new(); - let mut all_error_speeds = Vec::new(); + .unwrap_or(0); eprintln!("Benchmarking with {} repetitions", *REPETITION_COUNT); - for (language_path, example_paths) in EXAMPLE_PATHS_BY_LANGUAGE_DIR.iter() { + let mut parser = Parser::new(); + let mut all_normal_speeds = Vec::new(); + let mut all_error_speeds = Vec::new(); + + for (language_path, (example_paths, query_paths)) in + EXAMPLE_AND_QUERY_PATHS_BY_LANGUAGE_DIR.iter() + { let language_name = language_path.file_name().unwrap().to_str().unwrap(); if let Some(filter) = LANGUAGE_FILTER.as_ref() { @@ -79,9 +91,24 @@ fn main() { } eprintln!("\nLanguage: {}", language_name); - parser.set_language(get_language(language_path)).unwrap(); + let language = get_language(language_path); + parser.set_language(language).unwrap(); - eprintln!(" Normal examples:"); + eprintln!(" Constructing Queries"); + for path in query_paths { + if let Some(filter) = EXAMPLE_FILTER.as_ref() { + if !path.to_str().unwrap().contains(filter.as_str()) { + continue; + } + } + + parse(&path, max_path_length, |source| { + Query::new(language, str::from_utf8(source).unwrap()) + .expect("Failed to parse query"); + }); + } + + eprintln!(" Parsing Valid Code:"); let mut normal_speeds = Vec::new(); for example_path in example_paths { if let Some(filter) = EXAMPLE_FILTER.as_ref() { @@ -90,12 +117,16 @@ fn main() { } } - normal_speeds.push(parse(&mut parser, example_path, max_path_length)); + normal_speeds.push(parse(example_path, max_path_length, |code| { + parser.parse(code, 
None).expect("Failed to parse"); + })); } - eprintln!(" Error examples (mismatched languages):"); + eprintln!(" Parsing Invalid Code (mismatched languages):"); let mut error_speeds = Vec::new(); - for (other_language_path, example_paths) in EXAMPLE_PATHS_BY_LANGUAGE_DIR.iter() { + for (other_language_path, (example_paths, _)) in + EXAMPLE_AND_QUERY_PATHS_BY_LANGUAGE_DIR.iter() + { if other_language_path != language_path { for example_path in example_paths { if let Some(filter) = EXAMPLE_FILTER.as_ref() { @@ -104,7 +135,9 @@ fn main() { } } - error_speeds.push(parse(&mut parser, example_path, max_path_length)); + error_speeds.push(parse(example_path, max_path_length, |code| { + parser.parse(code, None).expect("Failed to parse"); + })); } } } @@ -123,7 +156,7 @@ fn main() { all_error_speeds.extend(error_speeds); } - eprintln!("\nOverall"); + eprintln!("\n Overall"); if let Some((average_normal, worst_normal)) = aggregate(&all_normal_speeds) { eprintln!(" Average Speed (normal): {} bytes/ms", average_normal); eprintln!(" Worst Speed (normal): {} bytes/ms", worst_normal); @@ -151,28 +184,25 @@ fn aggregate(speeds: &Vec<usize>) -> Option<(usize, usize)> { Some((total / speeds.len(), max)) } -fn parse(parser: &mut Parser, example_path: &Path, max_path_length: usize) -> usize { +fn parse(path: &Path, max_path_length: usize, mut action: impl FnMut(&[u8])) -> usize { eprint!( " {:width$}\t", - example_path.file_name().unwrap().to_str().unwrap(), + path.file_name().unwrap().to_str().unwrap(), width = max_path_length ); - let source_code = fs::read(example_path) - .map_err(Error::wrap(|| format!("Failed to read {:?}", example_path))) + let source_code = fs::read(path) + .map_err(Error::wrap(|| format!("Failed to read {:?}", path))) .unwrap(); let time = Instant::now(); for _ in 0..*REPETITION_COUNT { - parser - .parse(&source_code, None) - .expect("Incompatible language version"); + action(&source_code); } let duration = time.elapsed() / (*REPETITION_COUNT as u32); - let duration_ms 
= - duration.as_secs() as f64 * 1000.0 + duration.subsec_nanos() as f64 / 1000000.0; - let speed = (source_code.len() as f64 / duration_ms) as usize; + let duration_ms = duration.as_millis(); + let speed = source_code.len() as u128 / (duration_ms + 1); eprintln!("time {} ms\tspeed {} bytes/ms", duration_ms as usize, speed); - speed + speed as usize } fn get_language(path: &Path) -> Language { diff --git a/script/benchmark b/script/benchmark index 61e57920..7599e989 100755 --- a/script/benchmark +++ b/script/benchmark @@ -18,15 +18,22 @@ OPTIONS -r parse each sample the given number of times (default 5) + -g debug + EOF } -while getopts "hl:e:r:" option; do +mode=normal + +while getopts "hgl:e:r:" option; do case ${option} in h) usage exit ;; + g) + mode=debug + ;; e) export TREE_SITTER_BENCHMARK_EXAMPLE_FILTER=${OPTARG} ;; @@ -39,4 +46,13 @@ while getopts "hl:e:r:" option; do esac done -cargo bench benchmark +if [[ "${mode}" == "debug" ]]; then + test_binary=$( + cargo bench benchmark --no-run --message-format=json 2> /dev/null |\ + jq -rs 'map(select(.target.name == "benchmark" and .executable))[0].executable' + ) + env | grep TREE_SITTER + echo $test_binary +else + exec cargo bench benchmark +fi From 645aacb1e7b8a02cf7badaf90e08d77350daa74f Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 26 Jun 2020 15:40:34 -0700 Subject: [PATCH 10/26] Optimize query analysis using binary search --- lib/src/array.h | 47 +++++++++++++++++++++++++---------------------- lib/src/query.c | 24 +++++++++++++++--------- 2 files changed, 40 insertions(+), 31 deletions(-) diff --git a/lib/src/array.h b/lib/src/array.h index e95867cf..7fae7a40 100644 --- a/lib/src/array.h +++ b/lib/src/array.h @@ -66,43 +66,46 @@ extern "C" { #define array_assign(self, other) \ array__assign((VoidArray *)(self), (const VoidArray *)(other), array__elem_size(self)) -#define array_search_sorted_by(self, start, field, needle, out_index, out_exists) \ +#define array__search_sorted(self, start, compare, 
suffix, needle, index, exists) \ do { \ - *(out_exists) = false; \ - for (*(out_index) = start; *(out_index) < (self)->size; (*(out_index))++) { \ - int _comparison = (int)((self)->contents[*(out_index)] field) - (int)(needle); \ - if (_comparison >= 0) { \ - if (_comparison == 0) *(out_exists) = true; \ - break; \ - } \ + *(index) = start; \ + *(exists) = false; \ + uint32_t size = (self)->size - *(index); \ + if (size == 0) break; \ + int comparison; \ + while (size > 1) { \ + uint32_t half_size = size / 2; \ + uint32_t mid_index = *(index) + half_size; \ + comparison = compare(&((self)->contents[mid_index] suffix), (needle)); \ + if (comparison <= 0) *(index) = mid_index; \ + size -= half_size; \ } \ - } while (0); + comparison = compare(&((self)->contents[*(index)] suffix), (needle)); \ + if (comparison == 0) *(exists) = true; \ + else if (comparison < 0) *(index) += 1; \ + } while (0) -#define array_search_sorted_with(self, start, compare, needle, out_index, out_exists) \ - do { \ - *(out_exists) = false; \ - for (*(out_index) = start; *(out_index) < (self)->size; (*(out_index))++) { \ - int _comparison = compare(&(self)->contents[*(out_index)], (needle)); \ - if (_comparison >= 0) { \ - if (_comparison == 0) *(out_exists) = true; \ - break; \ - } \ - } \ - } while (0); +#define _compare_int(a, b) ((int)*(a) - (int)(b)) + +#define array_search_sorted_by(self, start, field, needle, index, exists) \ + array__search_sorted(self, start, _compare_int, field, needle, index, exists) + +#define array_search_sorted_with(self, start, compare, needle, index, exists) \ + array__search_sorted(self, start, compare, , needle, index, exists) #define array_insert_sorted_by(self, start, field, value) \ do { \ unsigned index, exists; \ array_search_sorted_by(self, start, field, (value) field, &index, &exists); \ if (!exists) array_insert(self, index, value); \ - } while (0); + } while (0) #define array_insert_sorted_with(self, start, compare, value) \ do { \ unsigned index, 
exists; \ array_search_sorted_with(self, start, compare, &(value), &index, &exists); \ if (!exists) array_insert(self, index, value); \ - } while (0); + } while (0) // Private diff --git a/lib/src/query.c b/lib/src/query.c index bf781204..64a1b8a0 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -612,10 +612,10 @@ static inline AnalysisStateEntry *analysis_state__top(AnalysisState *self) { static inline int analysis_subgraph_node__compare(const AnalysisSubgraphNode *self, const AnalysisSubgraphNode *other) { if (self->state < other->state) return -1; if (self->state > other->state) return 1; - if (self->done && !other->done) return -1; - if (!self->done && other->done) return 1; if (self->child_index < other->child_index) return -1; if (self->child_index > other->child_index) return 1; + if (self->done < other->done) return -1; + if (self->done > other->done) return 1; if (self->production_id < other->production_id) return -1; if (self->production_id > other->production_id) return 1; return 0; @@ -961,14 +961,20 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index // Follow every possible path in the parse table, but only visit states that // are part of the subgraph for the current symbol. 
for (TSSymbol sym = 0; sym < self->language->symbol_count; sym++) { - TSStateId successor_state = ts_language_next_state(self->language, parse_state, sym); - if (successor_state && successor_state != parse_state) { + AnalysisSubgraphNode successor = { + .state = ts_language_next_state(self->language, parse_state, sym), + .child_index = child_index + 1, + }; + if (successor.state && successor.state != parse_state) { unsigned node_index; - array_search_sorted_by(&subgraph->nodes, 0, .state, successor_state, &node_index, &exists); - while (exists && node_index < subgraph->nodes.size) { + array_search_sorted_with( + &subgraph->nodes, 0, + analysis_subgraph_node__compare, &successor, + &node_index, &exists + ); + while (node_index < subgraph->nodes.size) { AnalysisSubgraphNode *node = &subgraph->nodes.contents[node_index++]; - if (node->state != successor_state) break; - if (node->child_index != child_index + 1) continue; + if (node->state != successor.state || node->child_index != successor.child_index) break; // Use the subgraph to determine what alias and field will eventually be applied // to this child node. 
@@ -992,7 +998,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index AnalysisState next_state = *state; analysis_state__top(&next_state)->child_index++; - analysis_state__top(&next_state)->parse_state = successor_state; + analysis_state__top(&next_state)->parse_state = successor.state; if (node->done) analysis_state__top(&next_state)->done = true; // Determine if this hypothetical child node would match the current step From cc37da7457da79795e47a41878342758b443004b Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 26 Jun 2020 16:31:08 -0700 Subject: [PATCH 11/26] Query analysis: fix propagation of uncertainty from later siblings --- cli/src/tests/query_test.rs | 89 ++++++++++++++++++++++------------- lib/binding_rust/bindings.rs | 1 + lib/binding_rust/lib.rs | 4 +- lib/include/tree_sitter/api.h | 1 + lib/src/query.c | 71 ++++++++++++++++------------ 5 files changed, 99 insertions(+), 67 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index c73931ce..5f6979a2 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2079,90 +2079,111 @@ fn test_query_is_definite() { struct Row { language: Language, pattern: &'static str, - results_by_step_index: &'static [(usize, bool)], + results_by_symbol: &'static [(&'static str, bool)], } let rows = &[ Row { language: get_language("python"), pattern: r#"(expression_statement (string))"#, - results_by_step_index: &[ - (0, false), - (1, false), // string + results_by_symbol: &[ + ("expression_statement", false), + ("string", false), ], }, Row { language: get_language("javascript"), pattern: r#"(expression_statement (string))"#, - results_by_step_index: &[ - (0, false), - (1, false), // string + results_by_symbol: &[ + ("expression_statement", false), + ("string", false), // string ], }, Row { language: get_language("javascript"), pattern: r#"(object "{" "}")"#, - results_by_step_index: &[ - (0, false), - (1, true), // "{" - (2, true), // 
"}" + results_by_symbol: &[ + ("object", false), + ("{", true), + ("}", true), ], }, Row { language: get_language("javascript"), pattern: r#"(pair (property_identifier) ":")"#, - results_by_step_index: &[ - (0, false), - (1, false), // property_identifier - (2, true), // ":"" + results_by_symbol: &[ + ("pair", false), + ("property_identifier", false), + (":", true), ], }, Row { language: get_language("javascript"), pattern: r#"(object "{" (_) "}")"#, - results_by_step_index: &[ - (0, false), - (1, false), // "{"" - (2, false), // (_) - (3, true), // "}" + results_by_symbol: &[ + ("object", false), + ("{", false), + ("", false), + ("}", true), ], }, Row { - // Named wildcards, fields language: get_language("javascript"), pattern: r#"(binary_expression left: (identifier) right: (_))"#, - results_by_step_index: &[ - (0, false), - (1, false), // identifier - (2, true), // (_) + results_by_symbol: &[ + ("binary_expression", false), + ("identifier", false), + ("", true), ], }, Row { language: get_language("javascript"), pattern: r#"(function_declaration name: (identifier) body: (statement_block))"#, - results_by_step_index: &[ - (0, false), - (1, true), // identifier - (2, true), // statement_block + results_by_symbol: &[ + ("function_declaration", false), + ("identifier", true), + ("statement_block", true), + ], + }, + Row { + language: get_language("javascript"), + pattern: r#" + (function_declaration + name: (identifier) + body: (statement_block "{" (expression_statement) "}"))"#, + results_by_symbol: &[ + ("function_declaration", false), + ("identifier", false), + ("statement_block", false), + ("{", false), + ("expression_statement", false), + ("}", true), ], }, Row { language: get_language("javascript"), pattern: r#""#, - results_by_step_index: &[], + results_by_symbol: &[], }, ]; allocations::record(|| { for row in rows.iter() { let query = Query::new(row.language, row.pattern).unwrap(); - for (step_index, is_definite) in row.results_by_step_index { + for 
(symbol_name, is_definite) in row.results_by_symbol { + let mut symbol = 0; + if !symbol_name.is_empty() { + symbol = row.language.id_for_node_kind(symbol_name, true); + if symbol == 0 { + symbol = row.language.id_for_node_kind(symbol_name, false); + } + } assert_eq!( - query.pattern_is_definite(0, *step_index), + query.pattern_is_definite(0, symbol, 0), *is_definite, - "Pattern: {:?}, step: {}, expected is_definite to be {}", + "Pattern: {:?}, symbol: {}, expected is_definite to be {}", row.pattern, - step_index, + symbol_name, is_definite, ) } diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 167edebf..b5ff7a9e 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -654,6 +654,7 @@ extern "C" { pub fn ts_query_pattern_is_definite( self_: *const TSQuery, pattern_index: u32, + symbol: TSSymbol, step_index: u32, ) -> bool; } diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index d3284974..b4d6f8c5 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1467,9 +1467,9 @@ impl Query { /// Check if a pattern will definitely match after a certain number of steps /// have matched. 
- pub fn pattern_is_definite(&self, index: usize, step_index: usize) -> bool { + pub fn pattern_is_definite(&self, pattern_index: usize, symbol: u16, step_index: usize) -> bool { unsafe { - ffi::ts_query_pattern_is_definite(self.ptr.as_ptr(), index as u32, step_index as u32) + ffi::ts_query_pattern_is_definite(self.ptr.as_ptr(), pattern_index as u32, symbol, step_index as u32) } } diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 1abbf28c..850cd31e 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -722,6 +722,7 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( bool ts_query_pattern_is_definite( const TSQuery *self, uint32_t pattern_index, + TSSymbol symbol, uint32_t step_index ); diff --git a/lib/src/query.c b/lib/src/query.c index 64a1b8a0..dd6ad8c0 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -149,8 +149,7 @@ typedef struct { /* * AnalysisState - The state needed for walking the parse table when analyzing - * a query pattern, to determine the steps where the pattern could fail - * to match. + * a query pattern, to determine at which steps the pattern might fail to match. */ typedef struct { TSStateId parse_state; @@ -166,6 +165,12 @@ typedef struct { uint16_t step_index; } AnalysisState; +/* + * AnalysisSubgraph - A subset of the states in the parse table that are used + * in constructing nodes with a certain symbol. Each state is accompanied by + * some information about the possible node that could be produced in + * downstream states. + */ typedef struct { TSStateId state; uint8_t production_id; @@ -914,7 +919,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index " {parent: %s, child_index: %u, field: %s, state: %3u, done:%d}", self->language->symbol_names[state->stack[k].parent_symbol], state->stack[k].child_index, - self->language->field_names[state->stack[k].field_id], + state->stack[k].field_id ? 
self->language->field_names[state->stack[k].field_id] : "", state->stack[k].parse_state, state->stack[k].done ); @@ -1018,7 +1023,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index // If this is a hidden child, then push a new entry to the stack, in order to // walk through the children of this child. - else if (next_state.depth < MAX_ANALYSIS_STATE_DEPTH) { + else if (sym >= self->language->token_count && next_state.depth < MAX_ANALYSIS_STATE_DEPTH) { next_state.depth++; analysis_state__top(&next_state)->parse_state = parse_state; analysis_state__top(&next_state)->child_index = 0; @@ -1122,17 +1127,29 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index } } - // In order for a parent step to be definite, all of its child steps must - // be definite. Propagate the definiteness up the pattern trees by walking - // the query's steps in reverse. + // In order for a step to be definite, all of its child steps must be definite, + // and all of its later sibling steps must be definite. Propagate any indefiniteness + // upward and backward through the pattern trees. 
for (unsigned i = self->steps.size - 1; i + 1 > 0; i--) { QueryStep *step = &self->steps.contents[i]; - for (unsigned j = i + 1; j < self->steps.size; j++) { + bool all_later_children_definite = true; + unsigned end_step_index = i + 1; + while (end_step_index < self->steps.size) { + QueryStep *child_step = &self->steps.contents[end_step_index]; + if (child_step->depth <= step->depth || child_step->depth == PATTERN_DONE_MARKER) break; + end_step_index++; + } + for (unsigned j = end_step_index - 1; j > i; j--) { QueryStep *child_step = &self->steps.contents[j]; - if (child_step->depth <= step->depth) break; - if (child_step->depth == step->depth + 1 && !child_step->is_definite) { - step->is_definite = false; - break; + if (child_step->depth == step->depth + 1) { + if (all_later_children_definite) { + if (!child_step->is_definite) { + all_later_children_definite = false; + step->is_definite = false; + } + } else { + child_step->is_definite = false; + } } } } @@ -1870,29 +1887,21 @@ uint32_t ts_query_start_byte_for_pattern( bool ts_query_pattern_is_definite( const TSQuery *self, uint32_t pattern_index, - uint32_t step_count + TSSymbol symbol, + uint32_t index ) { uint32_t step_index = self->patterns.contents[pattern_index].start_step; - for (;;) { - QueryStep *start_step = &self->steps.contents[step_index]; - if (step_index + step_count < self->steps.size) { - QueryStep *step = start_step; - for (unsigned i = 0; i < step_count; i++) { - if (step->depth == PATTERN_DONE_MARKER) { - step = NULL; - break; - } - step++; - } - if (step && !step->is_definite) return false; - } - if (start_step->alternative_index != NONE && start_step->alternative_index > step_index) { - step_index = start_step->alternative_index; - } else { - break; + QueryStep *step = &self->steps.contents[step_index]; + for (; step->depth != PATTERN_DONE_MARKER; step++) { + bool does_match = symbol ? 
+ step->symbol == symbol : + step->symbol == WILDCARD_SYMBOL || step->symbol == NAMED_WILDCARD_SYMBOL; + if (does_match) { + if (index == 0) return step->is_definite; + index--; } } - return true; + return false; } void ts_query_disable_capture( From c3f9b2b377a789160a1ddb1eed5b215e226f5d31 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 17 Aug 2020 09:57:06 -0700 Subject: [PATCH 12/26] Fix query analysis bugs found in ruby tags query --- cli/src/tests/query_test.rs | 14 +++++++++++--- lib/src/query.c | 24 ++++++++++++++---------- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index c3a18d71..aa5a6744 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2310,9 +2310,17 @@ fn test_query_is_definite() { ], }, Row { - language: get_language("javascript"), - pattern: r#""#, - results_by_symbol: &[], + language: get_language("ruby"), + pattern: r#" + (singleton_class + value: (constant) + "end") + "#, + results_by_symbol: &[ + ("singleton_class", false), + ("constant", false), + ("end", true), + ], }, ]; diff --git a/lib/src/query.c b/lib/src/query.c index 15aa2fd1..52f46918 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -14,7 +14,7 @@ #define MAX_CAPTURE_LIST_COUNT 32 #define MAX_STEP_CAPTURE_COUNT 3 #define MAX_STATE_PREDECESSOR_COUNT 100 -#define MAX_ANALYSIS_STATE_DEPTH 4 +#define MAX_ANALYSIS_STATE_DEPTH 8 /* * Stream - A sequence of unicode characters derived from a UTF8 string. 
@@ -804,7 +804,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index } for (TSSymbol sym = self->language->token_count; sym < self->language->symbol_count; sym++) { TSStateId next_state = ts_language_next_state(self->language, state, sym); - if (next_state != 0) { + if (next_state != 0 && next_state != state) { state_predecessor_map_add(&predecessor_map, next_state, state); TSSymbol symbol = self->language->public_symbol_map[sym]; array_search_sorted_by( @@ -873,7 +873,10 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index printf(" %u, %s:\n", subgraph->symbol, ts_language_symbol_name(self->language, subgraph->symbol)); for (unsigned j = 0; j < subgraph->nodes.size; j++) { AnalysisSubgraphNode *node = &subgraph->nodes.contents[j]; - printf(" {state: %u, child_index: %u, production_id: %u}\n", node->state, node->child_index, node->production_id); + printf( + " {state: %u, child_index: %u, production_id: %u, done: %d}\n", + node->state, node->child_index, node->production_id, node->done + ); } printf("\n"); } @@ -924,23 +927,24 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index #ifdef DEBUG_ANALYZE_QUERY printf("Final step indices:"); for (unsigned j = 0; j < final_step_indices.size; j++) { - printf(" %u", final_step_indices.contents[j]); + printf(" %4u", final_step_indices.contents[j]); } printf("\nWalk states for %u %s:\n", i, ts_language_symbol_name(self->language, parent_symbol)); for (unsigned j = 0; j < states.size; j++) { AnalysisState *state = &states.contents[j]; - printf(" %3u: {step: %u, stack: [", j, state->step_index); + printf(" %3u: step: %u, stack: [", j, state->step_index); for (unsigned k = 0; k < state->depth; k++) { printf( - " {parent: %s, child_index: %u, field: %s, state: %3u, done:%d}", + " {%s, child: %u, state: %4u", self->language->symbol_names[state->stack[k].parent_symbol], state->stack[k].child_index, - state->stack[k].field_id ? 
self->language->field_names[state->stack[k].field_id] : "", - state->stack[k].parse_state, - state->stack[k].done + state->stack[k].parse_state ); + if (state->stack[k].field_id) printf(", field: %s", self->language->field_names[state->stack[k].field_id]); + if (state->stack[k].done) printf(", DONE"); + printf("}"); } - printf(" ]}\n"); + printf(" ]\n"); } #endif From 228a9e28e1c19f12a6ca60ea85fab2b5c6c101ab Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 17 Aug 2020 13:27:17 -0700 Subject: [PATCH 13/26] Add tests for impossible queries --- cli/src/tests/query_test.rs | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index aa5a6744..15c64afa 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -186,21 +186,40 @@ fn test_query_errors_on_invalid_conditions() { #[test] fn test_query_errors_on_impossible_patterns() { - allocations::record(|| { - let language = get_language("javascript"); + let js_lang = get_language("javascript"); + let rb_lang = get_language("ruby"); + allocations::record(|| { assert_eq!( Query::new( - language, - "(binary_expression left:(identifier) left:(identifier))" + js_lang, + "(binary_expression left: (identifier) left: (identifier))" ), Err(QueryError::Pattern( 1, - [ - "(binary_expression left:(identifier) left:(identifier))", // - "^" - ] - .join("\n") + "(binary_expression left: (identifier) left: (identifier))\n^".to_string(), + )) + ); + + Query::new( + js_lang, + "(function_declaration name: (identifier) (statement_block))", + ) + .unwrap(); + assert_eq!( + Query::new(js_lang, "(function_declaration name: (statement_block))"), + Err(QueryError::Pattern( + 1, + "(function_declaration name: (statement_block))\n^".to_string(), + )) + ); + + Query::new(rb_lang, "(call receiver:(call))").unwrap(); + assert_eq!( + Query::new(rb_lang, "(call receiver:(binary))"), + Err(QueryError::Pattern( + 1, + 
"(call receiver:(binary))\n^".to_string(), )) ); }); From 91fc9f5399e4513efb87c1981ff31f9fd1e2e6ec Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 17 Aug 2020 16:50:59 -0700 Subject: [PATCH 14/26] Use is_definite flag in ts_query_cursor_next_capture --- cli/src/tests/query_test.rs | 48 +++++++++ lib/src/query.c | 197 ++++++++++++++++++++---------------- 2 files changed, 158 insertions(+), 87 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 15c64afa..1df87c74 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1952,6 +1952,54 @@ fn test_query_captures_with_too_many_nested_results() { }); } +#[test] +fn test_query_captures_with_definite_pattern_containing_many_nested_matches() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + r#" + (array + "[" @l-bracket + "]" @r-bracket) + + "." @dot + "#, + ) + .unwrap(); + + // The '[' node must be returned before all of the '.' nodes, + // even though its pattern does not finish until the ']' node + // at the end of the document. But because the '[' is definite, + // it can be returned before the pattern finishes matching. 
+ let source = " + [ + a.b.c.d.e.f.g.h.i, + a.b.c.d.e.f.g.h.i, + a.b.c.d.e.f.g.h.i, + a.b.c.d.e.f.g.h.i, + a.b.c.d.e.f.g.h.i, + ] + "; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + + let captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + assert_eq!( + collect_captures(captures, &query, source), + [("l-bracket", "[")] + .iter() + .chain([("dot", "."); 40].iter()) + .chain([("r-bracket", "]")].iter()) + .cloned() + .collect::>(), + ); + }); +} + #[test] fn test_query_captures_ordered_by_both_start_and_end_positions() { allocations::record(|| { diff --git a/lib/src/query.c b/lib/src/query.c index 52f46918..a7bc9b81 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -2034,7 +2034,8 @@ static bool ts_query_cursor__first_in_progress_capture( TSQueryCursor *self, uint32_t *state_index, uint32_t *byte_offset, - uint32_t *pattern_index + uint32_t *pattern_index, + bool *is_definite ) { bool result = false; *state_index = UINT32_MAX; @@ -2047,13 +2048,20 @@ static bool ts_query_cursor__first_in_progress_capture( &self->capture_list_pool, state->capture_list_id ); - if (captures->size > 0) { - uint32_t capture_byte = ts_node_start_byte(captures->contents[0].node); + if (captures->size > state->consumed_capture_count) { + uint32_t capture_byte = ts_node_start_byte(captures->contents[state->consumed_capture_count].node); if ( !result || capture_byte < *byte_offset || (capture_byte == *byte_offset && state->pattern_index < *pattern_index) ) { + QueryStep *step = &self->query->steps.contents[state->step_index]; + if (is_definite) { + *is_definite = step->is_definite; + } else if (step->is_definite) { + continue; + } + result = true; *state_index = i; *byte_offset = capture_byte; @@ -2216,7 +2224,8 @@ static CaptureList *ts_query_cursor__prepare_to_capture( self, &state_index, &byte_offset, - &pattern_index + &pattern_index, + 
NULL ) && state_index != state_index_to_preserve ) { @@ -2275,7 +2284,10 @@ static QueryState *ts_query_cursor__copy_state( // If one or more patterns finish, return `true` and store their states in the // `finished_states` array. Multiple patterns can finish on the same node. If // there are no more matches, return `false`. -static inline bool ts_query_cursor__advance(TSQueryCursor *self) { +static inline bool ts_query_cursor__advance( + TSQueryCursor *self, + bool stop_on_definite_step +) { bool did_match = false; for (;;) { if (self->halted) { @@ -2290,6 +2302,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { if (did_match || self->halted) return did_match; + // Exit the current node. if (self->ascending) { LOG("leave node. type:%s\n", ts_node_type(ts_tree_cursor_current_node(&self->cursor))); @@ -2342,7 +2355,10 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { } } self->states.size -= deleted_count; - } else { + } + + // Enter a new node. + else { // If this node is before the selected range, then avoid descending into it. TSNode node = ts_tree_cursor_current_node(&self->cursor); if ( @@ -2516,6 +2532,9 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { state->step_index ); + QueryStep *next_step = &self->query->steps.contents[state->step_index]; + if (stop_on_definite_step && next_step->is_definite) did_match = true; + // If this state's next step has an alternative step, then copy the state in order // to pursue both alternatives. The alternative step itself may have an alternative, // so this is an interative process. 
@@ -2660,7 +2679,7 @@ bool ts_query_cursor_next_match( TSQueryMatch *match ) { if (self->finished_states.size == 0) { - if (!ts_query_cursor__advance(self)) { + if (!ts_query_cursor__advance(self, false)) { return false; } } @@ -2701,99 +2720,103 @@ bool ts_query_cursor_next_capture( TSQueryMatch *match, uint32_t *capture_index ) { + // The goal here is to return captures in order, even though they may not + // be discovered in order, because patterns can overlap. Search for matches + // until there is a finished capture that is before any unfinished capture. for (;;) { - // The goal here is to return captures in order, even though they may not - // be discovered in order, because patterns can overlap. If there are any - // finished patterns, then try to find one that contains a capture that - // is *definitely* before any capture in an *unfinished* pattern. - if (self->finished_states.size > 0) { - // First, identify the position of the earliest capture in an unfinished - // match. For a finished capture to be returned, it must be *before* - // this position. - uint32_t first_unfinished_capture_byte; - uint32_t first_unfinished_pattern_index; - uint32_t first_unfinished_state_index; - ts_query_cursor__first_in_progress_capture( - self, - &first_unfinished_state_index, - &first_unfinished_capture_byte, - &first_unfinished_pattern_index + // First, find the earliest capture in an unfinished match. + uint32_t first_unfinished_capture_byte; + uint32_t first_unfinished_pattern_index; + uint32_t first_unfinished_state_index; + bool first_unfinished_state_is_definite = false; + ts_query_cursor__first_in_progress_capture( + self, + &first_unfinished_state_index, + &first_unfinished_capture_byte, + &first_unfinished_pattern_index, + &first_unfinished_state_is_definite + ); + + // Then find the earliest capture in a finished match. It must occur + // before the first capture in an *unfinished* match. 
+ QueryState *first_finished_state = NULL; + uint32_t first_finished_capture_byte = first_unfinished_capture_byte; + uint32_t first_finished_pattern_index = first_unfinished_pattern_index; + for (unsigned i = 0; i < self->finished_states.size; i++) { + QueryState *state = &self->finished_states.contents[i]; + const CaptureList *captures = capture_list_pool_get( + &self->capture_list_pool, + state->capture_list_id ); - - // Find the earliest capture in a finished match. - int first_finished_state_index = -1; - uint32_t first_finished_capture_byte = first_unfinished_capture_byte; - uint32_t first_finished_pattern_index = first_unfinished_pattern_index; - for (unsigned i = 0; i < self->finished_states.size; i++) { - const QueryState *state = &self->finished_states.contents[i]; - const CaptureList *captures = capture_list_pool_get( - &self->capture_list_pool, - state->capture_list_id + if (captures->size > state->consumed_capture_count) { + uint32_t capture_byte = ts_node_start_byte( + captures->contents[state->consumed_capture_count].node ); - if (captures->size > state->consumed_capture_count) { - uint32_t capture_byte = ts_node_start_byte( - captures->contents[state->consumed_capture_count].node - ); - if ( - capture_byte < first_finished_capture_byte || - ( - capture_byte == first_finished_capture_byte && - state->pattern_index < first_finished_pattern_index - ) - ) { - first_finished_state_index = i; - first_finished_capture_byte = capture_byte; - first_finished_pattern_index = state->pattern_index; - } - } else { - capture_list_pool_release( - &self->capture_list_pool, - state->capture_list_id - ); - array_erase(&self->finished_states, i); - i--; + if ( + capture_byte < first_finished_capture_byte || + ( + capture_byte == first_finished_capture_byte && + state->pattern_index < first_finished_pattern_index + ) + ) { + first_finished_state = state; + first_finished_capture_byte = capture_byte; + first_finished_pattern_index = state->pattern_index; } - } - - // If 
there is finished capture that is clearly before any unfinished - // capture, then return its match, and its capture index. Internally - // record the fact that the capture has been 'consumed'. - if (first_finished_state_index != -1) { - QueryState *state = &self->finished_states.contents[ - first_finished_state_index - ]; - match->id = state->id; - match->pattern_index = state->pattern_index; - const CaptureList *captures = capture_list_pool_get( - &self->capture_list_pool, - state->capture_list_id - ); - match->captures = captures->contents; - match->capture_count = captures->size; - *capture_index = state->consumed_capture_count; - state->consumed_capture_count++; - return true; - } - - if (capture_list_pool_is_empty(&self->capture_list_pool)) { - LOG( - " abandon state. index:%u, pattern:%u, offset:%u.\n", - first_unfinished_state_index, - first_unfinished_pattern_index, - first_unfinished_capture_byte - ); + } else { capture_list_pool_release( &self->capture_list_pool, - self->states.contents[first_unfinished_state_index].capture_list_id + state->capture_list_id ); - array_erase(&self->states, first_unfinished_state_index); + array_erase(&self->finished_states, i); + i--; } } + // If there is finished capture that is clearly before any unfinished + // capture, then return its match, and its capture index. Internally + // record the fact that the capture has been 'consumed'. 
+ QueryState *state; + if (first_finished_state) { + state = first_finished_state; + } else if (first_unfinished_state_is_definite) { + state = &self->states.contents[first_unfinished_state_index]; + } else { + state = NULL; + } + + if (state) { + match->id = state->id; + match->pattern_index = state->pattern_index; + const CaptureList *captures = capture_list_pool_get( + &self->capture_list_pool, + state->capture_list_id + ); + match->captures = captures->contents; + match->capture_count = captures->size; + *capture_index = state->consumed_capture_count; + state->consumed_capture_count++; + return true; + } + + if (capture_list_pool_is_empty(&self->capture_list_pool)) { + LOG( + " abandon state. index:%u, pattern:%u, offset:%u.\n", + first_unfinished_state_index, + first_unfinished_pattern_index, + first_unfinished_capture_byte + ); + capture_list_pool_release( + &self->capture_list_pool, + self->states.contents[first_unfinished_state_index].capture_list_id + ); + array_erase(&self->states, first_unfinished_state_index); + } + // If there are no finished matches that are ready to be returned, then // continue finding more matches. 
if ( - !ts_query_cursor__advance(self) && + !ts_query_cursor__advance(self, true) && self->finished_states.size == 0 ) return false; } From 604f9e8148de6debdaf010978e994de93b18b0f0 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 18 Aug 2020 10:10:32 -0700 Subject: [PATCH 15/26] query: Assign is_definite correctly for steps within nested sub-patterns --- cli/src/tests/query_test.rs | 25 +++++++++++++++++++++++++ lib/src/query.c | 33 +++++++++++++-------------------- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 1df87c74..5a2fbdc9 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2389,6 +2389,31 @@ fn test_query_is_definite() { ("end", true), ], }, + Row { + language: get_language("javascript"), + pattern: r#" + (call_expression + function: (member_expression + property: (property_identifier) @template-tag) + arguments: (template_string)) @template-call + "#, + results_by_symbol: &[("property_identifier", false), ("template_string", false)], + }, + Row { + language: get_language("javascript"), + pattern: r#" + (subscript_expression + object: (member_expression + object: (identifier) @obj + property: (property_identifier) @prop) + "[") + "#, + results_by_symbol: &[ + ("identifier", false), + ("property_identifier", true), + ("[", true), + ], + }, ]; allocations::record(|| { diff --git a/lib/src/query.c b/lib/src/query.c index a7bc9b81..416e9614 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -1117,16 +1117,20 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index uint32_t child_step_index = parent_step_index + 1; QueryStep *child_step = &self->steps.contents[child_step_index]; while (child_step->depth == parent_depth + 1) { + // Check if there is any way for the pattern to reach this step, but fail + // to reach the end of the sub-pattern. 
for (unsigned k = 0; k < final_step_indices.size; k++) { uint32_t final_step_index = final_step_indices.contents[k]; if ( final_step_index >= child_step_index && - self->steps.contents[final_step_index].depth != PATTERN_DONE_MARKER + self->steps.contents[final_step_index].depth == child_step->depth ) { child_step->is_definite = false; break; } } + + // Advance to the next child step in this sub-pattern. do { child_step_index++; child_step++; @@ -1136,6 +1140,8 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index ); } + // If this pattern cannot match, store the pattern index so that it can be + // returned to the caller. if (result && !can_finish_pattern) { unsigned exists; array_search_sorted_by( @@ -1150,27 +1156,14 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index // In order for a step to be definite, all of its child steps must be definite, // and all of its later sibling steps must be definite. Propagate any indefiniteness // upward and backward through the pattern trees. 
+ bool all_later_children_definite = true; for (unsigned i = self->steps.size - 1; i + 1 > 0; i--) { QueryStep *step = &self->steps.contents[i]; - bool all_later_children_definite = true; - unsigned end_step_index = i + 1; - while (end_step_index < self->steps.size) { - QueryStep *child_step = &self->steps.contents[end_step_index]; - if (child_step->depth <= step->depth || child_step->depth == PATTERN_DONE_MARKER) break; - end_step_index++; - } - for (unsigned j = end_step_index - 1; j > i; j--) { - QueryStep *child_step = &self->steps.contents[j]; - if (child_step->depth == step->depth + 1) { - if (all_later_children_definite) { - if (!child_step->is_definite) { - all_later_children_definite = false; - step->is_definite = false; - } - } else { - child_step->is_definite = false; - } - } + if (step->depth == PATTERN_DONE_MARKER) { + all_later_children_definite = true; + } else { + if (!all_later_children_definite) step->is_definite = false; + if (!step->is_definite) all_later_children_definite = false; } } From bd42729a41181a71690e0b99d35346b51fa5c6a8 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 18 Aug 2020 13:01:45 -0700 Subject: [PATCH 16/26] query: Avoid early-returning captures due to predicates --- cli/src/tests/query_test.rs | 21 +++++++++++- lib/src/query.c | 66 ++++++++++++++++++++++++++++--------- 2 files changed, 70 insertions(+), 17 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 5a2fbdc9..a18c3a8b 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2414,6 +2414,22 @@ fn test_query_is_definite() { ("[", true), ], }, + Row { + language: get_language("javascript"), + pattern: r#" + (subscript_expression + object: (member_expression + object: (identifier) @obj + property: (property_identifier) @prop) + "[" + (#match? 
@prop "foo")) + "#, + results_by_symbol: &[ + ("identifier", false), + ("property_identifier", false), + ("[", true), + ], + }, ]; allocations::record(|| { @@ -2431,7 +2447,10 @@ fn test_query_is_definite() { query.pattern_is_definite(0, symbol, 0), *is_definite, "Pattern: {:?}, symbol: {}, expected is_definite to be {}", - row.pattern, + row.pattern + .split_ascii_whitespace() + .collect::>() + .join(" "), symbol_name, is_definite, ) diff --git a/lib/src/query.c b/lib/src/query.c index 416e9614..b3bf0b48 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -91,9 +91,9 @@ typedef struct { } PatternEntry; typedef struct { + Slice steps; Slice predicate_steps; uint32_t start_byte; - uint32_t start_step; } QueryPattern; /* @@ -1146,7 +1146,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index unsigned exists; array_search_sorted_by( &self->patterns, 0, - .start_step, parent_step_index, + .steps.offset, parent_step_index, impossible_index, &exists ); result = false; @@ -1156,12 +1156,45 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index // In order for a step to be definite, all of its child steps must be definite, // and all of its later sibling steps must be definite. Propagate any indefiniteness // upward and backward through the pattern trees. - bool all_later_children_definite = true; - for (unsigned i = self->steps.size - 1; i + 1 > 0; i--) { - QueryStep *step = &self->steps.contents[i]; - if (step->depth == PATTERN_DONE_MARKER) { - all_later_children_definite = true; - } else { + Array(uint16_t) predicate_capture_ids = array_new(); + for (unsigned i = 0; i < self->patterns.size; i++) { + QueryPattern *pattern = &self->patterns.contents[i]; + + // Gather all of the captures that are used in predicates for this pattern. 
+ array_clear(&predicate_capture_ids); + for ( + unsigned start = pattern->predicate_steps.offset, + end = start + pattern->predicate_steps.length, + j = start; j < end; j++ + ) { + TSQueryPredicateStep *step = &self->predicate_steps.contents[j]; + if (step->type == TSQueryPredicateStepTypeCapture) { + array_insert_sorted_by(&predicate_capture_ids, 0, , step->value_id); + } + } + + bool all_later_children_definite = true; + for ( + unsigned start = pattern->steps.offset, + end = start + pattern->steps.length, + j = end - 1; j + 1 > start; j-- + ) { + QueryStep *step = &self->steps.contents[j]; + + // If this step has a capture that is used in a predicate, + // then it is not definite. + for (unsigned k = 0; k < MAX_STEP_CAPTURE_COUNT; k++) { + uint16_t capture_id = step->capture_ids[k]; + if (capture_id == NONE) break; + unsigned index, exists; + array_search_sorted_by(&predicate_capture_ids, 0, , capture_id, &index, &exists); + if (exists) { + step->is_definite = false; + break; + } + } + + // If a step is not definite, then none of its predecessors can be definite. 
if (!all_later_children_definite) step->is_definite = false; if (!step->is_definite) all_later_children_definite = false; } @@ -1197,6 +1230,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index array_delete(&next_states); array_delete(&final_step_indices); array_delete(&parent_step_indices); + array_delete(&predicate_capture_ids); state_predecessor_map_delete(&predecessor_map); return result; @@ -1238,7 +1272,6 @@ static TSQueryError ts_query__parse_predicate( predicate_name, length ); - array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, .value_id = id, @@ -1249,7 +1282,6 @@ static TSQueryError ts_query__parse_predicate( if (stream->next == ')') { stream_advance(stream); stream_skip_whitespace(stream); - array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeDone, .value_id = 0, @@ -1278,7 +1310,6 @@ static TSQueryError ts_query__parse_predicate( return TSQueryErrorCapture; } - array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeCapture, .value_id = capture_id, @@ -1318,7 +1349,6 @@ static TSQueryError ts_query__parse_predicate( string_content, length ); - array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, .value_id = id, @@ -1338,7 +1368,6 @@ static TSQueryError ts_query__parse_predicate( symbol_start, length ); - array_back(&self->patterns)->predicate_steps.length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, .value_id = id, @@ -1778,14 +1807,19 @@ TSQuery *ts_query_new( while (stream.input < stream.end) { uint32_t pattern_index = self->patterns.size; uint32_t start_step_index = 
self->steps.size; + uint32_t start_predicate_step_index = self->predicate_steps.size; array_push(&self->patterns, ((QueryPattern) { - .predicate_steps = (Slice) {.offset = self->predicate_steps.size, .length = 0}, + .steps = (Slice) {.offset = start_step_index}, + .predicate_steps = (Slice) {.offset = start_predicate_step_index}, .start_byte = stream.input - source, - .start_step = self->steps.size, })); *error_type = ts_query__parse_pattern(self, &stream, 0, false); array_push(&self->steps, query_step__new(0, PATTERN_DONE_MARKER, false)); + QueryPattern *pattern = array_back(&self->patterns); + pattern->steps.length = self->steps.size - start_step_index; + pattern->predicate_steps.length = self->predicate_steps.size - start_predicate_step_index; + // If any pattern could not be parsed, then report the error information // and terminate. if (*error_type) { @@ -1903,7 +1937,7 @@ bool ts_query_pattern_is_definite( TSSymbol symbol, uint32_t index ) { - uint32_t step_index = self->patterns.contents[pattern_index].start_step; + uint32_t step_index = self->patterns.contents[pattern_index].steps.offset; QueryStep *step = &self->steps.contents[step_index]; for (; step->depth != PATTERN_DONE_MARKER; step++) { bool does_match = symbol ? 
From aac75e35b1e4c519158f26fe048699d127b1ed10 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 19 Aug 2020 13:15:45 -0700 Subject: [PATCH 17/26] Optimize iteration over state successors during query analysis --- lib/src/language.h | 106 +++++++++++++++- lib/src/query.c | 308 ++++++++++++++++++++++++--------------------- 2 files changed, 267 insertions(+), 147 deletions(-) diff --git a/lib/src/language.h b/lib/src/language.h index 288c2a2b..f8fd1ae5 100644 --- a/lib/src/language.h +++ b/lib/src/language.h @@ -20,6 +20,22 @@ typedef struct { bool is_reusable; } TableEntry; +typedef struct { + const TSLanguage *language; + const uint16_t *data; + const uint16_t *group_end; + TSStateId state; + uint16_t table_value; + uint16_t section_index; + uint16_t group_count; + bool is_small_state; + + const TSParseAction *actions; + TSSymbol symbol; + TSStateId next_state; + uint16_t action_count; +} LookaheadIterator; + void ts_language_table_entry(const TSLanguage *, TSStateId, TSSymbol, TableEntry *); TSSymbolMetadata ts_language_symbol_metadata(const TSLanguage *, TSSymbol); @@ -62,6 +78,13 @@ static inline bool ts_language_has_reduce_action( return entry.action_count > 0 && entry.actions[0].type == TSParseActionTypeReduce; } +// Lookup the table value for a given symbol and state. +// +// For non-terminal symbols, the table value represents a successor state. +// For terminal symbols, it represents an index in the actions table. +// For 'large' parse states, this is a direct lookup. For 'small' parse +// states, this requires searching through the symbol groups to find +// the given symbol. 
static inline uint16_t ts_language_lookup( const TSLanguage *self, TSStateId state, @@ -73,8 +96,8 @@ static inline uint16_t ts_language_lookup( ) { uint32_t index = self->small_parse_table_map[state - self->large_state_count]; const uint16_t *data = &self->small_parse_table[index]; - uint16_t section_count = *(data++); - for (unsigned i = 0; i < section_count; i++) { + uint16_t group_count = *(data++); + for (unsigned i = 0; i < group_count; i++) { uint16_t section_value = *(data++); uint16_t symbol_count = *(data++); for (unsigned i = 0; i < symbol_count; i++) { @@ -87,6 +110,85 @@ static inline uint16_t ts_language_lookup( } } +// Iterate over all of the symbols that are valid in the given state. +// +// For 'large' parse states, this just requires iterating through +// all possible symbols and checking the parse table for each one. +// For 'small' parse states, this exploits the structure of the +// table to only visit the valid symbols. +static inline LookaheadIterator ts_language_lookaheads( + const TSLanguage *self, + TSStateId state +) { + bool is_small_state = + self->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_SMALL_STATES && + state >= self->large_state_count; + const uint16_t *data; + const uint16_t *group_end = NULL; + uint16_t group_count = 0; + if (is_small_state) { + uint32_t index = self->small_parse_table_map[state - self->large_state_count]; + data = &self->small_parse_table[index]; + group_end = data + 1; + group_count = *data; + } else { + data = &self->parse_table[state * self->symbol_count] - 1; + } + return (LookaheadIterator) { + .language = self, + .data = data, + .group_end = group_end, + .group_count = group_count, + .is_small_state = is_small_state, + .symbol = UINT16_MAX, + .next_state = 0, + }; +} + +static inline bool ts_lookahead_iterator_next(LookaheadIterator *self) { + // For small parse states, valid symbols are listed explicitly, + // grouped by their value. 
There's no need to look up the actions + // again until moving to the next group. + if (self->is_small_state) { + self->data++; + if (self->data == self->group_end) { + if (self->group_count == 0) return false; + self->group_count--; + self->table_value = *(self->data++); + unsigned symbol_count = *(self->data++); + self->group_end = self->data + symbol_count; + self->symbol = *self->data; + } else { + self->symbol = *self->data; + return true; + } + } + + // For large parse states, iterate through every symbol until one + // is found that has valid actions. + else { + do { + self->data++; + self->symbol++; + if (self->symbol >= self->language->symbol_count) return false; + self->table_value = *self->data; + } while (!self->table_value); + } + + // Depending on if the symbols is terminal or non-terminal, the table value either + // represents a list of actions or a successor state. + if (self->symbol < self->language->token_count) { + const TSParseActionEntry *entry = &self->language->parse_actions[self->table_value]; + self->action_count = entry->entry.count; + self->actions = (const TSParseAction *)(entry + 1); + self->next_state = 0; + } else { + self->action_count = 0; + self->next_state = self->table_value; + } + return true; +} + static inline TSStateId ts_language_next_state( const TSLanguage *self, TSStateId state, diff --git a/lib/src/query.c b/lib/src/query.c index b3bf0b48..eba5955f 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -599,7 +599,7 @@ static inline int analysis_state__compare( if (self->stack[i].parse_state > other->stack[i].parse_state) return 1; if (self->stack[i].field_id < other->stack[i].field_id) return -1; if (self->stack[i].field_id > other->stack[i].field_id) return 1; - } + } if (self->step_index < other->step_index) return -1; if (self->step_index > other->step_index) return 1; return 0; @@ -769,47 +769,44 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index // 3) A list of predecessor states 
for each state. StatePredecessorMap predecessor_map = state_predecessor_map_new(self->language); for (TSStateId state = 1; state < self->language->state_count; state++) { - unsigned subgraph_index = 0, exists; - for (TSSymbol sym = 0; sym < self->language->token_count; sym++) { - unsigned count; - const TSParseAction *actions = ts_language_actions(self->language, state, sym, &count); - for (unsigned i = 0; i < count; i++) { - const TSParseAction *action = &actions[i]; - if (action->type == TSParseActionTypeReduce) { - TSSymbol symbol = self->language->public_symbol_map[action->params.reduce.symbol]; - array_search_sorted_by( - &subgraphs, - 0, - .symbol, - symbol, - &subgraph_index, - &exists - ); - if (exists) { - AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; - if (subgraph->nodes.size == 0 || array_back(&subgraph->nodes)->state != state) { - array_push(&subgraph->nodes, ((AnalysisSubgraphNode) { - .state = state, - .production_id = action->params.reduce.production_id, - .child_index = action->params.reduce.child_count, - .done = true, - })); + unsigned subgraph_index, exists; + LookaheadIterator lookahead_iterator = ts_language_lookaheads(self->language, state); + while (ts_lookahead_iterator_next(&lookahead_iterator)) { + if (lookahead_iterator.action_count) { + for (unsigned i = 0; i < lookahead_iterator.action_count; i++) { + const TSParseAction *action = &lookahead_iterator.actions[i]; + if (action->type == TSParseActionTypeReduce) { + TSSymbol symbol = self->language->public_symbol_map[action->params.reduce.symbol]; + array_search_sorted_by( + &subgraphs, + 0, + .symbol, + symbol, + &subgraph_index, + &exists + ); + if (exists) { + AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + if (subgraph->nodes.size == 0 || array_back(&subgraph->nodes)->state != state) { + array_push(&subgraph->nodes, ((AnalysisSubgraphNode) { + .state = state, + .production_id = action->params.reduce.production_id, + .child_index = 
action->params.reduce.child_count, + .done = true, + })); + } } + } else if (action->type == TSParseActionTypeShift && !action->params.shift.extra) { + TSStateId next_state = action->params.shift.state; + state_predecessor_map_add(&predecessor_map, next_state, state); } - } else if (action->type == TSParseActionTypeShift && !action->params.shift.extra) { - TSStateId next_state = action->params.shift.state; - state_predecessor_map_add(&predecessor_map, next_state, state); } - } - } - for (TSSymbol sym = self->language->token_count; sym < self->language->symbol_count; sym++) { - TSStateId next_state = ts_language_next_state(self->language, state, sym); - if (next_state != 0 && next_state != state) { - state_predecessor_map_add(&predecessor_map, next_state, state); - TSSymbol symbol = self->language->public_symbol_map[sym]; + } else if (lookahead_iterator.next_state != 0 && lookahead_iterator.next_state != state) { + state_predecessor_map_add(&predecessor_map, lookahead_iterator.next_state, state); + TSSymbol symbol = self->language->public_symbol_map[lookahead_iterator.symbol]; array_search_sorted_by( &subgraphs, - subgraph_index, + 0, .symbol, symbol, &subgraph_index, @@ -871,6 +868,12 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index for (unsigned i = 0; i < subgraphs.size; i++) { AnalysisSubgraph *subgraph = &subgraphs.contents[i]; printf(" %u, %s:\n", subgraph->symbol, ts_language_symbol_name(self->language, subgraph->symbol)); + for (unsigned j = 0; j < subgraph->start_states.size; j++) { + printf( + " {state: %u}\n", + subgraph->start_states.contents[j] + ); + } for (unsigned j = 0; j < subgraph->nodes.size; j++) { AnalysisSubgraphNode *node = &subgraph->nodes.contents[j]; printf( @@ -985,122 +988,137 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index // Follow every possible path in the parse table, but only visit states that // are part of the subgraph for the current symbol. 
- for (TSSymbol sym = 0; sym < self->language->symbol_count; sym++) { + LookaheadIterator lookahead_iterator = ts_language_lookaheads(self->language, parse_state); + while (ts_lookahead_iterator_next(&lookahead_iterator)) { + TSSymbol sym = lookahead_iterator.symbol; + + TSStateId next_parse_state; + if (lookahead_iterator.action_count) { + const TSParseAction *action = &lookahead_iterator.actions[lookahead_iterator.action_count - 1]; + if (action->type == TSParseActionTypeShift && !action->params.shift.extra) { + next_parse_state = action->params.shift.state; + } else { + continue; + } + } else if (lookahead_iterator.next_state != 0 && lookahead_iterator.next_state != parse_state) { + next_parse_state = lookahead_iterator.next_state; + } else { + continue; + } + AnalysisSubgraphNode successor = { - .state = ts_language_next_state(self->language, parse_state, sym), + .state = next_parse_state, .child_index = child_index + 1, }; - if (successor.state && successor.state != parse_state) { - unsigned node_index; - array_search_sorted_with( - &subgraph->nodes, 0, - analysis_subgraph_node__compare, &successor, - &node_index, &exists - ); - while (node_index < subgraph->nodes.size) { - AnalysisSubgraphNode *node = &subgraph->nodes.contents[node_index++]; - if (node->state != successor.state || node->child_index != successor.child_index) break; + unsigned node_index; + array_search_sorted_with( + &subgraph->nodes, 0, + analysis_subgraph_node__compare, &successor, + &node_index, &exists + ); + while (node_index < subgraph->nodes.size) { + AnalysisSubgraphNode *node = &subgraph->nodes.contents[node_index++]; + if (node->state != successor.state || node->child_index != successor.child_index) break; - // Use the subgraph to determine what alias and field will eventually be applied - // to this child node. - TSSymbol alias = ts_language_alias_at(self->language, node->production_id, child_index); - TSSymbol visible_symbol = alias - ? 
alias - : self->language->symbol_metadata[sym].visible - ? self->language->public_symbol_map[sym] - : 0; - TSFieldId field_id = parent_field_id; - if (!field_id) { - const TSFieldMapEntry *field_map, *field_map_end; - ts_language_field_map(self->language, node->production_id, &field_map, &field_map_end); - for (; field_map != field_map_end; field_map++) { - if (field_map->child_index == child_index) { - field_id = field_map->field_id; - break; - } - } - } - - AnalysisState next_state = *state; - analysis_state__top(&next_state)->child_index++; - analysis_state__top(&next_state)->parse_state = successor.state; - if (node->done) analysis_state__top(&next_state)->done = true; - - // Determine if this hypothetical child node would match the current step - // of the query pattern. - bool does_match = false; - if (visible_symbol) { - does_match = true; - if (step->symbol == NAMED_WILDCARD_SYMBOL) { - if (!self->language->symbol_metadata[visible_symbol].named) does_match = false; - } else if (step->symbol != WILDCARD_SYMBOL) { - if (step->symbol != visible_symbol) does_match = false; - } - if (step->field && step->field != field_id) { - does_match = false; - } - } - - // If this is a hidden child, then push a new entry to the stack, in order to - // walk through the children of this child. - else if (sym >= self->language->token_count && next_state.depth < MAX_ANALYSIS_STATE_DEPTH) { - next_state.depth++; - analysis_state__top(&next_state)->parse_state = parse_state; - analysis_state__top(&next_state)->child_index = 0; - analysis_state__top(&next_state)->parent_symbol = sym; - analysis_state__top(&next_state)->field_id = field_id; - analysis_state__top(&next_state)->done = false; - } else { - continue; - } - - // Pop from the stack when this state reached the end of its current syntax node. 
- while (next_state.depth > 0 && analysis_state__top(&next_state)->done) { - next_state.depth--; - } - - // If this hypothetical child did match the current step of the query pattern, - // then advance to the next step at the current depth. This involves skipping - // over any descendant steps of the current child. - const QueryStep *next_step = step; - if (does_match) { - for (;;) { - next_state.step_index++; - next_step = &self->steps.contents[next_state.step_index]; - if ( - next_step->depth == PATTERN_DONE_MARKER || - next_step->depth <= parent_depth + 1 - ) break; - } - } - - for (;;) { - // If this state can make further progress, then add it to the states for the next iteration. - // Otherwise, record the fact that matching can fail at this step of the pattern. - if (!next_step->is_dead_end) { - bool did_finish_pattern = self->steps.contents[next_state.step_index].depth != parent_depth + 1; - if (did_finish_pattern) can_finish_pattern = true; - if (next_state.depth > 0 && !did_finish_pattern) { - array_insert_sorted_with(&next_states, 0, analysis_state__compare, next_state); - } else { - array_insert_sorted_by(&final_step_indices, 0, , next_state.step_index); - } - } - - // If the state has advanced to a step with an alternative step, then add another state at - // that alternative step to the next iteration. - if ( - does_match && - next_step->alternative_index != NONE && - next_step->alternative_index > next_state.step_index - ) { - next_state.step_index = next_step->alternative_index; - next_step = &self->steps.contents[next_state.step_index]; - } else { + // Use the subgraph to determine what alias and field will eventually be applied + // to this child node. + TSSymbol alias = ts_language_alias_at(self->language, node->production_id, child_index); + TSSymbol visible_symbol = alias + ? alias + : self->language->symbol_metadata[sym].visible + ? 
self->language->public_symbol_map[sym] + : 0; + TSFieldId field_id = parent_field_id; + if (!field_id) { + const TSFieldMapEntry *field_map, *field_map_end; + ts_language_field_map(self->language, node->production_id, &field_map, &field_map_end); + for (; field_map != field_map_end; field_map++) { + if (field_map->child_index == child_index) { + field_id = field_map->field_id; break; } } } + + AnalysisState next_state = *state; + analysis_state__top(&next_state)->child_index++; + analysis_state__top(&next_state)->parse_state = successor.state; + if (node->done) analysis_state__top(&next_state)->done = true; + + // Determine if this hypothetical child node would match the current step + // of the query pattern. + bool does_match = false; + if (visible_symbol) { + does_match = true; + if (step->symbol == NAMED_WILDCARD_SYMBOL) { + if (!self->language->symbol_metadata[visible_symbol].named) does_match = false; + } else if (step->symbol != WILDCARD_SYMBOL) { + if (step->symbol != visible_symbol) does_match = false; + } + if (step->field && step->field != field_id) { + does_match = false; + } + } + + // If this is a hidden child, then push a new entry to the stack, in order to + // walk through the children of this child. + else if (sym >= self->language->token_count && next_state.depth < MAX_ANALYSIS_STATE_DEPTH) { + next_state.depth++; + analysis_state__top(&next_state)->parse_state = parse_state; + analysis_state__top(&next_state)->child_index = 0; + analysis_state__top(&next_state)->parent_symbol = sym; + analysis_state__top(&next_state)->field_id = field_id; + analysis_state__top(&next_state)->done = false; + } else { + continue; + } + + // Pop from the stack when this state reached the end of its current syntax node. + while (next_state.depth > 0 && analysis_state__top(&next_state)->done) { + next_state.depth--; + } + + // If this hypothetical child did match the current step of the query pattern, + // then advance to the next step at the current depth. 
This involves skipping + // over any descendant steps of the current child. + const QueryStep *next_step = step; + if (does_match) { + for (;;) { + next_state.step_index++; + next_step = &self->steps.contents[next_state.step_index]; + if ( + next_step->depth == PATTERN_DONE_MARKER || + next_step->depth <= parent_depth + 1 + ) break; + } + } + + for (;;) { + // If this state can make further progress, then add it to the states for the next iteration. + // Otherwise, record the fact that matching can fail at this step of the pattern. + if (!next_step->is_dead_end) { + bool did_finish_pattern = self->steps.contents[next_state.step_index].depth != parent_depth + 1; + if (did_finish_pattern) can_finish_pattern = true; + if (next_state.depth > 0 && !did_finish_pattern) { + array_insert_sorted_with(&next_states, 0, analysis_state__compare, next_state); + } else { + array_insert_sorted_by(&final_step_indices, 0, , next_state.step_index); + } + } + + // If the state has advanced to a step with an alternative step, then add another state at + // that alternative step to the next iteration. 
+ if ( + does_match && + next_step->alternative_index != NONE && + next_step->alternative_index > next_state.step_index + ) { + next_state.step_index = next_step->alternative_index; + next_step = &self->steps.contents[next_state.step_index]; + } else { + break; + } + } } } } From d47346abc076410a531876eb5635d8230c69b72f Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 20 Aug 2020 10:07:22 -0700 Subject: [PATCH 18/26] Avoid pushing duplicate start states in query analysis --- lib/src/query.c | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/src/query.c b/lib/src/query.c index eba5955f..a156beb9 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -599,7 +599,7 @@ static inline int analysis_state__compare( if (self->stack[i].parse_state > other->stack[i].parse_state) return 1; if (self->stack[i].field_id < other->stack[i].field_id) return -1; if (self->stack[i].field_id > other->stack[i].field_id) return 1; - } + } if (self->step_index < other->step_index) return -1; if (self->step_index > other->step_index) return 1; return 0; @@ -814,6 +814,10 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index ); if (exists) { AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + if ( + subgraph->start_states.size == 0 || + *array_back(&subgraph->start_states) != state + ) array_push(&subgraph->start_states, state); } } From 4301110c126b8fabe45a00b20ce965d4043910d8 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 20 Aug 2020 13:06:38 -0700 Subject: [PATCH 19/26] query: Indicate specific step that's impossible --- cli/src/tests/query_test.rs | 69 ++++++++++++----------- lib/binding_rust/bindings.rs | 7 +-- lib/binding_rust/lib.rs | 12 ++-- lib/include/tree_sitter/api.h | 6 +- lib/src/query.c | 100 ++++++++++++++++++++-------------- 5 files changed, 103 insertions(+), 91 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index a18c3a8b..1e4ea8cc 100644 --- 
a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -197,7 +197,11 @@ fn test_query_errors_on_impossible_patterns() { ), Err(QueryError::Pattern( 1, - "(binary_expression left: (identifier) left: (identifier))\n^".to_string(), + [ + "(binary_expression left: (identifier) left: (identifier))", + " ^" + ] + .join("\n"), )) ); @@ -210,7 +214,11 @@ fn test_query_errors_on_impossible_patterns() { Query::new(js_lang, "(function_declaration name: (statement_block))"), Err(QueryError::Pattern( 1, - "(function_declaration name: (statement_block))\n^".to_string(), + [ + "(function_declaration name: (statement_block))", + " ^", + ] + .join("\n") )) ); @@ -219,7 +227,11 @@ fn test_query_errors_on_impossible_patterns() { Query::new(rb_lang, "(call receiver:(binary))"), Err(QueryError::Pattern( 1, - "(call receiver:(binary))\n^".to_string(), + [ + "(call receiver:(binary))", // + " ^", + ] + .join("\n") )) ); }); @@ -2307,55 +2319,52 @@ fn test_query_alternative_predicate_prefix() { } #[test] -fn test_query_is_definite() { +fn test_query_step_is_definite() { struct Row { language: Language, pattern: &'static str, - results_by_symbol: &'static [(&'static str, bool)], + results_by_substring: &'static [(&'static str, bool)], } let rows = &[ Row { language: get_language("python"), pattern: r#"(expression_statement (string))"#, - results_by_symbol: &[("expression_statement", false), ("string", false)], + results_by_substring: &[("expression_statement", false), ("string", false)], }, Row { language: get_language("javascript"), pattern: r#"(expression_statement (string))"#, - results_by_symbol: &[ - ("expression_statement", false), - ("string", false), // string - ], + results_by_substring: &[("expression_statement", false), ("string", false)], }, Row { language: get_language("javascript"), pattern: r#"(object "{" "}")"#, - results_by_symbol: &[("object", false), ("{", true), ("}", true)], + results_by_substring: &[("object", false), ("{", true), ("}", true)], }, Row { 
language: get_language("javascript"), pattern: r#"(pair (property_identifier) ":")"#, - results_by_symbol: &[("pair", false), ("property_identifier", false), (":", true)], + results_by_substring: &[("pair", false), ("property_identifier", false), (":", true)], }, Row { language: get_language("javascript"), pattern: r#"(object "{" (_) "}")"#, - results_by_symbol: &[("object", false), ("{", false), ("", false), ("}", true)], + results_by_substring: &[("object", false), ("{", false), ("", false), ("}", true)], }, Row { language: get_language("javascript"), pattern: r#"(binary_expression left: (identifier) right: (_))"#, - results_by_symbol: &[ + results_by_substring: &[ ("binary_expression", false), - ("identifier", false), - ("", true), + ("(identifier)", false), + ("(_)", true), ], }, Row { language: get_language("javascript"), pattern: r#"(function_declaration name: (identifier) body: (statement_block))"#, - results_by_symbol: &[ + results_by_substring: &[ ("function_declaration", false), ("identifier", true), ("statement_block", true), @@ -2367,7 +2376,7 @@ fn test_query_is_definite() { (function_declaration name: (identifier) body: (statement_block "{" (expression_statement) "}"))"#, - results_by_symbol: &[ + results_by_substring: &[ ("function_declaration", false), ("identifier", false), ("statement_block", false), @@ -2383,7 +2392,7 @@ fn test_query_is_definite() { value: (constant) "end") "#, - results_by_symbol: &[ + results_by_substring: &[ ("singleton_class", false), ("constant", false), ("end", true), @@ -2397,7 +2406,7 @@ fn test_query_is_definite() { property: (property_identifier) @template-tag) arguments: (template_string)) @template-call "#, - results_by_symbol: &[("property_identifier", false), ("template_string", false)], + results_by_substring: &[("property_identifier", false), ("template_string", false)], }, Row { language: get_language("javascript"), @@ -2408,7 +2417,7 @@ fn test_query_is_definite() { property: (property_identifier) @prop) "[")
"#, - results_by_symbol: &[ + results_by_substring: &[ ("identifier", false), ("property_identifier", true), ("[", true), @@ -2424,7 +2433,7 @@ fn test_query_is_definite() { "[" (#match? @prop "foo")) "#, - results_by_symbol: &[ + results_by_substring: &[ ("identifier", false), ("property_identifier", false), ("[", true), @@ -2435,23 +2444,17 @@ fn test_query_is_definite() { allocations::record(|| { for row in rows.iter() { let query = Query::new(row.language, row.pattern).unwrap(); - for (symbol_name, is_definite) in row.results_by_symbol { - let mut symbol = 0; - if !symbol_name.is_empty() { - symbol = row.language.id_for_node_kind(symbol_name, true); - if symbol == 0 { - symbol = row.language.id_for_node_kind(symbol_name, false); - } - } + for (substring, is_definite) in row.results_by_substring { + let offset = row.pattern.find(substring).unwrap(); assert_eq!( - query.pattern_is_definite(0, symbol, 0), + query.step_is_definite(offset), *is_definite, - "Pattern: {:?}, symbol: {}, expected is_definite to be {}", + "Pattern: {:?}, substring: {:?}, expected is_definite to be {}", row.pattern .split_ascii_whitespace() .collect::<Vec<_>>() .join(" "), - symbol_name, + substring, is_definite, ) } diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index b5ff7a9e..81cc6f9a 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -651,12 +651,7 @@ extern "C" { ) -> *const TSQueryPredicateStep; } extern "C" { - pub fn ts_query_pattern_is_definite( - self_: *const TSQuery, - pattern_index: u32, - symbol: TSSymbol, - step_index: u32, - ) -> bool; + pub fn ts_query_step_is_definite(self_: *const TSQuery, byte_offset: u32) -> bool; } extern "C" { #[doc = " Get the name and length of one of the query\'s captures, or one of the"] diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index c601aecc..10cd9fc2 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1467,12 +1467,12 @@ impl Query { unsafe {
ffi::ts_query_disable_pattern(self.ptr.as_ptr(), index as u32) } } - /// Check if a pattern will definitely match after a certain number of steps - /// have matched. - pub fn pattern_is_definite(&self, pattern_index: usize, symbol: u16, step_index: usize) -> bool { - unsafe { - ffi::ts_query_pattern_is_definite(self.ptr.as_ptr(), pattern_index as u32, symbol, step_index as u32) - } + /// Check if a given step in a query is 'definite'. + /// + /// A query step is 'definite' if its parent pattern will be guaranteed to match + /// successfully once it reaches the step. + pub fn step_is_definite(&self, byte_offset: usize) -> bool { + unsafe { ffi::ts_query_step_is_definite(self.ptr.as_ptr(), byte_offset as u32) } } fn parse_property( diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 850cd31e..1e60e4b5 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -719,11 +719,9 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( uint32_t *length ); -bool ts_query_pattern_is_definite( +bool ts_query_step_is_definite( const TSQuery *self, - uint32_t pattern_index, - TSSymbol symbol, - uint32_t step_index + uint32_t byte_offset ); /** diff --git a/lib/src/query.c b/lib/src/query.c index a156beb9..5a2bb2fb 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -22,6 +22,7 @@ */ typedef struct { const char *input; + const char *start; const char *end; int32_t next; uint8_t next_size; @@ -96,6 +97,11 @@ typedef struct { uint32_t start_byte; } QueryPattern; +typedef struct { + uint32_t byte_offset; + uint16_t step_index; +} StepOffset; + /* * QueryState - The state of an in-progress match of a particular pattern * in a query. 
While executing, a `TSQueryCursor` must keep track of a number @@ -202,6 +208,7 @@ struct TSQuery { Array(PatternEntry) pattern_map; Array(TSQueryPredicateStep) predicate_steps; Array(QueryPattern) patterns; + Array(StepOffset) step_offsets; const TSLanguage *language; uint16_t wildcard_root_pattern_count; TSSymbol *symbol_map; @@ -268,21 +275,22 @@ static Stream stream_new(const char *string, uint32_t length) { Stream self = { .next = 0, .input = string, + .start = string, .end = string + length, }; stream_advance(&self); return self; } -static void stream_skip_whitespace(Stream *stream) { +static void stream_skip_whitespace(Stream *self) { for (;;) { - if (iswspace(stream->next)) { - stream_advance(stream); - } else if (stream->next == ';') { + if (iswspace(self->next)) { + stream_advance(self); + } else if (self->next == ';') { // skip over comments - stream_advance(stream); - while (stream->next && stream->next != '\n') { - if (!stream_advance(stream)) break; + stream_advance(self); + while (self->next && self->next != '\n') { + if (!stream_advance(self)) break; } } else { break; @@ -290,8 +298,8 @@ static void stream_skip_whitespace(Stream *stream) { } } -static bool stream_is_ident_start(Stream *stream) { - return iswalnum(stream->next) || stream->next == '_' || stream->next == '-'; +static bool stream_is_ident_start(Stream *self) { + return iswalnum(self->next) || self->next == '_' || self->next == '-'; } static void stream_scan_identifier(Stream *stream) { @@ -307,6 +315,10 @@ static void stream_scan_identifier(Stream *stream) { ); } +static uint32_t stream_offset(Stream *self) { + return self->input - self->start; +} + /****************** * CaptureListPool ******************/ @@ -716,7 +728,7 @@ static inline void ts_query__pattern_map_insert( // #define DEBUG_ANALYZE_QUERY -static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index) { +static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { // Identify all of 
the patterns in the query that have child patterns, both at the // top level and nested within other larger patterns. Record the step index where // each pattern starts. @@ -1165,12 +1177,12 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *impossible_index // If this pattern cannot match, store the pattern index so that it can be // returned to the caller. if (result && !can_finish_pattern) { - unsigned exists; - array_search_sorted_by( - &self->patterns, 0, - .steps.offset, parent_step_index, - impossible_index, &exists - ); + assert(final_step_indices.size > 0); + uint16_t *impossible_step_index = array_back(&final_step_indices); + uint32_t i, exists; + array_search_sorted_by(&self->step_offsets, 0, .step_index, *impossible_step_index, &i, &exists); + assert(exists); + *error_offset = self->step_offsets.contents[i].byte_offset; result = false; } } @@ -1415,17 +1427,24 @@ static TSQueryError ts_query__parse_pattern( uint32_t depth, bool is_immediate ) { + if (stream->next == 0) return TSQueryErrorSyntax; + if (stream->next == ')' || stream->next == ']') return PARENT_DONE; + const uint32_t starting_step_index = self->steps.size; - if (stream->next == 0) return TSQueryErrorSyntax; - - // Finish the parent S-expression. - if (stream->next == ')' || stream->next == ']') { - return PARENT_DONE; + // Store the byte offset of each step in the query. + if ( + self->step_offsets.size == 0 || + array_back(&self->step_offsets)->step_index != starting_step_index + ) { + array_push(&self->step_offsets, ((StepOffset) { + .step_index = starting_step_index, + .byte_offset = stream_offset(stream), + })); } // An open bracket is the start of an alternation. 
- else if (stream->next == '[') { + if (stream->next == '[') { stream_advance(stream); stream_skip_whitespace(stream); @@ -1818,6 +1837,7 @@ TSQuery *ts_query_new( .predicate_values = symbol_table_new(), .predicate_steps = array_new(), .patterns = array_new(), + .step_offsets = array_new(), .symbol_map = symbol_map, .wildcard_root_pattern_count = 0, .language = language, @@ -1833,7 +1853,7 @@ TSQuery *ts_query_new( array_push(&self->patterns, ((QueryPattern) { .steps = (Slice) {.offset = start_step_index}, .predicate_steps = (Slice) {.offset = start_predicate_step_index}, - .start_byte = stream.input - source, + .start_byte = stream_offset(&stream), })); *error_type = ts_query__parse_pattern(self, &stream, 0, false); array_push(&self->steps, query_step__new(0, PATTERN_DONE_MARKER, false)); @@ -1846,7 +1866,7 @@ TSQuery *ts_query_new( // and terminate. if (*error_type) { if (*error_type == PARENT_DONE) *error_type = TSQueryErrorSyntax; - *error_offset = stream.input - source; + *error_offset = stream_offset(&stream); ts_query_delete(self); return NULL; } @@ -1882,10 +1902,8 @@ TSQuery *ts_query_new( } if (self->language->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT) { - unsigned impossible_pattern_index = 0; - if (!ts_query__analyze_patterns(self, &impossible_pattern_index)) { + if (!ts_query__analyze_patterns(self, error_offset)) { *error_type = TSQueryErrorPattern; - *error_offset = self->patterns.contents[impossible_pattern_index].start_byte; ts_query_delete(self); return NULL; } @@ -1901,6 +1919,7 @@ void ts_query_delete(TSQuery *self) { array_delete(&self->pattern_map); array_delete(&self->predicate_steps); array_delete(&self->patterns); + array_delete(&self->step_offsets); symbol_table_delete(&self->captures); symbol_table_delete(&self->predicate_values); ts_free(self->symbol_map); @@ -1953,24 +1972,21 @@ uint32_t ts_query_start_byte_for_pattern( return self->patterns.contents[pattern_index].start_byte; } -bool ts_query_pattern_is_definite( +bool 
ts_query_step_is_definite( const TSQuery *self, - uint32_t pattern_index, - TSSymbol symbol, - uint32_t index + uint32_t byte_offset ) { - uint32_t step_index = self->patterns.contents[pattern_index].steps.offset; - QueryStep *step = &self->steps.contents[step_index]; - for (; step->depth != PATTERN_DONE_MARKER; step++) { - bool does_match = symbol ? - step->symbol == symbol : - step->symbol == WILDCARD_SYMBOL || step->symbol == NAMED_WILDCARD_SYMBOL; - if (does_match) { - if (index == 0) return step->is_definite; - index--; - } + uint32_t step_index = UINT32_MAX; + for (unsigned i = 0; i < self->step_offsets.size; i++) { + StepOffset *step_offset = &self->step_offsets.contents[i]; + if (step_offset->byte_offset >= byte_offset) break; + step_index = step_offset->step_index; + } + if (step_index < self->steps.size) { + return self->steps.contents[step_index].is_definite; + } else { + return false; } - return false; } void ts_query_disable_capture( From 9daec9cb22d6485acd776dd826a889e583eb74ad Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 20 Aug 2020 13:24:42 -0700 Subject: [PATCH 20/26] Tweak impossible pattern error messages --- cli/src/error.rs | 2 +- cli/src/tests/query_test.rs | 6 +++--- docs/assets/js/playground.js | 2 +- lib/binding_rust/bindings.rs | 2 +- lib/binding_rust/lib.rs | 6 +++--- lib/binding_web/binding.js | 8 +++++--- lib/binding_web/exports.json | 3 --- lib/binding_web/test/query-test.js | 3 +++ lib/include/tree_sitter/api.h | 2 +- lib/src/query.c | 2 +- 10 files changed, 19 insertions(+), 17 deletions(-) diff --git a/cli/src/error.rs b/cli/src/error.rs index c30e3647..075de3a6 100644 --- a/cli/src/error.rs +++ b/cli/src/error.rs @@ -70,7 +70,7 @@ impl<'a> From<QueryError> for Error { "Query error on line {}. Invalid syntax:\n{}", row, l )), - QueryError::Pattern(row, l) => Error::new(format!( + QueryError::Structure(row, l) => Error::new(format!( "Query error on line {}.
Impossible pattern:\n{}", row, l )), diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 1e4ea8cc..e7231ef0 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -195,7 +195,7 @@ fn test_query_errors_on_impossible_patterns() { js_lang, "(binary_expression left: (identifier) left: (identifier))" ), - Err(QueryError::Pattern( + Err(QueryError::Structure( 1, [ "(binary_expression left: (identifier) left: (identifier))", @@ -212,7 +212,7 @@ fn test_query_errors_on_impossible_patterns() { .unwrap(); assert_eq!( Query::new(js_lang, "(function_declaration name: (statement_block))"), - Err(QueryError::Pattern( + Err(QueryError::Structure( 1, [ "(function_declaration name: (statement_block))", @@ -225,7 +225,7 @@ fn test_query_errors_on_impossible_patterns() { Query::new(rb_lang, "(call receiver:(call))").unwrap(); assert_eq!( Query::new(rb_lang, "(call receiver:(binary))"), - Err(QueryError::Pattern( + Err(QueryError::Structure( 1, [ "(call receiver:(binary))", // diff --git a/docs/assets/js/playground.js b/docs/assets/js/playground.js index 686be90d..137bb352 100644 --- a/docs/assets/js/playground.js +++ b/docs/assets/js/playground.js @@ -277,7 +277,7 @@ let tree; const startPosition = queryEditor.posFromIndex(error.index); const endPosition = { line: startPosition.line, - ch: startPosition.ch + (error.length || 1) + ch: startPosition.ch + (error.length || Infinity) }; if (error.index === queryText.length) { diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 81cc6f9a..f28d3461 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -132,7 +132,7 @@ pub const TSQueryError_TSQueryErrorSyntax: TSQueryError = 1; pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2; pub const TSQueryError_TSQueryErrorField: TSQueryError = 3; pub const TSQueryError_TSQueryErrorCapture: TSQueryError = 4; -pub const TSQueryError_TSQueryErrorPattern: TSQueryError = 5; +pub const 
TSQueryError_TSQueryErrorStructure: TSQueryError = 5; pub type TSQueryError = u32; extern "C" { #[doc = " Create a new parser."] diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 10cd9fc2..ea5893b4 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -163,7 +163,7 @@ pub enum QueryError { Field(usize, String), Capture(usize, String), Predicate(String), - Pattern(usize, String), + Structure(usize, String), } #[derive(Debug)] @@ -1206,8 +1206,8 @@ impl Query { "Unexpected EOF".to_string() }; match error_type { - ffi::TSQueryError_TSQueryErrorPattern => { - Err(QueryError::Pattern(row, message)) + ffi::TSQueryError_TSQueryErrorStructure => { + Err(QueryError::Structure(row, message)) } _ => Err(QueryError::Syntax(row, message)), } diff --git a/lib/binding_web/binding.js b/lib/binding_web/binding.js index 404beeb6..f731e8f8 100644 --- a/lib/binding_web/binding.js +++ b/lib/binding_web/binding.js @@ -667,8 +667,8 @@ class Language { const errorId = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32'); const errorByte = getValue(TRANSFER_BUFFER, 'i32'); const errorIndex = UTF8ToString(sourceAddress, errorByte).length; - const suffix = source.substr(errorIndex, 100); - const word = suffix.match(QUERY_WORD_REGEX)[0]; + const suffix = source.substr(errorIndex, 100).split('\n')[0]; + let word = suffix.match(QUERY_WORD_REGEX)[0]; let error; switch (errorId) { case 2: @@ -681,10 +681,12 @@ class Language { error = new RangeError(`Bad capture name @${word}`); break; case 5: - error = new SyntaxError(`Impossible pattern at offset ${errorIndex}: '${suffix}'...`); + error = new TypeError(`Bad pattern structure at offset ${errorIndex}: '${suffix}'...`); + word = ""; break; default: error = new SyntaxError(`Bad syntax at offset ${errorIndex}: '${suffix}'...`); + word = ""; break; } error.index = errorIndex; diff --git a/lib/binding_web/exports.json b/lib/binding_web/exports.json index 2c638249..72105158 100644 --- a/lib/binding_web/exports.json +++ 
b/lib/binding_web/exports.json @@ -15,7 +15,6 @@ "__ZNSt3__212basic_stringIwNS_11char_traitsIwEENS_9allocatorIwEEED2Ev", "__ZdlPv", "__Znwm", - "___assert_fail", "_abort", "_iswalnum", "_iswalpha", @@ -73,8 +72,6 @@ "_ts_query_capture_count", "_ts_query_capture_name_for_id", "_ts_query_captures_wasm", - "_ts_query_context_delete", - "_ts_query_context_new", "_ts_query_delete", "_ts_query_matches_wasm", "_ts_query_new", diff --git a/lib/binding_web/test/query-test.js b/lib/binding_web/test/query-test.js index 9d1e24e1..23663e9a 100644 --- a/lib/binding_web/test/query-test.js +++ b/lib/binding_web/test/query-test.js @@ -30,6 +30,9 @@ describe("Query", () => { assert.throws(() => { JavaScript.query("(function_declaration non_existent:(identifier))"); }, "Bad field name 'non_existent'"); + assert.throws(() => { + JavaScript.query("(function_declaration name:(statement_block))"); + }, "Bad pattern structure at offset 22: 'name:(statement_block))'"); }); it("throws an error on invalid predicates", () => { diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 1e60e4b5..b85380d1 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -130,7 +130,7 @@ typedef enum { TSQueryErrorNodeType, TSQueryErrorField, TSQueryErrorCapture, - TSQueryErrorPattern, + TSQueryErrorStructure, } TSQueryError; /********************/ diff --git a/lib/src/query.c b/lib/src/query.c index 5a2bb2fb..60c892d3 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -1903,7 +1903,7 @@ TSQuery *ts_query_new( if (self->language->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT) { if (!ts_query__analyze_patterns(self, error_offset)) { - *error_type = TSQueryErrorPattern; + *error_type = TSQueryErrorStructure; ts_query_delete(self); return NULL; } From 456b1f6771de9ec689ea350eb4cbdfcf14baa283 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 20 Aug 2020 16:28:54 -0700 Subject: [PATCH 21/26] Fix handling of alternations and optional nodes in 
query analysis --- cli/src/tests/query_test.rs | 139 +++++++++++++++++++++++++++++++++--- lib/src/query.c | 91 ++++++++++++----------- script/test | 12 ++-- 3 files changed, 190 insertions(+), 52 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index e7231ef0..816c3aee 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1,11 +1,17 @@ use super::helpers::allocations; use super::helpers::fixtures::get_language; +use lazy_static::lazy_static; +use std::env; use std::fmt::Write; use tree_sitter::{ Language, Node, Parser, Query, QueryCapture, QueryCursor, QueryError, QueryMatch, QueryPredicate, QueryPredicateArg, QueryProperty, }; +lazy_static! { + static ref EXAMPLE_FILTER: Option<String> = env::var("TREE_SITTER_TEST_EXAMPLE_FILTER").ok(); +} + #[test] fn test_query_errors_on_invalid_syntax() { allocations::record(|| { @@ -234,6 +240,34 @@ fn test_query_errors_on_impossible_patterns() { .join("\n") )) ); + + Query::new( + js_lang, + "[ + (function (identifier)) + (function_declaration (identifier)) + (generator_function_declaration (identifier)) + ]", + ) + .unwrap(); + assert_eq!( + Query::new( + js_lang, + "[ + (function (identifier)) + (function_declaration (object)) + (generator_function_declaration (identifier)) + ]", + ), + Err(QueryError::Structure( + 3, + [ + " (function_declaration (object))", // + " ^", + ] + .join("\n") + )) + ); }); } @@ -2322,37 +2356,92 @@ fn test_query_alternative_predicate_prefix() { fn test_query_step_is_definite() { struct Row { language: Language, + description: &'static str, pattern: &'static str, results_by_substring: &'static [(&'static str, bool)], } let rows = &[ Row { + description: "no definite steps", language: get_language("python"), pattern: r#"(expression_statement (string))"#, results_by_substring: &[("expression_statement", false), ("string", false)], }, Row { - language: get_language("javascript"), - pattern: r#"(expression_statement (string))"#, -
results_by_substring: &[("expression_statement", false), ("string", false)], - }, - Row { + description: "all definite steps", language: get_language("javascript"), pattern: r#"(object "{" "}")"#, results_by_substring: &[("object", false), ("{", true), ("}", true)], }, Row { + description: "an indefinite step that is optional", + language: get_language("javascript"), + pattern: r#"(object "{" (identifier)? @foo "}")"#, + results_by_substring: &[ + ("object", false), + ("{", true), + ("(identifier)?", false), + ("}", true), + ], + }, + Row { + description: "multiple indefinite steps that are optional", + language: get_language("javascript"), + pattern: r#"(object "{" (identifier)? @id1 ("," (identifier) @id2)? "}")"#, + results_by_substring: &[ + ("object", false), + ("{", true), + ("(identifier)? @id1", false), + ("\",\"", false), + ("}", true), + ], + }, + Row { + description: "definite step after indefinite step", language: get_language("javascript"), pattern: r#"(pair (property_identifier) ":")"#, results_by_substring: &[("pair", false), ("property_identifier", false), (":", true)], }, Row { + description: "indefinite step in between two definite steps", language: get_language("javascript"), - pattern: r#"(object "{" (_) "}")"#, - results_by_substring: &[("object", false), ("{", false), ("", false), ("}", true)], + pattern: r#"(ternary_expression + condition: (_) + "?" 
+ consequence: (call_expression) + ":" + alternative: (_))"#, + results_by_substring: &[ + ("condition:", false), + ("\"?\"", false), + ("consequence:", false), + ("\":\"", true), + ("alternative:", true), + ], }, Row { + description: "one definite step after a repetition", + language: get_language("javascript"), + pattern: r#"(object "{" (_) "}")"#, + results_by_substring: &[("object", false), ("{", false), ("(_)", false), ("}", true)], + }, + Row { + description: "definite steps after multiple repetitions", + language: get_language("json"), + pattern: r#"(object "{" (pair) "," (pair) "," (_) "}")"#, + results_by_substring: &[ + ("object", false), + ("{", false), + ("(pair) \",\" (pair)", false), + ("(pair) \",\" (_)", false), + ("\",\" (_)", false), + ("(_)", true), + ("}", true), + ], + }, + Row { + description: "a definite with a field", language: get_language("javascript"), pattern: r#"(binary_expression left: (identifier) right: (_))"#, results_by_substring: &[ @@ -2362,6 +2451,7 @@ fn test_query_step_is_definite() { ], }, Row { + description: "multiple definite steps with fields", language: get_language("javascript"), pattern: r#"(function_declaration name: (identifier) body: (statement_block))"#, results_by_substring: &[ @@ -2371,6 +2461,7 @@ fn test_query_step_is_definite() { ], }, Row { + description: "nesting, one definite step", language: get_language("javascript"), pattern: r#" (function_declaration @@ -2386,6 +2477,7 @@ fn test_query_step_is_definite() { ], }, Row { + description: "definite step after some deeply nested hidden nodes", language: get_language("ruby"), pattern: r#" (singleton_class @@ -2399,6 +2491,7 @@ fn test_query_step_is_definite() { ], }, Row { + description: "nesting, no definite steps", language: get_language("javascript"), pattern: r#" (call_expression @@ -2409,6 +2502,7 @@ fn test_query_step_is_definite() { results_by_substring: &[("property_identifier", false), ("template_string", false)], }, Row { + description: "a definite 
step after a nested node", language: get_language("javascript"), pattern: r#" (subscript_expression @@ -2424,6 +2518,7 @@ fn test_query_step_is_definite() { ], }, Row { + description: "a step that is indefinite due to a predicate", language: get_language("javascript"), pattern: r#" (subscript_expression @@ -2439,17 +2534,45 @@ fn test_query_step_is_definite() { ("[", true), ], }, + Row { + description: "alternation where one branch has definite steps", + language: get_language("javascript"), + pattern: r#" + [ + (unary_expression (identifier)) + (call_expression + function: (_) + arguments: (_)) + (binary_expression right:(call_expression)) + ] + "#, + results_by_substring: &[ + ("identifier", false), + ("right:", false), + ("function:", true), + ("arguments:", true), + ], + }, ]; allocations::record(|| { + eprintln!(""); + for row in rows.iter() { + if let Some(filter) = EXAMPLE_FILTER.as_ref() { + if !row.description.contains(filter.as_str()) { + continue; + } + } + eprintln!(" query example: {:?}", row.description); let query = Query::new(row.language, row.pattern).unwrap(); for (substring, is_definite) in row.results_by_substring { let offset = row.pattern.find(substring).unwrap(); assert_eq!( query.step_is_definite(offset), *is_definite, - "Pattern: {:?}, substring: {:?}, expected is_definite to be {}", + "Description: {}, Pattern: {:?}, substring: {:?}, expected is_definite to be {}", + row.description, row.pattern .split_ascii_whitespace() .collect::<Vec<_>>() .join(" "), diff --git a/lib/src/query.c b/lib/src/query.c index 60c892d3..8464a691 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -1144,34 +1144,18 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { next_states = _states; } - // A query step is definite if the containing pattern will definitely match - // once the step is reached.
In other words, a step is *not* definite if - // it's possible to create a syntax node that matches up to until that step, - // but does not match the entire pattern. - uint32_t child_step_index = parent_step_index + 1; - QueryStep *child_step = &self->steps.contents[child_step_index]; - while (child_step->depth == parent_depth + 1) { - // Check if there is any way for the pattern to reach this step, but fail - // to reach the end of the sub-pattern. - for (unsigned k = 0; k < final_step_indices.size; k++) { - uint32_t final_step_index = final_step_indices.contents[k]; - if ( - final_step_index >= child_step_index && - self->steps.contents[final_step_index].depth == child_step->depth - ) { - child_step->is_definite = false; - break; - } + // Mark as indefinite any step where a match terminated. + // Later, this property will be propagated to all of the step's predecessors. + for (unsigned j = 0; j < final_step_indices.size; j++) { + uint32_t final_step_index = final_step_indices.contents[j]; + QueryStep *step = &self->steps.contents[final_step_index]; + if ( + step->depth != PATTERN_DONE_MARKER && + step->depth > parent_depth && + !step->is_dead_end + ) { + step->is_definite = false; } - - // Advance to the next child step in this sub-pattern. - do { - child_step_index++; - child_step++; - } while ( - child_step->depth != PATTERN_DONE_MARKER && - child_step->depth > parent_depth + 1 - ); } // If this pattern cannot match, store the pattern index so that it can be @@ -1187,9 +1171,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { } } - // In order for a step to be definite, all of its child steps must be definite, - // and all of its later sibling steps must be definite. Propagate any indefiniteness - // upward and backward through the pattern trees. + // Mark as indefinite any step with captures that are used in predicates. 
Array(uint16_t) predicate_capture_ids = array_new(); for (unsigned i = 0; i < self->patterns.size; i++) { QueryPattern *pattern = &self->patterns.contents[i]; @@ -1207,16 +1189,13 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { } } - bool all_later_children_definite = true; + // Find all of the steps that have these captures. for ( unsigned start = pattern->steps.offset, end = start + pattern->steps.length, - j = end - 1; j + 1 > start; j-- + j = start; j < end; j++ ) { QueryStep *step = &self->steps.contents[j]; - - // If this step has a capture that is used in a predicate, - // then it is not definite. for (unsigned k = 0; k < MAX_STEP_CAPTURE_COUNT; k++) { uint16_t capture_id = step->capture_ids[k]; if (capture_id == NONE) break; @@ -1227,10 +1206,41 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { break; } } + } + } - // If a step is not definite, then none of its predecessors can be definite. - if (!all_later_children_definite) step->is_definite = false; - if (!step->is_definite) all_later_children_definite = false; + // Propagate indefiniteness backwards. + bool done = self->steps.size == 0; + while (!done) { + done = true; + for (unsigned i = self->steps.size - 1; i > 0; i--) { + QueryStep *step = &self->steps.contents[i]; + + // Determine if this step is definite or has definite alternatives. + bool is_definite = false; + for (;;) { + if (step->is_definite) { + is_definite = true; + break; + } + if (step->alternative_index == NONE || step->alternative_index < i) { + break; + } + step = &self->steps.contents[step->alternative_index]; + } + + // If not, mark its predecessor as indefinite. 
+ if (!is_definite) { + QueryStep *prev_step = &self->steps.contents[i - 1]; + if ( + !prev_step->is_dead_end && + prev_step->depth != PATTERN_DONE_MARKER && + prev_step->is_definite + ) { + prev_step->is_definite = false; + done = false; + } + } } } @@ -1242,11 +1252,12 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { printf(" %u: DONE\n", i); } else { printf( - " %u: {symbol: %s, is_definite: %d}\n", + " %u: {symbol: %s, field: %s, is_definite: %d}\n", i, (step->symbol == WILDCARD_SYMBOL || step->symbol == NAMED_WILDCARD_SYMBOL) ? "ANY" : ts_language_symbol_name(self->language, step->symbol), + (step->field ? ts_language_field_name_for_id(self->language, step->field) : "-"), step->is_definite ); } @@ -1979,7 +1990,7 @@ bool ts_query_step_is_definite( uint32_t step_index = UINT32_MAX; for (unsigned i = 0; i < self->step_offsets.size; i++) { StepOffset *step_offset = &self->step_offsets.contents[i]; - if (step_offset->byte_offset >= byte_offset) break; + if (step_offset->byte_offset > byte_offset) break; step_index = step_offset->step_index; } if (step_index < self->steps.size) { diff --git a/script/test b/script/test index bcc88e24..31e90226 100755 --- a/script/test +++ b/script/test @@ -83,10 +83,14 @@ done shift $(expr $OPTIND - 1) -if [[ -n $TREE_SITTER_TEST_LANGUAGE_FILTER || -n $TREE_SITTER_TEST_EXAMPLE_FILTER || -n $TREE_SITTER_TEST_TRIAL_FILTER ]]; then - top_level_filter=corpus -else - top_level_filter=$1 +top_level_filter=$1 + +if [[ \ + -n $TREE_SITTER_TEST_LANGUAGE_FILTER || \ + -n $TREE_SITTER_TEST_EXAMPLE_FILTER || \ + -n $TREE_SITTER_TEST_TRIAL_FILTER \ +]]; then + echo ${top_level_filter:=corpus} fi if [[ "${mode}" == "debug" ]]; then From 2eb04094f80048db6811e7238b8ed9b1f92c95ba Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 21 Aug 2020 14:12:04 -0700 Subject: [PATCH 22/26] Handle aliased parent nodes in query analysis --- cli/src/generate/render.rs | 221 ++++++++++++++++++++----------- 
cli/src/tests/query_test.rs | 8 ++ lib/include/tree_sitter/parser.h | 1 + lib/src/language.h | 27 ++++ lib/src/query.c | 80 ++++++----- 5 files changed, 225 insertions(+), 112 deletions(-) diff --git a/cli/src/generate/render.rs b/cli/src/generate/render.rs index 300ad383..5b016cb6 100644 --- a/cli/src/generate/render.rs +++ b/cli/src/generate/render.rs @@ -7,7 +7,7 @@ use super::tables::{ }; use core::ops::Range; use std::cmp; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{HashMap, HashSet}; use std::fmt::Write; use std::mem::swap; @@ -69,7 +69,8 @@ struct Generator { symbol_order: HashMap, symbol_ids: HashMap, alias_ids: HashMap, - alias_map: BTreeMap>, + unique_aliases: Vec, + symbol_map: HashMap, field_names: Vec, next_abi: bool, } @@ -108,6 +109,8 @@ impl Generator { self.add_alias_sequences(); } + self.add_non_terminal_alias_map(); + let mut main_lex_table = LexTable::default(); swap(&mut main_lex_table, &mut self.main_lex_table); self.add_lex_function("ts_lex", main_lex_table, true); @@ -159,13 +162,72 @@ impl Generator { format!("anon_alias_sym_{}", self.sanitize_identifier(&alias.value)) }; self.alias_ids.entry(alias.clone()).or_insert(alias_id); - self.alias_map - .entry(alias.clone()) - .or_insert(matching_symbol); } } } + self.unique_aliases = self + .alias_ids + .keys() + .filter(|alias| { + self.parse_table + .symbols + .iter() + .cloned() + .find(|symbol| { + let (name, kind) = self.metadata_for_symbol(*symbol); + name == alias.value && kind == alias.kind() + }) + .is_none() + }) + .cloned() + .collect(); + self.unique_aliases.sort_unstable(); + + self.symbol_map = self + .parse_table + .symbols + .iter() + .map(|symbol| { + let mut mapping = symbol; + + // There can be multiple symbols in the grammar that have the same name and kind, + // due to simple aliases. When that happens, ensure that they map to the same + // public-facing symbol. 
If one of the symbols is not aliased, choose that one + // to be the public-facing symbol. Otherwise, pick the symbol with the lowest + // numeric value. + if let Some(alias) = self.simple_aliases.get(symbol) { + let kind = alias.kind(); + for other_symbol in &self.parse_table.symbols { + if let Some(other_alias) = self.simple_aliases.get(other_symbol) { + if other_symbol < mapping && other_alias == alias { + mapping = other_symbol; + } + } else if self.metadata_for_symbol(*other_symbol) == (&alias.value, kind) { + mapping = other_symbol; + break; + } + } + } + // Two anonymous tokens with different flags but the same string value + // should be represented with the same symbol in the public API. Examples: + // * "<" and token(prec(1, "<")) + // * "(" and token.immediate("(") + else if symbol.is_terminal() { + let metadata = self.metadata_for_symbol(*symbol); + for other_symbol in &self.parse_table.symbols { + let other_metadata = self.metadata_for_symbol(*other_symbol); + if other_metadata == metadata { + mapping = other_symbol; + break; + } + } + } + + (*symbol, *mapping) + }) + .collect(); + field_names.sort_unstable(); field_names.dedup(); self.field_names = field_names.into_iter().cloned().collect(); @@ -255,11 +317,7 @@ impl Generator { "#define SYMBOL_COUNT {}", self.parse_table.symbols.len() ); - add_line!( - self, - "#define ALIAS_COUNT {}", - self.alias_map.iter().filter(|e| e.1.is_none()).count() - ); + add_line!(self, "#define ALIAS_COUNT {}", self.unique_aliases.len(),); add_line!(self, "#define TOKEN_COUNT {}", token_count); add_line!( self, @@ -287,11 +345,9 @@ impl Generator { i += 1; } } - for (alias, symbol) in &self.alias_map { - if symbol.is_none() { - add_line!(self, "{} = {},", self.alias_ids[&alias], i); - i += 1; - } + for alias in &self.unique_aliases { + add_line!(self, "{} = {},", self.alias_ids[&alias], i); + i += 1; } dedent!(self); add_line!(self, "}};"); @@ -310,15 +366,13 @@ impl Generator { ); add_line!(self, "[{}] = \"{}\",", 
self.symbol_ids[&symbol], name); } - for (alias, symbol) in &self.alias_map { - if symbol.is_none() { - add_line!( - self, - "[{}] = \"{}\",", - self.alias_ids[&alias], - self.sanitize_string(&alias.value) - ); - } + for alias in &self.unique_aliases { + add_line!( + self, + "[{}] = \"{}\",", + self.alias_ids[&alias], + self.sanitize_string(&alias.value) + ); } dedent!(self); add_line!(self, "}};"); @@ -329,58 +383,21 @@ impl Generator { add_line!(self, "static TSSymbol ts_symbol_map[] = {{"); indent!(self); for symbol in &self.parse_table.symbols { - let mut mapping = symbol; - - // There can be multiple symbols in the grammar that have the same name and kind, - // due to simple aliases. When that happens, ensure that they map to the same - // public-facing symbol. If one of the symbols is not aliased, choose that one - // to be the public-facing symbol. Otherwise, pick the symbol with the lowest - // numeric value. - if let Some(alias) = self.simple_aliases.get(symbol) { - let kind = alias.kind(); - for other_symbol in &self.parse_table.symbols { - if let Some(other_alias) = self.simple_aliases.get(other_symbol) { - if other_symbol < mapping && other_alias == alias { - mapping = other_symbol; - } - } else if self.metadata_for_symbol(*other_symbol) == (&alias.value, kind) { - mapping = other_symbol; - break; - } - } - } - // Two anonymous tokens with different flags but the same string value - // should be represented with the same symbol in the public API. 
Examples: - // * "<" and token(prec(1, "<")) - // * "(" and token.immediate("(") - else if symbol.is_terminal() { - let metadata = self.metadata_for_symbol(*symbol); - for other_symbol in &self.parse_table.symbols { - let other_metadata = self.metadata_for_symbol(*other_symbol); - if other_metadata == metadata { - mapping = other_symbol; - break; - } - } - } - add_line!( self, "[{}] = {},", - self.symbol_ids[&symbol], - self.symbol_ids[mapping], + self.symbol_ids[symbol], + self.symbol_ids[&self.symbol_map[symbol]], ); } - for (alias, symbol) in &self.alias_map { - if symbol.is_none() { - add_line!( - self, - "[{}] = {},", - self.alias_ids[&alias], - self.alias_ids[&alias], - ); - } + for alias in &self.unique_aliases { + add_line!( + self, + "[{}] = {},", + self.alias_ids[&alias], + self.alias_ids[&alias], + ); } dedent!(self); @@ -451,15 +468,13 @@ impl Generator { dedent!(self); add_line!(self, "}},"); } - for (alias, matching_symbol) in &self.alias_map { - if matching_symbol.is_none() { - add_line!(self, "[{}] = {{", self.alias_ids[&alias]); - indent!(self); - add_line!(self, ".visible = true,"); - add_line!(self, ".named = {},", alias.is_named); - dedent!(self); - add_line!(self, "}},"); - } + for alias in &self.unique_aliases { + add_line!(self, "[{}] = {{", self.alias_ids[&alias]); + indent!(self); + add_line!(self, ".visible = true,"); + add_line!(self, ".named = {},", alias.is_named); + dedent!(self); + add_line!(self, "}},"); } dedent!(self); add_line!(self, "}};"); @@ -498,6 +513,50 @@ impl Generator { add_line!(self, ""); } + fn add_non_terminal_alias_map(&mut self) { + let mut aliases_by_symbol = HashMap::new(); + for variable in &self.syntax_grammar.variables { + for production in &variable.productions { + for step in &production.steps { + if let Some(alias) = &step.alias { + if step.symbol.is_non_terminal() + && !self.simple_aliases.contains_key(&step.symbol) + { + if self.symbol_ids.contains_key(&step.symbol) { + let alias_ids = + 
aliases_by_symbol.entry(step.symbol).or_insert(Vec::new()); + if let Err(i) = alias_ids.binary_search(&alias) { + alias_ids.insert(i, alias); + } + } + } + } + } + } + } + + let mut aliases_by_symbol = aliases_by_symbol.iter().collect::>(); + aliases_by_symbol.sort_unstable_by_key(|e| e.0); + + add_line!(self, "static uint16_t ts_non_terminal_alias_map[] = {{"); + indent!(self); + for (symbol, aliases) in aliases_by_symbol { + let symbol_id = &self.symbol_ids[symbol]; + let public_symbol_id = &self.symbol_ids[&self.symbol_map[&symbol]]; + add_line!(self, "{}, {},", symbol_id, 1 + aliases.len()); + indent!(self); + add_line!(self, "{},", public_symbol_id); + for alias in aliases { + add_line!(self, "{},", &self.alias_ids[&alias]); + } + dedent!(self); + } + add_line!(self, "0,"); + dedent!(self); + add_line!(self, "}};"); + add_line!(self, ""); + } + fn add_field_sequences(&mut self) { let mut flat_field_maps = vec![]; let mut next_flat_field_map_index = 0; @@ -1207,6 +1266,7 @@ impl Generator { add_line!(self, ".large_state_count = LARGE_STATE_COUNT,"); if self.next_abi { + add_line!(self, ".alias_map = ts_non_terminal_alias_map,"); add_line!(self, ".state_count = STATE_COUNT,"); } @@ -1517,7 +1577,8 @@ pub(crate) fn render_c_code( symbol_ids: HashMap::new(), symbol_order: HashMap::new(), alias_ids: HashMap::new(), - alias_map: BTreeMap::new(), + symbol_map: HashMap::new(), + unique_aliases: Vec::new(), field_names: Vec::new(), next_abi, } diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 816c3aee..822fdd22 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2553,6 +2553,14 @@ fn test_query_step_is_definite() { ("arguments:", true), ], }, + Row { + description: "aliased parent node", + language: get_language("ruby"), + pattern: r#" + (method_parameters "(" (identifier) @id")") + "#, + results_by_substring: &[("\"(\"", false), ("(identifier)", false), ("\")\"", true)], + }, ]; allocations::record(|| { diff --git 
a/lib/include/tree_sitter/parser.h b/lib/include/tree_sitter/parser.h index 360e012f..84096132 100644 --- a/lib/include/tree_sitter/parser.h +++ b/lib/include/tree_sitter/parser.h @@ -119,6 +119,7 @@ struct TSLanguage { const uint16_t *small_parse_table; const uint32_t *small_parse_table_map; const TSSymbol *public_symbol_map; + const uint16_t *alias_map; uint32_t state_count; }; diff --git a/lib/src/language.h b/lib/src/language.h index f8fd1ae5..e5c07aa2 100644 --- a/lib/src/language.h +++ b/lib/src/language.h @@ -13,6 +13,7 @@ extern "C" { #define TREE_SITTER_LANGUAGE_VERSION_WITH_SYMBOL_DEDUPING 11 #define TREE_SITTER_LANGUAGE_VERSION_WITH_SMALL_STATES 11 #define TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT 12 +#define TREE_SITTER_LANGUAGE_VERSION_WITH_ALIAS_MAP 12 typedef struct { const TSParseAction *actions; @@ -258,6 +259,32 @@ static inline void ts_language_field_map( *end = &self->field_map_entries[slice.index] + slice.length; } +static inline void ts_language_aliases_for_symbol( + const TSLanguage *self, + TSSymbol original_symbol, + const TSSymbol **start, + const TSSymbol **end +) { + *start = &self->public_symbol_map[original_symbol]; + *end = *start + 1; + + if (self->version < TREE_SITTER_LANGUAGE_VERSION_WITH_ALIAS_MAP) return; + + unsigned i = 0; + for (;;) { + TSSymbol symbol = self->alias_map[i++]; + if (symbol == 0 || symbol > original_symbol) break; + uint16_t count = self->alias_map[i++]; + if (symbol == original_symbol) { + *start = &self->alias_map[i]; + *end = &self->alias_map[i + count]; + break; + } + i += count; + } +} + + #ifdef __cplusplus } #endif diff --git a/lib/src/query.c b/lib/src/query.c index 8464a691..9f911438 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -788,24 +788,32 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { for (unsigned i = 0; i < lookahead_iterator.action_count; i++) { const TSParseAction *action = &lookahead_iterator.actions[i]; if (action->type == 
TSParseActionTypeReduce) { - TSSymbol symbol = self->language->public_symbol_map[action->params.reduce.symbol]; - array_search_sorted_by( - &subgraphs, - 0, - .symbol, - symbol, - &subgraph_index, - &exists + const TSSymbol *aliases, *aliases_end; + ts_language_aliases_for_symbol( + self->language, + action->params.reduce.symbol, + &aliases, + &aliases_end ); - if (exists) { - AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; - if (subgraph->nodes.size == 0 || array_back(&subgraph->nodes)->state != state) { - array_push(&subgraph->nodes, ((AnalysisSubgraphNode) { - .state = state, - .production_id = action->params.reduce.production_id, - .child_index = action->params.reduce.child_count, - .done = true, - })); + for (const TSSymbol *symbol = aliases; symbol < aliases_end; symbol++) { + array_search_sorted_by( + &subgraphs, + 0, + .symbol, + *symbol, + &subgraph_index, + &exists + ); + if (exists) { + AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + if (subgraph->nodes.size == 0 || array_back(&subgraph->nodes)->state != state) { + array_push(&subgraph->nodes, ((AnalysisSubgraphNode) { + .state = state, + .production_id = action->params.reduce.production_id, + .child_index = action->params.reduce.child_count, + .done = true, + })); + } } } } else if (action->type == TSParseActionTypeShift && !action->params.shift.extra) { @@ -815,22 +823,30 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { } } else if (lookahead_iterator.next_state != 0 && lookahead_iterator.next_state != state) { state_predecessor_map_add(&predecessor_map, lookahead_iterator.next_state, state); - TSSymbol symbol = self->language->public_symbol_map[lookahead_iterator.symbol]; - array_search_sorted_by( - &subgraphs, - 0, - .symbol, - symbol, - &subgraph_index, - &exists + const TSSymbol *aliases, *aliases_end; + ts_language_aliases_for_symbol( + self->language, + lookahead_iterator.symbol, + &aliases, + &aliases_end ); - if (exists) { 
- AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; - if ( - subgraph->start_states.size == 0 || - *array_back(&subgraph->start_states) != state - ) - array_push(&subgraph->start_states, state); + for (const TSSymbol *symbol = aliases; symbol < aliases_end; symbol++) { + array_search_sorted_by( + &subgraphs, + 0, + .symbol, + *symbol, + &subgraph_index, + &exists + ); + if (exists) { + AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + if ( + subgraph->start_states.size == 0 || + *array_back(&subgraph->start_states) != state + ) + array_push(&subgraph->start_states, state); + } } } } From 315f87bbff8b849734107cb6d0b8c66eea5d0276 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 24 Aug 2020 12:07:57 -0700 Subject: [PATCH 23/26] Remove unnecessary parameter from sorted array functions --- lib/src/array.h | 16 ++++++++-------- lib/src/query.c | 27 ++++++++++++--------------- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/lib/src/array.h b/lib/src/array.h index 7fae7a40..7b2d42fe 100644 --- a/lib/src/array.h +++ b/lib/src/array.h @@ -87,23 +87,23 @@ extern "C" { #define _compare_int(a, b) ((int)*(a) - (int)(b)) -#define array_search_sorted_by(self, start, field, needle, index, exists) \ - array__search_sorted(self, start, _compare_int, field, needle, index, exists) +#define array_search_sorted_by(self, field, needle, index, exists) \ + array__search_sorted(self, 0, _compare_int, field, needle, index, exists) -#define array_search_sorted_with(self, start, compare, needle, index, exists) \ - array__search_sorted(self, start, compare, , needle, index, exists) +#define array_search_sorted_with(self, compare, needle, index, exists) \ + array__search_sorted(self, 0, compare, , needle, index, exists) -#define array_insert_sorted_by(self, start, field, value) \ +#define array_insert_sorted_by(self, field, value) \ do { \ unsigned index, exists; \ - array_search_sorted_by(self, start, field, (value) field, &index, 
&exists); \ + array_search_sorted_by(self, field, (value) field, &index, &exists); \ if (!exists) array_insert(self, index, value); \ } while (0) -#define array_insert_sorted_with(self, start, compare, value) \ +#define array_insert_sorted_with(self, compare, value) \ do { \ unsigned index, exists; \ - array_search_sorted_with(self, start, compare, &(value), &index, &exists); \ + array_search_sorted_with(self, compare, &(value), &index, &exists); \ if (!exists) array_insert(self, index, value); \ } while (0) diff --git a/lib/src/query.c b/lib/src/query.c index 9f911438..8a7e5ea2 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -764,12 +764,12 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { uint32_t parent_step_index = parent_step_indices.contents[i]; TSSymbol parent_symbol = self->steps.contents[parent_step_index].symbol; AnalysisSubgraph subgraph = { .symbol = parent_symbol }; - array_insert_sorted_by(&subgraphs, 0, .symbol, subgraph); + array_insert_sorted_by(&subgraphs, .symbol, subgraph); } for (TSSymbol sym = self->language->token_count; sym < self->language->symbol_count; sym++) { if (!ts_language_symbol_metadata(self->language, sym).visible) { AnalysisSubgraph subgraph = { .symbol = sym }; - array_insert_sorted_by(&subgraphs, 0, .symbol, subgraph); + array_insert_sorted_by(&subgraphs, .symbol, subgraph); } } @@ -798,7 +798,6 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { for (const TSSymbol *symbol = aliases; symbol < aliases_end; symbol++) { array_search_sorted_by( &subgraphs, - 0, .symbol, *symbol, &subgraph_index, @@ -833,7 +832,6 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { for (const TSSymbol *symbol = aliases; symbol < aliases_end; symbol++) { array_search_sorted_by( &subgraphs, - 0, .symbol, *symbol, &subgraph_index, @@ -882,8 +880,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { }; unsigned index, 
exists; array_search_sorted_with( - &subgraph->nodes, 0, - analysis_subgraph_node__compare, &predecessor_node, + &subgraph->nodes, analysis_subgraph_node__compare, &predecessor_node, &index, &exists ); if (!exists) { @@ -930,7 +927,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { uint16_t parent_depth = self->steps.contents[parent_step_index].depth; TSSymbol parent_symbol = self->steps.contents[parent_step_index].symbol; unsigned subgraph_index, exists; - array_search_sorted_by(&subgraphs, 0, .symbol, parent_symbol, &subgraph_index, &exists); + array_search_sorted_by(&subgraphs, .symbol, parent_symbol, &subgraph_index, &exists); if (!exists) continue; AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; @@ -996,7 +993,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { if (next_states.size > 0) { int comparison = analysis_state__compare_position(state, array_back(&next_states)); if (comparison == 0) { - array_insert_sorted_with(&next_states, 0, analysis_state__compare, *state); + array_insert_sorted_with(&next_states, analysis_state__compare, *state); continue; } else if (comparison > 0) { while (j < states.size) { @@ -1014,7 +1011,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { const QueryStep * const step = &self->steps.contents[state->step_index]; unsigned subgraph_index, exists; - array_search_sorted_by(&subgraphs, 0, .symbol, parent_symbol, &subgraph_index, &exists); + array_search_sorted_by(&subgraphs, .symbol, parent_symbol, &subgraph_index, &exists); if (!exists) continue; const AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; @@ -1044,7 +1041,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { }; unsigned node_index; array_search_sorted_with( - &subgraph->nodes, 0, + &subgraph->nodes, analysis_subgraph_node__compare, &successor, &node_index, &exists ); @@ -1132,9 +1129,9 @@ static bool 
ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { bool did_finish_pattern = self->steps.contents[next_state.step_index].depth != parent_depth + 1; if (did_finish_pattern) can_finish_pattern = true; if (next_state.depth > 0 && !did_finish_pattern) { - array_insert_sorted_with(&next_states, 0, analysis_state__compare, next_state); + array_insert_sorted_with(&next_states, analysis_state__compare, next_state); } else { - array_insert_sorted_by(&final_step_indices, 0, , next_state.step_index); + array_insert_sorted_by(&final_step_indices, , next_state.step_index); } } @@ -1180,7 +1177,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { assert(final_step_indices.size > 0); uint16_t *impossible_step_index = array_back(&final_step_indices); uint32_t i, exists; - array_search_sorted_by(&self->step_offsets, 0, .step_index, *impossible_step_index, &i, &exists); + array_search_sorted_by(&self->step_offsets, .step_index, *impossible_step_index, &i, &exists); assert(exists); *error_offset = self->step_offsets.contents[i].byte_offset; result = false; @@ -1201,7 +1198,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { ) { TSQueryPredicateStep *step = &self->predicate_steps.contents[j]; if (step->type == TSQueryPredicateStepTypeCapture) { - array_insert_sorted_by(&predicate_capture_ids, 0, , step->value_id); + array_insert_sorted_by(&predicate_capture_ids, , step->value_id); } } @@ -1216,7 +1213,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { uint16_t capture_id = step->capture_ids[k]; if (capture_id == NONE) break; unsigned index, exists; - array_search_sorted_by(&predicate_capture_ids, 0, , capture_id, &index, &exists); + array_search_sorted_by(&predicate_capture_ids, , capture_id, &index, &exists); if (exists) { step->is_definite = false; break; From 4aba684d6681278c82c3a472e0bead950da1ec9d Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 24 Aug 2020 
15:53:05 -0700 Subject: [PATCH 24/26] Control recursion depth explicitly during query analysis --- cli/src/tests/query_test.rs | 49 +++++++++++++++++ lib/src/query.c | 107 ++++++++++++++++++++++++++++++------ 2 files changed, 138 insertions(+), 18 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 822fdd22..b857467b 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -268,6 +268,29 @@ fn test_query_errors_on_impossible_patterns() { .join("\n") )) ); + + assert_eq!( + Query::new(js_lang, "(identifier (identifier))",), + Err(QueryError::Structure( + 1, + [ + "(identifier (identifier))", // + " ^", + ] + .join("\n") + )) + ); + assert_eq!( + Query::new(js_lang, "(true (true))",), + Err(QueryError::Structure( + 1, + [ + "(true (true))", // + " ^", + ] + .join("\n") + )) + ); }); } @@ -2561,6 +2584,32 @@ fn test_query_step_is_definite() { "#, results_by_substring: &[("\"(\"", false), ("(identifier)", false), ("\")\"", true)], }, + Row { + description: "long, but not too long to analyze", + language: get_language("javascript"), + pattern: r#" + (object "{" (pair) (pair) (pair) (pair) "}") + "#, + results_by_substring: &[ + ("\"{\"", false), + ("(pair)", false), + ("(pair) \"}\"", false), + ("\"}\"", true), + ], + }, + Row { + description: "too long to analyze", + language: get_language("javascript"), + pattern: r#" + (object "{" (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) "}") + "#, + results_by_substring: &[ + ("\"{\"", false), + ("(pair)", false), + ("(pair) \"}\"", false), + ("\"}\"", false), + ], + }, ]; allocations::record(|| { diff --git a/lib/src/query.c b/lib/src/query.c index 8a7e5ea2..85f71aa6 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -8,13 +8,14 @@ #include "./unicode.h" #include +// #define DEBUG_ANALYZE_QUERY // #define LOG(...) fprintf(stderr, __VA_ARGS__) #define LOG(...) 
#define MAX_CAPTURE_LIST_COUNT 32 #define MAX_STEP_CAPTURE_COUNT 3 #define MAX_STATE_PREDECESSOR_COUNT 100 -#define MAX_ANALYSIS_STATE_DEPTH 8 +#define MAX_ANALYSIS_STATE_DEPTH 12 /* * Stream - A sequence of unicode characters derived from a UTF8 string. @@ -170,6 +171,8 @@ typedef struct { uint16_t step_index; } AnalysisState; +typedef Array(AnalysisState) AnalysisStateSet; + /* * AnalysisSubgraph - A subset of the states in the parse table that are used * in constructing nodes with a certain symbol. Each state is accompanied by @@ -585,6 +588,20 @@ static inline const TSStateId *state_predecessor_map_get( * AnalysisState ****************/ +static unsigned analysis_state__recursion_depth(const AnalysisState *self) { + unsigned result = 0; + for (unsigned i = 0; i < self->depth; i++) { + TSSymbol symbol = self->stack[i].parent_symbol; + for (unsigned j = 0; j < i; j++) { + if (self->stack[j].parent_symbol == symbol) { + result++; + break; + } + } + } + return result; +} + static inline int analysis_state__compare_position( const AnalysisState *self, const AnalysisState *other @@ -726,8 +743,6 @@ static inline void ts_query__pattern_map_insert( })); } -// #define DEBUG_ANALYZE_QUERY - static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { // Identify all of the patterns in the query that have child patterns, both at the // top level and nested within other larger patterns. Record the step index where @@ -917,23 +932,35 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { // For each non-terminal pattern, determine if the pattern can successfully match, // and identify all of the possible children within the pattern where matching could fail. 
bool result = true; - typedef Array(AnalysisState) AnalysisStateList; - AnalysisStateList states = array_new(); - AnalysisStateList next_states = array_new(); + AnalysisStateSet states = array_new(); + AnalysisStateSet next_states = array_new(); + AnalysisStateSet deeper_states = array_new(); Array(uint16_t) final_step_indices = array_new(); for (unsigned i = 0; i < parent_step_indices.size; i++) { - // Find the subgraph that corresponds to this pattern's root symbol. uint16_t parent_step_index = parent_step_indices.contents[i]; uint16_t parent_depth = self->steps.contents[parent_step_index].depth; TSSymbol parent_symbol = self->steps.contents[parent_step_index].symbol; + if (parent_symbol == ts_builtin_sym_error) continue; + + // Find the subgraph that corresponds to this pattern's root symbol. If the pattern's + // root symbol is not a non-terminal, then return an error. unsigned subgraph_index, exists; array_search_sorted_by(&subgraphs, .symbol, parent_symbol, &subgraph_index, &exists); - if (!exists) continue; - AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + if (!exists) { + unsigned first_child_step_index = parent_step_index + 1; + uint32_t i, exists; + array_search_sorted_by(&self->step_offsets, .step_index, first_child_step_index, &i, &exists); + assert(exists); + *error_offset = self->step_offsets.contents[i].byte_offset; + result = false; + break; + } // Initialize an analysis state at every parse state in the table where // this parent symbol can occur. 
+ AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; array_clear(&states); + array_clear(&deeper_states); for (unsigned j = 0; j < subgraph->start_states.size; j++) { TSStateId parse_state = subgraph->start_states.contents[j]; array_push(&states, ((AnalysisState) { @@ -954,6 +981,9 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { // Walk the subgraph for this non-terminal, tracking all of the possible // sequences of progress within the pattern. bool can_finish_pattern = false; + bool did_exceed_max_depth = false; + unsigned recursion_depth_limit = 0; + unsigned prev_final_step_count = 0; array_clear(&final_step_indices); for (;;) { #ifdef DEBUG_ANALYZE_QUERY @@ -980,7 +1010,23 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { } #endif - if (states.size == 0) break; + if (states.size == 0) { + if (deeper_states.size > 0 && final_step_indices.size > prev_final_step_count) { + #ifdef DEBUG_ANALYZE_QUERY + printf("Increase recursion depth limit to %u\n", recursion_depth_limit + 1); + #endif + + prev_final_step_count = final_step_indices.size; + recursion_depth_limit++; + AnalysisStateSet _states = states; + states = deeper_states; + deeper_states = _states; + continue; + } + + break; + } + array_clear(&next_states); for (unsigned j = 0; j < states.size; j++) { AnalysisState * const state = &states.contents[j]; @@ -1091,13 +1137,23 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { // If this is a hidden child, then push a new entry to the stack, in order to // walk through the children of this child. 
- else if (sym >= self->language->token_count && next_state.depth < MAX_ANALYSIS_STATE_DEPTH) { + else if (sym >= self->language->token_count) { + if (next_state.depth + 1 >= MAX_ANALYSIS_STATE_DEPTH) { + did_exceed_max_depth = true; + continue; + } + next_state.depth++; analysis_state__top(&next_state)->parse_state = parse_state; analysis_state__top(&next_state)->child_index = 0; analysis_state__top(&next_state)->parent_symbol = sym; analysis_state__top(&next_state)->field_id = field_id; analysis_state__top(&next_state)->done = false; + + if (analysis_state__recursion_depth(&next_state) > recursion_depth_limit) { + array_insert_sorted_with(&deeper_states, analysis_state__compare, next_state); + continue; + } } else { continue; } @@ -1128,10 +1184,10 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { if (!next_step->is_dead_end) { bool did_finish_pattern = self->steps.contents[next_state.step_index].depth != parent_depth + 1; if (did_finish_pattern) can_finish_pattern = true; - if (next_state.depth > 0 && !did_finish_pattern) { - array_insert_sorted_with(&next_states, analysis_state__compare, next_state); - } else { + if (did_finish_pattern || next_state.depth == 0) { array_insert_sorted_by(&final_step_indices, , next_state.step_index); + } else { + array_insert_sorted_with(&next_states, analysis_state__compare, next_state); } } @@ -1152,7 +1208,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { } } - AnalysisStateList _states = states; + AnalysisStateSet _states = states; states = next_states; next_states = _states; } @@ -1171,16 +1227,30 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { } } + if (did_exceed_max_depth) { + for (unsigned j = parent_step_index + 1; j < self->steps.size; j++) { + QueryStep *step = &self->steps.contents[j]; + if ( + step->depth <= parent_depth || + step->depth == PATTERN_DONE_MARKER + ) break; + if (!step->is_dead_end) { + 
step->is_definite = false; + } + } + } + // If this pattern cannot match, store the pattern index so that it can be // returned to the caller. - if (result && !can_finish_pattern) { + if (result && !can_finish_pattern && !did_exceed_max_depth) { assert(final_step_indices.size > 0); - uint16_t *impossible_step_index = array_back(&final_step_indices); + uint16_t impossible_step_index = *array_back(&final_step_indices); uint32_t i, exists; - array_search_sorted_by(&self->step_offsets, .step_index, *impossible_step_index, &i, &exists); + array_search_sorted_by(&self->step_offsets, .step_index, impossible_step_index, &i, &exists); assert(exists); *error_offset = self->step_offsets.contents[i].byte_offset; result = false; + break; } } @@ -1286,6 +1356,7 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { array_delete(&next_nodes); array_delete(&states); array_delete(&next_states); + array_delete(&deeper_states); array_delete(&final_step_indices); array_delete(&parent_step_indices); array_delete(&predicate_capture_ids); From 4b9db41584069d4a6a0bc0776410212eb00820d5 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 2 Sep 2020 09:17:48 -0700 Subject: [PATCH 25/26] Remove unnecessary echo in test script --- script/test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/script/test b/script/test index 31e90226..9b578dcf 100755 --- a/script/test +++ b/script/test @@ -90,7 +90,7 @@ if [[ \ -n $TREE_SITTER_TEST_EXAMPLE_FILTER || \ -n $TREE_SITTER_TEST_TRIAL_FILTER \ ]]; then - echo ${top_level_filter:=corpus} + : ${top_level_filter:=corpus} fi if [[ "${mode}" == "debug" ]]; then From 31a22fc627e49003bf5410cebbda08808600b4ac Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 2 Sep 2020 09:59:26 -0700 Subject: [PATCH 26/26] In array.h, add comments and sort functions more logically --- lib/src/array.h | 101 +++++++++++++++++++++++++++++++----------------- 1 file changed, 65 insertions(+), 36 deletions(-) diff --git 
a/lib/src/array.h b/lib/src/array.h index 7b2d42fe..de8c8cb3 100644 --- a/lib/src/array.h +++ b/lib/src/array.h @@ -12,9 +12,9 @@ extern "C" { #include #include "./alloc.h" -#define Array(T) \ - struct { \ - T *contents; \ +#define Array(T) \ + struct { \ + T *contents; \ uint32_t size; \ uint32_t capacity; \ } @@ -37,15 +37,15 @@ extern "C" { #define array_reserve(self, new_capacity) \ array__reserve((VoidArray *)(self), array__elem_size(self), new_capacity) -#define array_erase(self, index) \ - array__erase((VoidArray *)(self), array__elem_size(self), index) - +// Free any memory allocated for this array. #define array_delete(self) array__delete((VoidArray *)self) #define array_push(self, element) \ (array__grow((VoidArray *)(self), 1, array__elem_size(self)), \ (self)->contents[(self)->size++] = (element)) +// Increase the array's size by a given number of elements, reallocating +// if necessary. New elements are zero-initialized. #define array_grow_by(self, count) \ (array__grow((VoidArray *)(self), count, array__elem_size(self)), \ memset((self)->contents + (self)->size, 0, (count) * array__elem_size(self)), \ @@ -54,52 +54,46 @@ extern "C" { #define array_push_all(self, other) \ array_splice((self), (self)->size, 0, (other)->size, (other)->contents) +// Remove `old_count` elements from the array starting at the given `index`. At +// the same index, insert `new_count` new elements, reading their values from the +// `new_contents` pointer. #define array_splice(self, index, old_count, new_count, new_contents) \ array__splice((VoidArray *)(self), array__elem_size(self), index, old_count, \ new_count, new_contents) +// Insert one `element` into the array at the given `index`. #define array_insert(self, index, element) \ array__splice((VoidArray *)(self), array__elem_size(self), index, 0, 1, &element) +// Remove one `element` from the array at the given `index`. 
+#define array_erase(self, index) \ + array__erase((VoidArray *)(self), array__elem_size(self), index) + #define array_pop(self) ((self)->contents[--(self)->size]) #define array_assign(self, other) \ array__assign((VoidArray *)(self), (const VoidArray *)(other), array__elem_size(self)) -#define array__search_sorted(self, start, compare, suffix, needle, index, exists) \ - do { \ - *(index) = start; \ - *(exists) = false; \ - uint32_t size = (self)->size - *(index); \ - if (size == 0) break; \ - int comparison; \ - while (size > 1) { \ - uint32_t half_size = size / 2; \ - uint32_t mid_index = *(index) + half_size; \ - comparison = compare(&((self)->contents[mid_index] suffix), (needle)); \ - if (comparison <= 0) *(index) = mid_index; \ - size -= half_size; \ - } \ - comparison = compare(&((self)->contents[*(index)] suffix), (needle)); \ - if (comparison == 0) *(exists) = true; \ - else if (comparison < 0) *(index) += 1; \ - } while (0) - -#define _compare_int(a, b) ((int)*(a) - (int)(b)) - -#define array_search_sorted_by(self, field, needle, index, exists) \ - array__search_sorted(self, 0, _compare_int, field, needle, index, exists) - +// Search a sorted array for a given `needle` value, using the given `compare` +// callback to determine the order. +// +// If an existing element is found to be equal to `needle`, then the `index` +// out-parameter is set to the existing value's index, and the `exists` +// out-parameter is set to true. Otherwise, `index` is set to an index where +// `needle` should be inserted in order to preserve the sorting, and `exists` +// is set to false. 
#define array_search_sorted_with(self, compare, needle, index, exists) \ array__search_sorted(self, 0, compare, , needle, index, exists) -#define array_insert_sorted_by(self, field, value) \ - do { \ - unsigned index, exists; \ - array_search_sorted_by(self, field, (value) field, &index, &exists); \ - if (!exists) array_insert(self, index, value); \ - } while (0) +// Search a sorted array for a given `needle` value, using integer comparisons +// of a given struct field (specified with a leading dot) to determine the order. +// +// See also `array_search_sorted_with`. +#define array_search_sorted_by(self, field, needle, index, exists) \ + array__search_sorted(self, 0, _compare_int, field, needle, index, exists) +// Insert a given `value` into a sorted array, using the given `compare` +// callback to determine the order. #define array_insert_sorted_with(self, compare, value) \ do { \ unsigned index, exists; \ @@ -107,6 +101,17 @@ extern "C" { if (!exists) array_insert(self, index, value); \ } while (0) +// Insert a given `value` into a sorted array, using integer comparisons of +// a given struct field (specified with a leading dot) to determine the order. +// +// See also `array_search_sorted_by`. +#define array_insert_sorted_by(self, field, value) \ + do { \ + unsigned index, exists; \ + array_search_sorted_by(self, field, (value) field, &index, &exists); \ + if (!exists) array_insert(self, index, value); \ + } while (0) + // Private typedef Array(void) VoidArray; @@ -192,6 +197,30 @@ static inline void array__splice(VoidArray *self, size_t element_size, self->size += new_count - old_count; } +// A binary search routine, based on Rust's `std::slice::binary_search_by`. 
+#define array__search_sorted(self, start, compare, suffix, needle, index, exists) \ + do { \ + *(index) = start; \ + *(exists) = false; \ + uint32_t size = (self)->size - *(index); \ + if (size == 0) break; \ + int comparison; \ + while (size > 1) { \ + uint32_t half_size = size / 2; \ + uint32_t mid_index = *(index) + half_size; \ + comparison = compare(&((self)->contents[mid_index] suffix), (needle)); \ + if (comparison <= 0) *(index) = mid_index; \ + size -= half_size; \ + } \ + comparison = compare(&((self)->contents[*(index)] suffix), (needle)); \ + if (comparison == 0) *(exists) = true; \ + else if (comparison < 0) *(index) += 1; \ + } while (0) + +// Helper macro for the `_sorted_by` routines above. This takes the left (existing) +// parameter by reference in order to work with the generic sorting function above. +#define _compare_int(a, b) ((int)*(a) - (int)(b)) + #ifdef __cplusplus } #endif