diff --git a/cli/benches/benchmark.rs b/cli/benches/benchmark.rs index 50ee5370..53ab3fea 100644 --- a/cli/benches/benchmark.rs +++ b/cli/benches/benchmark.rs @@ -2,8 +2,8 @@ use lazy_static::lazy_static; use std::collections::BTreeMap; use std::path::{Path, PathBuf}; use std::time::Instant; -use std::{env, fs, usize}; -use tree_sitter::{Language, Parser}; +use std::{env, fs, str, usize}; +use tree_sitter::{Language, Parser, Query}; use tree_sitter_cli::error::Error; use tree_sitter_cli::loader::Loader; @@ -18,26 +18,33 @@ lazy_static! { .map(|s| usize::from_str_radix(&s, 10).unwrap()) .unwrap_or(5); static ref TEST_LOADER: Loader = Loader::new(SCRATCH_DIR.clone()); - static ref EXAMPLE_PATHS_BY_LANGUAGE_DIR: BTreeMap> = { - fn process_dir(result: &mut BTreeMap>, dir: &Path) { + static ref EXAMPLE_AND_QUERY_PATHS_BY_LANGUAGE_DIR: BTreeMap, Vec)> = { + fn process_dir(result: &mut BTreeMap, Vec)>, dir: &Path) { if dir.join("grammar.js").exists() { let relative_path = dir.strip_prefix(GRAMMARS_DIR.as_path()).unwrap(); + let (example_paths, query_paths) = + result.entry(relative_path.to_owned()).or_default(); + if let Ok(example_files) = fs::read_dir(&dir.join("examples")) { - result.insert( - relative_path.to_owned(), - example_files - .filter_map(|p| { - let p = p.unwrap().path(); - if p.is_file() { - Some(p) - } else { - None - } - }) - .collect(), - ); - } else { - result.insert(relative_path.to_owned(), Vec::new()); + example_paths.extend(example_files.filter_map(|p| { + let p = p.unwrap().path(); + if p.is_file() { + Some(p.to_owned()) + } else { + None + } + })); + } + + if let Ok(query_files) = fs::read_dir(&dir.join("queries")) { + query_paths.extend(query_files.filter_map(|p| { + let p = p.unwrap().path(); + if p.is_file() { + Some(p.to_owned()) + } else { + None + } + })); } } else { for entry in fs::read_dir(&dir).unwrap() { @@ -56,20 +63,25 @@ lazy_static! 
{ } fn main() { - let mut parser = Parser::new(); - let max_path_length = EXAMPLE_PATHS_BY_LANGUAGE_DIR - .iter() - .flat_map(|(_, paths)| paths.iter()) - .map(|p| p.file_name().unwrap().to_str().unwrap().chars().count()) + let max_path_length = EXAMPLE_AND_QUERY_PATHS_BY_LANGUAGE_DIR + .values() + .flat_map(|(e, q)| { + e.iter() + .chain(q.iter()) + .map(|s| s.file_name().unwrap().to_str().unwrap().len()) + }) .max() - .unwrap(); - - let mut all_normal_speeds = Vec::new(); - let mut all_error_speeds = Vec::new(); + .unwrap_or(0); eprintln!("Benchmarking with {} repetitions", *REPETITION_COUNT); - for (language_path, example_paths) in EXAMPLE_PATHS_BY_LANGUAGE_DIR.iter() { + let mut parser = Parser::new(); + let mut all_normal_speeds = Vec::new(); + let mut all_error_speeds = Vec::new(); + + for (language_path, (example_paths, query_paths)) in + EXAMPLE_AND_QUERY_PATHS_BY_LANGUAGE_DIR.iter() + { let language_name = language_path.file_name().unwrap().to_str().unwrap(); if let Some(filter) = LANGUAGE_FILTER.as_ref() { @@ -79,9 +91,24 @@ fn main() { } eprintln!("\nLanguage: {}", language_name); - parser.set_language(get_language(language_path)).unwrap(); + let language = get_language(language_path); + parser.set_language(language).unwrap(); - eprintln!(" Normal examples:"); + eprintln!(" Constructing Queries"); + for path in query_paths { + if let Some(filter) = EXAMPLE_FILTER.as_ref() { + if !path.to_str().unwrap().contains(filter.as_str()) { + continue; + } + } + + parse(&path, max_path_length, |source| { + Query::new(language, str::from_utf8(source).unwrap()) + .expect("Failed to parse query"); + }); + } + + eprintln!(" Parsing Valid Code:"); let mut normal_speeds = Vec::new(); for example_path in example_paths { if let Some(filter) = EXAMPLE_FILTER.as_ref() { @@ -90,12 +117,16 @@ fn main() { } } - normal_speeds.push(parse(&mut parser, example_path, max_path_length)); + normal_speeds.push(parse(example_path, max_path_length, |code| { + parser.parse(code, 
None).expect("Failed to parse"); + })); } - eprintln!(" Error examples (mismatched languages):"); + eprintln!(" Parsing Invalid Code (mismatched languages):"); let mut error_speeds = Vec::new(); - for (other_language_path, example_paths) in EXAMPLE_PATHS_BY_LANGUAGE_DIR.iter() { + for (other_language_path, (example_paths, _)) in + EXAMPLE_AND_QUERY_PATHS_BY_LANGUAGE_DIR.iter() + { if other_language_path != language_path { for example_path in example_paths { if let Some(filter) = EXAMPLE_FILTER.as_ref() { @@ -104,7 +135,9 @@ fn main() { } } - error_speeds.push(parse(&mut parser, example_path, max_path_length)); + error_speeds.push(parse(example_path, max_path_length, |code| { + parser.parse(code, None).expect("Failed to parse"); + })); } } } @@ -123,7 +156,7 @@ fn main() { all_error_speeds.extend(error_speeds); } - eprintln!("\nOverall"); + eprintln!("\n Overall"); if let Some((average_normal, worst_normal)) = aggregate(&all_normal_speeds) { eprintln!(" Average Speed (normal): {} bytes/ms", average_normal); eprintln!(" Worst Speed (normal): {} bytes/ms", worst_normal); @@ -151,28 +184,25 @@ fn aggregate(speeds: &Vec) -> Option<(usize, usize)> { Some((total / speeds.len(), max)) } -fn parse(parser: &mut Parser, example_path: &Path, max_path_length: usize) -> usize { +fn parse(path: &Path, max_path_length: usize, mut action: impl FnMut(&[u8])) -> usize { eprint!( " {:width$}\t", - example_path.file_name().unwrap().to_str().unwrap(), + path.file_name().unwrap().to_str().unwrap(), width = max_path_length ); - let source_code = fs::read(example_path) - .map_err(Error::wrap(|| format!("Failed to read {:?}", example_path))) + let source_code = fs::read(path) + .map_err(Error::wrap(|| format!("Failed to read {:?}", path))) .unwrap(); let time = Instant::now(); for _ in 0..*REPETITION_COUNT { - parser - .parse(&source_code, None) - .expect("Incompatible language version"); + action(&source_code); } let duration = time.elapsed() / (*REPETITION_COUNT as u32); - let duration_ms 
= - duration.as_secs() as f64 * 1000.0 + duration.subsec_nanos() as f64 / 1000000.0; - let speed = (source_code.len() as f64 / duration_ms) as usize; + let duration_ms = duration.as_millis(); + let speed = source_code.len() as u128 / (duration_ms + 1); eprintln!("time {} ms\tspeed {} bytes/ms", duration_ms as usize, speed); - speed + speed as usize } fn get_language(path: &Path) -> Language { diff --git a/cli/src/error.rs b/cli/src/error.rs index d583d1b9..075de3a6 100644 --- a/cli/src/error.rs +++ b/cli/src/error.rs @@ -70,6 +70,10 @@ impl<'a> From for Error { "Query error on line {}. Invalid syntax:\n{}", row, l )), + QueryError::Structure(row, l) => Error::new(format!( + "Query error on line {}. Impossible pattern:\n{}", + row, l + )), QueryError::Predicate(p) => Error::new(format!("Query error: {}", p)), } } diff --git a/cli/src/generate/render.rs b/cli/src/generate/render.rs index 270bd00d..5b016cb6 100644 --- a/cli/src/generate/render.rs +++ b/cli/src/generate/render.rs @@ -7,7 +7,7 @@ use super::tables::{ }; use core::ops::Range; use std::cmp; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{HashMap, HashSet}; use std::fmt::Write; use std::mem::swap; @@ -69,7 +69,8 @@ struct Generator { symbol_order: HashMap, symbol_ids: HashMap, alias_ids: HashMap, - alias_map: BTreeMap>, + unique_aliases: Vec, + symbol_map: HashMap, field_names: Vec, next_abi: bool, } @@ -95,11 +96,7 @@ impl Generator { self.add_stats(); self.add_symbol_enum(); self.add_symbol_names_list(); - - if self.next_abi { - self.add_unique_symbol_map(); - } - + self.add_unique_symbol_map(); self.add_symbol_metadata_list(); if !self.field_names.is_empty() { @@ -112,6 +109,8 @@ impl Generator { self.add_alias_sequences(); } + self.add_non_terminal_alias_map(); + let mut main_lex_table = LexTable::default(); swap(&mut main_lex_table, &mut self.main_lex_table); self.add_lex_function("ts_lex", main_lex_table, true); @@ -163,13 +162,72 @@ impl Generator { 
format!("anon_alias_sym_{}", self.sanitize_identifier(&alias.value)) }; self.alias_ids.entry(alias.clone()).or_insert(alias_id); - self.alias_map - .entry(alias.clone()) - .or_insert(matching_symbol); } } } + self.unique_aliases = self + .alias_ids + .keys() + .filter(|alias| { + self.parse_table + .symbols + .iter() + .cloned() + .find(|symbol| { + let (name, kind) = self.metadata_for_symbol(*symbol); + name == alias.value && kind == alias.kind() + }) + .is_none() + }) + .cloned() + .collect(); + self.unique_aliases.sort_unstable(); + + self.symbol_map = self + .parse_table + .symbols + .iter() + .map(|symbol| { + let mut mapping = symbol; + + // There can be multiple symbols in the grammar that have the same name and kind, + // due to simple aliases. When that happens, ensure that they map to the same + // public-facing symbol. If one of the symbols is not aliased, choose that one + // to be the public-facing symbol. Otherwise, pick the symbol with the lowest + // numeric value. + if let Some(alias) = self.simple_aliases.get(symbol) { + let kind = alias.kind(); + for other_symbol in &self.parse_table.symbols { + if let Some(other_alias) = self.simple_aliases.get(other_symbol) { + if other_symbol < mapping && other_alias == alias { + mapping = other_symbol; + } + } else if self.metadata_for_symbol(*other_symbol) == (&alias.value, kind) { + mapping = other_symbol; + break; + } + } + } + // Two anonymous tokens with different flags but the same string value + // should be represented with the same symbol in the public API. 
Examples: + // * "<" and token(prec(1, "<")) + // * "(" and token.immediate("(") + else if symbol.is_terminal() { + let metadata = self.metadata_for_symbol(*symbol); + for other_symbol in &self.parse_table.symbols { + let other_metadata = self.metadata_for_symbol(*other_symbol); + if other_metadata == metadata { + mapping = other_symbol; + break; + } + } + } + + (*symbol, *mapping) + }) + .collect(); + field_names.sort_unstable(); field_names.dedup(); self.field_names = field_names.into_iter().cloned().collect(); @@ -177,20 +235,16 @@ impl Generator { // If we are opting in to the new unstable language ABI, then use the concept of // "small parse states". Otherwise, use the same representation for all parse // states. - if self.next_abi { - let threshold = cmp::min(SMALL_STATE_THRESHOLD, self.parse_table.symbols.len() / 2); - self.large_state_count = self - .parse_table - .states - .iter() - .enumerate() - .take_while(|(i, s)| { - *i <= 1 || s.terminal_entries.len() + s.nonterminal_entries.len() > threshold - }) - .count(); - } else { - self.large_state_count = self.parse_table.states.len(); - } + let threshold = cmp::min(SMALL_STATE_THRESHOLD, self.parse_table.symbols.len() / 2); + self.large_state_count = self + .parse_table + .states + .iter() + .enumerate() + .take_while(|(i, s)| { + *i <= 1 || s.terminal_entries.len() + s.nonterminal_entries.len() > threshold + }) + .count(); } fn add_includes(&mut self) { @@ -256,21 +310,14 @@ impl Generator { "#define STATE_COUNT {}", self.parse_table.states.len() ); - - if self.next_abi { - add_line!(self, "#define LARGE_STATE_COUNT {}", self.large_state_count); - } + add_line!(self, "#define LARGE_STATE_COUNT {}", self.large_state_count); add_line!( self, "#define SYMBOL_COUNT {}", self.parse_table.symbols.len() ); - add_line!( - self, - "#define ALIAS_COUNT {}", - self.alias_map.iter().filter(|e| e.1.is_none()).count() - ); + add_line!(self, "#define ALIAS_COUNT {}", self.unique_aliases.len(),); add_line!(self, "#define 
TOKEN_COUNT {}", token_count); add_line!( self, @@ -298,11 +345,9 @@ impl Generator { i += 1; } } - for (alias, symbol) in &self.alias_map { - if symbol.is_none() { - add_line!(self, "{} = {},", self.alias_ids[&alias], i); - i += 1; - } + for alias in &self.unique_aliases { + add_line!(self, "{} = {},", self.alias_ids[&alias], i); + i += 1; } dedent!(self); add_line!(self, "}};"); @@ -321,15 +366,13 @@ impl Generator { ); add_line!(self, "[{}] = \"{}\",", self.symbol_ids[&symbol], name); } - for (alias, symbol) in &self.alias_map { - if symbol.is_none() { - add_line!( - self, - "[{}] = \"{}\",", - self.alias_ids[&alias], - self.sanitize_string(&alias.value) - ); - } + for alias in &self.unique_aliases { + add_line!( + self, + "[{}] = \"{}\",", + self.alias_ids[&alias], + self.sanitize_string(&alias.value) + ); } dedent!(self); add_line!(self, "}};"); @@ -340,58 +383,21 @@ impl Generator { add_line!(self, "static TSSymbol ts_symbol_map[] = {{"); indent!(self); for symbol in &self.parse_table.symbols { - let mut mapping = symbol; - - // There can be multiple symbols in the grammar that have the same name and kind, - // due to simple aliases. When that happens, ensure that they map to the same - // public-facing symbol. If one of the symbols is not aliased, choose that one - // to be the public-facing symbol. Otherwise, pick the symbol with the lowest - // numeric value. - if let Some(alias) = self.simple_aliases.get(symbol) { - let kind = alias.kind(); - for other_symbol in &self.parse_table.symbols { - if let Some(other_alias) = self.simple_aliases.get(other_symbol) { - if other_symbol < mapping && other_alias == alias { - mapping = other_symbol; - } - } else if self.metadata_for_symbol(*other_symbol) == (&alias.value, kind) { - mapping = other_symbol; - break; - } - } - } - // Two anonymous tokens with different flags but the same string value - // should be represented with the same symbol in the public API. 
Examples: - // * "<" and token(prec(1, "<")) - // * "(" and token.immediate("(") - else if symbol.is_terminal() { - let metadata = self.metadata_for_symbol(*symbol); - for other_symbol in &self.parse_table.symbols { - let other_metadata = self.metadata_for_symbol(*other_symbol); - if other_metadata == metadata { - mapping = other_symbol; - break; - } - } - } - add_line!( self, "[{}] = {},", - self.symbol_ids[&symbol], - self.symbol_ids[mapping], + self.symbol_ids[symbol], + self.symbol_ids[&self.symbol_map[symbol]], ); } - for (alias, symbol) in &self.alias_map { - if symbol.is_none() { - add_line!( - self, - "[{}] = {},", - self.alias_ids[&alias], - self.alias_ids[&alias], - ); - } + for alias in &self.unique_aliases { + add_line!( + self, + "[{}] = {},", + self.alias_ids[&alias], + self.alias_ids[&alias], + ); } dedent!(self); @@ -462,15 +468,13 @@ impl Generator { dedent!(self); add_line!(self, "}},"); } - for (alias, matching_symbol) in &self.alias_map { - if matching_symbol.is_none() { - add_line!(self, "[{}] = {{", self.alias_ids[&alias]); - indent!(self); - add_line!(self, ".visible = true,"); - add_line!(self, ".named = {},", alias.is_named); - dedent!(self); - add_line!(self, "}},"); - } + for alias in &self.unique_aliases { + add_line!(self, "[{}] = {{", self.alias_ids[&alias]); + indent!(self); + add_line!(self, ".visible = true,"); + add_line!(self, ".named = {},", alias.is_named); + dedent!(self); + add_line!(self, "}},"); } dedent!(self); add_line!(self, "}};"); @@ -509,6 +513,50 @@ impl Generator { add_line!(self, ""); } + fn add_non_terminal_alias_map(&mut self) { + let mut aliases_by_symbol = HashMap::new(); + for variable in &self.syntax_grammar.variables { + for production in &variable.productions { + for step in &production.steps { + if let Some(alias) = &step.alias { + if step.symbol.is_non_terminal() + && !self.simple_aliases.contains_key(&step.symbol) + { + if self.symbol_ids.contains_key(&step.symbol) { + let alias_ids = + 
aliases_by_symbol.entry(step.symbol).or_insert(Vec::new()); + if let Err(i) = alias_ids.binary_search(&alias) { + alias_ids.insert(i, alias); + } + } + } + } + } + } + } + + let mut aliases_by_symbol = aliases_by_symbol.iter().collect::>(); + aliases_by_symbol.sort_unstable_by_key(|e| e.0); + + add_line!(self, "static uint16_t ts_non_terminal_alias_map[] = {{"); + indent!(self); + for (symbol, aliases) in aliases_by_symbol { + let symbol_id = &self.symbol_ids[symbol]; + let public_symbol_id = &self.symbol_ids[&self.symbol_map[&symbol]]; + add_line!(self, "{}, {},", symbol_id, 1 + aliases.len()); + indent!(self); + add_line!(self, "{},", public_symbol_id); + for alias in aliases { + add_line!(self, "{},", &self.alias_ids[&alias]); + } + dedent!(self); + } + add_line!(self, "0,"); + dedent!(self); + add_line!(self, "}};"); + add_line!(self, ""); + } + fn add_field_sequences(&mut self) { let mut flat_field_maps = vec![]; let mut next_flat_field_map_index = 0; @@ -689,17 +737,12 @@ impl Generator { name ); indent!(self); + add_line!(self, "START_LEXER();"); - - if self.next_abi { - add_line!(self, "eof = lexer->eof(lexer);"); - } else { - add_line!(self, "eof = lookahead == 0;"); - } - + add_line!(self, "eof = lexer->eof(lexer);"); add_line!(self, "switch (state) {{"); - indent!(self); + indent!(self); for (i, state) in lex_table.states.into_iter().enumerate() { add_line!(self, "case {}:", i); indent!(self); @@ -714,6 +757,7 @@ impl Generator { dedent!(self); add_line!(self, "}}"); + dedent!(self); add_line!(self, "}}"); add_line!(self, ""); @@ -967,12 +1011,7 @@ impl Generator { add_line!( self, - "static uint16_t ts_parse_table[{}][SYMBOL_COUNT] = {{", - if self.next_abi { - "LARGE_STATE_COUNT" - } else { - "STATE_COUNT" - } + "static uint16_t ts_parse_table[LARGE_STATE_COUNT][SYMBOL_COUNT] = {{", ); indent!(self); @@ -1224,9 +1263,11 @@ impl Generator { add_line!(self, ".symbol_count = SYMBOL_COUNT,"); add_line!(self, ".alias_count = ALIAS_COUNT,"); add_line!(self, 
".token_count = TOKEN_COUNT,"); + add_line!(self, ".large_state_count = LARGE_STATE_COUNT,"); if self.next_abi { - add_line!(self, ".large_state_count = LARGE_STATE_COUNT,"); + add_line!(self, ".alias_map = ts_non_terminal_alias_map,"); + add_line!(self, ".state_count = STATE_COUNT,"); } add_line!(self, ".symbol_metadata = ts_symbol_metadata,"); @@ -1249,10 +1290,7 @@ impl Generator { add_line!(self, ".parse_actions = ts_parse_actions,"); add_line!(self, ".lex_modes = ts_lex_modes,"); add_line!(self, ".symbol_names = ts_symbol_names,"); - - if self.next_abi { - add_line!(self, ".public_symbol_map = ts_symbol_map,"); - } + add_line!(self, ".public_symbol_map = ts_symbol_map,"); if !self.parse_table.production_infos.is_empty() { add_line!( @@ -1539,7 +1577,8 @@ pub(crate) fn render_c_code( symbol_ids: HashMap::new(), symbol_order: HashMap::new(), alias_ids: HashMap::new(), - alias_map: BTreeMap::new(), + symbol_map: HashMap::new(), + unique_aliases: Vec::new(), field_names: Vec::new(), next_abi, } diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index a377ca51..b857467b 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1,18 +1,28 @@ use super::helpers::allocations; use super::helpers::fixtures::get_language; +use lazy_static::lazy_static; +use std::env; use std::fmt::Write; use tree_sitter::{ Language, Node, Parser, Query, QueryCapture, QueryCursor, QueryError, QueryMatch, QueryPredicate, QueryPredicateArg, QueryProperty, }; +lazy_static! 
{ + static ref EXAMPLE_FILTER: Option = env::var("TREE_SITTER_TEST_EXAMPLE_FILTER").ok(); +} + #[test] fn test_query_errors_on_invalid_syntax() { allocations::record(|| { let language = get_language("javascript"); assert!(Query::new(language, "(if_statement)").is_ok()); - assert!(Query::new(language, "(if_statement condition:(identifier))").is_ok()); + assert!(Query::new( + language, + "(if_statement condition:(parenthesized_expression (identifier)))" + ) + .is_ok()); // Mismatched parens assert_eq!( @@ -180,6 +190,110 @@ fn test_query_errors_on_invalid_conditions() { }); } +#[test] +fn test_query_errors_on_impossible_patterns() { + let js_lang = get_language("javascript"); + let rb_lang = get_language("ruby"); + + allocations::record(|| { + assert_eq!( + Query::new( + js_lang, + "(binary_expression left: (identifier) left: (identifier))" + ), + Err(QueryError::Structure( + 1, + [ + "(binary_expression left: (identifier) left: (identifier))", + " ^" + ] + .join("\n"), + )) + ); + + Query::new( + js_lang, + "(function_declaration name: (identifier) (statement_block))", + ) + .unwrap(); + assert_eq!( + Query::new(js_lang, "(function_declaration name: (statement_block))"), + Err(QueryError::Structure( + 1, + [ + "(function_declaration name: (statement_block))", + " ^", + ] + .join("\n") + )) + ); + + Query::new(rb_lang, "(call receiver:(call))").unwrap(); + assert_eq!( + Query::new(rb_lang, "(call receiver:(binary))"), + Err(QueryError::Structure( + 1, + [ + "(call receiver:(binary))", // + " ^", + ] + .join("\n") + )) + ); + + Query::new( + js_lang, + "[ + (function (identifier)) + (function_declaration (identifier)) + (generator_function_declaration (identifier)) + ]", + ) + .unwrap(); + assert_eq!( + Query::new( + js_lang, + "[ + (function (identifier)) + (function_declaration (object)) + (generator_function_declaration (identifier)) + ]", + ), + Err(QueryError::Structure( + 3, + [ + " (function_declaration (object))", // + " ^", + ] + .join("\n") + )) + ); + + 
assert_eq!( + Query::new(js_lang, "(identifier (identifier))",), + Err(QueryError::Structure( + 1, + [ + "(identifier (identifier))", // + " ^", + ] + .join("\n") + )) + ); + assert_eq!( + Query::new(js_lang, "(true (true))",), + Err(QueryError::Structure( + 1, + [ + "(true (true))", // + " ^", + ] + .join("\n") + )) + ); + }); +} + #[test] fn test_query_matches_with_simple_pattern() { allocations::record(|| { @@ -1907,6 +2021,54 @@ fn test_query_captures_with_too_many_nested_results() { }); } +#[test] +fn test_query_captures_with_definite_pattern_containing_many_nested_matches() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + r#" + (array + "[" @l-bracket + "]" @r-bracket) + + "." @dot + "#, + ) + .unwrap(); + + // The '[' node must be returned before all of the '.' nodes, + // even though its pattern does not finish until the ']' node + // at the end of the document. But because the '[' is definite, + // it can be returned before the pattern finishes matching. + let source = " + [ + a.b.c.d.e.f.g.h.i, + a.b.c.d.e.f.g.h.i, + a.b.c.d.e.f.g.h.i, + a.b.c.d.e.f.g.h.i, + a.b.c.d.e.f.g.h.i, + ] + "; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + + let captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + assert_eq!( + collect_captures(captures, &query, source), + [("l-bracket", "[")] + .iter() + .chain([("dot", "."); 40].iter()) + .chain([("r-bracket", "]")].iter()) + .cloned() + .collect::>(), + ); + }); +} + #[test] fn test_query_captures_ordered_by_both_start_and_end_positions() { allocations::record(|| { @@ -2051,7 +2213,7 @@ fn test_query_start_byte_for_pattern() { let patterns_3 = " ((identifier) @b (#match? 
@b i)) (function_declaration name: (identifier) @c) - (method_definition name: (identifier) @d) + (method_definition name: (property_identifier) @d) " .trim_start(); @@ -2078,10 +2240,10 @@ fn test_query_capture_names() { language, r#" (if_statement - condition: (binary_expression + condition: (parenthesized_expression (binary_expression left: _ @left-operand operator: "||" - right: _ @right-operand) + right: _ @right-operand)) consequence: (statement_block) @body) (while_statement @@ -2213,6 +2375,273 @@ fn test_query_alternative_predicate_prefix() { }); } +#[test] +fn test_query_step_is_definite() { + struct Row { + language: Language, + description: &'static str, + pattern: &'static str, + results_by_substring: &'static [(&'static str, bool)], + } + + let rows = &[ + Row { + description: "no definite steps", + language: get_language("python"), + pattern: r#"(expression_statement (string))"#, + results_by_substring: &[("expression_statement", false), ("string", false)], + }, + Row { + description: "all definite steps", + language: get_language("javascript"), + pattern: r#"(object "{" "}")"#, + results_by_substring: &[("object", false), ("{", true), ("}", true)], + }, + Row { + description: "an indefinite step that is optional", + language: get_language("javascript"), + pattern: r#"(object "{" (identifier)? @foo "}")"#, + results_by_substring: &[ + ("object", false), + ("{", true), + ("(identifier)?", false), + ("}", true), + ], + }, + Row { + description: "multiple indefinite steps that are optional", + language: get_language("javascript"), + pattern: r#"(object "{" (identifier)? @id1 ("," (identifier) @id2)? "}")"#, + results_by_substring: &[ + ("object", false), + ("{", true), + ("(identifier)? 
@id1", false), + ("\",\"", false), + ("}", true), + ], + }, + Row { + description: "definite step after indefinite step", + language: get_language("javascript"), + pattern: r#"(pair (property_identifier) ":")"#, + results_by_substring: &[("pair", false), ("property_identifier", false), (":", true)], + }, + Row { + description: "indefinite step in between two definite steps", + language: get_language("javascript"), + pattern: r#"(ternary_expression + condition: (_) + "?" + consequence: (call_expression) + ":" + alternative: (_))"#, + results_by_substring: &[ + ("condition:", false), + ("\"?\"", false), + ("consequence:", false), + ("\":\"", true), + ("alternative:", true), + ], + }, + Row { + description: "one definite step after a repetition", + language: get_language("javascript"), + pattern: r#"(object "{" (_) "}")"#, + results_by_substring: &[("object", false), ("{", false), ("(_)", false), ("}", true)], + }, + Row { + description: "definite steps after multiple repetitions", + language: get_language("json"), + pattern: r#"(object "{" (pair) "," (pair) "," (_) "}")"#, + results_by_substring: &[ + ("object", false), + ("{", false), + ("(pair) \",\" (pair)", false), + ("(pair) \",\" (_)", false), + ("\",\" (_)", false), + ("(_)", true), + ("}", true), + ], + }, + Row { + description: "a definite with a field", + language: get_language("javascript"), + pattern: r#"(binary_expression left: (identifier) right: (_))"#, + results_by_substring: &[ + ("binary_expression", false), + ("(identifier)", false), + ("(_)", true), + ], + }, + Row { + description: "multiple definite steps with fields", + language: get_language("javascript"), + pattern: r#"(function_declaration name: (identifier) body: (statement_block))"#, + results_by_substring: &[ + ("function_declaration", false), + ("identifier", true), + ("statement_block", true), + ], + }, + Row { + description: "nesting, one definite step", + language: get_language("javascript"), + pattern: r#" + (function_declaration + 
name: (identifier) + body: (statement_block "{" (expression_statement) "}"))"#, + results_by_substring: &[ + ("function_declaration", false), + ("identifier", false), + ("statement_block", false), + ("{", false), + ("expression_statement", false), + ("}", true), + ], + }, + Row { + description: "definite step after some deeply nested hidden nodes", + language: get_language("ruby"), + pattern: r#" + (singleton_class + value: (constant) + "end") + "#, + results_by_substring: &[ + ("singleton_class", false), + ("constant", false), + ("end", true), + ], + }, + Row { + description: "nesting, no definite steps", + language: get_language("javascript"), + pattern: r#" + (call_expression + function: (member_expression + property: (property_identifier) @template-tag) + arguments: (template_string)) @template-call + "#, + results_by_substring: &[("property_identifier", false), ("template_string", false)], + }, + Row { + description: "a definite step after a nested node", + language: get_language("javascript"), + pattern: r#" + (subscript_expression + object: (member_expression + object: (identifier) @obj + property: (property_identifier) @prop) + "[") + "#, + results_by_substring: &[ + ("identifier", false), + ("property_identifier", true), + ("[", true), + ], + }, + Row { + description: "a step that is indefinite due to a predicate", + language: get_language("javascript"), + pattern: r#" + (subscript_expression + object: (member_expression + object: (identifier) @obj + property: (property_identifier) @prop) + "[" + (#match? 
@prop "foo")) + "#, + results_by_substring: &[ + ("identifier", false), + ("property_identifier", false), + ("[", true), + ], + }, + Row { + description: "alternation where one branch has definite steps", + language: get_language("javascript"), + pattern: r#" + [ + (unary_expression (identifier)) + (call_expression + function: (_) + arguments: (_)) + (binary_expression right:(call_expression)) + ] + "#, + results_by_substring: &[ + ("identifier", false), + ("right:", false), + ("function:", true), + ("arguments:", true), + ], + }, + Row { + description: "aliased parent node", + language: get_language("ruby"), + pattern: r#" + (method_parameters "(" (identifier) @id")") + "#, + results_by_substring: &[("\"(\"", false), ("(identifier)", false), ("\")\"", true)], + }, + Row { + description: "long, but not too long to analyze", + language: get_language("javascript"), + pattern: r#" + (object "{" (pair) (pair) (pair) (pair) "}") + "#, + results_by_substring: &[ + ("\"{\"", false), + ("(pair)", false), + ("(pair) \"}\"", false), + ("\"}\"", true), + ], + }, + Row { + description: "too long to analyze", + language: get_language("javascript"), + pattern: r#" + (object "{" (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) "}") + "#, + results_by_substring: &[ + ("\"{\"", false), + ("(pair)", false), + ("(pair) \"}\"", false), + ("\"}\"", false), + ], + }, + ]; + + allocations::record(|| { + eprintln!(""); + + for row in rows.iter() { + if let Some(filter) = EXAMPLE_FILTER.as_ref() { + if !row.description.contains(filter.as_str()) { + continue; + } + } + eprintln!(" query example: {:?}", row.description); + let query = Query::new(row.language, row.pattern).unwrap(); + for (substring, is_definite) in row.results_by_substring { + let offset = row.pattern.find(substring).unwrap(); + assert_eq!( + query.step_is_definite(offset), + *is_definite, + "Description: {}, Pattern: {:?}, substring: {:?}, expected is_definite to be {}", + row.description, 
+ row.pattern + .split_ascii_whitespace() + .collect::>() + .join(" "), + substring, + is_definite, + ) + } + } + }); +} + fn assert_query_matches( language: Language, query: &Query, diff --git a/docs/assets/js/playground.js b/docs/assets/js/playground.js index 686be90d..137bb352 100644 --- a/docs/assets/js/playground.js +++ b/docs/assets/js/playground.js @@ -277,7 +277,7 @@ let tree; const startPosition = queryEditor.posFromIndex(error.index); const endPosition = { line: startPosition.line, - ch: startPosition.ch + (error.length || 1) + ch: startPosition.ch + (error.length || Infinity) }; if (error.index === queryText.length) { diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index cba87fa3..f28d3461 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -132,6 +132,7 @@ pub const TSQueryError_TSQueryErrorSyntax: TSQueryError = 1; pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2; pub const TSQueryError_TSQueryErrorField: TSQueryError = 3; pub const TSQueryError_TSQueryErrorCapture: TSQueryError = 4; +pub const TSQueryError_TSQueryErrorStructure: TSQueryError = 5; pub type TSQueryError = u32; extern "C" { #[doc = " Create a new parser."] @@ -172,9 +173,9 @@ extern "C" { #[doc = " the given ranges must be ordered from earliest to latest in the document,"] #[doc = " and they must not overlap. That is, the following must hold for all"] #[doc = " `i` < `length - 1`:"] - #[doc = " ```text"] + #[doc = ""] #[doc = " ranges[i].end_byte <= ranges[i + 1].start_byte"] - #[doc = " ```"] + #[doc = ""] #[doc = " If this requirement is not satisfied, the operation will fail, the ranges"] #[doc = " will not be assigned, and this function will return `false`. 
On success,"] #[doc = " this function returns `true`"] @@ -649,6 +650,9 @@ extern "C" { length: *mut u32, ) -> *const TSQueryPredicateStep; } +extern "C" { + pub fn ts_query_step_is_definite(self_: *const TSQuery, byte_offset: u32) -> bool; +} extern "C" { #[doc = " Get the name and length of one of the query\'s captures, or one of the"] #[doc = " query\'s string literals. Each capture and string is associated with a"] @@ -800,5 +804,5 @@ extern "C" { pub fn ts_language_version(arg1: *const TSLanguage) -> u32; } -pub const TREE_SITTER_LANGUAGE_VERSION: usize = 11; +pub const TREE_SITTER_LANGUAGE_VERSION: usize = 12; pub const TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION: usize = 9; diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index ec7cd791..ea5893b4 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -163,6 +163,7 @@ pub enum QueryError { Field(usize, String), Capture(usize, String), Predicate(String), + Structure(usize, String), } #[derive(Debug)] @@ -1175,27 +1176,42 @@ impl Query { } }); - let message = if let Some(line) = line_containing_error { - line.to_string() + "\n" + &" ".repeat(offset - line_start) + "^" - } else { - "Unexpected EOF".to_string() - }; - - // if line_containing_error - return if error_type != ffi::TSQueryError_TSQueryErrorSyntax { - let suffix = source.split_at(offset).1; - let end_offset = suffix - .find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-') - .unwrap_or(source.len()); - let name = suffix.split_at(end_offset).0.to_string(); - match error_type { - ffi::TSQueryError_TSQueryErrorNodeType => Err(QueryError::NodeType(row, name)), - ffi::TSQueryError_TSQueryErrorField => Err(QueryError::Field(row, name)), - ffi::TSQueryError_TSQueryErrorCapture => Err(QueryError::Capture(row, name)), - _ => Err(QueryError::Syntax(row, message)), + return match error_type { + // Error types that report names + ffi::TSQueryError_TSQueryErrorNodeType + | ffi::TSQueryError_TSQueryErrorField + | 
ffi::TSQueryError_TSQueryErrorCapture => { + let suffix = source.split_at(offset).1; + let end_offset = suffix + .find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-') + .unwrap_or(source.len()); + let name = suffix.split_at(end_offset).0.to_string(); + match error_type { + ffi::TSQueryError_TSQueryErrorNodeType => { + Err(QueryError::NodeType(row, name)) + } + ffi::TSQueryError_TSQueryErrorField => Err(QueryError::Field(row, name)), + ffi::TSQueryError_TSQueryErrorCapture => { + Err(QueryError::Capture(row, name)) + } + _ => unreachable!(), + } + } + + // Error types that report positions + _ => { + let message = if let Some(line) = line_containing_error { + line.to_string() + "\n" + &" ".repeat(offset - line_start) + "^" + } else { + "Unexpected EOF".to_string() + }; + match error_type { + ffi::TSQueryError_TSQueryErrorStructure => { + Err(QueryError::Structure(row, message)) + } + _ => Err(QueryError::Syntax(row, message)), + } } - } else { - Err(QueryError::Syntax(row, message)) }; } @@ -1451,6 +1467,14 @@ impl Query { unsafe { ffi::ts_query_disable_pattern(self.ptr.as_ptr(), index as u32) } } + /// Check if a given step in a query is 'definite'. + /// + /// A query step is 'definite' if its parent pattern will be guaranteed to match + /// successfully once it reaches the step. 
+ pub fn step_is_definite(&self, byte_offset: usize) -> bool { + unsafe { ffi::ts_query_step_is_definite(self.ptr.as_ptr(), byte_offset as u32) } + } + fn parse_property( function_name: &str, capture_names: &[String], diff --git a/lib/binding_web/binding.js b/lib/binding_web/binding.js index 3a193ef9..f731e8f8 100644 --- a/lib/binding_web/binding.js +++ b/lib/binding_web/binding.js @@ -667,8 +667,8 @@ class Language { const errorId = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32'); const errorByte = getValue(TRANSFER_BUFFER, 'i32'); const errorIndex = UTF8ToString(sourceAddress, errorByte).length; - const suffix = source.substr(errorIndex, 100); - const word = suffix.match(QUERY_WORD_REGEX)[0]; + const suffix = source.substr(errorIndex, 100).split('\n')[0]; + let word = suffix.match(QUERY_WORD_REGEX)[0]; let error; switch (errorId) { case 2: @@ -680,8 +680,13 @@ class Language { case 4: error = new RangeError(`Bad capture name @${word}`); break; + case 5: + error = new TypeError(`Bad pattern structure at offset ${errorIndex}: '${suffix}'...`); + word = ""; + break; default: error = new SyntaxError(`Bad syntax at offset ${errorIndex}: '${suffix}'...`); + word = ""; break; } error.index = errorIndex; diff --git a/lib/binding_web/exports.json b/lib/binding_web/exports.json index 2c638249..72105158 100644 --- a/lib/binding_web/exports.json +++ b/lib/binding_web/exports.json @@ -15,7 +15,6 @@ "__ZNSt3__212basic_stringIwNS_11char_traitsIwEENS_9allocatorIwEEED2Ev", "__ZdlPv", "__Znwm", - "___assert_fail", "_abort", "_iswalnum", "_iswalpha", @@ -73,8 +72,6 @@ "_ts_query_capture_count", "_ts_query_capture_name_for_id", "_ts_query_captures_wasm", - "_ts_query_context_delete", - "_ts_query_context_new", "_ts_query_delete", "_ts_query_matches_wasm", "_ts_query_new", diff --git a/lib/binding_web/test/query-test.js b/lib/binding_web/test/query-test.js index 9d1e24e1..23663e9a 100644 --- a/lib/binding_web/test/query-test.js +++ b/lib/binding_web/test/query-test.js @@ -30,6 +30,9 
@@ describe("Query", () => { assert.throws(() => { JavaScript.query("(function_declaration non_existent:(identifier))"); }, "Bad field name 'non_existent'"); + assert.throws(() => { + JavaScript.query("(function_declaration name:(statement_block))"); + }, "Bad pattern structure at offset 22: 'name:(statement_block))'"); }); it("throws an error on invalid predicates", () => { diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 9d832e6e..b85380d1 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -21,7 +21,7 @@ extern "C" { * The Tree-sitter library is generally backwards-compatible with languages * generated using older CLI versions, but is not forwards-compatible. */ -#define TREE_SITTER_LANGUAGE_VERSION 11 +#define TREE_SITTER_LANGUAGE_VERSION 12 /** * The earliest ABI version that is supported by the current version of the @@ -130,6 +130,7 @@ typedef enum { TSQueryErrorNodeType, TSQueryErrorField, TSQueryErrorCapture, + TSQueryErrorStructure, } TSQueryError; /********************/ @@ -718,6 +719,11 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( uint32_t *length ); +bool ts_query_step_is_definite( + const TSQuery *self, + uint32_t byte_offset +); + /** * Get the name and length of one of the query's captures, or one of the * query's string literals. 
Each capture and string is associated with a diff --git a/lib/include/tree_sitter/parser.h b/lib/include/tree_sitter/parser.h index 11bf4fc4..84096132 100644 --- a/lib/include/tree_sitter/parser.h +++ b/lib/include/tree_sitter/parser.h @@ -119,6 +119,8 @@ struct TSLanguage { const uint16_t *small_parse_table; const uint32_t *small_parse_table_map; const TSSymbol *public_symbol_map; + const uint16_t *alias_map; + uint32_t state_count; }; /* diff --git a/lib/src/array.h b/lib/src/array.h index 26cb8448..de8c8cb3 100644 --- a/lib/src/array.h +++ b/lib/src/array.h @@ -12,9 +12,9 @@ extern "C" { #include #include "./alloc.h" -#define Array(T) \ - struct { \ - T *contents; \ +#define Array(T) \ + struct { \ + T *contents; \ uint32_t size; \ uint32_t capacity; \ } @@ -37,15 +37,15 @@ extern "C" { #define array_reserve(self, new_capacity) \ array__reserve((VoidArray *)(self), array__elem_size(self), new_capacity) -#define array_erase(self, index) \ - array__erase((VoidArray *)(self), array__elem_size(self), index) - +// Free any memory allocated for this array. #define array_delete(self) array__delete((VoidArray *)self) #define array_push(self, element) \ (array__grow((VoidArray *)(self), 1, array__elem_size(self)), \ (self)->contents[(self)->size++] = (element)) +// Increase the array's size by a given number of elements, reallocating +// if necessary. New elements are zero-initialized. #define array_grow_by(self, count) \ (array__grow((VoidArray *)(self), count, array__elem_size(self)), \ memset((self)->contents + (self)->size, 0, (count) * array__elem_size(self)), \ @@ -54,18 +54,64 @@ extern "C" { #define array_push_all(self, other) \ array_splice((self), (self)->size, 0, (other)->size, (other)->contents) +// Remove `old_count` elements from the array starting at the given `index`. At +// the same index, insert `new_count` new elements, reading their values from the +// `new_contents` pointer. 
#define array_splice(self, index, old_count, new_count, new_contents) \ array__splice((VoidArray *)(self), array__elem_size(self), index, old_count, \ new_count, new_contents) +// Insert one `element` into the array at the given `index`. #define array_insert(self, index, element) \ array__splice((VoidArray *)(self), array__elem_size(self), index, 0, 1, &element) +// Remove one `element` from the array at the given `index`. +#define array_erase(self, index) \ + array__erase((VoidArray *)(self), array__elem_size(self), index) + #define array_pop(self) ((self)->contents[--(self)->size]) #define array_assign(self, other) \ array__assign((VoidArray *)(self), (const VoidArray *)(other), array__elem_size(self)) +// Search a sorted array for a given `needle` value, using the given `compare` +// callback to determine the order. +// +// If an existing element is found to be equal to `needle`, then the `index` +// out-parameter is set to the existing value's index, and the `exists` +// out-parameter is set to true. Otherwise, `index` is set to an index where +// `needle` should be inserted in order to preserve the sorting, and `exists` +// is set to false. +#define array_search_sorted_with(self, compare, needle, index, exists) \ + array__search_sorted(self, 0, compare, , needle, index, exists) + +// Search a sorted array for a given `needle` value, using integer comparisons +// of a given struct field (specified with a leading dot) to determine the order. +// +// See also `array_search_sorted_with`. +#define array_search_sorted_by(self, field, needle, index, exists) \ + array__search_sorted(self, 0, _compare_int, field, needle, index, exists) + +// Insert a given `value` into a sorted array, using the given `compare` +// callback to determine the order. 
+#define array_insert_sorted_with(self, compare, value) \ + do { \ + unsigned index, exists; \ + array_search_sorted_with(self, compare, &(value), &index, &exists); \ + if (!exists) array_insert(self, index, value); \ + } while (0) + +// Insert a given `value` into a sorted array, using integer comparisons of +// a given struct field (specified with a leading dot) to determine the order. +// +// See also `array_search_sorted_by`. +#define array_insert_sorted_by(self, field, value) \ + do { \ + unsigned index, exists; \ + array_search_sorted_by(self, field, (value) field, &index, &exists); \ + if (!exists) array_insert(self, index, value); \ + } while (0) + // Private typedef Array(void) VoidArray; @@ -151,6 +197,30 @@ static inline void array__splice(VoidArray *self, size_t element_size, self->size += new_count - old_count; } +// A binary search routine, based on Rust's `std::slice::binary_search_by`. +#define array__search_sorted(self, start, compare, suffix, needle, index, exists) \ + do { \ + *(index) = start; \ + *(exists) = false; \ + uint32_t size = (self)->size - *(index); \ + if (size == 0) break; \ + int comparison; \ + while (size > 1) { \ + uint32_t half_size = size / 2; \ + uint32_t mid_index = *(index) + half_size; \ + comparison = compare(&((self)->contents[mid_index] suffix), (needle)); \ + if (comparison <= 0) *(index) = mid_index; \ + size -= half_size; \ + } \ + comparison = compare(&((self)->contents[*(index)] suffix), (needle)); \ + if (comparison == 0) *(exists) = true; \ + else if (comparison < 0) *(index) += 1; \ + } while (0) + +// Helper macro for the `_sorted_by` routines below. This takes the left (existing) +// parameter by reference in order to work with the generic sorting function above. 
+#define _compare_int(a, b) ((int)*(a) - (int)(b)) + #ifdef __cplusplus } #endif diff --git a/lib/src/get_changed_ranges.c b/lib/src/get_changed_ranges.c index 5bd1d814..b24f3149 100644 --- a/lib/src/get_changed_ranges.c +++ b/lib/src/get_changed_ranges.c @@ -146,17 +146,21 @@ static bool iterator_tree_is_visible(const Iterator *self) { if (ts_subtree_visible(*entry.subtree)) return true; if (self->cursor.stack.size > 1) { Subtree parent = *self->cursor.stack.contents[self->cursor.stack.size - 2].subtree; - const TSSymbol *alias_sequence = ts_language_alias_sequence( + return ts_language_alias_at( self->language, - parent.ptr->production_id - ); - return alias_sequence && alias_sequence[entry.structural_child_index] != 0; + parent.ptr->production_id, + entry.structural_child_index + ) != 0; } return false; } -static void iterator_get_visible_state(const Iterator *self, Subtree *tree, - TSSymbol *alias_symbol, uint32_t *start_byte) { +static void iterator_get_visible_state( + const Iterator *self, + Subtree *tree, + TSSymbol *alias_symbol, + uint32_t *start_byte +) { uint32_t i = self->cursor.stack.size - 1; if (self->in_padding) { @@ -169,13 +173,11 @@ static void iterator_get_visible_state(const Iterator *self, Subtree *tree, if (i > 0) { const Subtree *parent = self->cursor.stack.contents[i - 1].subtree; - const TSSymbol *alias_sequence = ts_language_alias_sequence( + *alias_symbol = ts_language_alias_at( self->language, - parent->ptr->production_id + parent->ptr->production_id, + entry.structural_child_index ); - if (alias_sequence) { - *alias_symbol = alias_sequence[entry.structural_child_index]; - } } if (ts_subtree_visible(*entry.subtree) || *alias_symbol) { diff --git a/lib/src/language.h b/lib/src/language.h index 341f0f85..e5c07aa2 100644 --- a/lib/src/language.h +++ b/lib/src/language.h @@ -12,6 +12,8 @@ extern "C" { #define TREE_SITTER_LANGUAGE_VERSION_WITH_FIELDS 10 #define TREE_SITTER_LANGUAGE_VERSION_WITH_SYMBOL_DEDUPING 11 #define 
TREE_SITTER_LANGUAGE_VERSION_WITH_SMALL_STATES 11 +#define TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT 12 +#define TREE_SITTER_LANGUAGE_VERSION_WITH_ALIAS_MAP 12 typedef struct { const TSParseAction *actions; @@ -19,6 +21,22 @@ typedef struct { bool is_reusable; } TableEntry; +typedef struct { + const TSLanguage *language; + const uint16_t *data; + const uint16_t *group_end; + TSStateId state; + uint16_t table_value; + uint16_t section_index; + uint16_t group_count; + bool is_small_state; + + const TSParseAction *actions; + TSSymbol symbol; + TSStateId next_state; + uint16_t action_count; +} LookaheadIterator; + void ts_language_table_entry(const TSLanguage *, TSStateId, TSSymbol, TableEntry *); TSSymbolMetadata ts_language_symbol_metadata(const TSLanguage *, TSSymbol); @@ -41,22 +59,33 @@ static inline const TSParseAction *ts_language_actions( return entry.actions; } -static inline bool ts_language_has_actions(const TSLanguage *self, - TSStateId state, - TSSymbol symbol) { +static inline bool ts_language_has_actions( + const TSLanguage *self, + TSStateId state, + TSSymbol symbol +) { TableEntry entry; ts_language_table_entry(self, state, symbol, &entry); return entry.action_count > 0; } -static inline bool ts_language_has_reduce_action(const TSLanguage *self, - TSStateId state, - TSSymbol symbol) { +static inline bool ts_language_has_reduce_action( + const TSLanguage *self, + TSStateId state, + TSSymbol symbol +) { TableEntry entry; ts_language_table_entry(self, state, symbol, &entry); return entry.action_count > 0 && entry.actions[0].type == TSParseActionTypeReduce; } +// Lookup the table value for a given symbol and state. +// +// For non-terminal symbols, the table value represents a successor state. +// For terminal symbols, it represents an index in the actions table. +// For 'large' parse states, this is a direct lookup. For 'small' parse +// states, this requires searching through the symbol groups to find +// the given symbol. 
static inline uint16_t ts_language_lookup( const TSLanguage *self, TSStateId state, @@ -68,8 +97,8 @@ static inline uint16_t ts_language_lookup( ) { uint32_t index = self->small_parse_table_map[state - self->large_state_count]; const uint16_t *data = &self->small_parse_table[index]; - uint16_t section_count = *(data++); - for (unsigned i = 0; i < section_count; i++) { + uint16_t group_count = *(data++); + for (unsigned i = 0; i < group_count; i++) { uint16_t section_value = *(data++); uint16_t symbol_count = *(data++); for (unsigned i = 0; i < symbol_count; i++) { @@ -82,9 +111,90 @@ static inline uint16_t ts_language_lookup( } } -static inline TSStateId ts_language_next_state(const TSLanguage *self, - TSStateId state, - TSSymbol symbol) { +// Iterate over all of the symbols that are valid in the given state. +// +// For 'large' parse states, this just requires iterating through +// all possible symbols and checking the parse table for each one. +// For 'small' parse states, this exploits the structure of the +// table to only visit the valid symbols. 
+static inline LookaheadIterator ts_language_lookaheads( + const TSLanguage *self, + TSStateId state +) { + bool is_small_state = + self->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_SMALL_STATES && + state >= self->large_state_count; + const uint16_t *data; + const uint16_t *group_end = NULL; + uint16_t group_count = 0; + if (is_small_state) { + uint32_t index = self->small_parse_table_map[state - self->large_state_count]; + data = &self->small_parse_table[index]; + group_end = data + 1; + group_count = *data; + } else { + data = &self->parse_table[state * self->symbol_count] - 1; + } + return (LookaheadIterator) { + .language = self, + .data = data, + .group_end = group_end, + .group_count = group_count, + .is_small_state = is_small_state, + .symbol = UINT16_MAX, + .next_state = 0, + }; +} + +static inline bool ts_lookahead_iterator_next(LookaheadIterator *self) { + // For small parse states, valid symbols are listed explicitly, + // grouped by their value. There's no need to look up the actions + // again until moving to the next group. + if (self->is_small_state) { + self->data++; + if (self->data == self->group_end) { + if (self->group_count == 0) return false; + self->group_count--; + self->table_value = *(self->data++); + unsigned symbol_count = *(self->data++); + self->group_end = self->data + symbol_count; + self->symbol = *self->data; + } else { + self->symbol = *self->data; + return true; + } + } + + // For large parse states, iterate through every symbol until one + // is found that has valid actions. + else { + do { + self->data++; + self->symbol++; + if (self->symbol >= self->language->symbol_count) return false; + self->table_value = *self->data; + } while (!self->table_value); + } + + // Depending on if the symbols is terminal or non-terminal, the table value either + // represents a list of actions or a successor state. 
+ if (self->symbol < self->language->token_count) { + const TSParseActionEntry *entry = &self->language->parse_actions[self->table_value]; + self->action_count = entry->entry.count; + self->actions = (const TSParseAction *)(entry + 1); + self->next_state = 0; + } else { + self->action_count = 0; + self->next_state = self->table_value; + } + return true; +} + +static inline TSStateId ts_language_next_state( + const TSLanguage *self, + TSStateId state, + TSSymbol symbol +) { if (symbol == ts_builtin_sym_error || symbol == ts_builtin_sym_error_repeat) { return 0; } else if (symbol < self->token_count) { @@ -102,9 +212,10 @@ static inline TSStateId ts_language_next_state(const TSLanguage *self, } } -static inline const bool * -ts_language_enabled_external_tokens(const TSLanguage *self, - unsigned external_scanner_state) { +static inline const bool *ts_language_enabled_external_tokens( + const TSLanguage *self, + unsigned external_scanner_state +) { if (external_scanner_state == 0) { return NULL; } else { @@ -112,13 +223,25 @@ ts_language_enabled_external_tokens(const TSLanguage *self, } } -static inline const TSSymbol * -ts_language_alias_sequence(const TSLanguage *self, uint32_t production_id) { - return production_id > 0 ? - self->alias_sequences + production_id * self->max_alias_sequence_length : +static inline const TSSymbol *ts_language_alias_sequence( + const TSLanguage *self, + uint32_t production_id +) { + return production_id ? + &self->alias_sequences[production_id * self->max_alias_sequence_length] : NULL; } +static inline TSSymbol ts_language_alias_at( + const TSLanguage *self, + uint32_t production_id, + uint32_t child_index +) { + return production_id ? 
+ self->alias_sequences[production_id * self->max_alias_sequence_length + child_index] : + 0; +} + static inline void ts_language_field_map( const TSLanguage *self, uint32_t production_id, @@ -136,6 +259,32 @@ static inline void ts_language_field_map( *end = &self->field_map_entries[slice.index] + slice.length; } +static inline void ts_language_aliases_for_symbol( + const TSLanguage *self, + TSSymbol original_symbol, + const TSSymbol **start, + const TSSymbol **end +) { + *start = &self->public_symbol_map[original_symbol]; + *end = *start + 1; + + if (self->version < TREE_SITTER_LANGUAGE_VERSION_WITH_ALIAS_MAP) return; + + unsigned i = 0; + for (;;) { + TSSymbol symbol = self->alias_map[i++]; + if (symbol == 0 || symbol > original_symbol) break; + uint16_t count = self->alias_map[i++]; + if (symbol == original_symbol) { + *start = &self->alias_map[i]; + *end = &self->alias_map[i + count]; + break; + } + i += count; + } +} + + #ifdef __cplusplus } #endif diff --git a/lib/src/query.c b/lib/src/query.c index b887b74f..b629af51 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -8,11 +8,14 @@ #include "./unicode.h" #include +// #define DEBUG_ANALYZE_QUERY // #define LOG(...) fprintf(stderr, __VA_ARGS__) #define LOG(...) #define MAX_CAPTURE_LIST_COUNT 32 #define MAX_STEP_CAPTURE_COUNT 3 +#define MAX_STATE_PREDECESSOR_COUNT 100 +#define MAX_ANALYSIS_STATE_DEPTH 12 /* * Stream - A sequence of unicode characters derived from a UTF8 string. 
@@ -20,6 +23,7 @@ */ typedef struct { const char *input; + const char *start; const char *end; int32_t next; uint8_t next_size; @@ -53,6 +57,7 @@ typedef struct { bool is_pass_through: 1; bool is_dead_end: 1; bool alternative_is_immediate: 1; + bool is_definite: 1; } QueryStep; /* @@ -87,6 +92,17 @@ typedef struct { uint16_t pattern_index; } PatternEntry; +typedef struct { + Slice steps; + Slice predicate_steps; + uint32_t start_byte; +} QueryPattern; + +typedef struct { + uint32_t byte_offset; + uint16_t step_index; +} StepOffset; + /* * QueryState - The state of an in-progress match of a particular pattern * in a query. While executing, a `TSQueryCursor` must keep track of a number @@ -137,6 +153,52 @@ typedef struct { uint32_t usage_map; } CaptureListPool; +/* + * AnalysisState - The state needed for walking the parse table when analyzing + * a query pattern, to determine at which steps the pattern might fail to match. + */ +typedef struct { + TSStateId parse_state; + TSSymbol parent_symbol; + uint16_t child_index; + TSFieldId field_id: 15; + bool done: 1; +} AnalysisStateEntry; + +typedef struct { + AnalysisStateEntry stack[MAX_ANALYSIS_STATE_DEPTH]; + uint16_t depth; + uint16_t step_index; +} AnalysisState; + +typedef Array(AnalysisState) AnalysisStateSet; + +/* + * AnalysisSubgraph - A subset of the states in the parse table that are used + * in constructing nodes with a certain symbol. Each state is accompanied by + * some information about the possible node that could be produced in + * downstream states. + */ +typedef struct { + TSStateId state; + uint8_t production_id; + uint8_t child_index: 7; + bool done: 1; +} AnalysisSubgraphNode; + +typedef struct { + TSSymbol symbol; + Array(TSStateId) start_states; + Array(AnalysisSubgraphNode) nodes; +} AnalysisSubgraph; + +/* + * StatePredecessorMap - A map that stores the predecessors of each parse state. 
+ */ +typedef struct { + TSStateId *contents; +} StatePredecessorMap; + /* * TSQuery - A tree query, compiled from a string of S-expressions. The query * itself is immutable. The mutable state used in the process of executing the @@ -148,8 +210,8 @@ struct TSQuery { Array(QueryStep) steps; Array(PatternEntry) pattern_map; Array(TSQueryPredicateStep) predicate_steps; - Array(Slice) predicates_by_pattern; - Array(uint32_t) start_bytes_by_pattern; + Array(QueryPattern) patterns; + Array(StepOffset) step_offsets; const TSLanguage *language; uint16_t wildcard_root_pattern_count; TSSymbol *symbol_map; @@ -216,21 +278,22 @@ static Stream stream_new(const char *string, uint32_t length) { Stream self = { .next = 0, .input = string, + .start = string, .end = string + length, }; stream_advance(&self); return self; } -static void stream_skip_whitespace(Stream *stream) { +static void stream_skip_whitespace(Stream *self) { for (;;) { - if (iswspace(stream->next)) { - stream_advance(stream); - } else if (stream->next == ';') { + if (iswspace(self->next)) { + stream_advance(self); + } else if (self->next == ';') { // skip over comments - stream_advance(stream); - while (stream->next && stream->next != '\n') { - if (!stream_advance(stream)) break; + stream_advance(self); + while (self->next && self->next != '\n') { + if (!stream_advance(self)) break; } } else { break; @@ -238,8 +301,8 @@ static void stream_skip_whitespace(Stream *stream) { } } -static bool stream_is_ident_start(Stream *stream) { - return iswalnum(stream->next) || stream->next == '_' || stream->next == '-'; +static bool stream_is_ident_start(Stream *self) { + return iswalnum(self->next) || self->next == '_' || self->next == '-'; } static void stream_scan_identifier(Stream *stream) { @@ -255,6 +318,10 @@ static void stream_scan_identifier(Stream *stream) { ); } +static uint32_t stream_offset(Stream *self) { + return self->input - self->start; +} + /****************** * CaptureListPool ******************/ @@ -450,6 
+517,7 @@ static QueryStep query_step__new( .is_last_child = false, .is_pass_through = false, .is_dead_end = false, + .is_definite = false, .is_immediate = is_immediate, .alternative_is_immediate = false, }; @@ -479,6 +547,113 @@ static void query_step__remove_capture(QueryStep *self, uint16_t capture_id) { } } +/********************** + * StatePredecessorMap + **********************/ + +static inline StatePredecessorMap state_predecessor_map_new(const TSLanguage *language) { + return (StatePredecessorMap) { + .contents = ts_calloc(language->state_count * (MAX_STATE_PREDECESSOR_COUNT + 1), sizeof(TSStateId)), + }; +} + +static inline void state_predecessor_map_delete(StatePredecessorMap *self) { + ts_free(self->contents); +} + +static inline void state_predecessor_map_add( + StatePredecessorMap *self, + TSStateId state, + TSStateId predecessor +) { + unsigned index = state * (MAX_STATE_PREDECESSOR_COUNT + 1); + TSStateId *count = &self->contents[index]; + if (*count == 0 || (*count < MAX_STATE_PREDECESSOR_COUNT && self->contents[index + *count] != predecessor)) { + (*count)++; + self->contents[index + *count] = predecessor; + } +} + +static inline const TSStateId *state_predecessor_map_get( + const StatePredecessorMap *self, + TSStateId state, + unsigned *count +) { + unsigned index = state * (MAX_STATE_PREDECESSOR_COUNT + 1); + *count = self->contents[index]; + return &self->contents[index + 1]; +} + +/**************** + * AnalysisState + ****************/ + +static unsigned analysis_state__recursion_depth(const AnalysisState *self) { + unsigned result = 0; + for (unsigned i = 0; i < self->depth; i++) { + TSSymbol symbol = self->stack[i].parent_symbol; + for (unsigned j = 0; j < i; j++) { + if (self->stack[j].parent_symbol == symbol) { + result++; + break; + } + } + } + return result; +} + +static inline int analysis_state__compare_position( + const AnalysisState *self, + const AnalysisState *other +) { + for (unsigned i = 0; i < self->depth; i++) { + if (i >= 
other->depth) return -1; + if (self->stack[i].child_index < other->stack[i].child_index) return -1; + if (self->stack[i].child_index > other->stack[i].child_index) return 1; + } + if (self->depth < other->depth) return 1; + return 0; +} + +static inline int analysis_state__compare( + const AnalysisState *self, + const AnalysisState *other +) { + int result = analysis_state__compare_position(self, other); + if (result != 0) return result; + for (unsigned i = 0; i < self->depth; i++) { + if (self->stack[i].parent_symbol < other->stack[i].parent_symbol) return -1; + if (self->stack[i].parent_symbol > other->stack[i].parent_symbol) return 1; + if (self->stack[i].parse_state < other->stack[i].parse_state) return -1; + if (self->stack[i].parse_state > other->stack[i].parse_state) return 1; + if (self->stack[i].field_id < other->stack[i].field_id) return -1; + if (self->stack[i].field_id > other->stack[i].field_id) return 1; + } + if (self->step_index < other->step_index) return -1; + if (self->step_index > other->step_index) return 1; + return 0; +} + +static inline AnalysisStateEntry *analysis_state__top(AnalysisState *self) { + return &self->stack[self->depth - 1]; +} + +/*********************** + * AnalysisSubgraphNode + ***********************/ + +static inline int analysis_subgraph_node__compare(const AnalysisSubgraphNode *self, const AnalysisSubgraphNode *other) { + if (self->state < other->state) return -1; + if (self->state > other->state) return 1; + if (self->child_index < other->child_index) return -1; + if (self->child_index > other->child_index) return 1; + if (self->done < other->done) return -1; + if (self->done > other->done) return 1; + if (self->production_id < other->production_id) return -1; + if (self->production_id > other->production_id) return 1; + return 0; +} + /********* * Query *********/ @@ -568,6 +743,628 @@ static inline void ts_query__pattern_map_insert( })); } +static bool ts_query__analyze_patterns(TSQuery *self, unsigned *error_offset) 
{ + // Identify all of the patterns in the query that have child patterns, both at the + // top level and nested within other larger patterns. Record the step index where + // each pattern starts. + Array(uint32_t) parent_step_indices = array_new(); + for (unsigned i = 0; i < self->steps.size; i++) { + QueryStep *step = &self->steps.contents[i]; + if (i + 1 < self->steps.size) { + QueryStep *next_step = &self->steps.contents[i + 1]; + if ( + step->symbol != WILDCARD_SYMBOL && + step->symbol != NAMED_WILDCARD_SYMBOL && + next_step->depth > step->depth && + next_step->depth != PATTERN_DONE_MARKER + ) { + array_push(&parent_step_indices, i); + } + } + if (step->depth > 0) { + step->is_definite = true; + } + } + + // For every parent symbol in the query, initialize an 'analysis subgraph'. + // This subgraph lists all of the states in the parse table that are directly + // involved in building subtrees for this symbol. + // + // In addition to the parent symbols in the query, construct subgraphs for all + // of the hidden symbols in the grammar, because these might occur within + // one of the parent nodes, such that their children appear to belong to the + // parent. + Array(AnalysisSubgraph) subgraphs = array_new(); + for (unsigned i = 0; i < parent_step_indices.size; i++) { + uint32_t parent_step_index = parent_step_indices.contents[i]; + TSSymbol parent_symbol = self->steps.contents[parent_step_index].symbol; + AnalysisSubgraph subgraph = { .symbol = parent_symbol }; + array_insert_sorted_by(&subgraphs, .symbol, subgraph); + } + for (TSSymbol sym = self->language->token_count; sym < self->language->symbol_count; sym++) { + if (!ts_language_symbol_metadata(self->language, sym).visible) { + AnalysisSubgraph subgraph = { .symbol = sym }; + array_insert_sorted_by(&subgraphs, .symbol, subgraph); + } + } + + // Scan the parse table to find the data needed to populate these subgraphs. 
+ // Collect three things during this scan: + // 1) All of the parse states where one of these symbols can start. + // 2) All of the parse states where one of these symbols can end, along + // with information about the node that would be created. + // 3) A list of predecessor states for each state. + StatePredecessorMap predecessor_map = state_predecessor_map_new(self->language); + for (TSStateId state = 1; state < self->language->state_count; state++) { + unsigned subgraph_index, exists; + LookaheadIterator lookahead_iterator = ts_language_lookaheads(self->language, state); + while (ts_lookahead_iterator_next(&lookahead_iterator)) { + if (lookahead_iterator.action_count) { + for (unsigned i = 0; i < lookahead_iterator.action_count; i++) { + const TSParseAction *action = &lookahead_iterator.actions[i]; + if (action->type == TSParseActionTypeReduce) { + const TSSymbol *aliases, *aliases_end; + ts_language_aliases_for_symbol( + self->language, + action->params.reduce.symbol, + &aliases, + &aliases_end + ); + for (const TSSymbol *symbol = aliases; symbol < aliases_end; symbol++) { + array_search_sorted_by( + &subgraphs, + .symbol, + *symbol, + &subgraph_index, + &exists + ); + if (exists) { + AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + if (subgraph->nodes.size == 0 || array_back(&subgraph->nodes)->state != state) { + array_push(&subgraph->nodes, ((AnalysisSubgraphNode) { + .state = state, + .production_id = action->params.reduce.production_id, + .child_index = action->params.reduce.child_count, + .done = true, + })); + } + } + } + } else if (action->type == TSParseActionTypeShift && !action->params.shift.extra) { + TSStateId next_state = action->params.shift.state; + state_predecessor_map_add(&predecessor_map, next_state, state); + } + } + } else if (lookahead_iterator.next_state != 0 && lookahead_iterator.next_state != state) { + state_predecessor_map_add(&predecessor_map, lookahead_iterator.next_state, state); + const TSSymbol *aliases, 
*aliases_end; + ts_language_aliases_for_symbol( + self->language, + lookahead_iterator.symbol, + &aliases, + &aliases_end + ); + for (const TSSymbol *symbol = aliases; symbol < aliases_end; symbol++) { + array_search_sorted_by( + &subgraphs, + .symbol, + *symbol, + &subgraph_index, + &exists + ); + if (exists) { + AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + if ( + subgraph->start_states.size == 0 || + *array_back(&subgraph->start_states) != state + ) + array_push(&subgraph->start_states, state); + } + } + } + } + } + + // For each subgraph, compute the preceding states by walking backward + // from the end states using the predecessor map. + Array(AnalysisSubgraphNode) next_nodes = array_new(); + for (unsigned i = 0; i < subgraphs.size; i++) { + AnalysisSubgraph *subgraph = &subgraphs.contents[i]; + if (subgraph->nodes.size == 0) { + array_delete(&subgraph->start_states); + array_erase(&subgraphs, i); + i--; + continue; + } + array_assign(&next_nodes, &subgraph->nodes); + while (next_nodes.size > 0) { + AnalysisSubgraphNode node = array_pop(&next_nodes); + if (node.child_index > 1) { + unsigned predecessor_count; + const TSStateId *predecessors = state_predecessor_map_get( + &predecessor_map, + node.state, + &predecessor_count + ); + for (unsigned j = 0; j < predecessor_count; j++) { + AnalysisSubgraphNode predecessor_node = { + .state = predecessors[j], + .child_index = node.child_index - 1, + .production_id = node.production_id, + .done = false, + }; + unsigned index, exists; + array_search_sorted_with( + &subgraph->nodes, analysis_subgraph_node__compare, &predecessor_node, + &index, &exists + ); + if (!exists) { + array_insert(&subgraph->nodes, index, predecessor_node); + array_push(&next_nodes, predecessor_node); + } + } + } + } + } + + #ifdef DEBUG_ANALYZE_QUERY + printf("\nSubgraphs:\n"); + for (unsigned i = 0; i < subgraphs.size; i++) { + AnalysisSubgraph *subgraph = &subgraphs.contents[i]; + printf(" %u, %s:\n", subgraph->symbol, 
ts_language_symbol_name(self->language, subgraph->symbol)); + for (unsigned j = 0; j < subgraph->start_states.size; j++) { + printf( + " {state: %u}\n", + subgraph->start_states.contents[j] + ); + } + for (unsigned j = 0; j < subgraph->nodes.size; j++) { + AnalysisSubgraphNode *node = &subgraph->nodes.contents[j]; + printf( + " {state: %u, child_index: %u, production_id: %u, done: %d}\n", + node->state, node->child_index, node->production_id, node->done + ); + } + printf("\n"); + } + #endif + + // For each non-terminal pattern, determine if the pattern can successfully match, + // and identify all of the possible children within the pattern where matching could fail. + bool result = true; + AnalysisStateSet states = array_new(); + AnalysisStateSet next_states = array_new(); + AnalysisStateSet deeper_states = array_new(); + Array(uint16_t) final_step_indices = array_new(); + for (unsigned i = 0; i < parent_step_indices.size; i++) { + uint16_t parent_step_index = parent_step_indices.contents[i]; + uint16_t parent_depth = self->steps.contents[parent_step_index].depth; + TSSymbol parent_symbol = self->steps.contents[parent_step_index].symbol; + if (parent_symbol == ts_builtin_sym_error) continue; + + // Find the subgraph that corresponds to this pattern's root symbol. If the pattern's + // root symbol is not a non-terminal, then return an error. + unsigned subgraph_index, exists; + array_search_sorted_by(&subgraphs, .symbol, parent_symbol, &subgraph_index, &exists); + if (!exists) { + unsigned first_child_step_index = parent_step_index + 1; + uint32_t i, exists; + array_search_sorted_by(&self->step_offsets, .step_index, first_child_step_index, &i, &exists); + assert(exists); + *error_offset = self->step_offsets.contents[i].byte_offset; + result = false; + break; + } + + // Initialize an analysis state at every parse state in the table where + // this parent symbol can occur. 
+ AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + array_clear(&states); + array_clear(&deeper_states); + for (unsigned j = 0; j < subgraph->start_states.size; j++) { + TSStateId parse_state = subgraph->start_states.contents[j]; + array_push(&states, ((AnalysisState) { + .step_index = parent_step_index + 1, + .stack = { + [0] = { + .parse_state = parse_state, + .parent_symbol = parent_symbol, + .child_index = 0, + .field_id = 0, + .done = false, + }, + }, + .depth = 1, + })); + } + + // Walk the subgraph for this non-terminal, tracking all of the possible + // sequences of progress within the pattern. + bool can_finish_pattern = false; + bool did_exceed_max_depth = false; + unsigned recursion_depth_limit = 0; + unsigned prev_final_step_count = 0; + array_clear(&final_step_indices); + for (;;) { + #ifdef DEBUG_ANALYZE_QUERY + printf("Final step indices:"); + for (unsigned j = 0; j < final_step_indices.size; j++) { + printf(" %4u", final_step_indices.contents[j]); + } + printf("\nWalk states for %u %s:\n", i, ts_language_symbol_name(self->language, parent_symbol)); + for (unsigned j = 0; j < states.size; j++) { + AnalysisState *state = &states.contents[j]; + printf(" %3u: step: %u, stack: [", j, state->step_index); + for (unsigned k = 0; k < state->depth; k++) { + printf( + " {%s, child: %u, state: %4u", + self->language->symbol_names[state->stack[k].parent_symbol], + state->stack[k].child_index, + state->stack[k].parse_state + ); + if (state->stack[k].field_id) printf(", field: %s", self->language->field_names[state->stack[k].field_id]); + if (state->stack[k].done) printf(", DONE"); + printf("}"); + } + printf(" ]\n"); + } + #endif + + if (states.size == 0) { + if (deeper_states.size > 0 && final_step_indices.size > prev_final_step_count) { + #ifdef DEBUG_ANALYZE_QUERY + printf("Increase recursion depth limit to %u\n", recursion_depth_limit + 1); + #endif + + prev_final_step_count = final_step_indices.size; + recursion_depth_limit++; + 
AnalysisStateSet _states = states; + states = deeper_states; + deeper_states = _states; + continue; + } + + break; + } + + array_clear(&next_states); + for (unsigned j = 0; j < states.size; j++) { + AnalysisState * const state = &states.contents[j]; + + // For efficiency, it's important to avoid processing the same analysis state more + // than once. To achieve this, keep the states in order of ascending position within + // their hypothetical syntax trees. In each iteration of this loop, start by advancing + // the states that have made the least progress. Avoid advancing states that have already + // made more progress. + if (next_states.size > 0) { + int comparison = analysis_state__compare_position(state, array_back(&next_states)); + if (comparison == 0) { + array_insert_sorted_with(&next_states, analysis_state__compare, *state); + continue; + } else if (comparison > 0) { + while (j < states.size) { + array_push(&next_states, states.contents[j]); + j++; + } + break; + } + } + + const TSStateId parse_state = analysis_state__top(state)->parse_state; + const TSSymbol parent_symbol = analysis_state__top(state)->parent_symbol; + const TSFieldId parent_field_id = analysis_state__top(state)->field_id; + const unsigned child_index = analysis_state__top(state)->child_index; + const QueryStep * const step = &self->steps.contents[state->step_index]; + + unsigned subgraph_index, exists; + array_search_sorted_by(&subgraphs, .symbol, parent_symbol, &subgraph_index, &exists); + if (!exists) continue; + const AnalysisSubgraph *subgraph = &subgraphs.contents[subgraph_index]; + + // Follow every possible path in the parse table, but only visit states that + // are part of the subgraph for the current symbol. 
+ LookaheadIterator lookahead_iterator = ts_language_lookaheads(self->language, parse_state); + while (ts_lookahead_iterator_next(&lookahead_iterator)) { + TSSymbol sym = lookahead_iterator.symbol; + + TSStateId next_parse_state; + if (lookahead_iterator.action_count) { + const TSParseAction *action = &lookahead_iterator.actions[lookahead_iterator.action_count - 1]; + if (action->type == TSParseActionTypeShift && !action->params.shift.extra) { + next_parse_state = action->params.shift.state; + } else { + continue; + } + } else if (lookahead_iterator.next_state != 0 && lookahead_iterator.next_state != parse_state) { + next_parse_state = lookahead_iterator.next_state; + } else { + continue; + } + + AnalysisSubgraphNode successor = { + .state = next_parse_state, + .child_index = child_index + 1, + }; + unsigned node_index; + array_search_sorted_with( + &subgraph->nodes, + analysis_subgraph_node__compare, &successor, + &node_index, &exists + ); + while (node_index < subgraph->nodes.size) { + AnalysisSubgraphNode *node = &subgraph->nodes.contents[node_index++]; + if (node->state != successor.state || node->child_index != successor.child_index) break; + + // Use the subgraph to determine what alias and field will eventually be applied + // to this child node. + TSSymbol alias = ts_language_alias_at(self->language, node->production_id, child_index); + TSSymbol visible_symbol = alias + ? alias + : self->language->symbol_metadata[sym].visible + ? 
self->language->public_symbol_map[sym] + : 0; + TSFieldId field_id = parent_field_id; + if (!field_id) { + const TSFieldMapEntry *field_map, *field_map_end; + ts_language_field_map(self->language, node->production_id, &field_map, &field_map_end); + for (; field_map != field_map_end; field_map++) { + if (field_map->child_index == child_index) { + field_id = field_map->field_id; + break; + } + } + } + + AnalysisState next_state = *state; + analysis_state__top(&next_state)->child_index++; + analysis_state__top(&next_state)->parse_state = successor.state; + if (node->done) analysis_state__top(&next_state)->done = true; + + // Determine if this hypothetical child node would match the current step + // of the query pattern. + bool does_match = false; + if (visible_symbol) { + does_match = true; + if (step->symbol == NAMED_WILDCARD_SYMBOL) { + if (!self->language->symbol_metadata[visible_symbol].named) does_match = false; + } else if (step->symbol != WILDCARD_SYMBOL) { + if (step->symbol != visible_symbol) does_match = false; + } + if (step->field && step->field != field_id) { + does_match = false; + } + } + + // If this is a hidden child, then push a new entry to the stack, in order to + // walk through the children of this child. + else if (sym >= self->language->token_count) { + if (next_state.depth + 1 >= MAX_ANALYSIS_STATE_DEPTH) { + did_exceed_max_depth = true; + continue; + } + + next_state.depth++; + analysis_state__top(&next_state)->parse_state = parse_state; + analysis_state__top(&next_state)->child_index = 0; + analysis_state__top(&next_state)->parent_symbol = sym; + analysis_state__top(&next_state)->field_id = field_id; + analysis_state__top(&next_state)->done = false; + + if (analysis_state__recursion_depth(&next_state) > recursion_depth_limit) { + array_insert_sorted_with(&deeper_states, analysis_state__compare, next_state); + continue; + } + } else { + continue; + } + + // Pop from the stack when this state reached the end of its current syntax node. 
+ while (next_state.depth > 0 && analysis_state__top(&next_state)->done) { + next_state.depth--; + } + + // If this hypothetical child did match the current step of the query pattern, + // then advance to the next step at the current depth. This involves skipping + // over any descendant steps of the current child. + const QueryStep *next_step = step; + if (does_match) { + for (;;) { + next_state.step_index++; + next_step = &self->steps.contents[next_state.step_index]; + if ( + next_step->depth == PATTERN_DONE_MARKER || + next_step->depth <= parent_depth + 1 + ) break; + } + } + + for (;;) { + // If this state can make further progress, then add it to the states for the next iteration. + // Otherwise, record the fact that matching can fail at this step of the pattern. + if (!next_step->is_dead_end) { + bool did_finish_pattern = self->steps.contents[next_state.step_index].depth != parent_depth + 1; + if (did_finish_pattern) can_finish_pattern = true; + if (did_finish_pattern || next_state.depth == 0) { + array_insert_sorted_by(&final_step_indices, , next_state.step_index); + } else { + array_insert_sorted_with(&next_states, analysis_state__compare, next_state); + } + } + + // If the state has advanced to a step with an alternative step, then add another state at + // that alternative step to the next iteration. + if ( + does_match && + next_step->alternative_index != NONE && + next_step->alternative_index > next_state.step_index + ) { + next_state.step_index = next_step->alternative_index; + next_step = &self->steps.contents[next_state.step_index]; + } else { + break; + } + } + } + } + } + + AnalysisStateSet _states = states; + states = next_states; + next_states = _states; + } + + // Mark as indefinite any step where a match terminated. + // Later, this property will be propagated to all of the step's predecessors. 
+ for (unsigned j = 0; j < final_step_indices.size; j++) { + uint32_t final_step_index = final_step_indices.contents[j]; + QueryStep *step = &self->steps.contents[final_step_index]; + if ( + step->depth != PATTERN_DONE_MARKER && + step->depth > parent_depth && + !step->is_dead_end + ) { + step->is_definite = false; + } + } + + if (did_exceed_max_depth) { + for (unsigned j = parent_step_index + 1; j < self->steps.size; j++) { + QueryStep *step = &self->steps.contents[j]; + if ( + step->depth <= parent_depth || + step->depth == PATTERN_DONE_MARKER + ) break; + if (!step->is_dead_end) { + step->is_definite = false; + } + } + } + + // If this pattern cannot match, store the pattern index so that it can be + // returned to the caller. + if (result && !can_finish_pattern && !did_exceed_max_depth) { + assert(final_step_indices.size > 0); + uint16_t impossible_step_index = *array_back(&final_step_indices); + uint32_t i, exists; + array_search_sorted_by(&self->step_offsets, .step_index, impossible_step_index, &i, &exists); + assert(exists); + *error_offset = self->step_offsets.contents[i].byte_offset; + result = false; + break; + } + } + + // Mark as indefinite any step with captures that are used in predicates. + Array(uint16_t) predicate_capture_ids = array_new(); + for (unsigned i = 0; i < self->patterns.size; i++) { + QueryPattern *pattern = &self->patterns.contents[i]; + + // Gather all of the captures that are used in predicates for this pattern. + array_clear(&predicate_capture_ids); + for ( + unsigned start = pattern->predicate_steps.offset, + end = start + pattern->predicate_steps.length, + j = start; j < end; j++ + ) { + TSQueryPredicateStep *step = &self->predicate_steps.contents[j]; + if (step->type == TSQueryPredicateStepTypeCapture) { + array_insert_sorted_by(&predicate_capture_ids, , step->value_id); + } + } + + // Find all of the steps that have these captures. 
+ for ( + unsigned start = pattern->steps.offset, + end = start + pattern->steps.length, + j = start; j < end; j++ + ) { + QueryStep *step = &self->steps.contents[j]; + for (unsigned k = 0; k < MAX_STEP_CAPTURE_COUNT; k++) { + uint16_t capture_id = step->capture_ids[k]; + if (capture_id == NONE) break; + unsigned index, exists; + array_search_sorted_by(&predicate_capture_ids, , capture_id, &index, &exists); + if (exists) { + step->is_definite = false; + break; + } + } + } + } + + // Propagate indefiniteness backwards. + bool done = self->steps.size == 0; + while (!done) { + done = true; + for (unsigned i = self->steps.size - 1; i > 0; i--) { + QueryStep *step = &self->steps.contents[i]; + + // Determine if this step is definite or has definite alternatives. + bool is_definite = false; + for (;;) { + if (step->is_definite) { + is_definite = true; + break; + } + if (step->alternative_index == NONE || step->alternative_index < i) { + break; + } + step = &self->steps.contents[step->alternative_index]; + } + + // If not, mark its predecessor as indefinite. + if (!is_definite) { + QueryStep *prev_step = &self->steps.contents[i - 1]; + if ( + !prev_step->is_dead_end && + prev_step->depth != PATTERN_DONE_MARKER && + prev_step->is_definite + ) { + prev_step->is_definite = false; + done = false; + } + } + } + } + + #ifdef DEBUG_ANALYZE_QUERY + printf("Steps:\n"); + for (unsigned i = 0; i < self->steps.size; i++) { + QueryStep *step = &self->steps.contents[i]; + if (step->depth == PATTERN_DONE_MARKER) { + printf(" %u: DONE\n", i); + } else { + printf( + " %u: {symbol: %s, field: %s, is_definite: %d}\n", + i, + (step->symbol == WILDCARD_SYMBOL || step->symbol == NAMED_WILDCARD_SYMBOL) + ? "ANY" + : ts_language_symbol_name(self->language, step->symbol), + (step->field ? 
ts_language_field_name_for_id(self->language, step->field) : "-"), + step->is_definite + ); + } + } + #endif + + // Cleanup + for (unsigned i = 0; i < subgraphs.size; i++) { + array_delete(&subgraphs.contents[i].start_states); + array_delete(&subgraphs.contents[i].nodes); + } + array_delete(&subgraphs); + array_delete(&next_nodes); + array_delete(&states); + array_delete(&next_states); + array_delete(&deeper_states); + array_delete(&final_step_indices); + array_delete(&parent_step_indices); + array_delete(&predicate_capture_ids); + state_predecessor_map_delete(&predecessor_map); + + return result; +} + static void ts_query__finalize_steps(TSQuery *self) { for (unsigned i = 0; i < self->steps.size; i++) { QueryStep *step = &self->steps.contents[i]; @@ -604,7 +1401,6 @@ static TSQueryError ts_query__parse_predicate( predicate_name, length ); - array_back(&self->predicates_by_pattern)->length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, .value_id = id, @@ -615,7 +1411,6 @@ static TSQueryError ts_query__parse_predicate( if (stream->next == ')') { stream_advance(stream); stream_skip_whitespace(stream); - array_back(&self->predicates_by_pattern)->length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeDone, .value_id = 0, @@ -644,7 +1439,6 @@ static TSQueryError ts_query__parse_predicate( return TSQueryErrorCapture; } - array_back(&self->predicates_by_pattern)->length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeCapture, .value_id = capture_id, @@ -684,7 +1478,6 @@ static TSQueryError ts_query__parse_predicate( string_content, length ); - array_back(&self->predicates_by_pattern)->length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, .value_id = id, @@ -704,7 +1497,6 @@ static TSQueryError ts_query__parse_predicate( symbol_start, length ); - 
array_back(&self->predicates_by_pattern)->length++; array_push(&self->predicate_steps, ((TSQueryPredicateStep) { .type = TSQueryPredicateStepTypeString, .value_id = id, @@ -728,20 +1520,26 @@ static TSQueryError ts_query__parse_pattern( TSQuery *self, Stream *stream, uint32_t depth, - uint32_t *capture_count, bool is_immediate ) { + if (stream->next == 0) return TSQueryErrorSyntax; + if (stream->next == ')' || stream->next == ']') return PARENT_DONE; + const uint32_t starting_step_index = self->steps.size; - if (stream->next == 0) return TSQueryErrorSyntax; - - // Finish the parent S-expression. - if (stream->next == ')' || stream->next == ']') { - return PARENT_DONE; + // Store the byte offset of each step in the query. + if ( + self->step_offsets.size == 0 || + array_back(&self->step_offsets)->step_index != starting_step_index + ) { + array_push(&self->step_offsets, ((StepOffset) { + .step_index = starting_step_index, + .byte_offset = stream_offset(stream), + })); } // An open bracket is the start of an alternation. 
- else if (stream->next == '[') { + if (stream->next == '[') { stream_advance(stream); stream_skip_whitespace(stream); @@ -753,7 +1551,6 @@ static TSQueryError ts_query__parse_pattern( self, stream, depth, - capture_count, is_immediate ); @@ -806,7 +1603,6 @@ static TSQueryError ts_query__parse_pattern( self, stream, depth, - capture_count, child_is_immediate ); if (e == PARENT_DONE && stream->next == ')') { @@ -887,7 +1683,6 @@ static TSQueryError ts_query__parse_pattern( self, stream, depth + 1, - capture_count, child_is_immediate ); if (e == PARENT_DONE && stream->next == ')') { @@ -971,7 +1766,6 @@ static TSQueryError ts_query__parse_pattern( self, stream, depth, - capture_count, is_immediate ); if (e == PARENT_DONE) return TSQueryErrorSyntax; @@ -1085,8 +1879,6 @@ static TSQueryError ts_query__parse_pattern( break; } } - - (*capture_count)++; } // No more suffix modifiers @@ -1139,7 +1931,8 @@ TSQuery *ts_query_new( .captures = symbol_table_new(), .predicate_values = symbol_table_new(), .predicate_steps = array_new(), - .predicates_by_pattern = array_new(), + .patterns = array_new(), + .step_offsets = array_new(), .symbol_map = symbol_map, .wildcard_root_pattern_count = 0, .language = language, @@ -1149,22 +1942,26 @@ TSQuery *ts_query_new( Stream stream = stream_new(source, source_len); stream_skip_whitespace(&stream); while (stream.input < stream.end) { - uint32_t pattern_index = self->predicates_by_pattern.size; + uint32_t pattern_index = self->patterns.size; uint32_t start_step_index = self->steps.size; - uint32_t capture_count = 0; - array_push(&self->start_bytes_by_pattern, stream.input - source); - array_push(&self->predicates_by_pattern, ((Slice) { - .offset = self->predicate_steps.size, - .length = 0, + uint32_t start_predicate_step_index = self->predicate_steps.size; + array_push(&self->patterns, ((QueryPattern) { + .steps = (Slice) {.offset = start_step_index}, + .predicate_steps = (Slice) {.offset = start_predicate_step_index}, + .start_byte = 
stream_offset(&stream), })); - *error_type = ts_query__parse_pattern(self, &stream, 0, &capture_count, false); + *error_type = ts_query__parse_pattern(self, &stream, 0, false); array_push(&self->steps, query_step__new(0, PATTERN_DONE_MARKER, false)); + QueryPattern *pattern = array_back(&self->patterns); + pattern->steps.length = self->steps.size - start_step_index; + pattern->predicate_steps.length = self->predicate_steps.size - start_predicate_step_index; + // If any pattern could not be parsed, then report the error information // and terminate. if (*error_type) { if (*error_type == PARENT_DONE) *error_type = TSQueryErrorSyntax; - *error_offset = stream.input - source; + *error_offset = stream_offset(&stream); ts_query_delete(self); return NULL; } @@ -1199,6 +1996,14 @@ TSQuery *ts_query_new( } } + if (self->language->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT) { + if (!ts_query__analyze_patterns(self, error_offset)) { + *error_type = TSQueryErrorStructure; + ts_query_delete(self); + return NULL; + } + } + ts_query__finalize_steps(self); return self; } @@ -1208,8 +2013,8 @@ void ts_query_delete(TSQuery *self) { array_delete(&self->steps); array_delete(&self->pattern_map); array_delete(&self->predicate_steps); - array_delete(&self->predicates_by_pattern); - array_delete(&self->start_bytes_by_pattern); + array_delete(&self->patterns); + array_delete(&self->step_offsets); symbol_table_delete(&self->captures); symbol_table_delete(&self->predicate_values); ts_free(self->symbol_map); @@ -1218,7 +2023,7 @@ void ts_query_delete(TSQuery *self) { } uint32_t ts_query_pattern_count(const TSQuery *self) { - return self->predicates_by_pattern.size; + return self->patterns.size; } uint32_t ts_query_capture_count(const TSQuery *self) { @@ -1250,7 +2055,7 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern( uint32_t pattern_index, uint32_t *step_count ) { - Slice slice = self->predicates_by_pattern.contents[pattern_index]; + Slice slice = 
self->patterns.contents[pattern_index].predicate_steps; *step_count = slice.length; if (self->predicate_steps.contents == NULL) { return NULL; @@ -1262,7 +2067,24 @@ uint32_t ts_query_start_byte_for_pattern( const TSQuery *self, uint32_t pattern_index ) { - return self->start_bytes_by_pattern.contents[pattern_index]; + return self->patterns.contents[pattern_index].start_byte; +} + +bool ts_query_step_is_definite( + const TSQuery *self, + uint32_t byte_offset +) { + uint32_t step_index = UINT32_MAX; + for (unsigned i = 0; i < self->step_offsets.size; i++) { + StepOffset *step_offset = &self->step_offsets.contents[i]; + if (step_offset->byte_offset > byte_offset) break; + step_index = step_offset->step_index; + } + if (step_index < self->steps.size) { + return self->steps.contents[step_index].is_definite; + } else { + return false; + } } void ts_query_disable_capture( @@ -1375,7 +2197,8 @@ static bool ts_query_cursor__first_in_progress_capture( TSQueryCursor *self, uint32_t *state_index, uint32_t *byte_offset, - uint32_t *pattern_index + uint32_t *pattern_index, + bool *is_definite ) { bool result = false; *state_index = UINT32_MAX; @@ -1388,13 +2211,20 @@ static bool ts_query_cursor__first_in_progress_capture( &self->capture_list_pool, state->capture_list_id ); - if (captures->size > 0) { - uint32_t capture_byte = ts_node_start_byte(captures->contents[0].node); + if (captures->size > state->consumed_capture_count) { + uint32_t capture_byte = ts_node_start_byte(captures->contents[state->consumed_capture_count].node); if ( !result || capture_byte < *byte_offset || (capture_byte == *byte_offset && state->pattern_index < *pattern_index) ) { + QueryStep *step = &self->query->steps.contents[state->step_index]; + if (is_definite) { + *is_definite = step->is_definite; + } else if (step->is_definite) { + continue; + } + result = true; *state_index = i; *byte_offset = capture_byte; @@ -1557,7 +2387,8 @@ static CaptureList *ts_query_cursor__prepare_to_capture( self, 
&state_index, &byte_offset, - &pattern_index + &pattern_index, + NULL ) && state_index != state_index_to_preserve ) { @@ -1616,7 +2447,10 @@ static QueryState *ts_query_cursor__copy_state( // If one or more patterns finish, return `true` and store their states in the // `finished_states` array. Multiple patterns can finish on the same node. If // there are no more matches, return `false`. -static inline bool ts_query_cursor__advance(TSQueryCursor *self) { +static inline bool ts_query_cursor__advance( + TSQueryCursor *self, + bool stop_on_definite_step +) { bool did_match = false; for (;;) { if (self->halted) { @@ -1631,6 +2465,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { if (did_match || self->halted) return did_match; + // Exit the current node. if (self->ascending) { LOG("leave node. type:%s\n", ts_node_type(ts_tree_cursor_current_node(&self->cursor))); @@ -1683,7 +2518,10 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { } } self->states.size -= deleted_count; - } else { + } + + // Enter a new node. + else { // If this node is before the selected range, then avoid descending into it. TSNode node = ts_tree_cursor_current_node(&self->cursor); if ( @@ -1857,6 +2695,9 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { state->step_index ); + QueryStep *next_step = &self->query->steps.contents[state->step_index]; + if (stop_on_definite_step && next_step->is_definite) did_match = true; + // If this state's next step has an alternative step, then copy the state in order // to pursue both alternatives. The alternative step itself may have an alternative, // so this is an interative process. 
@@ -2001,7 +2842,7 @@ bool ts_query_cursor_next_match( TSQueryMatch *match ) { if (self->finished_states.size == 0) { - if (!ts_query_cursor__advance(self)) { + if (!ts_query_cursor__advance(self, false)) { return false; } } @@ -2042,99 +2883,103 @@ bool ts_query_cursor_next_capture( TSQueryMatch *match, uint32_t *capture_index ) { + // The goal here is to return captures in order, even though they may not + // be discovered in order, because patterns can overlap. Search for matches + // until there is a finished capture that is before any unfinished capture. for (;;) { - // The goal here is to return captures in order, even though they may not - // be discovered in order, because patterns can overlap. If there are any - // finished patterns, then try to find one that contains a capture that - // is *definitely* before any capture in an *unfinished* pattern. - if (self->finished_states.size > 0) { - // First, identify the position of the earliest capture in an unfinished - // match. For a finished capture to be returned, it must be *before* - // this position. - uint32_t first_unfinished_capture_byte; - uint32_t first_unfinished_pattern_index; - uint32_t first_unfinished_state_index; - ts_query_cursor__first_in_progress_capture( - self, - &first_unfinished_state_index, - &first_unfinished_capture_byte, - &first_unfinished_pattern_index + // First, find the earliest capture in an unfinished match. + uint32_t first_unfinished_capture_byte; + uint32_t first_unfinished_pattern_index; + uint32_t first_unfinished_state_index; + bool first_unfinished_state_is_definite = false; + ts_query_cursor__first_in_progress_capture( + self, + &first_unfinished_state_index, + &first_unfinished_capture_byte, + &first_unfinished_pattern_index, + &first_unfinished_state_is_definite + ); + + // Then find the earliest capture in a finished match. It must occur + // before the first capture in an *unfinished* match. 
+ QueryState *first_finished_state = NULL; + uint32_t first_finished_capture_byte = first_unfinished_capture_byte; + uint32_t first_finished_pattern_index = first_unfinished_pattern_index; + for (unsigned i = 0; i < self->finished_states.size; i++) { + QueryState *state = &self->finished_states.contents[i]; + const CaptureList *captures = capture_list_pool_get( + &self->capture_list_pool, + state->capture_list_id ); - - // Find the earliest capture in a finished match. - int first_finished_state_index = -1; - uint32_t first_finished_capture_byte = first_unfinished_capture_byte; - uint32_t first_finished_pattern_index = first_unfinished_pattern_index; - for (unsigned i = 0; i < self->finished_states.size; i++) { - const QueryState *state = &self->finished_states.contents[i]; - const CaptureList *captures = capture_list_pool_get( - &self->capture_list_pool, - state->capture_list_id + if (captures->size > state->consumed_capture_count) { + uint32_t capture_byte = ts_node_start_byte( + captures->contents[state->consumed_capture_count].node ); - if (captures->size > state->consumed_capture_count) { - uint32_t capture_byte = ts_node_start_byte( - captures->contents[state->consumed_capture_count].node - ); - if ( - capture_byte < first_finished_capture_byte || - ( - capture_byte == first_finished_capture_byte && - state->pattern_index < first_finished_pattern_index - ) - ) { - first_finished_state_index = i; - first_finished_capture_byte = capture_byte; - first_finished_pattern_index = state->pattern_index; - } - } else { - capture_list_pool_release( - &self->capture_list_pool, - state->capture_list_id - ); - array_erase(&self->finished_states, i); - i--; + if ( + capture_byte < first_finished_capture_byte || + ( + capture_byte == first_finished_capture_byte && + state->pattern_index < first_finished_pattern_index + ) + ) { + first_finished_state = state; + first_finished_capture_byte = capture_byte; + first_finished_pattern_index = state->pattern_index; } - } - - // If 
there is finished capture that is clearly before any unfinished - // capture, then return its match, and its capture index. Internally - // record the fact that the capture has been 'consumed'. - if (first_finished_state_index != -1) { - QueryState *state = &self->finished_states.contents[ - first_finished_state_index - ]; - match->id = state->id; - match->pattern_index = state->pattern_index; - const CaptureList *captures = capture_list_pool_get( - &self->capture_list_pool, - state->capture_list_id - ); - match->captures = captures->contents; - match->capture_count = captures->size; - *capture_index = state->consumed_capture_count; - state->consumed_capture_count++; - return true; - } - - if (capture_list_pool_is_empty(&self->capture_list_pool)) { - LOG( - " abandon state. index:%u, pattern:%u, offset:%u.\n", - first_unfinished_state_index, - first_unfinished_pattern_index, - first_unfinished_capture_byte - ); + } else { capture_list_pool_release( &self->capture_list_pool, - self->states.contents[first_unfinished_state_index].capture_list_id + state->capture_list_id ); - array_erase(&self->states, first_unfinished_state_index); + array_erase(&self->finished_states, i); + i--; } } + // If there is finished capture that is clearly before any unfinished + // capture, then return its match, and its capture index. Internally + // record the fact that the capture has been 'consumed'. 
+ QueryState *state; + if (first_finished_state) { + state = first_finished_state; + } else if (first_unfinished_state_is_definite) { + state = &self->states.contents[first_unfinished_state_index]; + } else { + state = NULL; + } + + if (state) { + match->id = state->id; + match->pattern_index = state->pattern_index; + const CaptureList *captures = capture_list_pool_get( + &self->capture_list_pool, + state->capture_list_id + ); + match->captures = captures->contents; + match->capture_count = captures->size; + *capture_index = state->consumed_capture_count; + state->consumed_capture_count++; + return true; + } + + if (capture_list_pool_is_empty(&self->capture_list_pool)) { + LOG( + " abandon state. index:%u, pattern:%u, offset:%u.\n", + first_unfinished_state_index, + first_unfinished_pattern_index, + first_unfinished_capture_byte + ); + capture_list_pool_release( + &self->capture_list_pool, + self->states.contents[first_unfinished_state_index].capture_list_id + ); + array_erase(&self->states, first_unfinished_state_index); + } + // If there are no finished matches that are ready to be returned, then // continue finding more matches. 
if ( - !ts_query_cursor__advance(self) && + !ts_query_cursor__advance(self, true) && self->finished_states.size == 0 ) return false; } diff --git a/lib/src/subtree.c b/lib/src/subtree.c index ef92a32f..24dc06b2 100644 --- a/lib/src/subtree.c +++ b/lib/src/subtree.c @@ -360,7 +360,7 @@ void ts_subtree_set_children( self.ptr->has_external_tokens = false; self.ptr->dynamic_precedence = 0; - uint32_t non_extra_index = 0; + uint32_t structural_index = 0; const TSSymbol *alias_sequence = ts_language_alias_sequence(language, self.ptr->production_id); uint32_t lookahead_end_byte = 0; @@ -387,9 +387,9 @@ void ts_subtree_set_children( self.ptr->dynamic_precedence += ts_subtree_dynamic_precedence(child); self.ptr->node_count += ts_subtree_node_count(child); - if (alias_sequence && alias_sequence[non_extra_index] != 0 && !ts_subtree_extra(child)) { + if (alias_sequence && alias_sequence[structural_index] != 0 && !ts_subtree_extra(child)) { self.ptr->visible_child_count++; - if (ts_language_symbol_metadata(language, alias_sequence[non_extra_index]).named) { + if (ts_language_symbol_metadata(language, alias_sequence[structural_index]).named) { self.ptr->named_child_count++; } } else if (ts_subtree_visible(child)) { @@ -407,7 +407,7 @@ void ts_subtree_set_children( self.ptr->parse_state = TS_TREE_STATE_NONE; } - if (!ts_subtree_extra(child)) non_extra_index++; + if (!ts_subtree_extra(child)) structural_index++; } self.ptr->lookahead_bytes = lookahead_end_byte - self.ptr->size.bytes - self.ptr->padding.bytes; diff --git a/lib/src/tree_cursor.c b/lib/src/tree_cursor.c index 00b9679d..06c724d2 100644 --- a/lib/src/tree_cursor.c +++ b/lib/src/tree_cursor.c @@ -205,19 +205,21 @@ bool ts_tree_cursor_goto_parent(TSTreeCursor *_self) { TreeCursor *self = (TreeCursor *)_self; for (unsigned i = self->stack.size - 2; i + 1 > 0; i--) { TreeCursorEntry *entry = &self->stack.contents[i]; - bool is_aliased = false; - if (i > 0) { - TreeCursorEntry *parent_entry = &self->stack.contents[i - 1]; - 
const TSSymbol *alias_sequence = ts_language_alias_sequence( - self->tree->language, - parent_entry->subtree->ptr->production_id - ); - is_aliased = alias_sequence && alias_sequence[entry->structural_child_index]; - } - if (ts_subtree_visible(*entry->subtree) || is_aliased) { + if (ts_subtree_visible(*entry->subtree)) { self->stack.size = i + 1; return true; } + if (i > 0 && !ts_subtree_extra(*entry->subtree)) { + TreeCursorEntry *parent_entry = &self->stack.contents[i - 1]; + if (ts_language_alias_at( + self->tree->language, + parent_entry->subtree->ptr->production_id, + entry->structural_child_index + )) { + self->stack.size = i + 1; + return true; + } + } } return false; } @@ -226,15 +228,13 @@ TSNode ts_tree_cursor_current_node(const TSTreeCursor *_self) { const TreeCursor *self = (const TreeCursor *)_self; TreeCursorEntry *last_entry = array_back(&self->stack); TSSymbol alias_symbol = 0; - if (self->stack.size > 1) { + if (self->stack.size > 1 && !ts_subtree_extra(*last_entry->subtree)) { TreeCursorEntry *parent_entry = &self->stack.contents[self->stack.size - 2]; - const TSSymbol *alias_sequence = ts_language_alias_sequence( + alias_symbol = ts_language_alias_at( self->tree->language, - parent_entry->subtree->ptr->production_id + parent_entry->subtree->ptr->production_id, + last_entry->structural_child_index ); - if (alias_sequence && !ts_subtree_extra(*last_entry->subtree)) { - alias_symbol = alias_sequence[last_entry->structural_child_index]; - } } return ts_node_new( self->tree, @@ -263,13 +263,14 @@ TSFieldId ts_tree_cursor_current_status( // Stop walking up when a visible ancestor is found. 
if (i != self->stack.size - 1) { if (ts_subtree_visible(*entry->subtree)) break; - const TSSymbol *alias_sequence = ts_language_alias_sequence( - self->tree->language, - parent_entry->subtree->ptr->production_id - ); - if (alias_sequence && alias_sequence[entry->structural_child_index]) { - break; - } + if ( + !ts_subtree_extra(*entry->subtree) && + ts_language_alias_at( + self->tree->language, + parent_entry->subtree->ptr->production_id, + entry->structural_child_index + ) + ) break; } if (ts_subtree_child_count(*parent_entry->subtree) > entry->child_index + 1) { @@ -321,13 +322,14 @@ TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) { // Stop walking up when another visible node is found. if (i != self->stack.size - 1) { if (ts_subtree_visible(*entry->subtree)) break; - const TSSymbol *alias_sequence = ts_language_alias_sequence( - self->tree->language, - parent_entry->subtree->ptr->production_id - ); - if (alias_sequence && alias_sequence[entry->structural_child_index]) { - break; - } + if ( + !ts_subtree_extra(*entry->subtree) && + ts_language_alias_at( + self->tree->language, + parent_entry->subtree->ptr->production_id, + entry->structural_child_index + ) + ) break; } if (ts_subtree_extra(*entry->subtree)) break; diff --git a/script/benchmark b/script/benchmark index 61e57920..7599e989 100755 --- a/script/benchmark +++ b/script/benchmark @@ -18,15 +18,22 @@ OPTIONS -r parse each sample the given number of times (default 5) + -g debug + EOF } -while getopts "hl:e:r:" option; do +mode=normal + +while getopts "hgl:e:r:" option; do case ${option} in h) usage exit ;; + g) + mode=debug + ;; e) export TREE_SITTER_BENCHMARK_EXAMPLE_FILTER=${OPTARG} ;; @@ -39,4 +46,13 @@ while getopts "hl:e:r:" option; do esac done -cargo bench benchmark +if [[ "${mode}" == "debug" ]]; then + test_binary=$( + cargo bench benchmark --no-run --message-format=json 2> /dev/null |\ + jq -rs 'map(select(.target.name == "benchmark" and .executable))[0].executable' + ) + env 
| grep TREE_SITTER + echo $test_binary +else + exec cargo bench benchmark +fi diff --git a/script/test b/script/test index bcc88e24..9b578dcf 100755 --- a/script/test +++ b/script/test @@ -83,10 +83,14 @@ done shift $(expr $OPTIND - 1) -if [[ -n $TREE_SITTER_TEST_LANGUAGE_FILTER || -n $TREE_SITTER_TEST_EXAMPLE_FILTER || -n $TREE_SITTER_TEST_TRIAL_FILTER ]]; then - top_level_filter=corpus -else - top_level_filter=$1 +top_level_filter=$1 + +if [[ \ + -n $TREE_SITTER_TEST_LANGUAGE_FILTER || \ + -n $TREE_SITTER_TEST_EXAMPLE_FILTER || \ + -n $TREE_SITTER_TEST_TRIAL_FILTER \ +]]; then + : ${top_level_filter:=corpus} fi if [[ "${mode}" == "debug" ]]; then