From 911fb7f1b2e746f2b103973b89bec89841bb1216 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 26 May 2020 13:39:11 -0700 Subject: [PATCH] Extract helper functions to reduce the code size of the lexer function (#626) * Extract helper functions to reduce code size of ts_lex * Name char set helper functions based on token name --- cli/src/generate/nfa.rs | 77 +++++++++ cli/src/generate/render.rs | 326 +++++++++++++++++++++---------------- 2 files changed, 263 insertions(+), 140 deletions(-) diff --git a/cli/src/generate/nfa.rs b/cli/src/generate/nfa.rs index abab8739..bf9ca58d 100644 --- a/cli/src/generate/nfa.rs +++ b/cli/src/generate/nfa.rs @@ -1,8 +1,10 @@ use std::char; use std::cmp::max; use std::cmp::Ordering; +use std::collections::HashSet; use std::fmt; use std::mem::swap; +use std::ops::Range; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum CharacterSet { @@ -178,6 +180,40 @@ impl CharacterSet { } } + pub fn ranges<'a>( + chars: &'a Vec, + ruled_out_characters: &'a HashSet, + ) -> impl Iterator> + 'a { + let mut prev_range: Option> = None; + chars + .iter() + .map(|c| (*c, false)) + .chain(Some(('\0', true))) + .filter_map(move |(c, done)| { + if done { + return prev_range.clone(); + } + if ruled_out_characters.contains(&(c as u32)) { + return None; + } + if let Some(range) = prev_range.clone() { + let mut prev_range_successor = range.end as u32 + 1; + while prev_range_successor < c as u32 { + if !ruled_out_characters.contains(&prev_range_successor) { + prev_range = Some(c..c); + return Some(range); + } + prev_range_successor += 1; + } + prev_range = Some(range.start..c); + None + } else { + prev_range = Some(c..c); + None + } + }) + } + #[cfg(test)] pub fn contains(&self, c: char) -> bool { match self { @@ -825,4 +861,45 @@ mod tests { assert!(a.does_intersect(&b)); assert!(b.does_intersect(&a)); } + + #[test] + fn test_character_set_get_ranges() { + struct Row { + chars: Vec, + ruled_out_chars: Vec, + expected_ranges: Vec>, + } + + let table = [ + Row { + chars: vec!['a'], + ruled_out_chars: vec![], + expected_ranges: vec!['a'..'a'], + }, + Row { + chars: vec!['a', 'b', 'c', 'e', 'z'], + ruled_out_chars: vec![], + expected_ranges: vec!['a'..'c', 'e'..'e', 'z'..'z'], + }, + Row { + chars: vec!['a', 'b', 'c', 'e', 'h', 'z'], + ruled_out_chars: vec!['d', 'f', 'g'], + expected_ranges: vec!['a'..'h', 'z'..'z'], + }, + ]; + + for Row { + chars, + ruled_out_chars, + expected_ranges, + } in table.iter() + { + let ruled_out_chars = ruled_out_chars + .into_iter() + .map(|c: &char| *c as u32) + .collect(); + let ranges = CharacterSet::ranges(chars, &ruled_out_chars).collect::>(); + assert_eq!(ranges, *expected_ranges); + } + } } diff --git a/cli/src/generate/render.rs b/cli/src/generate/render.rs index d6191d9d..f4a4bfc4 100644 --- a/cli/src/generate/render.rs +++ b/cli/src/generate/render.rs @@ -15,6 +15,8 @@ use std::mem::swap; // stabilized, and the parser generation does not use it by default. const STABLE_LANGUAGE_VERSION: usize = tree_sitter::LANGUAGE_VERSION - 1; +const LARGE_CHARACTER_RANGE_COUNT: usize = 8; + macro_rules! add { ($this: tt, $($arg: tt)*) => {{ $this.buffer.write_fmt(format_args!($($arg)*)).unwrap(); @@ -72,6 +74,12 @@ struct Generator { next_abi: bool, } +struct TransitionSummary { + is_included: bool, + ranges: Vec>, + call_id: Option, +} + impl Generator { fn generate(mut self) -> String { self.init(); @@ -99,12 +107,12 @@ impl Generator { let mut main_lex_table = LexTable::default(); swap(&mut main_lex_table, &mut self.main_lex_table); - self.add_lex_function("ts_lex", main_lex_table); + self.add_lex_function("ts_lex", main_lex_table, true); if self.keyword_capture_token.is_some() { let mut keyword_lex_table = LexTable::default(); swap(&mut keyword_lex_table, &mut self.keyword_lex_table); - self.add_lex_function("ts_lex_keywords", keyword_lex_table); + self.add_lex_function("ts_lex_keywords", keyword_lex_table, false); } self.add_lex_modes_list(); @@ -570,7 +578,100 @@ impl Generator { add_line!(self, ""); } - fn add_lex_function(&mut self, name: &str, lex_table: LexTable) { + fn add_lex_function( + &mut self, + name: &str, + lex_table: LexTable, + extract_helper_functions: bool, + ) { + let mut ruled_out_chars = HashSet::new(); + let mut large_character_sets = Vec::<(Symbol, usize, Vec>)>::new(); + + // For each lex state, compute a summary of the code that needs to be + // generated. + let state_transition_summaries: Vec> = lex_table + .states + .iter() + .map(|state| { + ruled_out_chars.clear(); + + // For each state transition, compute the set of character ranges + // that need to be checked. + state + .advance_actions + .iter() + .map(|(chars, action)| { + let (chars, is_included) = match chars { + CharacterSet::Include(c) => (c, true), + CharacterSet::Exclude(c) => (c, false), + }; + let mut call_id = None; + let mut ranges = + CharacterSet::ranges(chars, &ruled_out_chars).collect::>(); + if is_included { + ruled_out_chars.extend(chars.iter().map(|c| *c as u32)); + } else { + ranges.insert(0, '\0'..'\0') + } + + // Record any large character sets so that they can be extracted + // into helper functions, reducing code duplication. + if extract_helper_functions && ranges.len() > LARGE_CHARACTER_RANGE_COUNT { + let char_set_symbol = self + .symbol_for_advance_action(action, &lex_table) + .expect("No symbol for lex state"); + let mut count_for_symbol = 0; + for (i, (symbol, _, r)) in large_character_sets.iter().enumerate() { + if r == &ranges { + call_id = Some(i); + break; + } + if *symbol == char_set_symbol { + count_for_symbol += 1; + } + } + if call_id.is_none() { + call_id = Some(large_character_sets.len()); + large_character_sets.push(( + char_set_symbol, + count_for_symbol + 1, + ranges.clone(), + )); + } + } + + TransitionSummary { + is_included, + ranges, + call_id, + } + }) + .collect() + }) + .collect(); + + // Generate a helper function for each large character set. + let mut sorted_large_char_sets: Vec<_> = large_character_sets.iter().map(|e| e).collect(); + sorted_large_char_sets.sort_unstable_by_key(|(sym, count, _)| (sym, count)); + for (sym, count, ranges) in sorted_large_char_sets { + add_line!( + self, + "static inline bool {}_character_set_{}(int32_t lookahead) {{", + self.symbol_ids[sym], + count + ); + indent!(self); + add_line!(self, "return"); + indent!(self); + add_whitespace!(self); + self.add_character_range_conditions(ranges, true, 0); + add!(self, ";\n"); + dedent!(self); + dedent!(self); + add_line!(self, "}}"); + add_line!(self, ""); + } + add_line!( self, "static bool {}(TSLexer *lexer, TSStateId state) {{", @@ -591,7 +692,7 @@ impl Generator { for (i, state) in lex_table.states.into_iter().enumerate() { add_line!(self, "case {}:", i); indent!(self); - self.add_lex_state(state); + self.add_lex_state(state, &state_transition_summaries[i], &large_character_sets); dedent!(self); } @@ -607,7 +708,35 @@ impl Generator { add_line!(self, ""); } - fn add_lex_state(&mut self, state: LexState) { + fn symbol_for_advance_action( + &self, + action: &AdvanceAction, + lex_table: &LexTable, + ) -> Option { + let mut state_ids = vec![action.state]; + let mut i = 0; + while i < state_ids.len() { + let id = state_ids[i]; + let state = &lex_table.states[id]; + if let Some(accept) = state.accept_action { + return Some(accept); + } + for (_, action) in &state.advance_actions { + if !state_ids.contains(&action.state) { + state_ids.push(action.state); + } + } + i += 1; + } + return None; + } + + fn add_lex_state( + &mut self, + state: LexState, + transition_info: &Vec, + large_character_sets: &Vec<(Symbol, usize, Vec>)>, + ) { if let Some(accept_action) = state.accept_action { add_line!(self, "ACCEPT_TOKEN({});", self.symbol_ids[&accept_action]); } @@ -616,75 +745,53 @@ impl Generator { add_line!(self, "if (eof) ADVANCE({});", eof_action.state); } - let mut ruled_out_characters = HashSet::new(); - for (characters, action) in state.advance_actions { - let previous_length = self.buffer.len(); - + for (i, (_, action)) in state.advance_actions.into_iter().enumerate() { + let transition = &transition_info[i]; add_whitespace!(self); - add!(self, "if ("); - if self.add_character_set_condition(&characters, &ruled_out_characters) { - add!(self, ") "); - self.add_advance_action(&action); - if let CharacterSet::Include(chars) = characters { - ruled_out_characters.extend(chars.iter().map(|c| *c as u32)); + + // If there is a helper function for this transition's character + // set, then generate a call to that helper function. + if let Some(call_id) = transition.call_id { + add!(self, "if ("); + if !transition.is_included { + add!(self, "!"); } - } else { - self.buffer.truncate(previous_length); - self.add_advance_action(&action); + let (symbol, count, _) = &large_character_sets[call_id]; + add!( + self, + "{}_character_set_{}(lookahead)) ", + self.symbol_ids[symbol], + count + ); } + // Otherwise, generate code to compare the lookahead character + // with all of the character ranges. + else if transition.ranges.len() > 0 { + add!(self, "if ("); + self.add_character_range_conditions(&transition.ranges, transition.is_included, 2); + add!(self, ") "); + } + self.add_advance_action(&action); add!(self, "\n"); } add_line!(self, "END_STATE();"); } - fn add_character_set_condition( - &mut self, - characters: &CharacterSet, - ruled_out_characters: &HashSet, - ) -> bool { - match characters { - CharacterSet::Include(chars) => { - let ranges = Self::get_ranges(chars, ruled_out_characters); - self.add_character_range_conditions(ranges, false) - } - CharacterSet::Exclude(chars) => { - let ranges = Some('\0'..'\0') - .into_iter() - .chain(Self::get_ranges(chars, ruled_out_characters)); - self.add_character_range_conditions(ranges, true) - } - } - } - fn add_character_range_conditions( &mut self, - ranges: impl Iterator>, - is_negated: bool, + ranges: &[Range], + is_included: bool, + indent_count: usize, ) -> bool { - let line_break = "\n "; + let mut line_break = "\n".to_string(); + for _ in 0..self.indent_level + indent_count { + line_break.push_str(" "); + } + let mut did_add = false; for range in ranges { - if is_negated { - if did_add { - add!(self, " &&{}", line_break); - } - if range.end == range.start { - add!(self, "lookahead != "); - self.add_character(range.start); - } else if range.end as u32 == range.start as u32 + 1 { - add!(self, "lookahead != "); - self.add_character(range.start); - add!(self, " &&{}lookahead != ", line_break); - self.add_character(range.end); - } else { - add!(self, "(lookahead < "); - self.add_character(range.start); - add!(self, " || "); - self.add_character(range.end); - add!(self, " < lookahead)"); - } - } else { + if is_included { if did_add { add!(self, " ||{}", line_break); } @@ -703,46 +810,31 @@ impl Generator { self.add_character(range.end); add!(self, ")"); } + } else { + if did_add { + add!(self, " &&{}", line_break); + } + if range.end == range.start { + add!(self, "lookahead != "); + self.add_character(range.start); + } else if range.end as u32 == range.start as u32 + 1 { + add!(self, "lookahead != "); + self.add_character(range.start); + add!(self, " &&{}lookahead != ", line_break); + self.add_character(range.end); + } else { + add!(self, "(lookahead < "); + self.add_character(range.start); + add!(self, " || "); + self.add_character(range.end); + add!(self, " < lookahead)"); + } } did_add = true; } did_add } - fn get_ranges<'a>( - chars: &'a Vec, - ruled_out_characters: &'a HashSet, - ) -> impl Iterator> + 'a { - let mut prev_range: Option> = None; - chars - .iter() - .map(|c| (*c, false)) - .chain(Some(('\0', true))) - .filter_map(move |(c, done)| { - if done { - return prev_range.clone(); - } - if ruled_out_characters.contains(&(c as u32)) { - return None; - } - if let Some(range) = prev_range.clone() { - let mut prev_range_successor = range.end as u32 + 1; - while prev_range_successor < c as u32 { - if !ruled_out_characters.contains(&prev_range_successor) { - prev_range = Some(c..c); - return Some(range); - } - prev_range_successor += 1; - } - prev_range = Some(range.start..c); - None - } else { - prev_range = Some(c..c); - None - } - }) - } - fn add_advance_action(&mut self, action: &AdvanceAction) { if action.in_main_token { add!(self, "ADVANCE({});", action.state); @@ -1436,49 +1528,3 @@ pub(crate) fn render_c_code( } .generate() } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_get_char_ranges() { - struct Row { - chars: Vec, - ruled_out_chars: Vec, - expected_ranges: Vec>, - } - - let table = [ - Row { - chars: vec!['a'], - ruled_out_chars: vec![], - expected_ranges: vec!['a'..'a'], - }, - Row { - chars: vec!['a', 'b', 'c', 'e', 'z'], - ruled_out_chars: vec![], - expected_ranges: vec!['a'..'c', 'e'..'e', 'z'..'z'], - }, - Row { - chars: vec!['a', 'b', 'c', 'e', 'h', 'z'], - ruled_out_chars: vec!['d', 'f', 'g'], - expected_ranges: vec!['a'..'h', 'z'..'z'], - }, - ]; - - for Row { - chars, - ruled_out_chars, - expected_ranges, - } in table.iter() - { - let ruled_out_chars = ruled_out_chars - .into_iter() - .map(|c: &char| *c as u32) - .collect(); - let ranges = Generator::get_ranges(chars, &ruled_out_chars).collect::>(); - assert_eq!(ranges, *expected_ranges); - } - } -}