Handle aliased parent nodes in query analysis

This commit is contained in:
Max Brunsfeld 2020-08-21 14:12:04 -07:00
parent 456b1f6771
commit 2eb04094f8
5 changed files with 225 additions and 112 deletions

View file

@ -7,7 +7,7 @@ use super::tables::{
};
use core::ops::Range;
use std::cmp;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::collections::{HashMap, HashSet};
use std::fmt::Write;
use std::mem::swap;
@ -69,7 +69,8 @@ struct Generator {
symbol_order: HashMap<Symbol, usize>,
symbol_ids: HashMap<Symbol, String>,
alias_ids: HashMap<Alias, String>,
alias_map: BTreeMap<Alias, Option<Symbol>>,
unique_aliases: Vec<Alias>,
symbol_map: HashMap<Symbol, Symbol>,
field_names: Vec<String>,
next_abi: bool,
}
@ -108,6 +109,8 @@ impl Generator {
self.add_alias_sequences();
}
self.add_non_terminal_alias_map();
let mut main_lex_table = LexTable::default();
swap(&mut main_lex_table, &mut self.main_lex_table);
self.add_lex_function("ts_lex", main_lex_table, true);
@ -159,13 +162,72 @@ impl Generator {
format!("anon_alias_sym_{}", self.sanitize_identifier(&alias.value))
};
self.alias_ids.entry(alias.clone()).or_insert(alias_id);
self.alias_map
.entry(alias.clone())
.or_insert(matching_symbol);
}
}
}
self.unique_aliases = self
.alias_ids
.keys()
.filter(|alias| {
self.parse_table
.symbols
.iter()
.cloned()
.find(|symbol| {
let (name, kind) = self.metadata_for_symbol(*symbol);
name == alias.value && kind == alias.kind()
})
.is_none()
})
.cloned()
.collect();
self.unique_aliases.sort_unstable();
self.symbol_map = self
.parse_table
.symbols
.iter()
.map(|symbol| {
let mut mapping = symbol;
// There can be multiple symbols in the grammar that have the same name and kind,
// due to simple aliases. When that happens, ensure that they map to the same
// public-facing symbol. If one of the symbols is not aliased, choose that one
// to be the public-facing symbol. Otherwise, pick the symbol with the lowest
// numeric value.
if let Some(alias) = self.simple_aliases.get(symbol) {
let kind = alias.kind();
for other_symbol in &self.parse_table.symbols {
if let Some(other_alias) = self.simple_aliases.get(other_symbol) {
if other_symbol < mapping && other_alias == alias {
mapping = other_symbol;
}
} else if self.metadata_for_symbol(*other_symbol) == (&alias.value, kind) {
mapping = other_symbol;
break;
}
}
}
// Two anonymous tokens with different flags but the same string value
// should be represented with the same symbol in the public API. Examples:
// * "<" and token(prec(1, "<"))
// * "(" and token.immediate("(")
else if symbol.is_terminal() {
let metadata = self.metadata_for_symbol(*symbol);
for other_symbol in &self.parse_table.symbols {
let other_metadata = self.metadata_for_symbol(*other_symbol);
if other_metadata == metadata {
mapping = other_symbol;
break;
}
}
}
(*symbol, *mapping)
})
.collect();
field_names.sort_unstable();
field_names.dedup();
self.field_names = field_names.into_iter().cloned().collect();
@ -255,11 +317,7 @@ impl Generator {
"#define SYMBOL_COUNT {}",
self.parse_table.symbols.len()
);
add_line!(
self,
"#define ALIAS_COUNT {}",
self.alias_map.iter().filter(|e| e.1.is_none()).count()
);
add_line!(self, "#define ALIAS_COUNT {}", self.unique_aliases.len(),);
add_line!(self, "#define TOKEN_COUNT {}", token_count);
add_line!(
self,
@ -287,11 +345,9 @@ impl Generator {
i += 1;
}
}
for (alias, symbol) in &self.alias_map {
if symbol.is_none() {
add_line!(self, "{} = {},", self.alias_ids[&alias], i);
i += 1;
}
for alias in &self.unique_aliases {
add_line!(self, "{} = {},", self.alias_ids[&alias], i);
i += 1;
}
dedent!(self);
add_line!(self, "}};");
@ -310,15 +366,13 @@ impl Generator {
);
add_line!(self, "[{}] = \"{}\",", self.symbol_ids[&symbol], name);
}
for (alias, symbol) in &self.alias_map {
if symbol.is_none() {
add_line!(
self,
"[{}] = \"{}\",",
self.alias_ids[&alias],
self.sanitize_string(&alias.value)
);
}
for alias in &self.unique_aliases {
add_line!(
self,
"[{}] = \"{}\",",
self.alias_ids[&alias],
self.sanitize_string(&alias.value)
);
}
dedent!(self);
add_line!(self, "}};");
@ -329,58 +383,21 @@ impl Generator {
add_line!(self, "static TSSymbol ts_symbol_map[] = {{");
indent!(self);
for symbol in &self.parse_table.symbols {
let mut mapping = symbol;
// There can be multiple symbols in the grammar that have the same name and kind,
// due to simple aliases. When that happens, ensure that they map to the same
// public-facing symbol. If one of the symbols is not aliased, choose that one
// to be the public-facing symbol. Otherwise, pick the symbol with the lowest
// numeric value.
if let Some(alias) = self.simple_aliases.get(symbol) {
let kind = alias.kind();
for other_symbol in &self.parse_table.symbols {
if let Some(other_alias) = self.simple_aliases.get(other_symbol) {
if other_symbol < mapping && other_alias == alias {
mapping = other_symbol;
}
} else if self.metadata_for_symbol(*other_symbol) == (&alias.value, kind) {
mapping = other_symbol;
break;
}
}
}
// Two anonymous tokens with different flags but the same string value
// should be represented with the same symbol in the public API. Examples:
// * "<" and token(prec(1, "<"))
// * "(" and token.immediate("(")
else if symbol.is_terminal() {
let metadata = self.metadata_for_symbol(*symbol);
for other_symbol in &self.parse_table.symbols {
let other_metadata = self.metadata_for_symbol(*other_symbol);
if other_metadata == metadata {
mapping = other_symbol;
break;
}
}
}
add_line!(
self,
"[{}] = {},",
self.symbol_ids[&symbol],
self.symbol_ids[mapping],
self.symbol_ids[symbol],
self.symbol_ids[&self.symbol_map[symbol]],
);
}
for (alias, symbol) in &self.alias_map {
if symbol.is_none() {
add_line!(
self,
"[{}] = {},",
self.alias_ids[&alias],
self.alias_ids[&alias],
);
}
for alias in &self.unique_aliases {
add_line!(
self,
"[{}] = {},",
self.alias_ids[&alias],
self.alias_ids[&alias],
);
}
dedent!(self);
@ -451,15 +468,13 @@ impl Generator {
dedent!(self);
add_line!(self, "}},");
}
for (alias, matching_symbol) in &self.alias_map {
if matching_symbol.is_none() {
add_line!(self, "[{}] = {{", self.alias_ids[&alias]);
indent!(self);
add_line!(self, ".visible = true,");
add_line!(self, ".named = {},", alias.is_named);
dedent!(self);
add_line!(self, "}},");
}
for alias in &self.unique_aliases {
add_line!(self, "[{}] = {{", self.alias_ids[&alias]);
indent!(self);
add_line!(self, ".visible = true,");
add_line!(self, ".named = {},", alias.is_named);
dedent!(self);
add_line!(self, "}},");
}
dedent!(self);
add_line!(self, "}};");
@ -498,6 +513,50 @@ impl Generator {
add_line!(self, "");
}
/// Emits the `ts_non_terminal_alias_map` C array.
///
/// Aliases are gathered from every production step whose symbol is a
/// non-terminal, carries an alias, has no entry in `simple_aliases`, and has
/// an entry in `symbol_ids`. For each such symbol (in ascending symbol order)
/// one record is written:
///
/// * the symbol's id followed by the record length (`1 + number of aliases`),
/// * the id of the symbol it maps to in `symbol_map`,
/// * the id of each distinct alias, in sorted order.
///
/// The array is terminated by a single `0,` entry.
fn add_non_terminal_alias_map(&mut self) {
    let mut alias_lists = HashMap::new();

    // Walk every step of every production of every variable.
    let steps = self
        .syntax_grammar
        .variables
        .iter()
        .flat_map(|variable| variable.productions.iter())
        .flat_map(|production| production.steps.iter());

    for step in steps {
        let alias = match &step.alias {
            Some(alias) => alias,
            None => continue,
        };
        if !step.symbol.is_non_terminal() || self.simple_aliases.contains_key(&step.symbol) {
            continue;
        }
        if !self.symbol_ids.contains_key(&step.symbol) {
            continue;
        }

        // Keep each symbol's alias list sorted and free of duplicates.
        let sorted_aliases = alias_lists.entry(step.symbol).or_insert_with(Vec::new);
        if let Err(position) = sorted_aliases.binary_search(&alias) {
            sorted_aliases.insert(position, alias);
        }
    }

    // Hash map iteration order is arbitrary, so order the records by symbol.
    let mut records = alias_lists.iter().collect::<Vec<_>>();
    records.sort_unstable_by_key(|record| record.0);

    add_line!(self, "static uint16_t ts_non_terminal_alias_map[] = {{");
    indent!(self);
    for (symbol, aliases) in records {
        add_line!(
            self,
            "{}, {},",
            &self.symbol_ids[symbol],
            1 + aliases.len()
        );
        indent!(self);
        add_line!(self, "{},", &self.symbol_ids[&self.symbol_map[&symbol]]);
        for alias in aliases {
            add_line!(self, "{},", &self.alias_ids[&alias]);
        }
        dedent!(self);
    }
    add_line!(self, "0,");
    dedent!(self);
    add_line!(self, "}};");
    add_line!(self, "");
}
fn add_field_sequences(&mut self) {
let mut flat_field_maps = vec![];
let mut next_flat_field_map_index = 0;
@ -1207,6 +1266,7 @@ impl Generator {
add_line!(self, ".large_state_count = LARGE_STATE_COUNT,");
if self.next_abi {
add_line!(self, ".alias_map = ts_non_terminal_alias_map,");
add_line!(self, ".state_count = STATE_COUNT,");
}
@ -1517,7 +1577,8 @@ pub(crate) fn render_c_code(
symbol_ids: HashMap::new(),
symbol_order: HashMap::new(),
alias_ids: HashMap::new(),
alias_map: BTreeMap::new(),
symbol_map: HashMap::new(),
unique_aliases: Vec::new(),
field_names: Vec::new(),
next_abi,
}

View file

@ -2553,6 +2553,14 @@ fn test_query_step_is_definite() {
("arguments:", true),
],
},
Row {
description: "aliased parent node",
language: get_language("ruby"),
pattern: r#"
(method_parameters "(" (identifier) @id")")
"#,
results_by_substring: &[("\"(\"", false), ("(identifier)", false), ("\")\"", true)],
},
];
allocations::record(|| {

View file

@ -119,6 +119,7 @@ struct TSLanguage {
const uint16_t *small_parse_table;
const uint32_t *small_parse_table_map;
const TSSymbol *public_symbol_map;
const uint16_t *alias_map;
uint32_t state_count;
};

View file

@ -13,6 +13,7 @@ extern "C" {
#define TREE_SITTER_LANGUAGE_VERSION_WITH_SYMBOL_DEDUPING 11
#define TREE_SITTER_LANGUAGE_VERSION_WITH_SMALL_STATES 11
#define TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT 12
#define TREE_SITTER_LANGUAGE_VERSION_WITH_ALIAS_MAP 12
typedef struct {
const TSParseAction *actions;
@ -258,6 +259,32 @@ static inline void ts_language_field_map(
*end = &self->field_map_entries[slice.index] + slice.length;
}
// Look up the range of symbols associated with `original_symbol`.
//
// By default the range is the single entry for `original_symbol` in
// `public_symbol_map`. When the language's version is at least
// TREE_SITTER_LANGUAGE_VERSION_WITH_ALIAS_MAP, `alias_map` is scanned; it is
// read as a sequence of records of the form
//
//   symbol, count, <count entries>
//
// terminated by a 0 symbol. If a record for `original_symbol` is found,
// `*start`/`*end` are set to delimit that record's `count` entries instead.
// The scan stops early once a record's symbol exceeds `original_symbol`.
static inline void ts_language_aliases_for_symbol(
  const TSLanguage *self,
  TSSymbol original_symbol,
  const TSSymbol **start,
  const TSSymbol **end
) {
  *start = &self->public_symbol_map[original_symbol];
  *end = *start + 1;

  if (self->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_ALIAS_MAP) {
    const uint16_t *record = self->alias_map;
    while (*record != 0 && *record <= original_symbol) {
      TSSymbol record_symbol = record[0];
      uint16_t entry_count = record[1];
      const uint16_t *entries = record + 2;
      if (record_symbol == original_symbol) {
        *start = entries;
        *end = entries + entry_count;
        return;
      }
      // Not the record we want: skip past its entries to the next record.
      record = entries + entry_count;
    }
  }
}
#ifdef __cplusplus
}
#endif

View file

@ -788,24 +788,32 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) {
for (unsigned i = 0; i < lookahead_iterator.action_count; i++) {
const TSParseAction *action = &lookahead_iterator.actions[i];
if (action->type == TSParseActionTypeReduce) {
TSSymbol symbol = self->language->public_symbol_map[action->params.reduce.symbol];
array_search_sorted_by(
&subgraphs,
0,
.symbol,
symbol,
&subgraph_index,
&exists
const TSSymbol *aliases, *aliases_end;
ts_language_aliases_for_symbol(
self->language,
action->params.reduce.symbol,
&aliases,
&aliases_end
);
if (exists) {
AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index];
if (subgraph->nodes.size == 0 || array_back(&subgraph->nodes)->state != state) {
array_push(&subgraph->nodes, ((AnalysisSubgraphNode) {
.state = state,
.production_id = action->params.reduce.production_id,
.child_index = action->params.reduce.child_count,
.done = true,
}));
for (const TSSymbol *symbol = aliases; symbol < aliases_end; symbol++) {
array_search_sorted_by(
&subgraphs,
0,
.symbol,
*symbol,
&subgraph_index,
&exists
);
if (exists) {
AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index];
if (subgraph->nodes.size == 0 || array_back(&subgraph->nodes)->state != state) {
array_push(&subgraph->nodes, ((AnalysisSubgraphNode) {
.state = state,
.production_id = action->params.reduce.production_id,
.child_index = action->params.reduce.child_count,
.done = true,
}));
}
}
}
} else if (action->type == TSParseActionTypeShift && !action->params.shift.extra) {
@ -815,22 +823,30 @@ static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) {
}
} else if (lookahead_iterator.next_state != 0 && lookahead_iterator.next_state != state) {
state_predecessor_map_add(&predecessor_map, lookahead_iterator.next_state, state);
TSSymbol symbol = self->language->public_symbol_map[lookahead_iterator.symbol];
array_search_sorted_by(
&subgraphs,
0,
.symbol,
symbol,
&subgraph_index,
&exists
const TSSymbol *aliases, *aliases_end;
ts_language_aliases_for_symbol(
self->language,
lookahead_iterator.symbol,
&aliases,
&aliases_end
);
if (exists) {
AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index];
if (
subgraph->start_states.size == 0 ||
*array_back(&subgraph->start_states) != state
)
array_push(&subgraph->start_states, state);
for (const TSSymbol *symbol = aliases; symbol < aliases_end; symbol++) {
array_search_sorted_by(
&subgraphs,
0,
.symbol,
*symbol,
&subgraph_index,
&exists
);
if (exists) {
AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index];
if (
subgraph->start_states.size == 0 ||
*array_back(&subgraph->start_states) != state
)
array_push(&subgraph->start_states, state);
}
}
}
}