diff --git a/cli/src/generate/render.rs b/cli/src/generate/render.rs index 300ad383..5b016cb6 100644 --- a/cli/src/generate/render.rs +++ b/cli/src/generate/render.rs @@ -7,7 +7,7 @@ use super::tables::{ }; use core::ops::Range; use std::cmp; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{HashMap, HashSet}; use std::fmt::Write; use std::mem::swap; @@ -69,7 +69,8 @@ struct Generator { symbol_order: HashMap, symbol_ids: HashMap, alias_ids: HashMap, - alias_map: BTreeMap>, + unique_aliases: Vec, + symbol_map: HashMap, field_names: Vec, next_abi: bool, } @@ -108,6 +109,8 @@ impl Generator { self.add_alias_sequences(); } + self.add_non_terminal_alias_map(); + let mut main_lex_table = LexTable::default(); swap(&mut main_lex_table, &mut self.main_lex_table); self.add_lex_function("ts_lex", main_lex_table, true); @@ -159,13 +162,72 @@ impl Generator { format!("anon_alias_sym_{}", self.sanitize_identifier(&alias.value)) }; self.alias_ids.entry(alias.clone()).or_insert(alias_id); - self.alias_map - .entry(alias.clone()) - .or_insert(matching_symbol); } } } + self.unique_aliases = self + .alias_ids + .keys() + .filter(|alias| { + self.parse_table + .symbols + .iter() + .cloned() + .find(|symbol| { + let (name, kind) = self.metadata_for_symbol(*symbol); + name == alias.value && kind == alias.kind() + }) + .is_none() + }) + .cloned() + .collect(); + self.unique_aliases.sort_unstable(); + + self.symbol_map = self + .parse_table + .symbols + .iter() + .map(|symbol| { + let mut mapping = symbol; + + // There can be multiple symbols in the grammar that have the same name and kind, + // due to simple aliases. When that happens, ensure that they map to the same + // public-facing symbol. If one of the symbols is not aliased, choose that one + // to be the public-facing symbol. Otherwise, pick the symbol with the lowest + // numeric value. 
+ if let Some(alias) = self.simple_aliases.get(symbol) { + let kind = alias.kind(); + for other_symbol in &self.parse_table.symbols { + if let Some(other_alias) = self.simple_aliases.get(other_symbol) { + if other_symbol < mapping && other_alias == alias { + mapping = other_symbol; + } + } else if self.metadata_for_symbol(*other_symbol) == (&alias.value, kind) { + mapping = other_symbol; + break; + } + } + } + // Two anonymous tokens with different flags but the same string value + // should be represented with the same symbol in the public API. Examples: + // * "<" and token(prec(1, "<")) + // * "(" and token.immediate("(") + else if symbol.is_terminal() { + let metadata = self.metadata_for_symbol(*symbol); + for other_symbol in &self.parse_table.symbols { + let other_metadata = self.metadata_for_symbol(*other_symbol); + if other_metadata == metadata { + mapping = other_symbol; + break; + } + } + } + + (*symbol, *mapping) + }) + .collect(); + field_names.sort_unstable(); field_names.dedup(); self.field_names = field_names.into_iter().cloned().collect(); @@ -255,11 +317,7 @@ impl Generator { "#define SYMBOL_COUNT {}", self.parse_table.symbols.len() ); - add_line!( - self, - "#define ALIAS_COUNT {}", - self.alias_map.iter().filter(|e| e.1.is_none()).count() - ); + add_line!(self, "#define ALIAS_COUNT {}", self.unique_aliases.len(),); add_line!(self, "#define TOKEN_COUNT {}", token_count); add_line!( self, @@ -287,11 +345,9 @@ impl Generator { i += 1; } } - for (alias, symbol) in &self.alias_map { - if symbol.is_none() { - add_line!(self, "{} = {},", self.alias_ids[&alias], i); - i += 1; - } + for alias in &self.unique_aliases { + add_line!(self, "{} = {},", self.alias_ids[&alias], i); + i += 1; } dedent!(self); add_line!(self, "}};"); @@ -310,15 +366,13 @@ impl Generator { ); add_line!(self, "[{}] = \"{}\",", self.symbol_ids[&symbol], name); } - for (alias, symbol) in &self.alias_map { - if symbol.is_none() { - add_line!( - self, - "[{}] = \"{}\",", - 
self.alias_ids[&alias], - self.sanitize_string(&alias.value) - ); - } + for alias in &self.unique_aliases { + add_line!( + self, + "[{}] = \"{}\",", + self.alias_ids[&alias], + self.sanitize_string(&alias.value) + ); } dedent!(self); add_line!(self, "}};"); @@ -329,58 +383,21 @@ impl Generator { add_line!(self, "static TSSymbol ts_symbol_map[] = {{"); indent!(self); for symbol in &self.parse_table.symbols { - let mut mapping = symbol; - - // There can be multiple symbols in the grammar that have the same name and kind, - // due to simple aliases. When that happens, ensure that they map to the same - // public-facing symbol. If one of the symbols is not aliased, choose that one - // to be the public-facing symbol. Otherwise, pick the symbol with the lowest - // numeric value. - if let Some(alias) = self.simple_aliases.get(symbol) { - let kind = alias.kind(); - for other_symbol in &self.parse_table.symbols { - if let Some(other_alias) = self.simple_aliases.get(other_symbol) { - if other_symbol < mapping && other_alias == alias { - mapping = other_symbol; - } - } else if self.metadata_for_symbol(*other_symbol) == (&alias.value, kind) { - mapping = other_symbol; - break; - } - } - } - // Two anonymous tokens with different flags but the same string value - // should be represented with the same symbol in the public API. 
Examples: - // * "<" and token(prec(1, "<")) - // * "(" and token.immediate("(") - else if symbol.is_terminal() { - let metadata = self.metadata_for_symbol(*symbol); - for other_symbol in &self.parse_table.symbols { - let other_metadata = self.metadata_for_symbol(*other_symbol); - if other_metadata == metadata { - mapping = other_symbol; - break; - } - } - } - add_line!( self, "[{}] = {},", - self.symbol_ids[&symbol], - self.symbol_ids[mapping], + self.symbol_ids[symbol], + self.symbol_ids[&self.symbol_map[symbol]], ); } - for (alias, symbol) in &self.alias_map { - if symbol.is_none() { - add_line!( - self, - "[{}] = {},", - self.alias_ids[&alias], - self.alias_ids[&alias], - ); - } + for alias in &self.unique_aliases { + add_line!( + self, + "[{}] = {},", + self.alias_ids[&alias], + self.alias_ids[&alias], + ); } dedent!(self); @@ -451,15 +468,13 @@ impl Generator { dedent!(self); add_line!(self, "}},"); } - for (alias, matching_symbol) in &self.alias_map { - if matching_symbol.is_none() { - add_line!(self, "[{}] = {{", self.alias_ids[&alias]); - indent!(self); - add_line!(self, ".visible = true,"); - add_line!(self, ".named = {},", alias.is_named); - dedent!(self); - add_line!(self, "}},"); - } + for alias in &self.unique_aliases { + add_line!(self, "[{}] = {{", self.alias_ids[&alias]); + indent!(self); + add_line!(self, ".visible = true,"); + add_line!(self, ".named = {},", alias.is_named); + dedent!(self); + add_line!(self, "}},"); } dedent!(self); add_line!(self, "}};"); @@ -498,6 +513,50 @@ impl Generator { add_line!(self, ""); } + fn add_non_terminal_alias_map(&mut self) { + let mut aliases_by_symbol = HashMap::new(); + for variable in &self.syntax_grammar.variables { + for production in &variable.productions { + for step in &production.steps { + if let Some(alias) = &step.alias { + if step.symbol.is_non_terminal() + && !self.simple_aliases.contains_key(&step.symbol) + { + if self.symbol_ids.contains_key(&step.symbol) { + let alias_ids = + 
aliases_by_symbol.entry(step.symbol).or_insert(Vec::new()); + if let Err(i) = alias_ids.binary_search(&alias) { + alias_ids.insert(i, alias); + } + } + } + } + } + } + } + + let mut aliases_by_symbol = aliases_by_symbol.iter().collect::<Vec<_>>(); + aliases_by_symbol.sort_unstable_by_key(|e| e.0); + + add_line!(self, "static uint16_t ts_non_terminal_alias_map[] = {{"); + indent!(self); + for (symbol, aliases) in aliases_by_symbol { + let symbol_id = &self.symbol_ids[symbol]; + let public_symbol_id = &self.symbol_ids[&self.symbol_map[&symbol]]; + add_line!(self, "{}, {},", symbol_id, 1 + aliases.len()); + indent!(self); + add_line!(self, "{},", public_symbol_id); + for alias in aliases { + add_line!(self, "{},", &self.alias_ids[&alias]); + } + dedent!(self); + } + add_line!(self, "0,"); + dedent!(self); + add_line!(self, "}};"); + add_line!(self, ""); + } + fn add_field_sequences(&mut self) { let mut flat_field_maps = vec![]; let mut next_flat_field_map_index = 0; @@ -1207,6 +1266,7 @@ impl Generator { add_line!(self, ".large_state_count = LARGE_STATE_COUNT,"); if self.next_abi { + add_line!(self, ".alias_map = ts_non_terminal_alias_map,"); add_line!(self, ".state_count = STATE_COUNT,"); } @@ -1517,7 +1577,8 @@ pub(crate) fn render_c_code( symbol_ids: HashMap::new(), symbol_order: HashMap::new(), alias_ids: HashMap::new(), - alias_map: BTreeMap::new(), + symbol_map: HashMap::new(), + unique_aliases: Vec::new(), field_names: Vec::new(), next_abi, } diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 816c3aee..822fdd22 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2553,6 +2553,14 @@ fn test_query_step_is_definite() { ("arguments:", true), ], }, + Row { + description: "aliased parent node", + language: get_language("ruby"), + pattern: r#" + (method_parameters "(" (identifier) @id")") + "#, + results_by_substring: &[("\"(\"", false), ("(identifier)", false), ("\")\"", true)], + }, ]; allocations::record(|| { diff --git 
a/lib/include/tree_sitter/parser.h b/lib/include/tree_sitter/parser.h index 360e012f..84096132 100644 --- a/lib/include/tree_sitter/parser.h +++ b/lib/include/tree_sitter/parser.h @@ -119,6 +119,7 @@ struct TSLanguage { const uint16_t *small_parse_table; const uint32_t *small_parse_table_map; const TSSymbol *public_symbol_map; + const uint16_t *alias_map; uint32_t state_count; }; diff --git a/lib/src/language.h b/lib/src/language.h index f8fd1ae5..e5c07aa2 100644 --- a/lib/src/language.h +++ b/lib/src/language.h @@ -13,6 +13,7 @@ extern "C" { #define TREE_SITTER_LANGUAGE_VERSION_WITH_SYMBOL_DEDUPING 11 #define TREE_SITTER_LANGUAGE_VERSION_WITH_SMALL_STATES 11 #define TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT 12 +#define TREE_SITTER_LANGUAGE_VERSION_WITH_ALIAS_MAP 12 typedef struct { const TSParseAction *actions; @@ -258,6 +259,32 @@ static inline void ts_language_field_map( *end = &self->field_map_entries[slice.index] + slice.length; } +static inline void ts_language_aliases_for_symbol( + const TSLanguage *self, + TSSymbol original_symbol, + const TSSymbol **start, + const TSSymbol **end +) { + *start = &self->public_symbol_map[original_symbol]; + *end = *start + 1; + + if (self->version < TREE_SITTER_LANGUAGE_VERSION_WITH_ALIAS_MAP) return; + + unsigned i = 0; + for (;;) { + TSSymbol symbol = self->alias_map[i++]; + if (symbol == 0 || symbol > original_symbol) break; + uint16_t count = self->alias_map[i++]; + if (symbol == original_symbol) { + *start = &self->alias_map[i]; + *end = &self->alias_map[i + count]; + break; + } + i += count; + } +} + + #ifdef __cplusplus } #endif diff --git a/lib/src/query.c b/lib/src/query.c index 8464a691..9f911438 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -788,24 +788,32 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { for (unsigned i = 0; i < lookahead_iterator.action_count; i++) { const TSParseAction *action = &lookahead_iterator.actions[i]; if (action->type == 
TSParseActionTypeReduce) { - TSSymbol symbol = self->language->public_symbol_map[action->params.reduce.symbol]; - array_search_sorted_by( - &subgraphs, - 0, - .symbol, - symbol, - &subgraph_index, - &exists + const TSSymbol *aliases, *aliases_end; + ts_language_aliases_for_symbol( + self->language, + action->params.reduce.symbol, + &aliases, + &aliases_end ); - if (exists) { - AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; - if (subgraph->nodes.size == 0 || array_back(&subgraph->nodes)->state != state) { - array_push(&subgraph->nodes, ((AnalysisSubgraphNode) { - .state = state, - .production_id = action->params.reduce.production_id, - .child_index = action->params.reduce.child_count, - .done = true, - })); + for (const TSSymbol *symbol = aliases; symbol < aliases_end; symbol++) { + array_search_sorted_by( + &subgraphs, + 0, + .symbol, + *symbol, + &subgraph_index, + &exists + ); + if (exists) { + AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + if (subgraph->nodes.size == 0 || array_back(&subgraph->nodes)->state != state) { + array_push(&subgraph->nodes, ((AnalysisSubgraphNode) { + .state = state, + .production_id = action->params.reduce.production_id, + .child_index = action->params.reduce.child_count, + .done = true, + })); + } } } } else if (action->type == TSParseActionTypeShift && !action->params.shift.extra) { @@ -815,22 +823,30 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) { } } else if (lookahead_iterator.next_state != 0 && lookahead_iterator.next_state != state) { state_predecessor_map_add(&predecessor_map, lookahead_iterator.next_state, state); - TSSymbol symbol = self->language->public_symbol_map[lookahead_iterator.symbol]; - array_search_sorted_by( - &subgraphs, - 0, - .symbol, - symbol, - &subgraph_index, - &exists + const TSSymbol *aliases, *aliases_end; + ts_language_aliases_for_symbol( + self->language, + lookahead_iterator.symbol, + &aliases, + &aliases_end ); - if (exists) { 
- AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; - if ( - subgraph->start_states.size == 0 || - *array_back(&subgraph->start_states) != state - ) - array_push(&subgraph->start_states, state); + for (const TSSymbol *symbol = aliases; symbol < aliases_end; symbol++) { + array_search_sorted_by( + &subgraphs, + 0, + .symbol, + *symbol, + &subgraph_index, + &exists + ); + if (exists) { + AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + if ( + subgraph->start_states.size == 0 || + *array_back(&subgraph->start_states) != state + ) + array_push(&subgraph->start_states, state); + } } } }