diff --git a/cli/src/generate/render.rs b/cli/src/generate/render.rs index c5f1545b..34d8f391 100644 --- a/cli/src/generate/render.rs +++ b/cli/src/generate/render.rs @@ -49,7 +49,7 @@ macro_rules! dedent { }; } -const SMALL_STATE_THRESHOLD: usize = 48; +const SMALL_STATE_THRESHOLD: usize = 64; struct Generator { buffer: String, @@ -161,10 +161,7 @@ impl Generator { // "small parse states". Otherwise, use the same representation for all parse // states. if self.next_abi { - let threshold = cmp::min( - SMALL_STATE_THRESHOLD, - self.parse_table.symbols.len() / 2 - 1, - ); + let threshold = cmp::min(SMALL_STATE_THRESHOLD, self.parse_table.symbols.len() / 2); self.large_state_count = self .parse_table .states @@ -801,6 +798,8 @@ impl Generator { add_line!(self, "[{}] = {{", i); indent!(self); + // Ensure the entries are in a deterministic order, since they are + // internally represented as a hash map. terminal_entries.clear(); nonterminal_entries.clear(); terminal_entries.extend(state.terminal_entries.iter()); @@ -838,52 +837,85 @@ impl Generator { add_line!(self, ""); if self.large_state_count < self.parse_table.states.len() { - add_line!(self, "static uint32_t ts_small_parse_table_map[] = {{"); - indent!(self); - let mut index = 0; - for (i, state) in self - .parse_table - .states - .iter() - .enumerate() - .skip(self.large_state_count) - { - add_line!(self, "[SMALL_STATE({})] = {},", i, index); - index += 1 + 2 * state.symbol_count(); - } - dedent!(self); - add_line!(self, "}};"); - add_line!(self, ""); - - index = 0; add_line!(self, "static uint16_t ts_small_parse_table[] = {{"); indent!(self); + + let mut index = 0; + let mut small_state_indices = Vec::new(); + let mut symbols_by_value: HashMap<(usize, SymbolType), Vec> = HashMap::new(); for state in self.parse_table.states.iter().skip(self.large_state_count) { - add_line!(self, "[{}] = {},", index, state.symbol_count()); - indent!(self); + small_state_indices.push(index); + symbols_by_value.clear(); terminal_entries.clear(); - nonterminal_entries.clear(); terminal_entries.extend(state.terminal_entries.iter()); - nonterminal_entries.extend(state.nonterminal_entries.iter()); terminal_entries.sort_unstable_by_key(|e| self.symbol_order.get(e.0)); - nonterminal_entries.sort_unstable_by_key(|k| k.0); + // In a given parse state, many lookahead symbols have the same actions. + // So in the "small state" representation, group symbols by their action + // in order to avoid repeating the action. for (symbol, entry) in &terminal_entries { let entry_id = self.get_parse_action_list_id( entry, &mut parse_table_entries, &mut next_parse_action_list_index, ); - add_line!(self, "{}, ACTIONS({}),", self.symbol_ids[symbol], entry_id); + symbols_by_value + .entry((entry_id, SymbolType::Terminal)) + .or_default() + .push(**symbol); + } + for (symbol, state_id) in &state.nonterminal_entries { + symbols_by_value + .entry((*state_id, SymbolType::NonTerminal)) + .or_default() + .push(*symbol); } - for (symbol, state_id) in &nonterminal_entries { - add_line!(self, "{}, STATE({}),", self.symbol_ids[symbol], *state_id); + let mut values_with_symbols = symbols_by_value.drain().collect::>(); + values_with_symbols.sort_unstable_by_key(|((value, kind), symbols)| { + (symbols.len(), *kind, *value, symbols[0]) + }); + + add_line!(self, "[{}] = {},", index, values_with_symbols.len()); + indent!(self); + + for ((value, kind), symbols) in values_with_symbols.iter_mut() { + if *kind == SymbolType::NonTerminal { + add_line!(self, "STATE({}), {},", value, symbols.len()); + } else { + add_line!(self, "ACTIONS({}), {},", value, symbols.len()); + } + + symbols.sort_unstable(); + indent!(self); + for symbol in symbols { + add_line!(self, "{},", self.symbol_ids[symbol]); + } + dedent!(self); } + dedent!(self); - index += 1 + 2 * state.symbol_count(); + index += 1 + values_with_symbols + .iter() + .map(|(_, symbols)| 2 + symbols.len()) + .sum::(); + } + + dedent!(self); + add_line!(self, "}};"); + add_line!(self, ""); + + add_line!(self, "static uint32_t ts_small_parse_table_map[] = {{"); + indent!(self); + for i in self.large_state_count..self.parse_table.states.len() { + add_line!( + self, + "[SMALL_STATE({})] = {},", + i, + small_state_indices[i - self.large_state_count] + ); } dedent!(self); add_line!(self, "}};"); diff --git a/cli/src/generate/tables.rs b/cli/src/generate/tables.rs index 8a8cc089..fb593953 100644 --- a/cli/src/generate/tables.rs +++ b/cli/src/generate/tables.rs @@ -94,10 +94,6 @@ impl Default for LexTable { } impl ParseState { - pub fn symbol_count(&self) -> usize { - self.terminal_entries.len() + self.nonterminal_entries.len() - } - pub fn referenced_states<'a>(&'a self) -> impl Iterator + 'a { self.terminal_entries .iter() diff --git a/lib/src/language.h b/lib/src/language.h index de33d2d7..0741486a 100644 --- a/lib/src/language.h +++ b/lib/src/language.h @@ -62,13 +62,14 @@ static inline uint16_t ts_language_lookup( state >= self->large_state_count ) { uint32_t index = self->small_parse_table_map[state - self->large_state_count]; - const uint16_t *state_data = &self->small_parse_table[index]; - uint16_t symbol_count = *state_data; - state_data++; - for (unsigned i = 0; i < symbol_count; i++) { - if (state_data[0] == symbol) return state_data[1]; - if (state_data[0] > symbol) break; - state_data += 2; + const uint16_t *data = &self->small_parse_table[index]; + uint16_t section_count = *(data++); + for (unsigned i = 0; i < section_count; i++) { + uint16_t section_value = *(data++); + uint16_t symbol_count = *(data++); + for (unsigned i = 0; i < symbol_count; i++) { + if (*(data++) == symbol) return section_value; + } } return 0; } else {