Merge pull request #644 from tree-sitter/query-pattern-is-definite

Analyze queries on construction to identify impossible patterns, and patterns that will definitely match
This commit is contained in:
Max Brunsfeld 2020-09-02 10:28:21 -07:00 committed by GitHub
commit 18150a1573
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 2039 additions and 408 deletions

View file

@ -2,8 +2,8 @@ use lazy_static::lazy_static;
use std::collections::BTreeMap;
use std::path::{Path, PathBuf};
use std::time::Instant;
use std::{env, fs, usize};
use tree_sitter::{Language, Parser};
use std::{env, fs, str, usize};
use tree_sitter::{Language, Parser, Query};
use tree_sitter_cli::error::Error;
use tree_sitter_cli::loader::Loader;
@ -18,26 +18,33 @@ lazy_static! {
.map(|s| usize::from_str_radix(&s, 10).unwrap())
.unwrap_or(5);
static ref TEST_LOADER: Loader = Loader::new(SCRATCH_DIR.clone());
static ref EXAMPLE_PATHS_BY_LANGUAGE_DIR: BTreeMap<PathBuf, Vec<PathBuf>> = {
fn process_dir(result: &mut BTreeMap<PathBuf, Vec<PathBuf>>, dir: &Path) {
static ref EXAMPLE_AND_QUERY_PATHS_BY_LANGUAGE_DIR: BTreeMap<PathBuf, (Vec<PathBuf>, Vec<PathBuf>)> = {
fn process_dir(result: &mut BTreeMap<PathBuf, (Vec<PathBuf>, Vec<PathBuf>)>, dir: &Path) {
if dir.join("grammar.js").exists() {
let relative_path = dir.strip_prefix(GRAMMARS_DIR.as_path()).unwrap();
let (example_paths, query_paths) =
result.entry(relative_path.to_owned()).or_default();
if let Ok(example_files) = fs::read_dir(&dir.join("examples")) {
result.insert(
relative_path.to_owned(),
example_files
.filter_map(|p| {
let p = p.unwrap().path();
if p.is_file() {
Some(p)
} else {
None
}
})
.collect(),
);
} else {
result.insert(relative_path.to_owned(), Vec::new());
example_paths.extend(example_files.filter_map(|p| {
let p = p.unwrap().path();
if p.is_file() {
Some(p.to_owned())
} else {
None
}
}));
}
if let Ok(query_files) = fs::read_dir(&dir.join("queries")) {
query_paths.extend(query_files.filter_map(|p| {
let p = p.unwrap().path();
if p.is_file() {
Some(p.to_owned())
} else {
None
}
}));
}
} else {
for entry in fs::read_dir(&dir).unwrap() {
@ -56,20 +63,25 @@ lazy_static! {
}
fn main() {
let mut parser = Parser::new();
let max_path_length = EXAMPLE_PATHS_BY_LANGUAGE_DIR
.iter()
.flat_map(|(_, paths)| paths.iter())
.map(|p| p.file_name().unwrap().to_str().unwrap().chars().count())
let max_path_length = EXAMPLE_AND_QUERY_PATHS_BY_LANGUAGE_DIR
.values()
.flat_map(|(e, q)| {
e.iter()
.chain(q.iter())
.map(|s| s.file_name().unwrap().to_str().unwrap().len())
})
.max()
.unwrap();
let mut all_normal_speeds = Vec::new();
let mut all_error_speeds = Vec::new();
.unwrap_or(0);
eprintln!("Benchmarking with {} repetitions", *REPETITION_COUNT);
for (language_path, example_paths) in EXAMPLE_PATHS_BY_LANGUAGE_DIR.iter() {
let mut parser = Parser::new();
let mut all_normal_speeds = Vec::new();
let mut all_error_speeds = Vec::new();
for (language_path, (example_paths, query_paths)) in
EXAMPLE_AND_QUERY_PATHS_BY_LANGUAGE_DIR.iter()
{
let language_name = language_path.file_name().unwrap().to_str().unwrap();
if let Some(filter) = LANGUAGE_FILTER.as_ref() {
@ -79,9 +91,24 @@ fn main() {
}
eprintln!("\nLanguage: {}", language_name);
parser.set_language(get_language(language_path)).unwrap();
let language = get_language(language_path);
parser.set_language(language).unwrap();
eprintln!(" Normal examples:");
eprintln!(" Constructing Queries");
for path in query_paths {
if let Some(filter) = EXAMPLE_FILTER.as_ref() {
if !path.to_str().unwrap().contains(filter.as_str()) {
continue;
}
}
parse(&path, max_path_length, |source| {
Query::new(language, str::from_utf8(source).unwrap())
.expect("Failed to parse query");
});
}
eprintln!(" Parsing Valid Code:");
let mut normal_speeds = Vec::new();
for example_path in example_paths {
if let Some(filter) = EXAMPLE_FILTER.as_ref() {
@ -90,12 +117,16 @@ fn main() {
}
}
normal_speeds.push(parse(&mut parser, example_path, max_path_length));
normal_speeds.push(parse(example_path, max_path_length, |code| {
parser.parse(code, None).expect("Failed to parse");
}));
}
eprintln!(" Error examples (mismatched languages):");
eprintln!(" Parsing Invalid Code (mismatched languages):");
let mut error_speeds = Vec::new();
for (other_language_path, example_paths) in EXAMPLE_PATHS_BY_LANGUAGE_DIR.iter() {
for (other_language_path, (example_paths, _)) in
EXAMPLE_AND_QUERY_PATHS_BY_LANGUAGE_DIR.iter()
{
if other_language_path != language_path {
for example_path in example_paths {
if let Some(filter) = EXAMPLE_FILTER.as_ref() {
@ -104,7 +135,9 @@ fn main() {
}
}
error_speeds.push(parse(&mut parser, example_path, max_path_length));
error_speeds.push(parse(example_path, max_path_length, |code| {
parser.parse(code, None).expect("Failed to parse");
}));
}
}
}
@ -123,7 +156,7 @@ fn main() {
all_error_speeds.extend(error_speeds);
}
eprintln!("\nOverall");
eprintln!("\n Overall");
if let Some((average_normal, worst_normal)) = aggregate(&all_normal_speeds) {
eprintln!(" Average Speed (normal): {} bytes/ms", average_normal);
eprintln!(" Worst Speed (normal): {} bytes/ms", worst_normal);
@ -151,28 +184,25 @@ fn aggregate(speeds: &Vec<usize>) -> Option<(usize, usize)> {
Some((total / speeds.len(), max))
}
fn parse(parser: &mut Parser, example_path: &Path, max_path_length: usize) -> usize {
fn parse(path: &Path, max_path_length: usize, mut action: impl FnMut(&[u8])) -> usize {
eprint!(
" {:width$}\t",
example_path.file_name().unwrap().to_str().unwrap(),
path.file_name().unwrap().to_str().unwrap(),
width = max_path_length
);
let source_code = fs::read(example_path)
.map_err(Error::wrap(|| format!("Failed to read {:?}", example_path)))
let source_code = fs::read(path)
.map_err(Error::wrap(|| format!("Failed to read {:?}", path)))
.unwrap();
let time = Instant::now();
for _ in 0..*REPETITION_COUNT {
parser
.parse(&source_code, None)
.expect("Incompatible language version");
action(&source_code);
}
let duration = time.elapsed() / (*REPETITION_COUNT as u32);
let duration_ms =
duration.as_secs() as f64 * 1000.0 + duration.subsec_nanos() as f64 / 1000000.0;
let speed = (source_code.len() as f64 / duration_ms) as usize;
let duration_ms = duration.as_millis();
let speed = source_code.len() as u128 / (duration_ms + 1);
eprintln!("time {} ms\tspeed {} bytes/ms", duration_ms as usize, speed);
speed
speed as usize
}
fn get_language(path: &Path) -> Language {

View file

@ -70,6 +70,10 @@ impl<'a> From<QueryError> for Error {
"Query error on line {}. Invalid syntax:\n{}",
row, l
)),
QueryError::Structure(row, l) => Error::new(format!(
"Query error on line {}. Impossible pattern:\n{}",
row, l
)),
QueryError::Predicate(p) => Error::new(format!("Query error: {}", p)),
}
}

View file

@ -7,7 +7,7 @@ use super::tables::{
};
use core::ops::Range;
use std::cmp;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::collections::{HashMap, HashSet};
use std::fmt::Write;
use std::mem::swap;
@ -69,7 +69,8 @@ struct Generator {
symbol_order: HashMap<Symbol, usize>,
symbol_ids: HashMap<Symbol, String>,
alias_ids: HashMap<Alias, String>,
alias_map: BTreeMap<Alias, Option<Symbol>>,
unique_aliases: Vec<Alias>,
symbol_map: HashMap<Symbol, Symbol>,
field_names: Vec<String>,
next_abi: bool,
}
@ -95,11 +96,7 @@ impl Generator {
self.add_stats();
self.add_symbol_enum();
self.add_symbol_names_list();
if self.next_abi {
self.add_unique_symbol_map();
}
self.add_unique_symbol_map();
self.add_symbol_metadata_list();
if !self.field_names.is_empty() {
@ -112,6 +109,8 @@ impl Generator {
self.add_alias_sequences();
}
self.add_non_terminal_alias_map();
let mut main_lex_table = LexTable::default();
swap(&mut main_lex_table, &mut self.main_lex_table);
self.add_lex_function("ts_lex", main_lex_table, true);
@ -163,13 +162,72 @@ impl Generator {
format!("anon_alias_sym_{}", self.sanitize_identifier(&alias.value))
};
self.alias_ids.entry(alias.clone()).or_insert(alias_id);
self.alias_map
.entry(alias.clone())
.or_insert(matching_symbol);
}
}
}
self.unique_aliases = self
.alias_ids
.keys()
.filter(|alias| {
self.parse_table
.symbols
.iter()
.cloned()
.find(|symbol| {
let (name, kind) = self.metadata_for_symbol(*symbol);
name == alias.value && kind == alias.kind()
})
.is_none()
})
.cloned()
.collect();
self.unique_aliases.sort_unstable();
self.symbol_map = self
.parse_table
.symbols
.iter()
.map(|symbol| {
let mut mapping = symbol;
// There can be multiple symbols in the grammar that have the same name and kind,
// due to simple aliases. When that happens, ensure that they map to the same
// public-facing symbol. If one of the symbols is not aliased, choose that one
// to be the public-facing symbol. Otherwise, pick the symbol with the lowest
// numeric value.
if let Some(alias) = self.simple_aliases.get(symbol) {
let kind = alias.kind();
for other_symbol in &self.parse_table.symbols {
if let Some(other_alias) = self.simple_aliases.get(other_symbol) {
if other_symbol < mapping && other_alias == alias {
mapping = other_symbol;
}
} else if self.metadata_for_symbol(*other_symbol) == (&alias.value, kind) {
mapping = other_symbol;
break;
}
}
}
// Two anonymous tokens with different flags but the same string value
// should be represented with the same symbol in the public API. Examples:
// * "<" and token(prec(1, "<"))
// * "(" and token.immediate("(")
else if symbol.is_terminal() {
let metadata = self.metadata_for_symbol(*symbol);
for other_symbol in &self.parse_table.symbols {
let other_metadata = self.metadata_for_symbol(*other_symbol);
if other_metadata == metadata {
mapping = other_symbol;
break;
}
}
}
(*symbol, *mapping)
})
.collect();
field_names.sort_unstable();
field_names.dedup();
self.field_names = field_names.into_iter().cloned().collect();
@ -177,20 +235,16 @@ impl Generator {
// If we are opting in to the new unstable language ABI, then use the concept of
// "small parse states". Otherwise, use the same representation for all parse
// states.
if self.next_abi {
let threshold = cmp::min(SMALL_STATE_THRESHOLD, self.parse_table.symbols.len() / 2);
self.large_state_count = self
.parse_table
.states
.iter()
.enumerate()
.take_while(|(i, s)| {
*i <= 1 || s.terminal_entries.len() + s.nonterminal_entries.len() > threshold
})
.count();
} else {
self.large_state_count = self.parse_table.states.len();
}
let threshold = cmp::min(SMALL_STATE_THRESHOLD, self.parse_table.symbols.len() / 2);
self.large_state_count = self
.parse_table
.states
.iter()
.enumerate()
.take_while(|(i, s)| {
*i <= 1 || s.terminal_entries.len() + s.nonterminal_entries.len() > threshold
})
.count();
}
fn add_includes(&mut self) {
@ -256,21 +310,14 @@ impl Generator {
"#define STATE_COUNT {}",
self.parse_table.states.len()
);
if self.next_abi {
add_line!(self, "#define LARGE_STATE_COUNT {}", self.large_state_count);
}
add_line!(self, "#define LARGE_STATE_COUNT {}", self.large_state_count);
add_line!(
self,
"#define SYMBOL_COUNT {}",
self.parse_table.symbols.len()
);
add_line!(
self,
"#define ALIAS_COUNT {}",
self.alias_map.iter().filter(|e| e.1.is_none()).count()
);
add_line!(self, "#define ALIAS_COUNT {}", self.unique_aliases.len(),);
add_line!(self, "#define TOKEN_COUNT {}", token_count);
add_line!(
self,
@ -298,11 +345,9 @@ impl Generator {
i += 1;
}
}
for (alias, symbol) in &self.alias_map {
if symbol.is_none() {
add_line!(self, "{} = {},", self.alias_ids[&alias], i);
i += 1;
}
for alias in &self.unique_aliases {
add_line!(self, "{} = {},", self.alias_ids[&alias], i);
i += 1;
}
dedent!(self);
add_line!(self, "}};");
@ -321,15 +366,13 @@ impl Generator {
);
add_line!(self, "[{}] = \"{}\",", self.symbol_ids[&symbol], name);
}
for (alias, symbol) in &self.alias_map {
if symbol.is_none() {
add_line!(
self,
"[{}] = \"{}\",",
self.alias_ids[&alias],
self.sanitize_string(&alias.value)
);
}
for alias in &self.unique_aliases {
add_line!(
self,
"[{}] = \"{}\",",
self.alias_ids[&alias],
self.sanitize_string(&alias.value)
);
}
dedent!(self);
add_line!(self, "}};");
@ -340,58 +383,21 @@ impl Generator {
add_line!(self, "static TSSymbol ts_symbol_map[] = {{");
indent!(self);
for symbol in &self.parse_table.symbols {
let mut mapping = symbol;
// There can be multiple symbols in the grammar that have the same name and kind,
// due to simple aliases. When that happens, ensure that they map to the same
// public-facing symbol. If one of the symbols is not aliased, choose that one
// to be the public-facing symbol. Otherwise, pick the symbol with the lowest
// numeric value.
if let Some(alias) = self.simple_aliases.get(symbol) {
let kind = alias.kind();
for other_symbol in &self.parse_table.symbols {
if let Some(other_alias) = self.simple_aliases.get(other_symbol) {
if other_symbol < mapping && other_alias == alias {
mapping = other_symbol;
}
} else if self.metadata_for_symbol(*other_symbol) == (&alias.value, kind) {
mapping = other_symbol;
break;
}
}
}
// Two anonymous tokens with different flags but the same string value
// should be represented with the same symbol in the public API. Examples:
// * "<" and token(prec(1, "<"))
// * "(" and token.immediate("(")
else if symbol.is_terminal() {
let metadata = self.metadata_for_symbol(*symbol);
for other_symbol in &self.parse_table.symbols {
let other_metadata = self.metadata_for_symbol(*other_symbol);
if other_metadata == metadata {
mapping = other_symbol;
break;
}
}
}
add_line!(
self,
"[{}] = {},",
self.symbol_ids[&symbol],
self.symbol_ids[mapping],
self.symbol_ids[symbol],
self.symbol_ids[&self.symbol_map[symbol]],
);
}
for (alias, symbol) in &self.alias_map {
if symbol.is_none() {
add_line!(
self,
"[{}] = {},",
self.alias_ids[&alias],
self.alias_ids[&alias],
);
}
for alias in &self.unique_aliases {
add_line!(
self,
"[{}] = {},",
self.alias_ids[&alias],
self.alias_ids[&alias],
);
}
dedent!(self);
@ -462,15 +468,13 @@ impl Generator {
dedent!(self);
add_line!(self, "}},");
}
for (alias, matching_symbol) in &self.alias_map {
if matching_symbol.is_none() {
add_line!(self, "[{}] = {{", self.alias_ids[&alias]);
indent!(self);
add_line!(self, ".visible = true,");
add_line!(self, ".named = {},", alias.is_named);
dedent!(self);
add_line!(self, "}},");
}
for alias in &self.unique_aliases {
add_line!(self, "[{}] = {{", self.alias_ids[&alias]);
indent!(self);
add_line!(self, ".visible = true,");
add_line!(self, ".named = {},", alias.is_named);
dedent!(self);
add_line!(self, "}},");
}
dedent!(self);
add_line!(self, "}};");
@ -509,6 +513,50 @@ impl Generator {
add_line!(self, "");
}
fn add_non_terminal_alias_map(&mut self) {
let mut aliases_by_symbol = HashMap::new();
for variable in &self.syntax_grammar.variables {
for production in &variable.productions {
for step in &production.steps {
if let Some(alias) = &step.alias {
if step.symbol.is_non_terminal()
&& !self.simple_aliases.contains_key(&step.symbol)
{
if self.symbol_ids.contains_key(&step.symbol) {
let alias_ids =
aliases_by_symbol.entry(step.symbol).or_insert(Vec::new());
if let Err(i) = alias_ids.binary_search(&alias) {
alias_ids.insert(i, alias);
}
}
}
}
}
}
}
let mut aliases_by_symbol = aliases_by_symbol.iter().collect::<Vec<_>>();
aliases_by_symbol.sort_unstable_by_key(|e| e.0);
add_line!(self, "static uint16_t ts_non_terminal_alias_map[] = {{");
indent!(self);
for (symbol, aliases) in aliases_by_symbol {
let symbol_id = &self.symbol_ids[symbol];
let public_symbol_id = &self.symbol_ids[&self.symbol_map[&symbol]];
add_line!(self, "{}, {},", symbol_id, 1 + aliases.len());
indent!(self);
add_line!(self, "{},", public_symbol_id);
for alias in aliases {
add_line!(self, "{},", &self.alias_ids[&alias]);
}
dedent!(self);
}
add_line!(self, "0,");
dedent!(self);
add_line!(self, "}};");
add_line!(self, "");
}
fn add_field_sequences(&mut self) {
let mut flat_field_maps = vec![];
let mut next_flat_field_map_index = 0;
@ -689,17 +737,12 @@ impl Generator {
name
);
indent!(self);
add_line!(self, "START_LEXER();");
if self.next_abi {
add_line!(self, "eof = lexer->eof(lexer);");
} else {
add_line!(self, "eof = lookahead == 0;");
}
add_line!(self, "eof = lexer->eof(lexer);");
add_line!(self, "switch (state) {{");
indent!(self);
indent!(self);
for (i, state) in lex_table.states.into_iter().enumerate() {
add_line!(self, "case {}:", i);
indent!(self);
@ -714,6 +757,7 @@ impl Generator {
dedent!(self);
add_line!(self, "}}");
dedent!(self);
add_line!(self, "}}");
add_line!(self, "");
@ -967,12 +1011,7 @@ impl Generator {
add_line!(
self,
"static uint16_t ts_parse_table[{}][SYMBOL_COUNT] = {{",
if self.next_abi {
"LARGE_STATE_COUNT"
} else {
"STATE_COUNT"
}
"static uint16_t ts_parse_table[LARGE_STATE_COUNT][SYMBOL_COUNT] = {{",
);
indent!(self);
@ -1224,9 +1263,11 @@ impl Generator {
add_line!(self, ".symbol_count = SYMBOL_COUNT,");
add_line!(self, ".alias_count = ALIAS_COUNT,");
add_line!(self, ".token_count = TOKEN_COUNT,");
add_line!(self, ".large_state_count = LARGE_STATE_COUNT,");
if self.next_abi {
add_line!(self, ".large_state_count = LARGE_STATE_COUNT,");
add_line!(self, ".alias_map = ts_non_terminal_alias_map,");
add_line!(self, ".state_count = STATE_COUNT,");
}
add_line!(self, ".symbol_metadata = ts_symbol_metadata,");
@ -1249,10 +1290,7 @@ impl Generator {
add_line!(self, ".parse_actions = ts_parse_actions,");
add_line!(self, ".lex_modes = ts_lex_modes,");
add_line!(self, ".symbol_names = ts_symbol_names,");
if self.next_abi {
add_line!(self, ".public_symbol_map = ts_symbol_map,");
}
add_line!(self, ".public_symbol_map = ts_symbol_map,");
if !self.parse_table.production_infos.is_empty() {
add_line!(
@ -1539,7 +1577,8 @@ pub(crate) fn render_c_code(
symbol_ids: HashMap::new(),
symbol_order: HashMap::new(),
alias_ids: HashMap::new(),
alias_map: BTreeMap::new(),
symbol_map: HashMap::new(),
unique_aliases: Vec::new(),
field_names: Vec::new(),
next_abi,
}

View file

@ -1,18 +1,28 @@
use super::helpers::allocations;
use super::helpers::fixtures::get_language;
use lazy_static::lazy_static;
use std::env;
use std::fmt::Write;
use tree_sitter::{
Language, Node, Parser, Query, QueryCapture, QueryCursor, QueryError, QueryMatch,
QueryPredicate, QueryPredicateArg, QueryProperty,
};
lazy_static! {
static ref EXAMPLE_FILTER: Option<String> = env::var("TREE_SITTER_TEST_EXAMPLE_FILTER").ok();
}
#[test]
fn test_query_errors_on_invalid_syntax() {
allocations::record(|| {
let language = get_language("javascript");
assert!(Query::new(language, "(if_statement)").is_ok());
assert!(Query::new(language, "(if_statement condition:(identifier))").is_ok());
assert!(Query::new(
language,
"(if_statement condition:(parenthesized_expression (identifier)))"
)
.is_ok());
// Mismatched parens
assert_eq!(
@ -180,6 +190,110 @@ fn test_query_errors_on_invalid_conditions() {
});
}
#[test]
fn test_query_errors_on_impossible_patterns() {
let js_lang = get_language("javascript");
let rb_lang = get_language("ruby");
allocations::record(|| {
assert_eq!(
Query::new(
js_lang,
"(binary_expression left: (identifier) left: (identifier))"
),
Err(QueryError::Structure(
1,
[
"(binary_expression left: (identifier) left: (identifier))",
" ^"
]
.join("\n"),
))
);
Query::new(
js_lang,
"(function_declaration name: (identifier) (statement_block))",
)
.unwrap();
assert_eq!(
Query::new(js_lang, "(function_declaration name: (statement_block))"),
Err(QueryError::Structure(
1,
[
"(function_declaration name: (statement_block))",
" ^",
]
.join("\n")
))
);
Query::new(rb_lang, "(call receiver:(call))").unwrap();
assert_eq!(
Query::new(rb_lang, "(call receiver:(binary))"),
Err(QueryError::Structure(
1,
[
"(call receiver:(binary))", //
" ^",
]
.join("\n")
))
);
Query::new(
js_lang,
"[
(function (identifier))
(function_declaration (identifier))
(generator_function_declaration (identifier))
]",
)
.unwrap();
assert_eq!(
Query::new(
js_lang,
"[
(function (identifier))
(function_declaration (object))
(generator_function_declaration (identifier))
]",
),
Err(QueryError::Structure(
3,
[
" (function_declaration (object))", //
" ^",
]
.join("\n")
))
);
assert_eq!(
Query::new(js_lang, "(identifier (identifier))",),
Err(QueryError::Structure(
1,
[
"(identifier (identifier))", //
" ^",
]
.join("\n")
))
);
assert_eq!(
Query::new(js_lang, "(true (true))",),
Err(QueryError::Structure(
1,
[
"(true (true))", //
" ^",
]
.join("\n")
))
);
});
}
#[test]
fn test_query_matches_with_simple_pattern() {
allocations::record(|| {
@ -1907,6 +2021,54 @@ fn test_query_captures_with_too_many_nested_results() {
});
}
#[test]
fn test_query_captures_with_definite_pattern_containing_many_nested_matches() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
r#"
(array
"[" @l-bracket
"]" @r-bracket)
"." @dot
"#,
)
.unwrap();
// The '[' node must be returned before all of the '.' nodes,
// even though its pattern does not finish until the ']' node
// at the end of the document. But because the '[' is definite,
// it can be returned before the pattern finishes matching.
let source = "
[
a.b.c.d.e.f.g.h.i,
a.b.c.d.e.f.g.h.i,
a.b.c.d.e.f.g.h.i,
a.b.c.d.e.f.g.h.i,
a.b.c.d.e.f.g.h.i,
]
";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let captures = cursor.captures(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_captures(captures, &query, source),
[("l-bracket", "[")]
.iter()
.chain([("dot", "."); 40].iter())
.chain([("r-bracket", "]")].iter())
.cloned()
.collect::<Vec<_>>(),
);
});
}
#[test]
fn test_query_captures_ordered_by_both_start_and_end_positions() {
allocations::record(|| {
@ -2051,7 +2213,7 @@ fn test_query_start_byte_for_pattern() {
let patterns_3 = "
((identifier) @b (#match? @b i))
(function_declaration name: (identifier) @c)
(method_definition name: (identifier) @d)
(method_definition name: (property_identifier) @d)
"
.trim_start();
@ -2078,10 +2240,10 @@ fn test_query_capture_names() {
language,
r#"
(if_statement
condition: (binary_expression
condition: (parenthesized_expression (binary_expression
left: _ @left-operand
operator: "||"
right: _ @right-operand)
right: _ @right-operand))
consequence: (statement_block) @body)
(while_statement
@ -2213,6 +2375,273 @@ fn test_query_alternative_predicate_prefix() {
});
}
#[test]
fn test_query_step_is_definite() {
struct Row {
language: Language,
description: &'static str,
pattern: &'static str,
results_by_substring: &'static [(&'static str, bool)],
}
let rows = &[
Row {
description: "no definite steps",
language: get_language("python"),
pattern: r#"(expression_statement (string))"#,
results_by_substring: &[("expression_statement", false), ("string", false)],
},
Row {
description: "all definite steps",
language: get_language("javascript"),
pattern: r#"(object "{" "}")"#,
results_by_substring: &[("object", false), ("{", true), ("}", true)],
},
Row {
description: "an indefinite step that is optional",
language: get_language("javascript"),
pattern: r#"(object "{" (identifier)? @foo "}")"#,
results_by_substring: &[
("object", false),
("{", true),
("(identifier)?", false),
("}", true),
],
},
Row {
description: "multiple indefinite steps that are optional",
language: get_language("javascript"),
pattern: r#"(object "{" (identifier)? @id1 ("," (identifier) @id2)? "}")"#,
results_by_substring: &[
("object", false),
("{", true),
("(identifier)? @id1", false),
("\",\"", false),
("}", true),
],
},
Row {
description: "definite step after indefinite step",
language: get_language("javascript"),
pattern: r#"(pair (property_identifier) ":")"#,
results_by_substring: &[("pair", false), ("property_identifier", false), (":", true)],
},
Row {
description: "indefinite step in between two definite steps",
language: get_language("javascript"),
pattern: r#"(ternary_expression
condition: (_)
"?"
consequence: (call_expression)
":"
alternative: (_))"#,
results_by_substring: &[
("condition:", false),
("\"?\"", false),
("consequence:", false),
("\":\"", true),
("alternative:", true),
],
},
Row {
description: "one definite step after a repetition",
language: get_language("javascript"),
pattern: r#"(object "{" (_) "}")"#,
results_by_substring: &[("object", false), ("{", false), ("(_)", false), ("}", true)],
},
Row {
description: "definite steps after multiple repetitions",
language: get_language("json"),
pattern: r#"(object "{" (pair) "," (pair) "," (_) "}")"#,
results_by_substring: &[
("object", false),
("{", false),
("(pair) \",\" (pair)", false),
("(pair) \",\" (_)", false),
("\",\" (_)", false),
("(_)", true),
("}", true),
],
},
Row {
description: "a definite with a field",
language: get_language("javascript"),
pattern: r#"(binary_expression left: (identifier) right: (_))"#,
results_by_substring: &[
("binary_expression", false),
("(identifier)", false),
("(_)", true),
],
},
Row {
description: "multiple definite steps with fields",
language: get_language("javascript"),
pattern: r#"(function_declaration name: (identifier) body: (statement_block))"#,
results_by_substring: &[
("function_declaration", false),
("identifier", true),
("statement_block", true),
],
},
Row {
description: "nesting, one definite step",
language: get_language("javascript"),
pattern: r#"
(function_declaration
name: (identifier)
body: (statement_block "{" (expression_statement) "}"))"#,
results_by_substring: &[
("function_declaration", false),
("identifier", false),
("statement_block", false),
("{", false),
("expression_statement", false),
("}", true),
],
},
Row {
description: "definite step after some deeply nested hidden nodes",
language: get_language("ruby"),
pattern: r#"
(singleton_class
value: (constant)
"end")
"#,
results_by_substring: &[
("singleton_class", false),
("constant", false),
("end", true),
],
},
Row {
description: "nesting, no definite steps",
language: get_language("javascript"),
pattern: r#"
(call_expression
function: (member_expression
property: (property_identifier) @template-tag)
arguments: (template_string)) @template-call
"#,
results_by_substring: &[("property_identifier", false), ("template_string", false)],
},
Row {
description: "a definite step after a nested node",
language: get_language("javascript"),
pattern: r#"
(subscript_expression
object: (member_expression
object: (identifier) @obj
property: (property_identifier) @prop)
"[")
"#,
results_by_substring: &[
("identifier", false),
("property_identifier", true),
("[", true),
],
},
Row {
description: "a step that is indefinite due to a predicate",
language: get_language("javascript"),
pattern: r#"
(subscript_expression
object: (member_expression
object: (identifier) @obj
property: (property_identifier) @prop)
"["
(#match? @prop "foo"))
"#,
results_by_substring: &[
("identifier", false),
("property_identifier", false),
("[", true),
],
},
Row {
description: "alternation where one branch has definite steps",
language: get_language("javascript"),
pattern: r#"
[
(unary_expression (identifier))
(call_expression
function: (_)
arguments: (_))
(binary_expression right:(call_expression))
]
"#,
results_by_substring: &[
("identifier", false),
("right:", false),
("function:", true),
("arguments:", true),
],
},
Row {
description: "aliased parent node",
language: get_language("ruby"),
pattern: r#"
(method_parameters "(" (identifier) @id")")
"#,
results_by_substring: &[("\"(\"", false), ("(identifier)", false), ("\")\"", true)],
},
Row {
description: "long, but not too long to analyze",
language: get_language("javascript"),
pattern: r#"
(object "{" (pair) (pair) (pair) (pair) "}")
"#,
results_by_substring: &[
("\"{\"", false),
("(pair)", false),
("(pair) \"}\"", false),
("\"}\"", true),
],
},
Row {
description: "too long to analyze",
language: get_language("javascript"),
pattern: r#"
(object "{" (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) (pair) "}")
"#,
results_by_substring: &[
("\"{\"", false),
("(pair)", false),
("(pair) \"}\"", false),
("\"}\"", false),
],
},
];
allocations::record(|| {
eprintln!("");
for row in rows.iter() {
if let Some(filter) = EXAMPLE_FILTER.as_ref() {
if !row.description.contains(filter.as_str()) {
continue;
}
}
eprintln!(" query example: {:?}", row.description);
let query = Query::new(row.language, row.pattern).unwrap();
for (substring, is_definite) in row.results_by_substring {
let offset = row.pattern.find(substring).unwrap();
assert_eq!(
query.step_is_definite(offset),
*is_definite,
"Description: {}, Pattern: {:?}, substring: {:?}, expected is_definite to be {}",
row.description,
row.pattern
.split_ascii_whitespace()
.collect::<Vec<_>>()
.join(" "),
substring,
is_definite,
)
}
}
});
}
fn assert_query_matches(
language: Language,
query: &Query,

View file

@ -277,7 +277,7 @@ let tree;
const startPosition = queryEditor.posFromIndex(error.index);
const endPosition = {
line: startPosition.line,
ch: startPosition.ch + (error.length || 1)
ch: startPosition.ch + (error.length || Infinity)
};
if (error.index === queryText.length) {

View file

@ -132,6 +132,7 @@ pub const TSQueryError_TSQueryErrorSyntax: TSQueryError = 1;
pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2;
pub const TSQueryError_TSQueryErrorField: TSQueryError = 3;
pub const TSQueryError_TSQueryErrorCapture: TSQueryError = 4;
pub const TSQueryError_TSQueryErrorStructure: TSQueryError = 5;
pub type TSQueryError = u32;
extern "C" {
#[doc = " Create a new parser."]
@ -172,9 +173,9 @@ extern "C" {
#[doc = " the given ranges must be ordered from earliest to latest in the document,"]
#[doc = " and they must not overlap. That is, the following must hold for all"]
#[doc = " `i` < `length - 1`:"]
#[doc = " ```text"]
#[doc = ""]
#[doc = " ranges[i].end_byte <= ranges[i + 1].start_byte"]
#[doc = " ```"]
#[doc = ""]
#[doc = " If this requirement is not satisfied, the operation will fail, the ranges"]
#[doc = " will not be assigned, and this function will return `false`. On success,"]
#[doc = " this function returns `true`"]
@ -649,6 +650,9 @@ extern "C" {
length: *mut u32,
) -> *const TSQueryPredicateStep;
}
extern "C" {
pub fn ts_query_step_is_definite(self_: *const TSQuery, byte_offset: u32) -> bool;
}
extern "C" {
#[doc = " Get the name and length of one of the query\'s captures, or one of the"]
#[doc = " query\'s string literals. Each capture and string is associated with a"]
@ -800,5 +804,5 @@ extern "C" {
pub fn ts_language_version(arg1: *const TSLanguage) -> u32;
}
pub const TREE_SITTER_LANGUAGE_VERSION: usize = 11;
pub const TREE_SITTER_LANGUAGE_VERSION: usize = 12;
pub const TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION: usize = 9;

View file

@ -163,6 +163,7 @@ pub enum QueryError {
Field(usize, String),
Capture(usize, String),
Predicate(String),
Structure(usize, String),
}
#[derive(Debug)]
@ -1175,27 +1176,42 @@ impl Query {
}
});
let message = if let Some(line) = line_containing_error {
line.to_string() + "\n" + &" ".repeat(offset - line_start) + "^"
} else {
"Unexpected EOF".to_string()
};
// if line_containing_error
return if error_type != ffi::TSQueryError_TSQueryErrorSyntax {
let suffix = source.split_at(offset).1;
let end_offset = suffix
.find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-')
.unwrap_or(source.len());
let name = suffix.split_at(end_offset).0.to_string();
match error_type {
ffi::TSQueryError_TSQueryErrorNodeType => Err(QueryError::NodeType(row, name)),
ffi::TSQueryError_TSQueryErrorField => Err(QueryError::Field(row, name)),
ffi::TSQueryError_TSQueryErrorCapture => Err(QueryError::Capture(row, name)),
_ => Err(QueryError::Syntax(row, message)),
return match error_type {
// Error types that report names
ffi::TSQueryError_TSQueryErrorNodeType
| ffi::TSQueryError_TSQueryErrorField
| ffi::TSQueryError_TSQueryErrorCapture => {
let suffix = source.split_at(offset).1;
let end_offset = suffix
.find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-')
.unwrap_or(source.len());
let name = suffix.split_at(end_offset).0.to_string();
match error_type {
ffi::TSQueryError_TSQueryErrorNodeType => {
Err(QueryError::NodeType(row, name))
}
ffi::TSQueryError_TSQueryErrorField => Err(QueryError::Field(row, name)),
ffi::TSQueryError_TSQueryErrorCapture => {
Err(QueryError::Capture(row, name))
}
_ => unreachable!(),
}
}
// Error types that report positions
_ => {
let message = if let Some(line) = line_containing_error {
line.to_string() + "\n" + &" ".repeat(offset - line_start) + "^"
} else {
"Unexpected EOF".to_string()
};
match error_type {
ffi::TSQueryError_TSQueryErrorStructure => {
Err(QueryError::Structure(row, message))
}
_ => Err(QueryError::Syntax(row, message)),
}
}
} else {
Err(QueryError::Syntax(row, message))
};
}
@ -1451,6 +1467,14 @@ impl Query {
unsafe { ffi::ts_query_disable_pattern(self.ptr.as_ptr(), index as u32) }
}
/// Check if a given step in a query is 'definite'.
///
/// A query step is 'definite' if its parent pattern will be guaranteed to match
/// successfully once it reaches the step.
pub fn step_is_definite(&self, byte_offset: usize) -> bool {
unsafe { ffi::ts_query_step_is_definite(self.ptr.as_ptr(), byte_offset as u32) }
}
fn parse_property(
function_name: &str,
capture_names: &[String],

View file

@ -667,8 +667,8 @@ class Language {
const errorId = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32');
const errorByte = getValue(TRANSFER_BUFFER, 'i32');
const errorIndex = UTF8ToString(sourceAddress, errorByte).length;
const suffix = source.substr(errorIndex, 100);
const word = suffix.match(QUERY_WORD_REGEX)[0];
const suffix = source.substr(errorIndex, 100).split('\n')[0];
let word = suffix.match(QUERY_WORD_REGEX)[0];
let error;
switch (errorId) {
case 2:
@ -680,8 +680,13 @@ class Language {
case 4:
error = new RangeError(`Bad capture name @${word}`);
break;
case 5:
error = new TypeError(`Bad pattern structure at offset ${errorIndex}: '${suffix}'...`);
word = "";
break;
default:
error = new SyntaxError(`Bad syntax at offset ${errorIndex}: '${suffix}'...`);
word = "";
break;
}
error.index = errorIndex;

View file

@ -15,7 +15,6 @@
"__ZNSt3__212basic_stringIwNS_11char_traitsIwEENS_9allocatorIwEEED2Ev",
"__ZdlPv",
"__Znwm",
"___assert_fail",
"_abort",
"_iswalnum",
"_iswalpha",
@ -73,8 +72,6 @@
"_ts_query_capture_count",
"_ts_query_capture_name_for_id",
"_ts_query_captures_wasm",
"_ts_query_context_delete",
"_ts_query_context_new",
"_ts_query_delete",
"_ts_query_matches_wasm",
"_ts_query_new",

View file

@ -30,6 +30,9 @@ describe("Query", () => {
assert.throws(() => {
JavaScript.query("(function_declaration non_existent:(identifier))");
}, "Bad field name 'non_existent'");
assert.throws(() => {
JavaScript.query("(function_declaration name:(statement_block))");
}, "Bad pattern structure at offset 22: 'name:(statement_block))'");
});
it("throws an error on invalid predicates", () => {

View file

@ -21,7 +21,7 @@ extern "C" {
* The Tree-sitter library is generally backwards-compatible with languages
* generated using older CLI versions, but is not forwards-compatible.
*/
#define TREE_SITTER_LANGUAGE_VERSION 11
#define TREE_SITTER_LANGUAGE_VERSION 12
/**
* The earliest ABI version that is supported by the current version of the
@ -130,6 +130,7 @@ typedef enum {
TSQueryErrorNodeType,
TSQueryErrorField,
TSQueryErrorCapture,
TSQueryErrorStructure,
} TSQueryError;
/********************/
@ -718,6 +719,11 @@ const TSQueryPredicateStep *ts_query_predicates_for_pattern(
uint32_t *length
);
bool ts_query_step_is_definite(
const TSQuery *self,
uint32_t byte_offset
);
/**
* Get the name and length of one of the query's captures, or one of the
* query's string literals. Each capture and string is associated with a

View file

@ -119,6 +119,8 @@ struct TSLanguage {
const uint16_t *small_parse_table;
const uint32_t *small_parse_table_map;
const TSSymbol *public_symbol_map;
const uint16_t *alias_map;
uint32_t state_count;
};
/*

View file

@ -12,9 +12,9 @@ extern "C" {
#include <stdbool.h>
#include "./alloc.h"
#define Array(T) \
struct { \
T *contents; \
#define Array(T) \
struct { \
T *contents; \
uint32_t size; \
uint32_t capacity; \
}
@ -37,15 +37,15 @@ extern "C" {
#define array_reserve(self, new_capacity) \
array__reserve((VoidArray *)(self), array__elem_size(self), new_capacity)
#define array_erase(self, index) \
array__erase((VoidArray *)(self), array__elem_size(self), index)
// Free any memory allocated for this array.
#define array_delete(self) array__delete((VoidArray *)self)
#define array_push(self, element) \
(array__grow((VoidArray *)(self), 1, array__elem_size(self)), \
(self)->contents[(self)->size++] = (element))
// Increase the array's size by a given number of elements, reallocating
// if necessary. New elements are zero-initialized.
#define array_grow_by(self, count) \
(array__grow((VoidArray *)(self), count, array__elem_size(self)), \
memset((self)->contents + (self)->size, 0, (count) * array__elem_size(self)), \
@ -54,18 +54,64 @@ extern "C" {
#define array_push_all(self, other) \
array_splice((self), (self)->size, 0, (other)->size, (other)->contents)
// Remove `old_count` elements from the array starting at the given `index`. At
// the same index, insert `new_count` new elements, reading their values from the
// `new_contents` pointer.
#define array_splice(self, index, old_count, new_count, new_contents) \
array__splice((VoidArray *)(self), array__elem_size(self), index, old_count, \
new_count, new_contents)
// Insert one `element` into the array at the given `index`.
#define array_insert(self, index, element) \
array__splice((VoidArray *)(self), array__elem_size(self), index, 0, 1, &element)
// Remove one `element` from the array at the given `index`.
#define array_erase(self, index) \
array__erase((VoidArray *)(self), array__elem_size(self), index)
#define array_pop(self) ((self)->contents[--(self)->size])
#define array_assign(self, other) \
array__assign((VoidArray *)(self), (const VoidArray *)(other), array__elem_size(self))
// Search a sorted array for a given `needle` value, using the given `compare`
// callback to determine the order.
//
// If an existing element is found to be equal to `needle`, then the `index`
// out-parameter is set to the existing value's index, and the `exists`
// out-parameter is set to true. Otherwise, `index` is set to an index where
// `needle` should be inserted in order to preserve the sorting, and `exists`
// is set to false.
#define array_search_sorted_with(self, compare, needle, index, exists) \
array__search_sorted(self, 0, compare, , needle, index, exists)
// Search a sorted array for a given `needle` value, using integer comparisons
// of a given struct field (specified with a leading dot) to determine the order.
//
// See also `array_search_sorted_with`.
#define array_search_sorted_by(self, field, needle, index, exists) \
array__search_sorted(self, 0, _compare_int, field, needle, index, exists)
// Insert a given `value` into a sorted array, using the given `compare`
// callback to determine the order.
#define array_insert_sorted_with(self, compare, value) \
do { \
unsigned index, exists; \
array_search_sorted_with(self, compare, &(value), &index, &exists); \
if (!exists) array_insert(self, index, value); \
} while (0)
// Insert a given `value` into a sorted array, using integer comparisons of
// a given struct field (specified with a leading dot) to determine the order.
//
// See also `array_search_sorted_by`.
#define array_insert_sorted_by(self, field, value) \
do { \
unsigned index, exists; \
array_search_sorted_by(self, field, (value) field, &index, &exists); \
if (!exists) array_insert(self, index, value); \
} while (0)
// Private
typedef Array(void) VoidArray;
@ -151,6 +197,30 @@ static inline void array__splice(VoidArray *self, size_t element_size,
self->size += new_count - old_count;
}
// A binary search routine, based on Rust's `std::slice::binary_search_by`.
#define array__search_sorted(self, start, compare, suffix, needle, index, exists) \
do { \
*(index) = start; \
*(exists) = false; \
uint32_t size = (self)->size - *(index); \
if (size == 0) break; \
int comparison; \
while (size > 1) { \
uint32_t half_size = size / 2; \
uint32_t mid_index = *(index) + half_size; \
comparison = compare(&((self)->contents[mid_index] suffix), (needle)); \
if (comparison <= 0) *(index) = mid_index; \
size -= half_size; \
} \
comparison = compare(&((self)->contents[*(index)] suffix), (needle)); \
if (comparison == 0) *(exists) = true; \
else if (comparison < 0) *(index) += 1; \
} while (0)
// Helper macro for the `_sorted_by` routines below. This takes the left (existing)
// parameter by reference in order to work with the generic sorting function above.
#define _compare_int(a, b) ((int)*(a) - (int)(b))
#ifdef __cplusplus
}
#endif

View file

@ -146,17 +146,21 @@ static bool iterator_tree_is_visible(const Iterator *self) {
if (ts_subtree_visible(*entry.subtree)) return true;
if (self->cursor.stack.size > 1) {
Subtree parent = *self->cursor.stack.contents[self->cursor.stack.size - 2].subtree;
const TSSymbol *alias_sequence = ts_language_alias_sequence(
return ts_language_alias_at(
self->language,
parent.ptr->production_id
);
return alias_sequence && alias_sequence[entry.structural_child_index] != 0;
parent.ptr->production_id,
entry.structural_child_index
) != 0;
}
return false;
}
static void iterator_get_visible_state(const Iterator *self, Subtree *tree,
TSSymbol *alias_symbol, uint32_t *start_byte) {
static void iterator_get_visible_state(
const Iterator *self,
Subtree *tree,
TSSymbol *alias_symbol,
uint32_t *start_byte
) {
uint32_t i = self->cursor.stack.size - 1;
if (self->in_padding) {
@ -169,13 +173,11 @@ static void iterator_get_visible_state(const Iterator *self, Subtree *tree,
if (i > 0) {
const Subtree *parent = self->cursor.stack.contents[i - 1].subtree;
const TSSymbol *alias_sequence = ts_language_alias_sequence(
*alias_symbol = ts_language_alias_at(
self->language,
parent->ptr->production_id
parent->ptr->production_id,
entry.structural_child_index
);
if (alias_sequence) {
*alias_symbol = alias_sequence[entry.structural_child_index];
}
}
if (ts_subtree_visible(*entry.subtree) || *alias_symbol) {

View file

@ -12,6 +12,8 @@ extern "C" {
#define TREE_SITTER_LANGUAGE_VERSION_WITH_FIELDS 10
#define TREE_SITTER_LANGUAGE_VERSION_WITH_SYMBOL_DEDUPING 11
#define TREE_SITTER_LANGUAGE_VERSION_WITH_SMALL_STATES 11
#define TREE_SITTER_LANGUAGE_VERSION_WITH_STATE_COUNT 12
#define TREE_SITTER_LANGUAGE_VERSION_WITH_ALIAS_MAP 12
typedef struct {
const TSParseAction *actions;
@ -19,6 +21,22 @@ typedef struct {
bool is_reusable;
} TableEntry;
typedef struct {
const TSLanguage *language;
const uint16_t *data;
const uint16_t *group_end;
TSStateId state;
uint16_t table_value;
uint16_t section_index;
uint16_t group_count;
bool is_small_state;
const TSParseAction *actions;
TSSymbol symbol;
TSStateId next_state;
uint16_t action_count;
} LookaheadIterator;
void ts_language_table_entry(const TSLanguage *, TSStateId, TSSymbol, TableEntry *);
TSSymbolMetadata ts_language_symbol_metadata(const TSLanguage *, TSSymbol);
@ -41,22 +59,33 @@ static inline const TSParseAction *ts_language_actions(
return entry.actions;
}
static inline bool ts_language_has_actions(const TSLanguage *self,
TSStateId state,
TSSymbol symbol) {
static inline bool ts_language_has_actions(
const TSLanguage *self,
TSStateId state,
TSSymbol symbol
) {
TableEntry entry;
ts_language_table_entry(self, state, symbol, &entry);
return entry.action_count > 0;
}
static inline bool ts_language_has_reduce_action(const TSLanguage *self,
TSStateId state,
TSSymbol symbol) {
static inline bool ts_language_has_reduce_action(
const TSLanguage *self,
TSStateId state,
TSSymbol symbol
) {
TableEntry entry;
ts_language_table_entry(self, state, symbol, &entry);
return entry.action_count > 0 && entry.actions[0].type == TSParseActionTypeReduce;
}
// Lookup the table value for a given symbol and state.
//
// For non-terminal symbols, the table value represents a successor state.
// For terminal symbols, it represents an index in the actions table.
// For 'large' parse states, this is a direct lookup. For 'small' parse
// states, this requires searching through the symbol groups to find
// the given symbol.
static inline uint16_t ts_language_lookup(
const TSLanguage *self,
TSStateId state,
@ -68,8 +97,8 @@ static inline uint16_t ts_language_lookup(
) {
uint32_t index = self->small_parse_table_map[state - self->large_state_count];
const uint16_t *data = &self->small_parse_table[index];
uint16_t section_count = *(data++);
for (unsigned i = 0; i < section_count; i++) {
uint16_t group_count = *(data++);
for (unsigned i = 0; i < group_count; i++) {
uint16_t section_value = *(data++);
uint16_t symbol_count = *(data++);
for (unsigned i = 0; i < symbol_count; i++) {
@ -82,9 +111,90 @@ static inline uint16_t ts_language_lookup(
}
}
static inline TSStateId ts_language_next_state(const TSLanguage *self,
TSStateId state,
TSSymbol symbol) {
// Iterate over all of the symbols that are valid in the given state.
//
// For 'large' parse states, this just requires iterating through
// all possible symbols and checking the parse table for each one.
// For 'small' parse states, this exploits the structure of the
// table to only visit the valid symbols.
static inline LookaheadIterator ts_language_lookaheads(
const TSLanguage *self,
TSStateId state
) {
bool is_small_state =
self->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_SMALL_STATES &&
state >= self->large_state_count;
const uint16_t *data;
const uint16_t *group_end = NULL;
uint16_t group_count = 0;
if (is_small_state) {
uint32_t index = self->small_parse_table_map[state - self->large_state_count];
data = &self->small_parse_table[index];
group_end = data + 1;
group_count = *data;
} else {
data = &self->parse_table[state * self->symbol_count] - 1;
}
return (LookaheadIterator) {
.language = self,
.data = data,
.group_end = group_end,
.group_count = group_count,
.is_small_state = is_small_state,
.symbol = UINT16_MAX,
.next_state = 0,
};
}
static inline bool ts_lookahead_iterator_next(LookaheadIterator *self) {
// For small parse states, valid symbols are listed explicitly,
// grouped by their value. There's no need to look up the actions
// again until moving to the next group.
if (self->is_small_state) {
self->data++;
if (self->data == self->group_end) {
if (self->group_count == 0) return false;
self->group_count--;
self->table_value = *(self->data++);
unsigned symbol_count = *(self->data++);
self->group_end = self->data + symbol_count;
self->symbol = *self->data;
} else {
self->symbol = *self->data;
return true;
}
}
// For large parse states, iterate through every symbol until one
// is found that has valid actions.
else {
do {
self->data++;
self->symbol++;
if (self->symbol >= self->language->symbol_count) return false;
self->table_value = *self->data;
} while (!self->table_value);
}
// Depending on if the symbols is terminal or non-terminal, the table value either
// represents a list of actions or a successor state.
if (self->symbol < self->language->token_count) {
const TSParseActionEntry *entry = &self->language->parse_actions[self->table_value];
self->action_count = entry->entry.count;
self->actions = (const TSParseAction *)(entry + 1);
self->next_state = 0;
} else {
self->action_count = 0;
self->next_state = self->table_value;
}
return true;
}
static inline TSStateId ts_language_next_state(
const TSLanguage *self,
TSStateId state,
TSSymbol symbol
) {
if (symbol == ts_builtin_sym_error || symbol == ts_builtin_sym_error_repeat) {
return 0;
} else if (symbol < self->token_count) {
@ -102,9 +212,10 @@ static inline TSStateId ts_language_next_state(const TSLanguage *self,
}
}
static inline const bool *
ts_language_enabled_external_tokens(const TSLanguage *self,
unsigned external_scanner_state) {
static inline const bool *ts_language_enabled_external_tokens(
const TSLanguage *self,
unsigned external_scanner_state
) {
if (external_scanner_state == 0) {
return NULL;
} else {
@ -112,13 +223,25 @@ ts_language_enabled_external_tokens(const TSLanguage *self,
}
}
static inline const TSSymbol *
ts_language_alias_sequence(const TSLanguage *self, uint32_t production_id) {
return production_id > 0 ?
self->alias_sequences + production_id * self->max_alias_sequence_length :
static inline const TSSymbol *ts_language_alias_sequence(
const TSLanguage *self,
uint32_t production_id
) {
return production_id ?
&self->alias_sequences[production_id * self->max_alias_sequence_length] :
NULL;
}
static inline TSSymbol ts_language_alias_at(
const TSLanguage *self,
uint32_t production_id,
uint32_t child_index
) {
return production_id ?
self->alias_sequences[production_id * self->max_alias_sequence_length + child_index] :
0;
}
static inline void ts_language_field_map(
const TSLanguage *self,
uint32_t production_id,
@ -136,6 +259,32 @@ static inline void ts_language_field_map(
*end = &self->field_map_entries[slice.index] + slice.length;
}
static inline void ts_language_aliases_for_symbol(
const TSLanguage *self,
TSSymbol original_symbol,
const TSSymbol **start,
const TSSymbol **end
) {
*start = &self->public_symbol_map[original_symbol];
*end = *start + 1;
if (self->version < TREE_SITTER_LANGUAGE_VERSION_WITH_ALIAS_MAP) return;
unsigned i = 0;
for (;;) {
TSSymbol symbol = self->alias_map[i++];
if (symbol == 0 || symbol > original_symbol) break;
uint16_t count = self->alias_map[i++];
if (symbol == original_symbol) {
*start = &self->alias_map[i];
*end = &self->alias_map[i + count];
break;
}
i += count;
}
}
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large Load diff

View file

@ -360,7 +360,7 @@ void ts_subtree_set_children(
self.ptr->has_external_tokens = false;
self.ptr->dynamic_precedence = 0;
uint32_t non_extra_index = 0;
uint32_t structural_index = 0;
const TSSymbol *alias_sequence = ts_language_alias_sequence(language, self.ptr->production_id);
uint32_t lookahead_end_byte = 0;
@ -387,9 +387,9 @@ void ts_subtree_set_children(
self.ptr->dynamic_precedence += ts_subtree_dynamic_precedence(child);
self.ptr->node_count += ts_subtree_node_count(child);
if (alias_sequence && alias_sequence[non_extra_index] != 0 && !ts_subtree_extra(child)) {
if (alias_sequence && alias_sequence[structural_index] != 0 && !ts_subtree_extra(child)) {
self.ptr->visible_child_count++;
if (ts_language_symbol_metadata(language, alias_sequence[non_extra_index]).named) {
if (ts_language_symbol_metadata(language, alias_sequence[structural_index]).named) {
self.ptr->named_child_count++;
}
} else if (ts_subtree_visible(child)) {
@ -407,7 +407,7 @@ void ts_subtree_set_children(
self.ptr->parse_state = TS_TREE_STATE_NONE;
}
if (!ts_subtree_extra(child)) non_extra_index++;
if (!ts_subtree_extra(child)) structural_index++;
}
self.ptr->lookahead_bytes = lookahead_end_byte - self.ptr->size.bytes - self.ptr->padding.bytes;

View file

@ -205,19 +205,21 @@ bool ts_tree_cursor_goto_parent(TSTreeCursor *_self) {
TreeCursor *self = (TreeCursor *)_self;
for (unsigned i = self->stack.size - 2; i + 1 > 0; i--) {
TreeCursorEntry *entry = &self->stack.contents[i];
bool is_aliased = false;
if (i > 0) {
TreeCursorEntry *parent_entry = &self->stack.contents[i - 1];
const TSSymbol *alias_sequence = ts_language_alias_sequence(
self->tree->language,
parent_entry->subtree->ptr->production_id
);
is_aliased = alias_sequence && alias_sequence[entry->structural_child_index];
}
if (ts_subtree_visible(*entry->subtree) || is_aliased) {
if (ts_subtree_visible(*entry->subtree)) {
self->stack.size = i + 1;
return true;
}
if (i > 0 && !ts_subtree_extra(*entry->subtree)) {
TreeCursorEntry *parent_entry = &self->stack.contents[i - 1];
if (ts_language_alias_at(
self->tree->language,
parent_entry->subtree->ptr->production_id,
entry->structural_child_index
)) {
self->stack.size = i + 1;
return true;
}
}
}
return false;
}
@ -226,15 +228,13 @@ TSNode ts_tree_cursor_current_node(const TSTreeCursor *_self) {
const TreeCursor *self = (const TreeCursor *)_self;
TreeCursorEntry *last_entry = array_back(&self->stack);
TSSymbol alias_symbol = 0;
if (self->stack.size > 1) {
if (self->stack.size > 1 && !ts_subtree_extra(*last_entry->subtree)) {
TreeCursorEntry *parent_entry = &self->stack.contents[self->stack.size - 2];
const TSSymbol *alias_sequence = ts_language_alias_sequence(
alias_symbol = ts_language_alias_at(
self->tree->language,
parent_entry->subtree->ptr->production_id
parent_entry->subtree->ptr->production_id,
last_entry->structural_child_index
);
if (alias_sequence && !ts_subtree_extra(*last_entry->subtree)) {
alias_symbol = alias_sequence[last_entry->structural_child_index];
}
}
return ts_node_new(
self->tree,
@ -263,13 +263,14 @@ TSFieldId ts_tree_cursor_current_status(
// Stop walking up when a visible ancestor is found.
if (i != self->stack.size - 1) {
if (ts_subtree_visible(*entry->subtree)) break;
const TSSymbol *alias_sequence = ts_language_alias_sequence(
self->tree->language,
parent_entry->subtree->ptr->production_id
);
if (alias_sequence && alias_sequence[entry->structural_child_index]) {
break;
}
if (
!ts_subtree_extra(*entry->subtree) &&
ts_language_alias_at(
self->tree->language,
parent_entry->subtree->ptr->production_id,
entry->structural_child_index
)
) break;
}
if (ts_subtree_child_count(*parent_entry->subtree) > entry->child_index + 1) {
@ -321,13 +322,14 @@ TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) {
// Stop walking up when another visible node is found.
if (i != self->stack.size - 1) {
if (ts_subtree_visible(*entry->subtree)) break;
const TSSymbol *alias_sequence = ts_language_alias_sequence(
self->tree->language,
parent_entry->subtree->ptr->production_id
);
if (alias_sequence && alias_sequence[entry->structural_child_index]) {
break;
}
if (
!ts_subtree_extra(*entry->subtree) &&
ts_language_alias_at(
self->tree->language,
parent_entry->subtree->ptr->production_id,
entry->structural_child_index
)
) break;
}
if (ts_subtree_extra(*entry->subtree)) break;

View file

@ -18,15 +18,22 @@ OPTIONS
-r parse each sample the given number of times (default 5)
-g debug
EOF
}
while getopts "hl:e:r:" option; do
mode=normal
while getopts "hgl:e:r:" option; do
case ${option} in
h)
usage
exit
;;
g)
mode=debug
;;
e)
export TREE_SITTER_BENCHMARK_EXAMPLE_FILTER=${OPTARG}
;;
@ -39,4 +46,13 @@ while getopts "hl:e:r:" option; do
esac
done
cargo bench benchmark
if [[ "${mode}" == "debug" ]]; then
test_binary=$(
cargo bench benchmark --no-run --message-format=json 2> /dev/null |\
jq -rs 'map(select(.target.name == "benchmark" and .executable))[0].executable'
)
env | grep TREE_SITTER
echo $test_binary
else
exec cargo bench benchmark
fi

View file

@ -83,10 +83,14 @@ done
shift $(expr $OPTIND - 1)
if [[ -n $TREE_SITTER_TEST_LANGUAGE_FILTER || -n $TREE_SITTER_TEST_EXAMPLE_FILTER || -n $TREE_SITTER_TEST_TRIAL_FILTER ]]; then
top_level_filter=corpus
else
top_level_filter=$1
top_level_filter=$1
if [[ \
-n $TREE_SITTER_TEST_LANGUAGE_FILTER || \
-n $TREE_SITTER_TEST_EXAMPLE_FILTER || \
-n $TREE_SITTER_TEST_TRIAL_FILTER \
]]; then
: ${top_level_filter:=corpus}
fi
if [[ "${mode}" == "debug" ]]; then