Store parse states with few lookahead symbols in a more compact way

This commit is contained in:
Max Brunsfeld 2019-05-16 16:59:50 -07:00
parent 48a883c1d4
commit 09a2755399
7 changed files with 135 additions and 11 deletions

View file

@ -45,6 +45,8 @@ macro_rules! dedent {
};
}
const SMALL_STATE_THRESHOLD: usize = 48;
struct Generator {
buffer: String,
indent_level: usize,
@ -52,10 +54,12 @@ struct Generator {
parse_table: ParseTable,
main_lex_table: LexTable,
keyword_lex_table: LexTable,
large_state_count: usize,
keyword_capture_token: Option<Symbol>,
syntax_grammar: SyntaxGrammar,
lexical_grammar: LexicalGrammar,
simple_aliases: AliasMap,
symbol_order: HashMap<Symbol, usize>,
symbol_ids: HashMap<Symbol, String>,
alias_ids: HashMap<Alias, String>,
alias_map: BTreeMap<Alias, Option<Symbol>>,
@ -144,6 +148,15 @@ impl Generator {
}
}
self.large_state_count = self
.parse_table
.states
.iter()
.take_while(|s| {
s.terminal_entries.len() + s.nonterminal_entries.len() > SMALL_STATE_THRESHOLD
})
.count();
field_names.sort_unstable();
field_names.dedup();
self.field_names = field_names.into_iter().cloned().collect();
@ -203,6 +216,7 @@ impl Generator {
"#define STATE_COUNT {}",
self.parse_table.states.len()
);
add_line!(self, "#define LARGE_STATE_COUNT {}", self.large_state_count);
add_line!(
self,
"#define SYMBOL_COUNT {}",
@ -231,9 +245,11 @@ impl Generator {
fn add_symbol_enum(&mut self) {
add_line!(self, "enum {{");
indent!(self);
self.symbol_order.insert(Symbol::end(), 0);
let mut i = 1;
for symbol in self.parse_table.symbols.iter() {
if *symbol != Symbol::end() {
self.symbol_order.insert(*symbol, i);
add_line!(self, "{} = {},", self.symbol_ids[&symbol], i);
i += 1;
}
@ -733,25 +749,37 @@ impl Generator {
add_line!(
self,
"static uint16_t ts_parse_table[STATE_COUNT][SYMBOL_COUNT] = {{"
"static uint16_t ts_parse_table[LARGE_STATE_COUNT][SYMBOL_COUNT] = {{"
);
indent!(self);
let mut terminal_entries = Vec::new();
let mut nonterminal_entries = Vec::new();
for (i, state) in self.parse_table.states.iter().enumerate() {
for (i, state) in self
.parse_table
.states
.iter()
.enumerate()
.take(self.large_state_count)
{
add_line!(self, "[{}] = {{", i);
indent!(self);
terminal_entries.clear();
nonterminal_entries.clear();
terminal_entries.extend(state.terminal_entries.iter());
nonterminal_entries.extend(state.nonterminal_entries.iter());
terminal_entries.sort_unstable_by_key(|e| e.0);
nonterminal_entries.sort_unstable_by_key(|e| e.0);
terminal_entries.sort_unstable_by_key(|e| self.symbol_order.get(e.0));
nonterminal_entries.sort_unstable_by_key(|k| k.0);
add_line!(self, "[{}] = {{", i);
indent!(self);
for (symbol, state_id) in &nonterminal_entries {
add_line!(self, "[{}] = STATE({}),", self.symbol_ids[symbol], state_id);
add_line!(
self,
"[{}] = STATE({}),",
self.symbol_ids[symbol],
*state_id
);
}
for (symbol, entry) in &terminal_entries {
@ -774,6 +802,57 @@ impl Generator {
add_line!(self, "}};");
add_line!(self, "");
add_line!(self, "static uint32_t ts_small_parse_table_map[] = {{");
indent!(self);
let mut index = 0;
for (i, state) in self
.parse_table
.states
.iter()
.enumerate()
.skip(self.large_state_count)
{
add_line!(self, "[SMALL_STATE({})] = {},", i, index);
index += 1 + 2 * state.symbol_count();
}
dedent!(self);
add_line!(self, "}};");
add_line!(self, "");
index = 0;
add_line!(self, "static uint16_t ts_small_parse_table[] = {{");
indent!(self);
for state in self.parse_table.states.iter().skip(self.large_state_count) {
add_line!(self, "[{}] = {},", index, state.symbol_count());
indent!(self);
terminal_entries.clear();
nonterminal_entries.clear();
terminal_entries.extend(state.terminal_entries.iter());
nonterminal_entries.extend(state.nonterminal_entries.iter());
terminal_entries.sort_unstable_by_key(|e| self.symbol_order.get(e.0));
nonterminal_entries.sort_unstable_by_key(|k| k.0);
for (symbol, entry) in &terminal_entries {
let entry_id = self.get_parse_action_list_id(
entry,
&mut parse_table_entries,
&mut next_parse_action_list_index,
);
add_line!(self, "{}, ACTIONS({}),", self.symbol_ids[symbol], entry_id);
}
for (symbol, state_id) in &nonterminal_entries {
add_line!(self, "{}, STATE({}),", self.symbol_ids[symbol], *state_id);
}
dedent!(self);
index += 1 + 2 * state.symbol_count();
}
dedent!(self);
add_line!(self, "}};");
add_line!(self, "");
self.add_parse_action_list(parse_table_entries);
}
@ -872,11 +951,20 @@ impl Generator {
add_line!(self, ".symbol_count = SYMBOL_COUNT,");
add_line!(self, ".alias_count = ALIAS_COUNT,");
add_line!(self, ".token_count = TOKEN_COUNT,");
add_line!(self, ".large_state_count = LARGE_STATE_COUNT,");
add_line!(self, ".symbol_metadata = ts_symbol_metadata,");
add_line!(
self,
".parse_table = (const unsigned short *)ts_parse_table,"
);
add_line!(
self,
".small_parse_table = (const uint16_t *)ts_small_parse_table,"
);
add_line!(
self,
".small_parse_table_map = (const uint32_t *)ts_small_parse_table_map,"
);
add_line!(self, ".parse_actions = ts_parse_actions,");
add_line!(self, ".lex_modes = ts_lex_modes,");
add_line!(self, ".symbol_names = ts_symbol_names,");
@ -1131,6 +1219,7 @@ pub(crate) fn render_c_code(
buffer: String::new(),
indent_level: 0,
language_name: name.to_string(),
large_state_count: 0,
parse_table,
main_lex_table,
keyword_lex_table,
@ -1139,6 +1228,7 @@ pub(crate) fn render_c_code(
lexical_grammar,
simple_aliases,
symbol_ids: HashMap::new(),
symbol_order: HashMap::new(),
alias_ids: HashMap::new(),
alias_map: BTreeMap::new(),
field_names: Vec::new(),

View file

@ -94,6 +94,10 @@ impl Default for LexTable {
}
impl ParseState {
pub fn symbol_count(&self) -> usize {
self.terminal_entries.len() + self.nonterminal_entries.len()
}
pub fn referenced_states<'a>(&'a self) -> impl Iterator<Item = ParseStateId> + 'a {
self.terminal_entries
.iter()

View file

@ -591,5 +591,5 @@ extern "C" {
pub fn ts_language_version(arg1: *const TSLanguage) -> u32;
}
pub const TREE_SITTER_LANGUAGE_VERSION: usize = 10;
pub const TREE_SITTER_LANGUAGE_VERSION: usize = 11;
pub const TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION: usize = 9;

View file

@ -14,7 +14,7 @@ extern "C" {
/* Section - ABI Versioning */
/****************************/
#define TREE_SITTER_LANGUAGE_VERSION 10
#define TREE_SITTER_LANGUAGE_VERSION 11
#define TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION 9
/*******************/

View file

@ -114,6 +114,9 @@ struct TSLanguage {
const TSFieldMapSlice *field_map_slices;
const TSFieldMapEntry *field_map_entries;
const char **field_names;
uint32_t large_state_count;
const uint16_t *small_parse_table;
const uint32_t *small_parse_table_map;
};
/*
@ -155,6 +158,8 @@ struct TSLanguage {
* Parse Table Macros
*/
#define SMALL_STATE(id) id - LARGE_STATE_COUNT
#define STATE(id) id
#define ACTIONS(id) id

View file

@ -11,7 +11,7 @@ void ts_language_table_entry(const TSLanguage *self, TSStateId state,
result->actions = NULL;
} else {
assert(symbol < self->token_count);
uint32_t action_index = self->parse_table[state * self->symbol_count + symbol];
uint32_t action_index = ts_language_lookup(self, state, symbol);
const TSParseActionEntry *entry = &self->parse_actions[action_index];
result->action_count = entry->count;
result->is_reusable = entry->reusable;

View file

@ -10,6 +10,7 @@ extern "C" {
#define ts_builtin_sym_error_repeat (ts_builtin_sym_error - 1)
#define TREE_SITTER_LANGUAGE_VERSION_WITH_FIELDS 10
#define TREE_SITTER_LANGUAGE_VERSION_WITH_SMALL_STATES 11
typedef struct {
const TSParseAction *actions;
@ -51,6 +52,30 @@ static inline bool ts_language_has_reduce_action(const TSLanguage *self,
return entry.action_count > 0 && entry.actions[0].type == TSParseActionTypeReduce;
}
static inline uint16_t ts_language_lookup(
const TSLanguage *self,
TSStateId state,
TSSymbol symbol
) {
if (
self->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_SMALL_STATES &&
state >= self->large_state_count
) {
uint32_t index = self->small_parse_table_map[state - self->large_state_count];
const uint16_t *state_data = &self->small_parse_table[index];
uint16_t symbol_count = *state_data;
state_data++;
for (unsigned i = 0; i < symbol_count; i++) {
if (state_data[0] == symbol) return state_data[1];
if (state_data[0] > symbol) break;
state_data += 2;
}
return 0;
} else {
return self->parse_table[state * self->symbol_count + symbol];
}
}
static inline TSStateId ts_language_next_state(const TSLanguage *self,
TSStateId state,
TSSymbol symbol) {
@ -67,7 +92,7 @@ static inline TSStateId ts_language_next_state(const TSLanguage *self,
}
return 0;
} else {
return self->parse_table[state * self->symbol_count + symbol];
return ts_language_lookup(self, state, symbol);
}
}