feat: add Supertype API

Introduces a new function that takes in a supertype symbol and returns
all associated subtypes. Can be used by query.c to give better errors
for invalid subtypes, as well as downstream applications like the query
LSP to give better diagnostics.
This commit is contained in:
Riley Bruins 2024-11-12 11:43:00 -08:00 committed by Amaan Qureshi
parent 07c08432ca
commit 19482834bd
11 changed files with 459 additions and 78 deletions

View file

@ -1,6 +1,6 @@
use std::{
cmp,
collections::{HashMap, HashSet},
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
fmt::Write,
mem::swap,
};
@ -9,6 +9,7 @@ use super::{
build_tables::Tables,
grammars::{ExternalToken, LexicalGrammar, SyntaxGrammar, VariableType},
nfa::CharacterSet,
node_types::ChildType,
rules::{Alias, AliasMap, Symbol, SymbolType, TokenSet},
tables::{
AdvanceAction, FieldLocation, GotoAction, LexState, LexTable, ParseAction, ParseTable,
@ -80,6 +81,8 @@ struct Generator {
reserved_word_sets: Vec<TokenSet>,
reserved_word_set_ids_by_parse_state: Vec<usize>,
field_names: Vec<String>,
supertype_symbol_map: BTreeMap<Symbol, Vec<ChildType>>,
supertype_map: BTreeMap<String, Vec<ChildType>>,
#[allow(unused)]
abi_version: usize,
@ -115,6 +118,10 @@ impl Generator {
self.add_non_terminal_alias_map();
self.add_primary_state_id_list();
if self.abi_version >= ABI_VERSION_WITH_RESERVED_WORDS && !self.supertype_map.is_empty() {
self.add_supertype_map();
}
let buffer_offset_before_lex_functions = self.buffer.len();
let mut main_lex_table = LexTable::default();
@ -224,33 +231,24 @@ impl Generator {
for alias in &production_info.alias_sequence {
// Generate a mapping from aliases to C identifiers.
if let Some(alias) = &alias {
let existing_symbol = self.parse_table.symbols.iter().copied().find(|symbol| {
self.default_aliases.get(symbol).map_or_else(
|| {
let (name, kind) = self.metadata_for_symbol(*symbol);
name == alias.value && kind == alias.kind()
},
|default_alias| default_alias == alias,
)
});
// Some aliases match an existing symbol in the grammar.
let alias_id = if let Some(existing_symbol) = existing_symbol {
self.symbol_ids[&self.symbol_map[&existing_symbol]].clone()
}
// Other aliases don't match any existing symbol, and need their own
// identifiers.
else {
if let Err(i) = self.unique_aliases.binary_search(alias) {
self.unique_aliases.insert(i, alias.clone());
let alias_id =
if let Some(existing_symbol) = self.symbols_for_alias(alias).first() {
self.symbol_ids[&self.symbol_map[existing_symbol]].clone()
}
// Other aliases don't match any existing symbol, and need their own
// identifiers.
else {
if let Err(i) = self.unique_aliases.binary_search(alias) {
self.unique_aliases.insert(i, alias.clone());
}
if alias.is_named {
format!("alias_sym_{}", self.sanitize_identifier(&alias.value))
} else {
format!("anon_alias_sym_{}", self.sanitize_identifier(&alias.value))
}
};
if alias.is_named {
format!("alias_sym_{}", self.sanitize_identifier(&alias.value))
} else {
format!("anon_alias_sym_{}", self.sanitize_identifier(&alias.value))
}
};
self.alias_ids.entry(alias.clone()).or_insert(alias_id);
}
@ -290,6 +288,18 @@ impl Generator {
self.reserved_word_set_ids_by_parse_state.push(id);
}
if self.abi_version >= ABI_VERSION_WITH_RESERVED_WORDS {
for (supertype, subtypes) in &self.supertype_symbol_map {
if let Some(supertype) = self.symbol_ids.get(supertype) {
self.supertype_map
.entry(supertype.clone())
.or_insert_with(|| subtypes.clone());
}
}
self.supertype_symbol_map.clear();
}
// Determine which states should use the "small state" representation, and which should
// use the normal array representation.
let threshold = cmp::min(SMALL_STATE_THRESHOLD, self.parse_table.symbols.len() / 2);
@ -404,6 +414,7 @@ impl Generator {
"#define PRODUCTION_ID_COUNT {}",
self.parse_table.production_infos.len()
);
add_line!(self, "#define SUPERTYPE_COUNT {}", self.supertype_map.len());
add_line!(self, "");
}
@ -689,7 +700,7 @@ impl Generator {
add_line!(
self,
"static const TSFieldMapSlice ts_field_map_slices[PRODUCTION_ID_COUNT] = {{",
"static const TSMapSlice ts_field_map_slices[PRODUCTION_ID_COUNT] = {{",
);
indent!(self);
for (production_id, (row_id, length)) in field_map_ids.into_iter().enumerate() {
@ -728,6 +739,83 @@ impl Generator {
add_line!(self, "");
}
fn add_supertype_map(&mut self) {
add_line!(
self,
"static const TSSymbol ts_supertype_symbols[SUPERTYPE_COUNT] = {{"
);
indent!(self);
for supertype in self.supertype_map.keys() {
add_line!(self, "{supertype},");
}
dedent!(self);
add_line!(self, "}};\n");
add_line!(
self,
"static const TSMapSlice ts_supertype_map_slices[] = {{",
);
indent!(self);
let mut row_id = 0;
let mut supertype_ids = vec![0];
let mut supertype_string_map = BTreeMap::new();
for (supertype, subtypes) in &self.supertype_map {
supertype_string_map.insert(
supertype,
subtypes
.iter()
.flat_map(|s| match s {
ChildType::Normal(symbol) => vec![self.symbol_ids.get(symbol).cloned()],
ChildType::Aliased(alias) => {
self.alias_ids.get(alias).cloned().map_or_else(
|| {
self.symbols_for_alias(alias)
.into_iter()
.map(|s| self.symbol_ids.get(&s).cloned())
.collect()
},
|a| vec![Some(a)],
)
}
})
.flatten()
.collect::<BTreeSet<String>>(),
);
}
for (supertype, subtypes) in &supertype_string_map {
let length = subtypes.len();
add_line!(
self,
"[{supertype}] = {{.index = {row_id}, .length = {length}}},",
);
row_id += length;
supertype_ids.push(row_id);
}
dedent!(self);
add_line!(self, "}};");
add_line!(self, "");
add_line!(
self,
"static const TSSymbol ts_supertype_map_entries[] = {{",
);
indent!(self);
for (i, (_, subtypes)) in supertype_string_map.iter().enumerate() {
let row_index = supertype_ids[i];
add_line!(self, "[{row_index}] =");
indent!(self);
for subtype in subtypes {
add_whitespace!(self);
add!(self, "{subtype},\n");
}
dedent!(self);
}
dedent!(self);
add_line!(self, "}};");
add_line!(self, "");
}
fn add_lex_function(&mut self, name: &str, lex_table: LexTable) {
add_line!(
self,
@ -1462,6 +1550,9 @@ impl Generator {
add_line!(self, ".state_count = STATE_COUNT,");
add_line!(self, ".large_state_count = LARGE_STATE_COUNT,");
add_line!(self, ".production_id_count = PRODUCTION_ID_COUNT,");
if self.abi_version >= ABI_VERSION_WITH_RESERVED_WORDS {
add_line!(self, ".supertype_count = SUPERTYPE_COUNT,");
}
add_line!(self, ".field_count = FIELD_COUNT,");
add_line!(
self,
@ -1483,6 +1574,11 @@ impl Generator {
add_line!(self, ".field_map_slices = ts_field_map_slices,");
add_line!(self, ".field_map_entries = ts_field_map_entries,");
}
if !self.supertype_map.is_empty() && self.abi_version >= ABI_VERSION_WITH_RESERVED_WORDS {
add_line!(self, ".supertype_map_slices = ts_supertype_map_slices,");
add_line!(self, ".supertype_map_entries = ts_supertype_map_entries,");
add_line!(self, ".supertype_symbols = ts_supertype_symbols,");
}
add_line!(self, ".symbol_metadata = ts_symbol_metadata,");
add_line!(self, ".public_symbol_map = ts_symbol_map,");
add_line!(self, ".alias_map = ts_non_terminal_alias_map,");
@ -1635,6 +1731,23 @@ impl Generator {
}
}
fn symbols_for_alias(&self, alias: &Alias) -> Vec<Symbol> {
self.parse_table
.symbols
.iter()
.copied()
.filter(move |symbol| {
self.default_aliases.get(symbol).map_or_else(
|| {
let (name, kind) = self.metadata_for_symbol(*symbol);
name == alias.value && kind == alias.kind()
},
|default_alias| default_alias == alias,
)
})
.collect()
}
fn sanitize_identifier(&self, name: &str) -> String {
let mut result = String::with_capacity(name.len());
for c in name.chars() {
@ -1802,6 +1915,7 @@ pub fn render_c_code(
lexical_grammar: LexicalGrammar,
default_aliases: AliasMap,
abi_version: usize,
supertype_symbol_map: BTreeMap<Symbol, Vec<ChildType>>,
) -> String {
assert!(
(ABI_VERSION_MIN..=ABI_VERSION_MAX).contains(&abi_version),
@ -1819,6 +1933,7 @@ pub fn render_c_code(
lexical_grammar,
default_aliases,
abi_version,
supertype_symbol_map,
..Default::default()
}
.generate()