Extract helper functions to reduce the code size of the lexer function (#626)

* Extract helper functions to reduce code size of ts_lex

* Name char set helper functions based on token name
This commit is contained in:
Max Brunsfeld 2020-05-26 13:39:11 -07:00 committed by GitHub
parent e8e80b1cf1
commit 911fb7f1b2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 263 additions and 140 deletions

View file

@ -1,8 +1,10 @@
use std::char;
use std::cmp::max;
use std::cmp::Ordering;
use std::collections::HashSet;
use std::fmt;
use std::mem::swap;
use std::ops::Range;
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum CharacterSet {
@ -178,6 +180,40 @@ impl CharacterSet {
}
}
/// Collapse a list of characters into contiguous ranges.
///
/// Each yielded `start..end` is used as an *inclusive* span (a lone character
/// `c` is reported as `c..c`). Characters listed in `ruled_out_characters` are
/// never reported on their own, and a gap between two listed characters does
/// not split a range when every code point inside the gap is ruled out.
///
/// Only consecutive entries of `chars` are compared, so `chars` is presumably
/// expected to be sorted in ascending order; this is not checked here.
pub fn ranges<'a>(
    chars: &'a Vec<char>,
    ruled_out_characters: &'a HashSet<u32>,
) -> impl Iterator<Item = Range<char>> + 'a {
    // The range currently being grown; it is emitted once it can no longer
    // be extended, or when the input runs out.
    let mut pending: Option<Range<char>> = None;
    chars
        .iter()
        .copied()
        .map(Some)
        // A trailing `None` marks the end of input so the last range is flushed.
        .chain(std::iter::once(None))
        .filter_map(move |next| {
            let c = match next {
                Some(c) => c,
                None => return pending.take(),
            };
            if ruled_out_characters.contains(&(c as u32)) {
                return None;
            }
            let current = match pending.take() {
                Some(range) => range,
                None => {
                    pending = Some(c..c);
                    return None;
                }
            };
            // The pending range may absorb `c` only if every code point
            // strictly between its end and `c` has been ruled out.
            let gap_is_ruled_out = ((current.end as u32 + 1)..(c as u32))
                .all(|code_point| ruled_out_characters.contains(&code_point));
            if gap_is_ruled_out {
                pending = Some(current.start..c);
                None
            } else {
                pending = Some(c..c);
                Some(current)
            }
        })
}
#[cfg(test)]
pub fn contains(&self, c: char) -> bool {
match self {
@ -825,4 +861,45 @@ mod tests {
assert!(a.does_intersect(&b));
assert!(b.does_intersect(&a));
}
/// Table-driven check that `CharacterSet::ranges` merges adjacent characters
/// and bridges gaps made up entirely of ruled-out characters.
#[test]
fn test_character_set_get_ranges() {
    struct Row {
        chars: Vec<char>,
        ruled_out_chars: Vec<char>,
        expected_ranges: Vec<Range<char>>,
    }

    let table = [
        // A single character becomes a one-character range.
        Row {
            chars: vec!['a'],
            ruled_out_chars: vec![],
            expected_ranges: vec!['a'..'a'],
        },
        // Adjacent characters merge; gaps split ranges.
        Row {
            chars: vec!['a', 'b', 'c', 'e', 'z'],
            ruled_out_chars: vec![],
            expected_ranges: vec!['a'..'c', 'e'..'e', 'z'..'z'],
        },
        // Gaps consisting only of ruled-out characters do not split ranges.
        Row {
            chars: vec!['a', 'b', 'c', 'e', 'h', 'z'],
            ruled_out_chars: vec!['d', 'f', 'g'],
            expected_ranges: vec!['a'..'h', 'z'..'z'],
        },
    ];

    for row in table.iter() {
        let ruled_out_code_points = row
            .ruled_out_chars
            .iter()
            .map(|c| *c as u32)
            .collect();
        let actual =
            CharacterSet::ranges(&row.chars, &ruled_out_code_points).collect::<Vec<_>>();
        assert_eq!(actual, row.expected_ranges);
    }
}
}

View file

@ -15,6 +15,8 @@ use std::mem::swap;
// stabilized, and the parser generation does not use it by default.
const STABLE_LANGUAGE_VERSION: usize = tree_sitter::LANGUAGE_VERSION - 1;
const LARGE_CHARACTER_RANGE_COUNT: usize = 8;
macro_rules! add {
($this: tt, $($arg: tt)*) => {{
$this.buffer.write_fmt(format_args!($($arg)*)).unwrap();
@ -72,6 +74,12 @@ struct Generator {
next_abi: bool,
}
/// Summary of the code that must be generated for one advance action of a
/// lex state. Built in `add_lex_function` and read by `add_lex_state`.
struct TransitionSummary {
/// `true` when the transition's character set is `CharacterSet::Include`,
/// `false` when it is `CharacterSet::Exclude`.
is_included: bool,
/// Character ranges to compare the lookahead against. For excluded sets a
/// leading `'\0'..'\0'` range is inserted when the summary is built.
ranges: Vec<Range<char>>,
/// Index into the list of large character sets when this transition's
/// ranges were extracted into a helper function; `None` otherwise.
call_id: Option<usize>,
}
impl Generator {
fn generate(mut self) -> String {
self.init();
@ -99,12 +107,12 @@ impl Generator {
let mut main_lex_table = LexTable::default();
swap(&mut main_lex_table, &mut self.main_lex_table);
self.add_lex_function("ts_lex", main_lex_table);
self.add_lex_function("ts_lex", main_lex_table, true);
if self.keyword_capture_token.is_some() {
let mut keyword_lex_table = LexTable::default();
swap(&mut keyword_lex_table, &mut self.keyword_lex_table);
self.add_lex_function("ts_lex_keywords", keyword_lex_table);
self.add_lex_function("ts_lex_keywords", keyword_lex_table, false);
}
self.add_lex_modes_list();
@ -570,7 +578,100 @@ impl Generator {
add_line!(self, "");
}
fn add_lex_function(&mut self, name: &str, lex_table: LexTable) {
fn add_lex_function(
&mut self,
name: &str,
lex_table: LexTable,
extract_helper_functions: bool,
) {
let mut ruled_out_chars = HashSet::new();
let mut large_character_sets = Vec::<(Symbol, usize, Vec<Range<char>>)>::new();
// For each lex state, compute a summary of the code that needs to be
// generated.
let state_transition_summaries: Vec<Vec<TransitionSummary>> = lex_table
.states
.iter()
.map(|state| {
ruled_out_chars.clear();
// For each state transition, compute the set of character ranges
// that need to be checked.
state
.advance_actions
.iter()
.map(|(chars, action)| {
let (chars, is_included) = match chars {
CharacterSet::Include(c) => (c, true),
CharacterSet::Exclude(c) => (c, false),
};
let mut call_id = None;
let mut ranges =
CharacterSet::ranges(chars, &ruled_out_chars).collect::<Vec<_>>();
if is_included {
ruled_out_chars.extend(chars.iter().map(|c| *c as u32));
} else {
ranges.insert(0, '\0'..'\0')
}
// Record any large character sets so that they can be extracted
// into helper functions, reducing code duplication.
if extract_helper_functions && ranges.len() > LARGE_CHARACTER_RANGE_COUNT {
let char_set_symbol = self
.symbol_for_advance_action(action, &lex_table)
.expect("No symbol for lex state");
let mut count_for_symbol = 0;
for (i, (symbol, _, r)) in large_character_sets.iter().enumerate() {
if r == &ranges {
call_id = Some(i);
break;
}
if *symbol == char_set_symbol {
count_for_symbol += 1;
}
}
if call_id.is_none() {
call_id = Some(large_character_sets.len());
large_character_sets.push((
char_set_symbol,
count_for_symbol + 1,
ranges.clone(),
));
}
}
TransitionSummary {
is_included,
ranges,
call_id,
}
})
.collect()
})
.collect();
// Generate a helper function for each large character set.
let mut sorted_large_char_sets: Vec<_> = large_character_sets.iter().map(|e| e).collect();
sorted_large_char_sets.sort_unstable_by_key(|(sym, count, _)| (sym, count));
for (sym, count, ranges) in sorted_large_char_sets {
add_line!(
self,
"static inline bool {}_character_set_{}(int32_t lookahead) {{",
self.symbol_ids[sym],
count
);
indent!(self);
add_line!(self, "return");
indent!(self);
add_whitespace!(self);
self.add_character_range_conditions(ranges, true, 0);
add!(self, ";\n");
dedent!(self);
dedent!(self);
add_line!(self, "}}");
add_line!(self, "");
}
add_line!(
self,
"static bool {}(TSLexer *lexer, TSStateId state) {{",
@ -591,7 +692,7 @@ impl Generator {
for (i, state) in lex_table.states.into_iter().enumerate() {
add_line!(self, "case {}:", i);
indent!(self);
self.add_lex_state(state);
self.add_lex_state(state, &state_transition_summaries[i], &large_character_sets);
dedent!(self);
}
@ -607,7 +708,35 @@ impl Generator {
add_line!(self, "");
}
fn add_lex_state(&mut self, state: LexState) {
/// Find a symbol that the given advance action can lead to.
///
/// Starting from `action.state`, lex states are visited breadth-first by
/// following their advance actions. The accept action of the first visited
/// state that has one is returned; `None` means no reachable state accepts.
fn symbol_for_advance_action(
    &self,
    action: &AdvanceAction,
    lex_table: &LexTable,
) -> Option<Symbol> {
    // Work list of state ids in visiting order. It doubles as the
    // "already seen" set, so each state is enqueued at most once.
    let mut queue = vec![action.state];
    let mut cursor = 0;
    while let Some(&state_id) = queue.get(cursor) {
        cursor += 1;
        let state = &lex_table.states[state_id];
        if let Some(symbol) = state.accept_action {
            return Some(symbol);
        }
        for (_, next_action) in state.advance_actions.iter() {
            if !queue.contains(&next_action.state) {
                queue.push(next_action.state);
            }
        }
    }
    None
}
fn add_lex_state(
&mut self,
state: LexState,
transition_info: &Vec<TransitionSummary>,
large_character_sets: &Vec<(Symbol, usize, Vec<Range<char>>)>,
) {
if let Some(accept_action) = state.accept_action {
add_line!(self, "ACCEPT_TOKEN({});", self.symbol_ids[&accept_action]);
}
@ -616,75 +745,53 @@ impl Generator {
add_line!(self, "if (eof) ADVANCE({});", eof_action.state);
}
let mut ruled_out_characters = HashSet::new();
for (characters, action) in state.advance_actions {
let previous_length = self.buffer.len();
for (i, (_, action)) in state.advance_actions.into_iter().enumerate() {
let transition = &transition_info[i];
add_whitespace!(self);
add!(self, "if (");
if self.add_character_set_condition(&characters, &ruled_out_characters) {
add!(self, ") ");
self.add_advance_action(&action);
if let CharacterSet::Include(chars) = characters {
ruled_out_characters.extend(chars.iter().map(|c| *c as u32));
// If there is a helper function for this transition's character
// set, then generate a call to that helper function.
if let Some(call_id) = transition.call_id {
add!(self, "if (");
if !transition.is_included {
add!(self, "!");
}
} else {
self.buffer.truncate(previous_length);
self.add_advance_action(&action);
let (symbol, count, _) = &large_character_sets[call_id];
add!(
self,
"{}_character_set_{}(lookahead)) ",
self.symbol_ids[symbol],
count
);
}
// Otherwise, generate code to compare the lookahead character
// with all of the character ranges.
else if transition.ranges.len() > 0 {
add!(self, "if (");
self.add_character_range_conditions(&transition.ranges, transition.is_included, 2);
add!(self, ") ");
}
self.add_advance_action(&action);
add!(self, "\n");
}
add_line!(self, "END_STATE();");
}
/// Emit the condition that tests the lookahead against `characters`,
/// ignoring any characters already present in `ruled_out_characters`.
///
/// Included sets are written as a positive match over their ranges.
/// Excluded sets are written as a negated match, with a `'\0'..'\0'`
/// range placed ahead of the excluded ranges. Returns whatever
/// `add_character_range_conditions` reports for the emitted ranges.
fn add_character_set_condition(
    &mut self,
    characters: &CharacterSet,
    ruled_out_characters: &HashSet<u32>,
) -> bool {
    let (chars, is_negated) = match characters {
        CharacterSet::Include(chars) => (chars, false),
        CharacterSet::Exclude(chars) => (chars, true),
    };
    let ranges = Self::get_ranges(chars, ruled_out_characters);
    if is_negated {
        let with_nul = std::iter::once('\0'..'\0').chain(ranges);
        self.add_character_range_conditions(with_nul, true)
    } else {
        self.add_character_range_conditions(ranges, false)
    }
}
fn add_character_range_conditions(
&mut self,
ranges: impl Iterator<Item = Range<char>>,
is_negated: bool,
ranges: &[Range<char>],
is_included: bool,
indent_count: usize,
) -> bool {
let line_break = "\n ";
let mut line_break = "\n".to_string();
for _ in 0..self.indent_level + indent_count {
line_break.push_str(" ");
}
let mut did_add = false;
for range in ranges {
if is_negated {
if did_add {
add!(self, " &&{}", line_break);
}
if range.end == range.start {
add!(self, "lookahead != ");
self.add_character(range.start);
} else if range.end as u32 == range.start as u32 + 1 {
add!(self, "lookahead != ");
self.add_character(range.start);
add!(self, " &&{}lookahead != ", line_break);
self.add_character(range.end);
} else {
add!(self, "(lookahead < ");
self.add_character(range.start);
add!(self, " || ");
self.add_character(range.end);
add!(self, " < lookahead)");
}
} else {
if is_included {
if did_add {
add!(self, " ||{}", line_break);
}
@ -703,46 +810,31 @@ impl Generator {
self.add_character(range.end);
add!(self, ")");
}
} else {
if did_add {
add!(self, " &&{}", line_break);
}
if range.end == range.start {
add!(self, "lookahead != ");
self.add_character(range.start);
} else if range.end as u32 == range.start as u32 + 1 {
add!(self, "lookahead != ");
self.add_character(range.start);
add!(self, " &&{}lookahead != ", line_break);
self.add_character(range.end);
} else {
add!(self, "(lookahead < ");
self.add_character(range.start);
add!(self, " || ");
self.add_character(range.end);
add!(self, " < lookahead)");
}
}
did_add = true;
}
did_add
}
/// Group `chars` into contiguous ranges, skipping ruled-out characters.
///
/// A yielded `start..end` covers both endpoints (a single character `c` is
/// reported as `c..c`). Two listed characters end up in the same range when
/// every code point between them is either listed or ruled out.
///
/// Only neighbouring entries of `chars` are compared, so the input is
/// presumably sorted in ascending order; nothing here enforces that.
fn get_ranges<'a>(
    chars: &'a Vec<char>,
    ruled_out_characters: &'a HashSet<u32>,
) -> impl Iterator<Item = Range<char>> + 'a {
    // Range still open for extension; flushed when a real gap is found or
    // when the end-of-input marker arrives.
    let mut open: Option<Range<char>> = None;
    let end_marker = std::iter::once(None);
    chars
        .iter()
        .map(|c| Some(*c))
        .chain(end_marker)
        .filter_map(move |item| {
            let c = match item {
                None => return open.take(),
                Some(c) => c,
            };
            if ruled_out_characters.contains(&(c as u32)) {
                return None;
            }
            match open.take() {
                None => {
                    open = Some(c..c);
                    None
                }
                Some(range) => {
                    // A code point in the gap that is not ruled out forces
                    // the open range to be closed before `c`.
                    let has_live_gap = ((range.end as u32 + 1)..(c as u32))
                        .any(|code_point| !ruled_out_characters.contains(&code_point));
                    if has_live_gap {
                        open = Some(c..c);
                        Some(range)
                    } else {
                        open = Some(range.start..c);
                        None
                    }
                }
            }
        })
}
fn add_advance_action(&mut self, action: &AdvanceAction) {
if action.in_main_token {
add!(self, "ADVANCE({});", action.state);
@ -1436,49 +1528,3 @@ pub(crate) fn render_c_code(
}
.generate()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check that `Generator::get_ranges` merges adjacent
    /// characters and bridges gaps made up entirely of ruled-out characters.
    #[test]
    fn test_get_char_ranges() {
        struct Row {
            chars: Vec<char>,
            ruled_out_chars: Vec<char>,
            expected_ranges: Vec<Range<char>>,
        }

        let table = [
            // A single character becomes a one-character range.
            Row {
                chars: vec!['a'],
                ruled_out_chars: vec![],
                expected_ranges: vec!['a'..'a'],
            },
            // Adjacent characters merge; gaps split ranges.
            Row {
                chars: vec!['a', 'b', 'c', 'e', 'z'],
                ruled_out_chars: vec![],
                expected_ranges: vec!['a'..'c', 'e'..'e', 'z'..'z'],
            },
            // Gaps consisting only of ruled-out characters do not split ranges.
            Row {
                chars: vec!['a', 'b', 'c', 'e', 'h', 'z'],
                ruled_out_chars: vec!['d', 'f', 'g'],
                expected_ranges: vec!['a'..'h', 'z'..'z'],
            },
        ];

        for row in table.iter() {
            let ruled_out_code_points = row
                .ruled_out_chars
                .iter()
                .map(|c| *c as u32)
                .collect();
            let actual =
                Generator::get_ranges(&row.chars, &ruled_out_code_points).collect::<Vec<_>>();
            assert_eq!(actual, row.expected_ranges);
        }
    }
}