Merge pull request #444 from tree-sitter/tree-queries

Introduce the 'Tree query' - an API for pattern-matching on syntax trees
This commit is contained in:
Max Brunsfeld 2019-09-18 17:36:21 -07:00 committed by GitHub
commit 07afce0686
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 4054 additions and 123 deletions

View file

@ -5,6 +5,7 @@ pub mod highlight;
pub mod loader;
pub mod logger;
pub mod parse;
pub mod query;
pub mod test;
pub mod util;
pub mod wasm;

View file

@ -3,8 +3,9 @@ use error::Error;
use std::path::Path;
use std::process::exit;
use std::{env, fs, u64};
use tree_sitter::Language;
use tree_sitter_cli::{
config, error, generate, highlight, loader, logger, parse, test, wasm, web_ui,
config, error, generate, highlight, loader, logger, parse, query, test, wasm, web_ui,
};
const BUILD_VERSION: &'static str = env!("CARGO_PKG_VERSION");
@ -50,7 +51,7 @@ fn run() -> error::Result<()> {
)
.subcommand(
SubCommand::with_name("parse")
.about("Parse a file")
.about("Parse files")
.arg(
Arg::with_name("path")
.index(1)
@ -73,6 +74,19 @@ fn run() -> error::Result<()> {
.number_of_values(1),
),
)
.subcommand(
SubCommand::with_name("query")
.about("Search files using a syntax tree query")
.arg(Arg::with_name("query-path").index(1).required(true))
.arg(
Arg::with_name("path")
.index(2)
.multiple(true)
.required(true),
)
.arg(Arg::with_name("scope").long("scope").takes_value(true))
.arg(Arg::with_name("captures").long("captures").short("c")),
)
.subcommand(
SubCommand::with_name("test")
.about("Run a parser's tests")
@ -168,7 +182,6 @@ fn run() -> error::Result<()> {
let timeout = matches
.value_of("timeout")
.map_or(0, |t| u64::from_str_radix(t, 10).unwrap());
loader.find_all_languages(&config.parser_directories)?;
let paths = matches
.values_of("path")
.unwrap()
@ -176,43 +189,11 @@ fn run() -> error::Result<()> {
.collect::<Vec<_>>();
let max_path_length = paths.iter().map(|p| p.chars().count()).max().unwrap();
let mut has_error = false;
loader.find_all_languages(&config.parser_directories)?;
for path in paths {
let path = Path::new(path);
let language = if let Some(scope) = matches.value_of("scope") {
if let Some(config) =
loader
.language_configuration_for_scope(scope)
.map_err(Error::wrap(|| {
format!("Failed to load language for scope '{}'", scope)
}))?
{
config.0
} else {
return Error::err(format!("Unknown scope '{}'", scope));
}
} else if let Some((lang, _)) = loader
.language_configuration_for_file_name(path)
.map_err(Error::wrap(|| {
format!(
"Failed to load language for file name {:?}",
path.file_name().unwrap()
)
}))?
{
lang
} else if let Some(lang) = loader
.languages_at_path(&current_dir)
.map_err(Error::wrap(|| {
"Failed to load language in current directory"
}))?
.first()
.cloned()
{
lang
} else {
eprintln!("No language found");
return Ok(());
};
let language =
select_language(&mut loader, path, &current_dir, matches.value_of("scope"))?;
has_error |= parse::parse_file_at_path(
language,
path,
@ -226,10 +207,26 @@ fn run() -> error::Result<()> {
allow_cancellation,
)?;
}
if has_error {
return Error::err(String::new());
}
} else if let Some(matches) = matches.subcommand_matches("query") {
let ordered_captures = matches.values_of("captures").is_some();
let paths = matches
.values_of("path")
.unwrap()
.into_iter()
.map(Path::new)
.collect::<Vec<&Path>>();
loader.find_all_languages(&config.parser_directories)?;
let language = select_language(
&mut loader,
paths[0],
&current_dir,
matches.value_of("scope"),
)?;
let query_path = Path::new(matches.value_of("query-path").unwrap());
query::query_files_at_paths(language, paths, query_path, ordered_captures)?;
} else if let Some(matches) = matches.subcommand_matches("highlight") {
let paths = matches.values_of("path").unwrap().into_iter();
let html_mode = matches.is_present("html");
@ -296,3 +293,47 @@ fn run() -> error::Result<()> {
Ok(())
}
fn select_language(
loader: &mut loader::Loader,
path: &Path,
current_dir: &Path,
scope: Option<&str>,
) -> Result<Language, Error> {
if let Some(scope) = scope {
if let Some(config) =
loader
.language_configuration_for_scope(scope)
.map_err(Error::wrap(|| {
format!("Failed to load language for scope '{}'", scope)
}))?
{
Ok(config.0)
} else {
return Error::err(format!("Unknown scope '{}'", scope));
}
} else if let Some((lang, _)) =
loader
.language_configuration_for_file_name(path)
.map_err(Error::wrap(|| {
format!(
"Failed to load language for file name {:?}",
path.file_name().unwrap()
)
}))?
{
Ok(lang)
} else if let Some(lang) = loader
.languages_at_path(&current_dir)
.map_err(Error::wrap(|| {
"Failed to load language in current directory"
}))?
.first()
.cloned()
{
Ok(lang)
} else {
eprintln!("No language found");
Error::err("No language found".to_string())
}
}

64
cli/src/query.rs Normal file
View file

@ -0,0 +1,64 @@
use super::error::{Error, Result};
use std::fs;
use std::io::{self, Write};
use std::path::Path;
use tree_sitter::{Language, Node, Parser, Query, QueryCursor};
pub fn query_files_at_paths(
language: Language,
paths: Vec<&Path>,
query_path: &Path,
ordered_captures: bool,
) -> Result<()> {
let stdout = io::stdout();
let mut stdout = stdout.lock();
let query_source = fs::read_to_string(query_path).map_err(Error::wrap(|| {
format!("Error reading query file {:?}", query_path)
}))?;
let query = Query::new(language, &query_source)
.map_err(|e| Error::new(format!("Query compilation failed: {:?}", e)))?;
let mut query_cursor = QueryCursor::new();
let mut parser = Parser::new();
parser.set_language(language).map_err(|e| e.to_string())?;
for path in paths {
writeln!(&mut stdout, "{}", path.to_str().unwrap())?;
let source_code = fs::read(path).map_err(Error::wrap(|| {
format!("Error reading source file {:?}", path)
}))?;
let text_callback = |n: Node| &source_code[n.byte_range()];
let tree = parser.parse(&source_code, None).unwrap();
if ordered_captures {
for (pattern_index, capture) in query_cursor.captures(&query, tree.root_node(), text_callback) {
writeln!(
&mut stdout,
" pattern: {}, capture: {}, row: {}, text: {:?}",
pattern_index,
&query.capture_names()[capture.index],
capture.node.start_position().row,
capture.node.utf8_text(&source_code).unwrap_or("")
)?;
}
} else {
for m in query_cursor.matches(&query, tree.root_node(), text_callback) {
writeln!(&mut stdout, " pattern: {}", m.pattern_index)?;
for capture in m.captures() {
writeln!(
&mut stdout,
" capture: {}, row: {}, text: {:?}",
&query.capture_names()[capture.index],
capture.node.start_position().row,
capture.node.utf8_text(&source_code).unwrap_or("")
)?;
}
}
}
}
Ok(())
}

View file

@ -51,6 +51,12 @@ pub fn stop_recording() {
}
}
/// Runs `f` exactly once between a call to `start_recording()` and a call to
/// `stop_recording()`.
///
/// NOTE(review): if `f` panics, `stop_recording()` is never reached — confirm
/// that this is acceptable for callers.
pub fn record(f: impl FnOnce()) {
start_recording();
f();
stop_recording();
}
fn record_alloc(ptr: *mut c_void) {
let mut recorder = RECORDER.lock();
if recorder.enabled {

View file

@ -4,4 +4,5 @@ mod highlight_test;
mod node_test;
mod parser_test;
mod properties_test;
mod query_test;
mod tree_test;

897
cli/src/tests/query_test.rs Normal file
View file

@ -0,0 +1,897 @@
use super::helpers::allocations;
use super::helpers::fixtures::get_language;
use std::fmt::Write;
use tree_sitter::{Node, Parser, Query, QueryCapture, QueryCursor, QueryError, QueryMatch};
/// Checks that `Query::new` accepts well-formed patterns and reports
/// `QueryError::Syntax` with the expected position for malformed ones.
#[test]
fn test_query_errors_on_invalid_syntax() {
allocations::record(|| {
let language = get_language("javascript");
assert!(Query::new(language, "(if_statement)").is_ok());
assert!(Query::new(language, "(if_statement condition:(identifier))").is_ok());
// Mismatched parens
assert_eq!(
Query::new(language, "(if_statement"),
Err(QueryError::Syntax(13))
);
assert_eq!(
Query::new(language, "(if_statement))"),
Err(QueryError::Syntax(14))
);
// Return an error at the *beginning* of a bare identifier not followed by a colon.
// If there's a colon but no pattern, return an error at the end of the colon.
assert_eq!(
Query::new(language, "(if_statement identifier)"),
Err(QueryError::Syntax(14))
);
assert_eq!(
Query::new(language, "(if_statement condition:)"),
Err(QueryError::Syntax(24))
);
// Return an error at the beginning of an unterminated string.
assert_eq!(
Query::new(language, r#"(identifier) "h "#),
Err(QueryError::Syntax(13))
);
// Unterminated predicate lists are reported at the end of the input.
assert_eq!(
Query::new(language, r#"((identifier) ()"#),
Err(QueryError::Syntax(16))
);
assert_eq!(
Query::new(language, r#"((identifier) @x (eq? @x a"#),
Err(QueryError::Syntax(26))
);
});
}
/// Checks that unknown node type names produce `QueryError::NodeType` and
/// unknown field names produce `QueryError::Field`, each carrying the
/// offending name.
#[test]
fn test_query_errors_on_invalid_symbols() {
allocations::record(|| {
let language = get_language("javascript");
assert_eq!(
Query::new(language, "(clas)"),
Err(QueryError::NodeType("clas"))
);
assert_eq!(
Query::new(language, "(if_statement (arrayyyyy))"),
Err(QueryError::NodeType("arrayyyyy"))
);
assert_eq!(
Query::new(language, "(if_statement condition: (non_existent3))"),
Err(QueryError::NodeType("non_existent3"))
);
assert_eq!(
Query::new(language, "(if_statement condit: (identifier))"),
Err(QueryError::Field("condit"))
);
assert_eq!(
Query::new(language, "(if_statement conditioning: (identifier))"),
Err(QueryError::Field("conditioning"))
);
});
}
/// Checks the errors returned for malformed predicates: a predicate that does
/// not start with a function name, an `eq?` with the wrong number of
/// arguments, and a reference to a capture name (`@ok`) that the pattern
/// never defines.
#[test]
fn test_query_errors_on_invalid_conditions() {
allocations::record(|| {
let language = get_language("javascript");
assert_eq!(
Query::new(language, "((identifier) @id (@id))"),
Err(QueryError::Predicate(
"Expected predicate to start with a function name. Got @id.".to_string()
))
);
assert_eq!(
Query::new(language, "((identifier) @id (eq? @id))"),
Err(QueryError::Predicate(
"Wrong number of arguments to eq? predicate. Expected 2, got 1.".to_string()
))
);
assert_eq!(
Query::new(language, "((identifier) @id (eq? @id @ok))"),
Err(QueryError::Capture("ok"))
);
});
}
/// Checks that a single-pattern query capturing function names matches both
/// the outer and the nested function declaration in the source.
#[test]
fn test_query_matches_with_simple_pattern() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
"(function_declaration name: (identifier) @fn-name)",
)
.unwrap();
let source = "function one() { two(); function three() {} }";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[
(0, vec![("fn-name", "one")]),
(0, vec![("fn-name", "three")])
],
);
});
}
/// Checks that one pattern can produce several matches rooted at the same
/// node: the class name is paired with each of the two methods in its body.
#[test]
fn test_query_matches_with_multiple_on_same_root() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
"(class_declaration
name: (identifier) @the-class-name
(class_body
(method_definition
name: (property_identifier) @the-method-name)))",
)
.unwrap();
let source = "
class Person {
// the constructor
constructor(name) { this.name = name; }
// the getter
getFullName() { return this.name; }
}
";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[
(
0,
vec![
("the-class-name", "Person"),
("the-method-name", "constructor")
]
),
(
0,
vec![
("the-class-name", "Person"),
("the-method-name", "getFullName")
]
),
],
);
});
}
/// Checks a query with two patterns whose root node types differ (function
/// declarations and call expressions): each match reports the index of the
/// pattern that produced it.
#[test]
fn test_query_matches_with_multiple_patterns_different_roots() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
"
(function_declaration name:(identifier) @fn-def)
(call_expression function:(identifier) @fn-ref)
",
)
.unwrap();
let source = "
function f1() {
f2(f3());
}
";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[
(0, vec![("fn-def", "f1")]),
(1, vec![("fn-ref", "f2")]),
(1, vec![("fn-ref", "f3")]),
],
);
});
}
/// Checks a query with two patterns that share the same root node type
/// (`pair`) and capture name but differ in the required value node: the arrow
/// function pair matches pattern 1 and the `function` pair matches pattern 0.
#[test]
fn test_query_matches_with_multiple_patterns_same_root() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
"
(pair
key: (property_identifier) @method-def
value: (function))
(pair
key: (property_identifier) @method-def
value: (arrow_function))
",
)
.unwrap();
let source = "
a = {
b: () => { return c; },
d: function() { return d; }
};
";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[
(1, vec![("method-def", "b")]),
(0, vec![("method-def", "d")]),
],
);
});
}
/// Checks a nested pattern that uses no field names: for every inner array
/// with at least two identifiers, each ordered pair of identifiers is
/// expected as a separate match; inner arrays with a single identifier
/// produce none.
#[test]
fn test_query_matches_with_nesting_and_no_fields() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
"
(array
(array
(identifier) @x1
(identifier) @x2))
",
)
.unwrap();
let source = "
[[a]];
[[c, d], [e, f, g, h]];
[[h], [i]];
";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[
(0, vec![("x1", "c"), ("x2", "d")]),
(0, vec![("x1", "e"), ("x2", "f")]),
(0, vec![("x1", "e"), ("x2", "g")]),
(0, vec![("x1", "f"), ("x2", "g")]),
(0, vec![("x1", "e"), ("x2", "h")]),
(0, vec![("x1", "f"), ("x2", "h")]),
(0, vec![("x1", "g"), ("x2", "h")]),
],
);
});
}
/// Checks that a source made of 50 repeated one-element arrays yields exactly
/// 50 identical matches.
#[test]
fn test_query_matches_with_many() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(language, "(array (identifier) @element)").unwrap();
let source = "[hello];\n".repeat(50);
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(&source));
assert_eq!(
collect_matches(matches, &query, source.as_str()),
vec![(0, vec![("element", "hello")]); 50],
);
});
}
/// Runs a two-capture pattern over a single array of 50 identifiers and only
/// asserts on the first match returned.
#[test]
fn test_query_matches_with_too_many_permutations_to_track() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
"
(array (identifier) @pre (identifier) @post)
",
)
.unwrap();
// Build the source `[hello, hello, ... ];` with 50 elements.
let mut source = "hello, ".repeat(50);
source.insert(0, '[');
source.push_str("];");
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(&source));
// For this pathological query, some match permutations will be dropped.
// Just check that a subset of the results are returned, and that no crash
// or leak occurs.
assert_eq!(
collect_matches(matches, &query, source.as_str())[0],
(0, vec![("pre", "hello"), ("post", "hello")]),
);
});
}
/// Checks that string-literal patterns match anonymous tokens (`;` and `&&`)
/// in the source.
#[test]
fn test_query_matches_with_anonymous_tokens() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
r#"
";" @punctuation
"&&" @operator
"#,
)
.unwrap();
let source = "foo(a && b);";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[
(1, vec![("operator", "&&")]),
(0, vec![("punctuation", ";")]),
]
);
});
}
/// Checks that `set_byte_range(5, 15)` restricts the matches to the
/// identifiers `c`, `d` and `e` of the seven-element array.
#[test]
fn test_query_matches_within_byte_range() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(language, "(identifier) @element").unwrap();
let source = "[a, b, c, d, e, f, g]";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches =
cursor
.set_byte_range(5, 15)
.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[
(0, vec![("element", "c")]),
(0, vec![("element", "d")]),
(0, vec![("element", "e")]),
]
);
});
}
/// Checks that one `QueryCursor` can be reused with different queries: three
/// queries with one, two and three patterns are run in the order 1, 3, 2 on
/// the same cursor and tree, and each run returns only its own matches.
#[test]
fn test_query_matches_different_queries_same_cursor() {
allocations::record(|| {
let language = get_language("javascript");
let query1 = Query::new(
language,
"
(array (identifier) @id1)
",
)
.unwrap();
let query2 = Query::new(
language,
"
(array (identifier) @id1)
(pair (identifier) @id2)
",
)
.unwrap();
let query3 = Query::new(
language,
"
(array (identifier) @id1)
(pair (identifier) @id2)
(parenthesized_expression (identifier) @id3)
",
)
.unwrap();
let source = "[a, {b: b}, (c)];";
let mut parser = Parser::new();
let mut cursor = QueryCursor::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let matches = cursor.matches(&query1, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query1, source),
&[(0, vec![("id1", "a")]),]
);
let matches = cursor.matches(&query3, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query3, source),
&[
(0, vec![("id1", "a")]),
(1, vec![("id2", "b")]),
(2, vec![("id3", "c")]),
]
);
let matches = cursor.matches(&query2, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query2, source),
&[(0, vec![("id1", "a")]), (1, vec![("id2", "b")]),]
);
});
}
/// Runs the same four-pattern query through both `QueryCursor::matches` and
/// `QueryCursor::captures` on one cursor and checks the expected output of
/// each: `matches` groups captures per match, while `captures` yields a flat
/// list of individual captures.
#[test]
fn test_query_captures() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
r#"
(pair
key: * @method.def
(function
name: (identifier) @method.alias))
(variable_declarator
name: * @function.def
value: (function
name: (identifier) @function.alias))
":" @delimiter
"=" @operator
"#,
)
.unwrap();
let source = "
a({
bc: function de() {
const fg = function hi() {}
},
jk: function lm() {
const no = function pq() {}
},
});
";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[
(2, vec![("delimiter", ":")]),
(0, vec![("method.def", "bc"), ("method.alias", "de")]),
(3, vec![("operator", "=")]),
(1, vec![("function.def", "fg"), ("function.alias", "hi")]),
(2, vec![("delimiter", ":")]),
(0, vec![("method.def", "jk"), ("method.alias", "lm")]),
(3, vec![("operator", "=")]),
(1, vec![("function.def", "no"), ("function.alias", "pq")]),
],
);
// The flat capture list below follows the order in which the captured
// text appears in the source.
let captures = cursor.captures(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_captures(captures, &query, source),
&[
("method.def", "bc"),
("delimiter", ":"),
("method.alias", "de"),
("function.def", "fg"),
("operator", "="),
("function.alias", "hi"),
("method.def", "jk"),
("delimiter", ":"),
("method.alias", "lm"),
("function.def", "no"),
("operator", "="),
("function.alias", "pq"),
]
);
});
}
/// Checks that `match?` and `eq?` predicates filter captures by node text:
/// an identifier is reported under every capture whose predicate its text
/// satisfies, plus the unconditional `@variable` capture.
#[test]
fn test_query_captures_with_text_conditions() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
r#"
((identifier) @constant
(match? @constant "^[A-Z]{2,}$"))
((identifier) @constructor
(match? @constructor "^[A-Z]"))
((identifier) @function.builtin
(eq? @function.builtin "require"))
(identifier) @variable
"#,
)
.unwrap();
let source = "
const ab = require('./ab');
new Cd(EF);
";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let captures = cursor.captures(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_captures(captures, &query, source),
&[
("variable", "ab"),
("function.builtin", "require"),
("variable", "require"),
("constructor", "Cd"),
("variable", "Cd"),
("constant", "EF"),
("constructor", "EF"),
("variable", "EF"),
],
);
});
}
/// Checks that when two patterns capture the same node (`x`), both captures
/// are returned.
#[test]
fn test_query_captures_with_duplicates() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
r#"
(variable_declarator
name: (identifier) @function
value: (function))
(identifier) @variable
"#,
)
.unwrap();
let source = "
var x = function() {};
";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let captures = cursor.captures(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_captures(captures, &query, source),
&[("function", "x"), ("variable", "x"),],
);
});
}
/// Checks that captures are not dropped when a non-matching outer `pair`
/// contains 50 inner pairs that do match: the first 13 captures are compared
/// exactly and the total count must be `1 + 3 * method_count`.
#[test]
fn test_query_captures_with_many_nested_results_without_fields() {
allocations::record(|| {
let language = get_language("javascript");
// Search for key-value pairs whose values are anonymous functions.
let query = Query::new(
language,
r#"
(pair
key: * @method-def
(arrow_function))
":" @colon
"," @comma
"#,
)
.unwrap();
// The `pair` node for key `y` does not match any pattern, but inside of
// its value, it contains many other `pair` nodes that do match the pattern.
// The match for the *outer* pair should be terminated *before* descending into
// the object value, so that we can avoid needing to buffer all of the inner
// matches.
let method_count = 50;
let mut source = "x = { y: {\n".to_owned();
for i in 0..method_count {
writeln!(&mut source, " method{}: $ => null,", i).unwrap();
}
source.push_str("}};\n");
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let captures = cursor.captures(&query, tree.root_node(), to_callback(&source));
let captures = collect_captures(captures, &query, &source);
assert_eq!(
&captures[0..13],
&[
("colon", ":"),
("method-def", "method0"),
("colon", ":"),
("comma", ","),
("method-def", "method1"),
("colon", ":"),
("comma", ","),
("method-def", "method2"),
("colon", ":"),
("comma", ","),
("method-def", "method3"),
("colon", ":"),
("comma", ","),
]
);
// Ensure that we don't drop matches because of needing to buffer too many.
assert_eq!(captures.len(), 1 + 3 * method_count);
});
}
/// Checks that captures are not dropped when a non-matching outer ternary
/// contains 50 inner ternaries that match a field-based pattern with an
/// `eq?` predicate: the first 20 captures are compared exactly and the total
/// count must be `2 * count`.
#[test]
fn test_query_captures_with_many_nested_results_with_fields() {
allocations::record(|| {
let language = get_language("javascript");
// Search expressions like `a ? a.b : null`
let query = Query::new(
language,
r#"
((ternary_expression
condition: (identifier) @left
consequence: (member_expression
object: (identifier) @right)
alternative: (null))
(eq? @left @right))
"#,
)
.unwrap();
// The outer expression does not match the pattern, but the consequence of the ternary
// is an object that *does* contain many occurrences of the pattern.
let count = 50;
let mut source = "a ? {".to_owned();
for i in 0..count {
writeln!(&mut source, " x: y{} ? y{}.z : null,", i, i).unwrap();
}
source.push_str("} : null;\n");
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let captures = cursor.captures(&query, tree.root_node(), to_callback(&source));
let captures = collect_captures(captures, &query, &source);
assert_eq!(
&captures[0..20],
&[
("left", "y0"),
("right", "y0"),
("left", "y1"),
("right", "y1"),
("left", "y2"),
("right", "y2"),
("left", "y3"),
("right", "y3"),
("left", "y4"),
("right", "y4"),
("left", "y5"),
("right", "y5"),
("left", "y6"),
("right", "y6"),
("left", "y7"),
("right", "y7"),
("left", "y8"),
("right", "y8"),
("left", "y9"),
("right", "y9"),
]
);
// Ensure that we don't drop matches because of needing to buffer too many.
assert_eq!(captures.len(), 2 * count);
});
}
/// Checks `Query::start_byte_for_pattern` on a query built by concatenating
/// three groups of patterns (5, 2 and 3 patterns): pattern 0 starts at byte 0,
/// pattern 5 at the length of the first group, and pattern 7 at the combined
/// length of the first two groups.
///
/// NOTE(review): unlike the other tests in this file, this one is not wrapped
/// in `allocations::record` — confirm whether that is intentional.
#[test]
fn test_query_start_byte_for_pattern() {
let language = get_language("javascript");
let patterns_1 = r#"
"+" @operator
"-" @operator
"*" @operator
"=" @operator
"=>" @operator
"#
.trim_start();
let patterns_2 = "
(identifier) @a
(string) @b
"
.trim_start();
let patterns_3 = "
((identifier) @b (match? @b i))
(function_declaration name: (identifier) @c)
(method_definition name: (identifier) @d)
"
.trim_start();
// Leading whitespace is trimmed from each group so that its first pattern
// starts exactly at the group's offset in the concatenated source.
let mut source = String::new();
source += patterns_1;
source += patterns_2;
source += patterns_3;
let query = Query::new(language, &source).unwrap();
assert_eq!(query.start_byte_for_pattern(0), 0);
assert_eq!(query.start_byte_for_pattern(5), patterns_1.len());
assert_eq!(
query.start_byte_for_pattern(7),
patterns_1.len() + patterns_2.len()
);
}
/// Checks that `Query::capture_names` lists the capture names of all patterns
/// in the order they first appear in the query source.
#[test]
fn test_query_capture_names() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
r#"
(if_statement
condition: (binary_expression
left: * @left-operand
operator: "||"
right: * @right-operand)
consequence: (statement_block) @body)
(while_statement
condition:* @loop-condition)
"#,
)
.unwrap();
assert_eq!(
query.capture_names(),
&[
"left-operand".to_string(),
"right-operand".to_string(),
"body".to_string(),
"loop-condition".to_string(),
]
);
});
}
/// Checks that `;` comments are accepted both before a pattern and inside
/// one, and that the pattern still matches.
#[test]
fn test_query_comments() {
allocations::record(|| {
let language = get_language("javascript");
let query = Query::new(
language,
"
; this is my first comment
; i have two comments here
(function_declaration
; there is also a comment here
; and here
name: (identifier) @fn-name)",
)
.unwrap();
let source = "function one() { }";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[(0, vec![("fn-name", "one")]),],
);
});
}
/// Drains `matches` into a vector of `(pattern_index, captures)` pairs, where
/// `captures` is the list of `(capture name, captured text)` produced by
/// `collect_captures` for that match.
fn collect_matches<'a>(
    matches: impl Iterator<Item = QueryMatch<'a>>,
    query: &'a Query,
    source: &'a str,
) -> Vec<(usize, Vec<(&'a str, &'a str)>)> {
    let mut collected = Vec::new();
    for query_match in matches {
        let pattern_index = query_match.pattern_index;
        // `collect_captures` expects `(pattern_index, capture)` items, so tag
        // every capture of this match with the match's pattern index.
        let tagged = query_match
            .captures()
            .map(move |capture| (pattern_index, capture));
        collected.push((pattern_index, collect_captures(tagged, query, source)));
    }
    collected
}
/// Drains `captures` into a vector of `(capture name, captured text)` pairs.
///
/// The pattern index in each item is ignored. The name is looked up in
/// `query.capture_names()` by the capture's index, and the text is the
/// node's UTF-8 text within `source`.
///
/// Panics if a node's text is not valid UTF-8.
fn collect_captures<'a, 'b>(
    captures: impl Iterator<Item = (usize, QueryCapture<'a>)>,
    query: &'b Query,
    source: &'b str,
) -> Vec<(&'b str, &'b str)> {
    let source_bytes = source.as_bytes();
    let mut pairs = Vec::new();
    for (_, capture) in captures {
        let QueryCapture { index, node } = capture;
        let name = query.capture_names()[index].as_str();
        let text = node.utf8_text(source_bytes).unwrap();
        pairs.push((name, text));
    }
    pairs
}
/// Builds the text callback passed to `QueryCursor::matches` / `captures`:
/// a closure that returns the bytes of `source` covered by a node's byte range.
fn to_callback<'a>(source: &'a str) -> impl Fn(Node) -> &'a [u8] {
    let source_bytes = source.as_bytes();
    move |node| &source_bytes[node.byte_range()]
}

View file

@ -7,7 +7,7 @@
</head>
<body>
<div id="playground-container">
<div id="playground-container" style="visibility: hidden;">
<header>
<div class=header-item>
<bold>THE_LANGUAGE_NAME</bold>
@ -18,18 +18,31 @@
<input id="logging-checkbox" type="checkbox"></input>
</div>
<div class=header-item>
<label for="query-checkbox">query</label>
<input id="query-checkbox" type="checkbox"></input>
</div>
<div class=header-item>
<label for="update-time">parse time: </label>
<span id="update-time"></span>
</div>
</header>
<main>
<select id="language-select" style="display: none;">
<option value="parser">Parser</option>
</select>
</header>
<textarea id="code-input"></textarea>
<main>
<div id="input-pane">
<div id="code-container">
<textarea id="code-input"></textarea>
</div>
<div id="query-container" style="visibility: hidden; position: absolute;">
<textarea id="query-input"></textarea>
</div>
</div>
<div id="output-container-scroll">
<pre id="output-container" class="highlight"></pre>
@ -51,15 +64,13 @@
<style>
body {
font: Sans Serif;
margin: 0;
padding: 0;
}
#playground-container {
position: absolute;
top: 0;
bottom: 0;
left: 0;
right: 0;
width: 100%;
height: 100%;
display: flex;
flex-direction: column;
}
@ -73,24 +84,51 @@
}
main {
flex: 1;
position: relative;
}
#input-pane {
position: absolute;
top: 0;
left: 0;
bottom: 0;
right: 50%;
display: flex;
height: 100%;
flex-direction: row;
flex-direction: column;
}
#code-container, #query-container {
flex: 1;
position: relative;
overflow: hidden;
border-right: 1px solid #aaa;
border-bottom: 1px solid #aaa;
}
#output-container-scroll {
position: absolute;
top: 0;
left: 50%;
bottom: 0;
right: 0;
}
.header-item {
margin-right: 30px;
}
.CodeMirror {
width: 50%;
#playground-container .CodeMirror {
position: absolute;
top: 0;
bottom: 0;
left: 0;
right: 0;
height: 100%;
border-right: 1px solid #aaa;
}
#output-container-scroll {
width: 50%;
height: 100%;
flex: 1;
padding: 0;
overflow: auto;
}
@ -124,5 +162,9 @@
border-radius: 3px;
text-decoration: underline;
}
.query-error {
text-decoration: underline red dashed;
}
</style>
</body>

View file

@ -118,7 +118,7 @@ body {
}
#playground-container {
> .CodeMirror {
.CodeMirror {
height: auto;
max-height: 350px;
border: 1px solid #aaa;
@ -129,7 +129,7 @@ body {
max-height: 350px;
}
h4, select, .field {
h4, select, .field, label {
display: inline-block;
margin-right: 20px;
}
@ -161,3 +161,7 @@ a.highlighted {
background-color: #ddd;
text-decoration: underline;
}
.query-error {
text-decoration: underline red dashed;
}

View file

@ -1,16 +1,31 @@
let tree;
(async () => {
const CAPTURE_REGEX = /@\s*([\w\._-]+)/g;
const COLORS_BY_INDEX = [
'red',
'green',
'blue',
'orange',
'violet',
];
const scriptURL = document.currentScript.getAttribute('src');
const codeInput = document.getElementById('code-input');
const languageSelect = document.getElementById('language-select');
const loggingCheckbox = document.getElementById('logging-checkbox');
const outputContainer = document.getElementById('output-container');
const outputContainerScroll = document.getElementById('output-container-scroll');
const playgroundContainer = document.getElementById('playground-container');
const queryCheckbox = document.getElementById('query-checkbox');
const queryContainer = document.getElementById('query-container');
const queryInput = document.getElementById('query-input');
const updateTimeSpan = document.getElementById('update-time');
const demoContainer = document.getElementById('playground-container');
const languagesByName = {};
loadState();
await TreeSitter.init();
const parser = new TreeSitter();
@ -18,6 +33,12 @@ let tree;
lineNumbers: true,
showCursorWhenSelecting: true
});
const queryEditor = CodeMirror.fromTextArea(queryInput, {
lineNumbers: true,
showCursorWhenSelecting: true
});
const cluster = new Clusterize({
rows: [],
noDataText: null,
@ -25,22 +46,30 @@ let tree;
scrollElem: outputContainerScroll
});
const renderTreeOnCodeChange = debounce(renderTree, 50);
const saveStateOnChange = debounce(saveState, 2000);
const runTreeQueryOnChange = debounce(runTreeQuery, 50);
let languageName = languageSelect.value;
let treeRows = null;
let treeRowHighlightedIndex = -1;
let parseCount = 0;
let isRendering = 0;
let query;
codeEditor.on('changes', handleCodeChange);
codeEditor.on('viewportChange', runTreeQueryOnChange);
codeEditor.on('cursorActivity', debounce(handleCursorMovement, 150));
queryEditor.on('changes', debounce(handleQueryChange, 150));
loggingCheckbox.addEventListener('change', handleLoggingChange);
queryCheckbox.addEventListener('change', handleQueryEnableChange);
languageSelect.addEventListener('change', handleLanguageChange);
outputContainer.addEventListener('click', handleTreeClick);
handleQueryEnableChange();
await handleLanguageChange()
demoContainer.style.visibility = 'visible';
playgroundContainer.style.visibility = 'visible';
async function handleLanguageChange() {
const newLanguageName = languageSelect.value;
@ -62,6 +91,7 @@ let tree;
languageName = newLanguageName;
parser.setLanguage(languagesByName[newLanguageName]);
handleCodeChange();
handleQueryChange();
}
async function handleCodeChange(editor, changes) {
@ -81,6 +111,8 @@ let tree;
tree = newTree;
parseCount++;
renderTreeOnCodeChange();
runTreeQueryOnChange();
saveStateOnChange();
}
async function renderTree() {
@ -164,6 +196,104 @@ let tree;
handleCursorMovement();
}
// Re-colours the code editor according to the current query's captures.
//
// Usable as a CodeMirror 'viewportChange' handler (editor, from, to) or
// called with no arguments, in which case the row range is taken from the
// code editor's current viewport. All existing marks are cleared first; new
// marks are added only when both a syntax tree and a compiled query exist.
function runTreeQuery(_, startRow, endRow) {
  if (endRow == null) {
    ({from: startRow, to: endRow} = codeEditor.getViewport());
  }

  // Batch all mark changes into a single editor operation.
  codeEditor.operation(() => {
    for (const mark of codeEditor.getAllMarks()) {
      mark.clear();
    }

    if (!tree || !query) return;

    const rangeStart = {row: startRow, column: 0};
    const rangeEnd = {row: endRow, column: 0};
    for (const {name, node} of query.captures(tree.rootNode, rangeStart, rangeEnd)) {
      const from = {line: node.startPosition.row, ch: node.startPosition.column};
      const to = {line: node.endPosition.row, ch: node.endPosition.column};
      codeEditor.markText(from, to, {
        inclusiveLeft: true,
        inclusiveRight: true,
        css: `color: ${colorForCaptureName(name)}`
      });
    }
  });
}
// Rebuild the tree query from the query editor's text.
//
// Steps:
//   1. Delete the previously compiled query, if any.
//   2. Clear all marks in the query editor. If the query checkbox is
//      unchecked, stop there (leaving `query` null).
//   3. Compile the new query. On success, color each capture name in the
//      query text. On failure, underline the offending span with the
//      'query-error' class and expose the message as a tooltip.
//   4. Re-run the query against the code and persist the query state.
function handleQueryChange() {
  // Release the old compiled query before building a new one.
  if (query) {
    query.delete();
    query.deleted = true;
    query = null;
  }

  queryEditor.operation(() => {
    queryEditor.getAllMarks().forEach(m => m.clear());
    if (!queryCheckbox.checked) return;

    const queryText = queryEditor.getValue();

    try {
      query = parser.getLanguage().query(queryText);
      let match;

      // Color every capture name that appears in the query source.
      let row = 0;
      queryEditor.eachLine((line) => {
        while (match = CAPTURE_REGEX.exec(line.text)) {
          queryEditor.markText(
            {line: row, ch: match.index},
            {line: row, ch: match.index + match[0].length},
            {
              inclusiveLeft: true,
              inclusiveRight: true,
              css: `color: ${colorForCaptureName(match[1])}`
            }
          );
        }
        row++;
      });
    } catch (error) {
      const startPosition = queryEditor.posFromIndex(error.index);
      const endPosition = {
        line: startPosition.line,
        ch: startPosition.ch + (error.length || 1)
      };

      // An error located at the very end of the text has no character to
      // underline, so pull the start back by one position. CodeMirror
      // positions are `{line, ch}` objects, so those are the fields that
      // must be adjusted.
      if (error.index === queryText.length) {
        if (startPosition.ch > 0) {
          startPosition.ch--;
        } else if (startPosition.line > 0) {
          startPosition.line--;
          startPosition.ch = Infinity;
        }
      }

      queryEditor.markText(
        startPosition,
        endPosition,
        {
          className: 'query-error',
          inclusiveLeft: true,
          inclusiveRight: true,
          attributes: {title: error.message}
        }
      );
    }
  });

  runTreeQuery();
  saveQueryState();
}
function handleCursorMovement() {
if (isRendering) return;
@ -236,6 +366,17 @@ let tree;
}
}
// Show or hide the query pane according to the query checkbox, then rebuild
// the query. When hidden, the container is made invisible and absolutely
// positioned; when shown, both inline styles are reset.
function handleQueryEnableChange() {
  const enabled = queryCheckbox.checked;
  const paneStyle = queryContainer.style;
  paneStyle.visibility = enabled ? '' : 'hidden';
  paneStyle.position = enabled ? '' : 'absolute';
  handleQueryChange();
}
function treeEditForEditorChange(change) {
const oldLineCount = change.removed.length;
const newLineCount = change.text.length;
@ -262,6 +403,35 @@ let tree;
};
}
// Pick the highlight color for a capture name: the name's position in the
// current query's `captureNames` list selects an entry of COLORS_BY_INDEX,
// wrapping around when there are more captures than colors.
function colorForCaptureName(capture) {
  const captureIndex = query.captureNames.indexOf(capture);
  const paletteSlot = captureIndex % COLORS_BY_INDEX.length;
  return COLORS_BY_INDEX[paletteSlot];
}
// Restore the playground's inputs from localStorage. Nothing is restored
// unless the language, source code and query were all previously saved.
function loadState() {
  const [language, sourceCode, query, queryEnabled] =
    ["language", "sourceCode", "query", "queryEnabled"]
      .map(key => localStorage.getItem(key));

  if (language == null || sourceCode == null || query == null) return;

  queryInput.value = query;
  codeInput.value = sourceCode;
  languageSelect.value = language;
  // localStorage holds strings, so the flag is compared against 'true'.
  queryCheckbox.checked = (queryEnabled === 'true');
}
// Persist the selected language and the source code to localStorage, then
// persist the query-related state as well.
function saveState() {
  const selectedLanguage = languageSelect.value;
  localStorage.setItem("language", selectedLanguage);
  const currentSource = codeEditor.getValue();
  localStorage.setItem("sourceCode", currentSource);
  saveQueryState();
}
// Persist whether the query pane is enabled, and the query text itself, to
// localStorage.
function saveQueryState() {
  const isEnabled = queryCheckbox.checked;
  localStorage.setItem("queryEnabled", isEnabled);
  const queryText = queryEditor.getValue();
  localStorage.setItem("query", queryText);
}
function debounce(func, wait, immediate) {
var timeout;
return function() {

View file

@ -31,6 +31,9 @@ permalink: playground
<input id="logging-checkbox" type="checkbox"></input>
<label for="logging-checkbox">Log</label>
<input id="query-checkbox" type="checkbox"></input>
<label for="query-checkbox">Query</label>
<textarea id="code-input">
</textarea>

View file

@ -19,6 +19,16 @@ pub struct TSParser {
pub struct TSTree {
_unused: [u8; 0],
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSQuery {
_unused: [u8; 0],
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSQueryCursor {
_unused: [u8; 0],
}
pub const TSInputEncoding_TSInputEncodingUTF8: TSInputEncoding = 0;
pub const TSInputEncoding_TSInputEncodingUTF16: TSInputEncoding = 1;
pub type TSInputEncoding = u32;
@ -93,6 +103,36 @@ pub struct TSTreeCursor {
pub id: *const ::std::os::raw::c_void,
pub context: [u32; 2usize],
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSQueryCapture {
pub node: TSNode,
pub index: u32,
}
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSQueryMatch {
pub id: u32,
pub pattern_index: u16,
pub capture_count: u16,
pub captures: *const TSQueryCapture,
}
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeDone: TSQueryPredicateStepType = 0;
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture: TSQueryPredicateStepType = 1;
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeString: TSQueryPredicateStepType = 2;
pub type TSQueryPredicateStepType = u32;
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSQueryPredicateStep {
pub type_: TSQueryPredicateStepType,
pub value_id: u32,
}
pub const TSQueryError_TSQueryErrorNone: TSQueryError = 0;
pub const TSQueryError_TSQueryErrorSyntax: TSQueryError = 1;
pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2;
pub const TSQueryError_TSQueryErrorField: TSQueryError = 3;
pub const TSQueryError_TSQueryErrorCapture: TSQueryError = 4;
pub type TSQueryError = u32;
extern "C" {
#[doc = " Create a new parser."]
pub fn ts_parser_new() -> *mut TSParser;
@ -538,6 +578,140 @@ extern "C" {
extern "C" {
pub fn ts_tree_cursor_copy(arg1: *const TSTreeCursor) -> TSTreeCursor;
}
extern "C" {
#[doc = " Create a new query from a string containing one or more S-expression"]
#[doc = " patterns. The query is associated with a particular language, and can"]
#[doc = " only be run on syntax nodes parsed with that language."]
#[doc = ""]
#[doc = " If all of the given patterns are valid, this returns a `TSQuery`."]
#[doc = " If a pattern is invalid, this returns `NULL`, and provides two pieces"]
#[doc = " of information about the problem:"]
#[doc = " 1. The byte offset of the error is written to the `error_offset` parameter."]
#[doc = " 2. The type of error is written to the `error_type` parameter."]
pub fn ts_query_new(
language: *const TSLanguage,
source: *const ::std::os::raw::c_char,
source_len: u32,
error_offset: *mut u32,
error_type: *mut TSQueryError,
) -> *mut TSQuery;
}
extern "C" {
#[doc = " Delete a query, freeing all of the memory that it used."]
pub fn ts_query_delete(arg1: *mut TSQuery);
}
extern "C" {
#[doc = " Get the number of patterns, captures, or string literals in the query."]
pub fn ts_query_pattern_count(arg1: *const TSQuery) -> u32;
}
extern "C" {
pub fn ts_query_capture_count(arg1: *const TSQuery) -> u32;
}
extern "C" {
pub fn ts_query_string_count(arg1: *const TSQuery) -> u32;
}
extern "C" {
#[doc = " Get the byte offset where the given pattern starts in the query\'s source."]
#[doc = ""]
#[doc = " This can be useful when combining queries by concatenating their source"]
#[doc = " code strings."]
pub fn ts_query_start_byte_for_pattern(arg1: *const TSQuery, arg2: u32) -> u32;
}
extern "C" {
#[doc = " Get all of the predicates for the given pattern in the query."]
#[doc = ""]
#[doc = " The predicates are represented as a single array of steps. There are three"]
#[doc = " types of steps in this array, which correspond to the three legal values for"]
#[doc = " the `type` field:"]
#[doc = " - `TSQueryPredicateStepTypeCapture` - Steps with this type represent names"]
#[doc = " of captures. Their `value_id` can be used with the"]
#[doc = " `ts_query_capture_name_for_id` function to obtain the name of the capture."]
#[doc = " - `TSQueryPredicateStepTypeString` - Steps with this type represent literal"]
#[doc = " strings. Their `value_id` can be used with the"]
#[doc = " `ts_query_string_value_for_id` function to obtain their string value."]
#[doc = " - `TSQueryPredicateStepTypeDone` - Steps with this type are *sentinels*"]
#[doc = " that represent the end of an individual predicate. If a pattern has two"]
#[doc = " predicates, then there will be two steps with this `type` in the array."]
pub fn ts_query_predicates_for_pattern(
self_: *const TSQuery,
pattern_index: u32,
length: *mut u32,
) -> *const TSQueryPredicateStep;
}
extern "C" {
#[doc = " Get the name and length of one of the query\'s captures, or one of the"]
#[doc = " query\'s string literals. Each capture and string is associated with a"]
#[doc = " numeric id based on the order that it appeared in the query\'s source."]
pub fn ts_query_capture_name_for_id(
arg1: *const TSQuery,
id: u32,
length: *mut u32,
) -> *const ::std::os::raw::c_char;
}
extern "C" {
pub fn ts_query_string_value_for_id(
arg1: *const TSQuery,
id: u32,
length: *mut u32,
) -> *const ::std::os::raw::c_char;
}
extern "C" {
#[doc = " Create a new cursor for executing a given query."]
#[doc = ""]
#[doc = " The cursor stores the state that is needed to iteratively search"]
#[doc = " for matches. To use the query cursor, first call `ts_query_cursor_exec`"]
#[doc = " to start running a given query on a given syntax node. Then, there are"]
#[doc = " two options for consuming the results of the query:"]
#[doc = " 1. Repeatedly call `ts_query_cursor_next_match` to iterate over all of the"]
#[doc = " *matches* in the order that they were found. Each match contains the"]
#[doc = " index of the pattern that matched, and an array of captures. Because"]
#[doc = " multiple patterns can match the same set of nodes, one match may contain"]
#[doc = " captures that appear *before* some of the captures from a previous match."]
#[doc = " 2. Repeatedly call `ts_query_cursor_next_capture` to iterate over all of the"]
#[doc = " individual *captures* in the order that they appear. This is useful if"]
#[doc = " you don\'t care about which pattern matched, and just want a single ordered"]
#[doc = " sequence of captures."]
#[doc = ""]
#[doc = " If you don\'t care about consuming all of the results, you can stop calling"]
#[doc = " `ts_query_cursor_next_match` or `ts_query_cursor_next_capture` at any point."]
#[doc = " You can then start executing another query on another node by calling"]
#[doc = " `ts_query_cursor_exec` again."]
pub fn ts_query_cursor_new() -> *mut TSQueryCursor;
}
extern "C" {
#[doc = " Delete a query cursor, freeing all of the memory that it used."]
pub fn ts_query_cursor_delete(arg1: *mut TSQueryCursor);
}
extern "C" {
#[doc = " Start running a given query on a given node."]
pub fn ts_query_cursor_exec(arg1: *mut TSQueryCursor, arg2: *const TSQuery, arg3: TSNode);
}
extern "C" {
#[doc = " Set the range of bytes or (row, column) positions in which the query"]
#[doc = " will be executed."]
pub fn ts_query_cursor_set_byte_range(arg1: *mut TSQueryCursor, arg2: u32, arg3: u32);
}
extern "C" {
pub fn ts_query_cursor_set_point_range(arg1: *mut TSQueryCursor, arg2: TSPoint, arg3: TSPoint);
}
extern "C" {
#[doc = " Advance to the next match of the currently running query."]
#[doc = ""]
#[doc = " If there is a match, write it to `*match` and return `true`."]
#[doc = " Otherwise, return `false`."]
pub fn ts_query_cursor_next_match(arg1: *mut TSQueryCursor, match_: *mut TSQueryMatch) -> bool;
}
extern "C" {
#[doc = " Advance to the next capture of the currently running query."]
#[doc = ""]
#[doc = " If there is a capture, write its match to `*match` and its index within"]
#[doc = " the match\'s capture list to `*capture_index`. Otherwise, return `false`."]
pub fn ts_query_cursor_next_capture(
arg1: *mut TSQueryCursor,
match_: *mut TSQueryMatch,
capture_index: *mut u32,
) -> bool;
}
extern "C" {
#[doc = " Get the number of distinct node types in the language."]
pub fn ts_language_symbol_count(arg1: *const TSLanguage) -> u32;

View file

@ -15,9 +15,10 @@ use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::ffi::CStr;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::os::raw::{c_char, c_void};
use std::sync::atomic::AtomicUsize;
use std::{fmt, ptr, str, u16};
use std::{char, fmt, ptr, slice, str, u16};
pub const LANGUAGE_VERSION: usize = ffi::TREE_SITTER_LANGUAGE_VERSION;
pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/parser.h");
@ -136,6 +137,41 @@ pub struct TreePropertyCursor<'a, P> {
source: &'a [u8],
}
/// A condition attached to a query pattern. It is evaluated against the text
/// of a match's captured nodes, and a match is only reported when all of its
/// pattern's predicates hold. Captures are identified by their numeric id.
#[derive(Debug)]
enum QueryPredicate {
/// `eq?` with a string literal: the capture's text must equal the string.
CaptureEqString(u32, String),
/// `eq?` with a second capture: both captures must have the same text.
CaptureEqCapture(u32, u32),
/// `match?`: the capture's text must match the compiled regex.
CaptureMatchString(u32, regex::bytes::Regex),
}
/// A compiled tree query, together with the metadata that the Rust binding
/// extracts from it up front.
#[derive(Debug)]
pub struct Query {
// Raw handle to the underlying C query; released in `Drop`.
ptr: *mut ffi::TSQuery,
// Capture names, indexed by capture id.
capture_names: Vec<String>,
// One list of predicates per pattern, indexed by pattern index.
predicates: Vec<Vec<QueryPredicate>>,
}
/// Owns a raw `TSQueryCursor`, the C object that holds the state needed to
/// iterate over a query's results. The cursor is deleted on drop.
pub struct QueryCursor(*mut ffi::TSQueryCursor);
/// A single match of one of a query's patterns.
pub struct QueryMatch<'a> {
/// Index of the pattern that matched.
pub pattern_index: usize,
// The raw captures of this match, as reported by the C library; exposed
// through `QueryMatch::captures`.
captures: &'a [ffi::TSQueryCapture],
}
/// A syntax node captured by a query, tagged with the id of its capture.
pub struct QueryCapture<'a> {
/// The capture's id; this is the capture name's position in
/// `Query::capture_names`.
pub index: usize,
/// The syntax node that was captured.
pub node: Node<'a>,
}
/// The reasons that `Query::new` can fail. The borrowed variants hold a slice
/// of the query source naming the offending identifier.
#[derive(Debug, PartialEq, Eq)]
pub enum QueryError<'a> {
/// The query source is malformed; holds the byte offset of the error.
Syntax(usize),
/// The query refers to a node type that was rejected for the language.
NodeType(&'a str),
/// The query refers to a field name that was rejected for the language.
Field(&'a str),
/// The query refers to a capture name that was rejected.
Capture(&'a str),
/// A predicate is malformed or unknown; holds a human-readable message.
Predicate(String),
}
impl Language {
pub fn version(&self) -> usize {
unsafe { ffi::ts_language_version(self.0) as usize }
@ -237,7 +273,7 @@ impl Parser {
pub fn set_logger(&mut self, logger: Option<Logger>) {
let prev_logger = unsafe { ffi::ts_parser_logger(self.0) };
if !prev_logger.payload.is_null() {
unsafe { Box::from_raw(prev_logger.payload as *mut Logger) };
drop(unsafe { Box::from_raw(prev_logger.payload as *mut Logger) });
}
let c_logger;
@ -308,7 +344,7 @@ impl Parser {
)
}
/// Parse a slice UTF16 text.
/// Parse a slice of UTF16 text.
///
/// # Arguments:
/// * `text` The UTF16-encoded text to parse.
@ -592,6 +628,10 @@ impl<'tree> Node<'tree> {
unsafe { ffi::ts_node_end_byte(self.0) as usize }
}
/// Get the half-open range running from this node's `start_byte()` to its
/// `end_byte()`.
pub fn byte_range(&self) -> std::ops::Range<usize> {
    let start = self.start_byte();
    let end = self.end_byte();
    start..end
}
pub fn range(&self) -> Range {
Range {
start_byte: self.start_byte(),
@ -921,6 +961,353 @@ impl<'a, P> TreePropertyCursor<'a, P> {
}
}
impl Query {
    /// Compile a query for `language` from `source`, a string containing one
    /// or more S-expression patterns.
    ///
    /// On success, the capture names and the predicates of every pattern are
    /// read out of the compiled query once and stored alongside it.
    ///
    /// # Errors
    ///
    /// * `QueryError::Syntax(offset)` - the source could not be parsed;
    ///   `offset` is the byte offset reported by the C library.
    /// * `QueryError::NodeType`, `QueryError::Field`, `QueryError::Capture` -
    ///   the identifier (sliced out of `source`) that was rejected.
    /// * `QueryError::Predicate` - a predicate does not start with a function
    ///   name, names an unknown function, has the wrong number or kind of
    ///   arguments, or contains an invalid regex.
    pub fn new<'a>(language: Language, source: &'a str) -> Result<Self, QueryError<'a>> {
        let mut error_offset = 0u32;
        let mut error_type: ffi::TSQueryError = 0;
        let bytes = source.as_bytes();

        // Compile the query.
        let ptr = unsafe {
            ffi::ts_query_new(
                language.0,
                bytes.as_ptr() as *const c_char,
                bytes.len() as u32,
                &mut error_offset as *mut u32,
                &mut error_type as *mut ffi::TSQueryError,
            )
        };

        // On failure, build an error based on the error code and offset.
        if ptr.is_null() {
            let offset = error_offset as usize;
            return if error_type != ffi::TSQueryError_TSQueryErrorSyntax {
                let suffix = source.split_at(offset).1;
                // `end_offset` is relative to `suffix`, so when the name runs
                // to the end of the source the fallback must be the suffix's
                // own length; using `source.len()` would make the `split_at`
                // below panic whenever `offset > 0`.
                let end_offset = suffix
                    .find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-')
                    .unwrap_or(suffix.len());
                let name = suffix.split_at(end_offset).0;
                match error_type {
                    ffi::TSQueryError_TSQueryErrorNodeType => Err(QueryError::NodeType(name)),
                    ffi::TSQueryError_TSQueryErrorField => Err(QueryError::Field(name)),
                    ffi::TSQueryError_TSQueryErrorCapture => Err(QueryError::Capture(name)),
                    _ => Err(QueryError::Syntax(offset)),
                }
            } else {
                Err(QueryError::Syntax(offset))
            };
        }

        let string_count = unsafe { ffi::ts_query_string_count(ptr) };
        let capture_count = unsafe { ffi::ts_query_capture_count(ptr) };
        let pattern_count = unsafe { ffi::ts_query_pattern_count(ptr) as usize };

        // From here on `result` owns `ptr`: any early return below drops it,
        // which deletes the underlying C query.
        let mut result = Query {
            ptr,
            capture_names: Vec::with_capacity(capture_count as usize),
            predicates: Vec::with_capacity(pattern_count),
        };

        // Build a vector of strings to store the capture names.
        for i in 0..capture_count {
            unsafe {
                let mut length = 0u32;
                let name =
                    ffi::ts_query_capture_name_for_id(ptr, i, &mut length as *mut u32) as *const u8;
                let name = slice::from_raw_parts(name, length as usize);
                let name = str::from_utf8_unchecked(name);
                result.capture_names.push(name.to_string());
            }
        }

        // Build a vector of strings to represent literal values used in predicates.
        let string_values = (0..string_count)
            .map(|i| unsafe {
                let mut length = 0u32;
                let value =
                    ffi::ts_query_string_value_for_id(ptr, i as u32, &mut length as *mut u32)
                        as *const u8;
                let value = slice::from_raw_parts(value, length as usize);
                let value = str::from_utf8_unchecked(value);
                value.to_string()
            })
            .collect::<Vec<_>>();

        // Build a vector of predicates for each pattern.
        for i in 0..pattern_count {
            let predicate_steps: &[ffi::TSQueryPredicateStep] = unsafe {
                let mut length = 0u32;
                let raw_predicates =
                    ffi::ts_query_predicates_for_pattern(ptr, i as u32, &mut length as *mut u32);
                // `slice::from_raw_parts` requires a non-null pointer even
                // for an empty slice, so treat a null result as "no steps".
                if raw_predicates.is_null() {
                    &[]
                } else {
                    slice::from_raw_parts(raw_predicates, length as usize)
                }
            };

            let type_done = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeDone;
            let type_capture = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture;
            let type_string = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeString;

            // Each predicate is a run of steps terminated by a `Done` step.
            let mut pattern_predicates = Vec::new();
            for p in predicate_steps.split(|s| s.type_ == type_done) {
                if p.is_empty() {
                    continue;
                }

                if p[0].type_ != type_string {
                    return Err(QueryError::Predicate(format!(
                        "Expected predicate to start with a function name. Got @{}.",
                        result.capture_names[p[0].value_id as usize],
                    )));
                }

                // Build a predicate for each of the known predicate function names.
                let operator_name = &string_values[p[0].value_id as usize];
                pattern_predicates.push(match operator_name.as_str() {
                    "eq?" => {
                        if p.len() != 3 {
                            return Err(QueryError::Predicate(format!(
                                "Wrong number of arguments to eq? predicate. Expected 2, got {}.",
                                p.len() - 1
                            )));
                        }
                        if p[1].type_ != type_capture {
                            return Err(QueryError::Predicate(format!(
                                "First argument to eq? predicate must be a capture name. Got literal \"{}\".",
                                string_values[p[1].value_id as usize],
                            )));
                        }

                        if p[2].type_ == type_capture {
                            Ok(QueryPredicate::CaptureEqCapture(
                                p[1].value_id,
                                p[2].value_id,
                            ))
                        } else {
                            Ok(QueryPredicate::CaptureEqString(
                                p[1].value_id,
                                string_values[p[2].value_id as usize].clone(),
                            ))
                        }
                    }

                    "match?" => {
                        if p.len() != 3 {
                            return Err(QueryError::Predicate(format!(
                                "Wrong number of arguments to match? predicate. Expected 2, got {}.",
                                p.len() - 1
                            )));
                        }
                        if p[1].type_ != type_capture {
                            return Err(QueryError::Predicate(format!(
                                "First argument to match? predicate must be a capture name. Got literal \"{}\".",
                                string_values[p[1].value_id as usize],
                            )));
                        }
                        if p[2].type_ == type_capture {
                            return Err(QueryError::Predicate(format!(
                                "Second argument to match? predicate must be a literal. Got capture @{}.",
                                result.capture_names[p[2].value_id as usize],
                            )));
                        }

                        let regex = &string_values[p[2].value_id as usize];
                        Ok(QueryPredicate::CaptureMatchString(
                            p[1].value_id,
                            regex::bytes::Regex::new(regex)
                                .map_err(|_| QueryError::Predicate(format!("Invalid regex '{}'", regex)))?,
                        ))
                    }

                    _ => Err(QueryError::Predicate(format!(
                        "Unknown query predicate function {}",
                        operator_name,
                    ))),
                }?);
            }

            result.predicates.push(pattern_predicates);
        }

        Ok(result)
    }

    /// Get the byte offset where the given pattern starts in the query's
    /// source.
    ///
    /// # Panics
    ///
    /// Panics if `pattern_index` is not less than the number of patterns.
    pub fn start_byte_for_pattern(&self, pattern_index: usize) -> usize {
        if pattern_index >= self.predicates.len() {
            panic!(
                "Pattern index is {} but the pattern count is {}",
                pattern_index,
                self.predicates.len(),
            );
        }
        unsafe { ffi::ts_query_start_byte_for_pattern(self.ptr, pattern_index as u32) as usize }
    }

    /// Get the number of patterns in the query.
    pub fn pattern_count(&self) -> usize {
        unsafe { ffi::ts_query_pattern_count(self.ptr) as usize }
    }

    /// Get the names of the query's captures, indexed by capture id.
    pub fn capture_names(&self) -> &[String] {
        &self.capture_names
    }
}
impl QueryCursor {
/// Create a new cursor for executing queries.
pub fn new() -> Self {
QueryCursor(unsafe { ffi::ts_query_cursor_new() })
}
/// Start running `query` on `node` and return an iterator over its matches.
///
/// `text_callback` must return the source text of the node it is given; it
/// is used to evaluate the query's predicates. Matches whose pattern has a
/// predicate that does not hold are skipped.
pub fn matches<'a>(
&'a mut self,
query: &'a Query,
node: Node<'a>,
mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a,
) -> impl Iterator<Item = QueryMatch<'a>> + 'a {
unsafe {
ffi::ts_query_cursor_exec(self.0, query.ptr, node.0);
}
std::iter::from_fn(move || -> Option<QueryMatch<'a>> {
// Keep pulling matches from the C cursor until one satisfies its
// predicates or the cursor is exhausted.
loop {
unsafe {
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
if ffi::ts_query_cursor_next_match(self.0, m.as_mut_ptr()) {
// `next_match` returned true, so it wrote the match into `m`.
let m = m.assume_init();
let captures = slice::from_raw_parts(m.captures, m.capture_count as usize);
if self.captures_match_condition(
query,
captures,
m.pattern_index as usize,
&mut text_callback,
) {
return Some(QueryMatch {
pattern_index: m.pattern_index as usize,
captures,
});
}
} else {
return None;
}
}
}
})
}
/// Start running `query` on `node` and return an iterator over individual
/// captures, each paired with the index of the pattern that produced it.
///
/// `text_callback` must return the source text of the node it is given.
/// Captures belonging to a match whose predicates do not hold are skipped.
pub fn captures<'a>(
&'a mut self,
query: &'a Query,
node: Node<'a>,
mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a,
) -> impl Iterator<Item = (usize, QueryCapture)> + 'a {
unsafe {
ffi::ts_query_cursor_exec(self.0, query.ptr, node.0);
}
std::iter::from_fn(move || loop {
unsafe {
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
let mut capture_index = 0u32;
if ffi::ts_query_cursor_next_capture(
self.0,
m.as_mut_ptr(),
&mut capture_index as *mut u32,
) {
// `next_capture` returned true, so it wrote the match into `m` and
// the capture's position within that match into `capture_index`.
let m = m.assume_init();
let captures = slice::from_raw_parts(m.captures, m.capture_count as usize);
if self.captures_match_condition(
query,
captures,
m.pattern_index as usize,
&mut text_callback,
) {
let capture = captures[capture_index as usize];
return Some((
m.pattern_index as usize,
QueryCapture {
index: capture.index as usize,
node: Node::new(capture.node).unwrap(),
},
));
}
} else {
return None;
}
}
})
}
// Check whether every predicate of the pattern at `pattern_index` holds
// for the given captures, using `text_callback` to obtain node text.
//
// Panics if a predicate refers to a capture that has no usable node among
// `captures` (the `unwrap` calls below).
fn captures_match_condition<'a>(
&self,
query: &'a Query,
captures: &'a [ffi::TSQueryCapture],
pattern_index: usize,
text_callback: &mut impl FnMut(Node<'a>) -> &'a [u8],
) -> bool {
query.predicates[pattern_index]
.iter()
.all(|predicate| match predicate {
QueryPredicate::CaptureEqCapture(i, j) => {
let node1 = Self::capture_for_id(captures, *i).unwrap();
let node2 = Self::capture_for_id(captures, *j).unwrap();
text_callback(node1) == text_callback(node2)
}
QueryPredicate::CaptureEqString(i, s) => {
let node = Self::capture_for_id(captures, *i).unwrap();
text_callback(node) == s.as_bytes()
}
QueryPredicate::CaptureMatchString(i, r) => {
let node = Self::capture_for_id(captures, *i).unwrap();
r.is_match(text_callback(node))
}
})
}
// Find the node of the first capture in `captures` whose id is
// `capture_id`; `None` if there is no such capture.
fn capture_for_id(captures: &[ffi::TSQueryCapture], capture_id: u32) -> Option<Node> {
for c in captures {
if c.index == capture_id {
return Node::new(c.node);
}
}
None
}
/// Set the range of bytes in which the query will be executed. Returns
/// `self` so that calls can be chained.
pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self {
unsafe {
ffi::ts_query_cursor_set_byte_range(self.0, start as u32, end as u32);
}
self
}
/// Set the range of (row, column) positions in which the query will be
/// executed. Returns `self` so that calls can be chained.
pub fn set_point_range(&mut self, start: Point, end: Point) -> &mut Self {
unsafe {
ffi::ts_query_cursor_set_point_range(self.0, start.into(), end.into());
}
self
}
}
impl<'a> QueryMatch<'a> {
    /// Iterate over the captures of this match, converting each raw capture
    /// into a `QueryCapture`.
    pub fn captures(&self) -> impl ExactSizeIterator<Item = QueryCapture> {
        self.captures.iter().map(|raw| {
            let index = raw.index as usize;
            let node = Node::new(raw.node).unwrap();
            QueryCapture { index, node }
        })
    }
}
/// Two `Query` values are equal only when they wrap the same underlying
/// query pointer.
impl PartialEq for Query {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.ptr, other.ptr);
        lhs == rhs
    }
}
/// Deletes the underlying C query when the `Query` goes out of scope.
impl Drop for Query {
    fn drop(&mut self) {
        let raw_query = self.ptr;
        unsafe { ffi::ts_query_delete(raw_query) }
    }
}
/// Deletes the underlying C query cursor when the `QueryCursor` goes out of
/// scope.
impl Drop for QueryCursor {
    fn drop(&mut self) {
        let raw_cursor = self.0;
        unsafe { ffi::ts_query_cursor_delete(raw_cursor) }
    }
}
impl Point {
pub fn new(row: usize, column: usize) -> Self {
Point { row, column }

View file

@ -2,6 +2,7 @@
#include <tree_sitter/api.h>
#include <stdio.h>
#include "array.h"
#include "point.h"
/*****************************/
/* Section - Data marshaling */
@ -31,18 +32,18 @@ static uint32_t byte_to_code_unit(uint32_t byte) {
static inline void marshal_node(const void **buffer, TSNode node) {
buffer[0] = (const void *)node.id;
buffer[1] = (const void *)node.context[0];
buffer[1] = (const void *)byte_to_code_unit(node.context[0]);
buffer[2] = (const void *)node.context[1];
buffer[3] = (const void *)node.context[2];
buffer[3] = (const void *)byte_to_code_unit(node.context[2]);
buffer[4] = (const void *)node.context[3];
}
static inline TSNode unmarshal_node(const TSTree *tree) {
TSNode node;
node.id = TRANSFER_BUFFER[0];
node.context[0] = (uint32_t)TRANSFER_BUFFER[1];
node.context[0] = code_unit_to_byte((uint32_t)TRANSFER_BUFFER[1]);
node.context[1] = (uint32_t)TRANSFER_BUFFER[2];
node.context[2] = (uint32_t)TRANSFER_BUFFER[3];
node.context[2] = code_unit_to_byte((uint32_t)TRANSFER_BUFFER[3]);
node.context[3] = (uint32_t)TRANSFER_BUFFER[4];
node.tree = tree;
return node;
@ -305,6 +306,7 @@ void ts_tree_cursor_current_node_wasm(const TSTree *tree) {
/******************/
static TSTreeCursor scratch_cursor = {0};
static TSQueryCursor *scratch_query_cursor = NULL;
uint16_t ts_node_symbol_wasm(const TSTree *tree) {
TSNode node = unmarshal_node(tree);
@ -464,12 +466,6 @@ void ts_node_named_children_wasm(const TSTree *tree) {
TRANSFER_BUFFER[1] = result;
}
bool point_lte(TSPoint a, TSPoint b) {
if (a.row < b.row) return true;
if (a.row > b.row) return false;
return a.column <= b.column;
}
bool symbols_contain(const uint32_t *set, uint32_t length, uint32_t value) {
for (unsigned i = 0; i < length; i++) {
if (set[i] == value) return true;
@ -566,3 +562,90 @@ int ts_node_is_missing_wasm(const TSTree *tree) {
TSNode node = unmarshal_node(tree);
return ts_node_is_missing(node);
}
/******************/
/* Section - Query */
/******************/
// Run `self` on the node currently stored in the transfer buffer, limited to
// the given point range, and hand every match back through the transfer
// buffer.
//
// Output: TRANSFER_BUFFER[0] holds the number of matches and
// TRANSFER_BUFFER[1] the address of a flat array of words. For each match the
// array contains the pattern index, the capture count, and then, per capture,
// the capture's id followed by its 5-word marshaled node.
void ts_query_matches_wasm(
  const TSQuery *self,
  const TSTree *tree,
  uint32_t start_row,
  uint32_t start_column,
  uint32_t end_row,
  uint32_t end_column
) {
  // The query cursor is created on first use and then reused.
  if (scratch_query_cursor == NULL) {
    scratch_query_cursor = ts_query_cursor_new();
  }

  TSNode node = unmarshal_node(tree);
  TSPoint range_start = {start_row, code_unit_to_byte(start_column)};
  TSPoint range_end = {end_row, code_unit_to_byte(end_column)};
  ts_query_cursor_set_point_range(scratch_query_cursor, range_start, range_end);
  ts_query_cursor_exec(scratch_query_cursor, self, node);

  uint32_t write_index = 0;
  uint32_t match_count = 0;
  Array(const void *) result = array_new();

  TSQueryMatch match;
  while (ts_query_cursor_next_match(scratch_query_cursor, &match)) {
    match_count += 1;

    // Two header words, plus six words (id + node) for each capture.
    array_grow_by(&result, 2 + 6 * match.capture_count);
    result.contents[write_index++] = (const void *)(uint32_t)match.pattern_index;
    result.contents[write_index++] = (const void *)(uint32_t)match.capture_count;

    for (unsigned c = 0; c < match.capture_count; c++) {
      const TSQueryCapture *capture = match.captures + c;
      result.contents[write_index++] = (const void *)capture->index;
      marshal_node(result.contents + write_index, capture->node);
      write_index += 5;
    }
  }

  TRANSFER_BUFFER[0] = (const void *)(match_count);
  TRANSFER_BUFFER[1] = result.contents;
}
// Run `self` on the node currently stored in the transfer buffer, limited to
// the given point range, and hand every capture back through the transfer
// buffer.
//
// Output: TRANSFER_BUFFER[0] holds the number of captures and
// TRANSFER_BUFFER[1] the address of a flat array of words. For each capture
// the array contains the pattern index and capture count of the capture's
// match, the capture's position within that match, and then, for every
// capture of the match, the capture's id followed by its 5-word marshaled
// node.
void ts_query_captures_wasm(
  const TSQuery *self,
  const TSTree *tree,
  uint32_t start_row,
  uint32_t start_column,
  uint32_t end_row,
  uint32_t end_column
) {
  // The query cursor is created on first use and then reused.
  if (scratch_query_cursor == NULL) {
    scratch_query_cursor = ts_query_cursor_new();
  }

  TSNode node = unmarshal_node(tree);
  TSPoint range_start = {start_row, code_unit_to_byte(start_column)};
  TSPoint range_end = {end_row, code_unit_to_byte(end_column)};
  ts_query_cursor_set_point_range(scratch_query_cursor, range_start, range_end);
  ts_query_cursor_exec(scratch_query_cursor, self, node);

  unsigned write_index = 0;
  unsigned capture_count = 0;
  Array(const void *) result = array_new();

  TSQueryMatch match;
  uint32_t capture_index;
  while (ts_query_cursor_next_capture(scratch_query_cursor, &match, &capture_index)) {
    capture_count += 1;

    // Three header words, plus six words (id + node) for each capture.
    array_grow_by(&result, 3 + 6 * match.capture_count);
    result.contents[write_index++] = (const void *)(uint32_t)match.pattern_index;
    result.contents[write_index++] = (const void *)(uint32_t)match.capture_count;
    result.contents[write_index++] = (const void *)(uint32_t)capture_index;

    for (unsigned c = 0; c < match.capture_count; c++) {
      const TSQueryCapture *capture = match.captures + c;
      result.contents[write_index++] = (const void *)capture->index;
      marshal_node(result.contents + write_index, capture->node);
      write_index += 5;
    }
  }

  TRANSFER_BUFFER[0] = (const void *)(capture_count);
  TRANSFER_BUFFER[1] = result.contents;
}

View file

@ -5,6 +5,11 @@ const SIZE_OF_NODE = 5 * SIZE_OF_INT;
const SIZE_OF_POINT = 2 * SIZE_OF_INT;
const SIZE_OF_RANGE = 2 * SIZE_OF_INT + 2 * SIZE_OF_POINT;
const ZERO_POINT = {row: 0, column: 0};
const QUERY_WORD_REGEX = /[\w-.]*/g;
const PREDICATE_STEP_TYPE_DONE = 0;
const PREDICATE_STEP_TYPE_CAPTURE = 1;
const PREDICATE_STEP_TYPE_STRING = 2;
var VERSION;
var MIN_COMPATIBLE_VERSION;
@ -143,9 +148,7 @@ class Parser {
class Tree {
constructor(internal, address, language, textCallback) {
if (internal !== INTERNAL) {
throw new Error('Illegal constructor')
}
assertInternal(internal);
this[0] = address;
this.language = language;
this.textCallback = textCallback;
@ -201,16 +204,10 @@ class Tree {
class Node {
constructor(internal, tree) {
if (internal !== INTERNAL) {
throw new Error('Illegal constructor')
}
assertInternal(internal);
this.tree = tree;
}
get id() {
return this[0];
}
get typeId() {
marshalNode(this);
return C._ts_node_symbol_wasm(this.tree);
@ -220,23 +217,12 @@ class Node {
return this.tree.language.types[this.typeId] || 'ERROR';
}
get startPosition() {
marshalNode(this);
C._ts_node_start_point_wasm(this.tree[0]);
return unmarshalPoint(TRANSFER_BUFFER);
}
get endPosition() {
marshalNode(this);
C._ts_node_end_point_wasm(this.tree[0]);
return unmarshalPoint(TRANSFER_BUFFER);
}
get startIndex() {
marshalNode(this);
return C._ts_node_start_index_wasm(this.tree[0]);
}
get endIndex() {
marshalNode(this);
return C._ts_node_end_index_wasm(this.tree[0]);
@ -526,9 +512,7 @@ class Node {
class TreeCursor {
constructor(internal, tree) {
if (internal !== INTERNAL) {
throw new Error('Illegal constructor')
}
assertInternal(internal);
this.tree = tree;
unmarshalTreeCursor(this);
}
@ -630,9 +614,7 @@ class TreeCursor {
class Language {
constructor(internal, address) {
if (internal !== INTERNAL) {
throw new Error('Illegal constructor')
}
assertInternal(internal);
this[0] = address;
this.types = new Array(C._ts_language_symbol_count(this[0]));
for (let i = 0, n = this.types.length; i < n; i++) {
@ -672,6 +654,103 @@ class Language {
return this.fields[fieldName] || null;
}
query(source) {
const sourceLength = lengthBytesUTF8(source);
const sourceAddress = C._malloc(sourceLength + 1);
stringToUTF8(source, sourceAddress, sourceLength + 1);
const address = C._ts_query_new(
this[0],
sourceAddress,
sourceLength,
TRANSFER_BUFFER,
TRANSFER_BUFFER + SIZE_OF_INT
);
if (!address) {
const errorId = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32');
const errorByte = getValue(TRANSFER_BUFFER, 'i32');
const errorIndex = UTF8ToString(sourceAddress, errorByte).length;
const suffix = source.substr(errorIndex, 100);
const word = suffix.match(QUERY_WORD_REGEX)[0];
let error;
switch (errorId) {
case 2:
error = new RangeError(`Bad node name '${word}'`);
break;
case 3:
error = new RangeError(`Bad field name '${word}'`);
break;
case 4:
error = new RangeError(`Bad capture name @${word}`);
break;
default:
error = new SyntaxError(`Bad syntax at offset ${errorIndex}: '${suffix}'...`);
break;
}
error.index = errorIndex;
error.length = word.length;
C._free(sourceAddress);
throw error;
}
const stringCount = C._ts_query_string_count(address);
const captureCount = C._ts_query_capture_count(address);
const patternCount = C._ts_query_pattern_count(address);
const captureNames = new Array(captureCount);
const stringValues = new Array(stringCount);
for (let i = 0; i < captureCount; i++) {
const nameAddress = C._ts_query_capture_name_for_id(
address,
i,
TRANSFER_BUFFER
);
const nameLength = getValue(TRANSFER_BUFFER, 'i32');
captureNames[i] = UTF8ToString(nameAddress, nameLength);
}
for (let i = 0; i < stringCount; i++) {
const valueAddress = C._ts_query_string_value_for_id(
address,
i,
TRANSFER_BUFFER
);
const nameLength = getValue(TRANSFER_BUFFER, 'i32');
stringValues[i] = UTF8ToString(valueAddress, nameLength);
}
const predicates = new Array(patternCount);
for (let i = 0; i < patternCount; i++) {
const predicatesAddress = C._ts_query_predicates_for_pattern(
address,
i,
TRANSFER_BUFFER
);
const stepCount = getValue(TRANSFER_BUFFER, 'i32');
predicates[i] = [];
const steps = [];
let stepAddress = predicatesAddress;
for (let j = 0; j < stepCount; j++) {
const stepType = getValue(stepAddress, 'i32');
stepAddress += SIZE_OF_INT;
const stepValueId = getValue(stepAddress, 'i32');
stepAddress += SIZE_OF_INT;
if (stepType === PREDICATE_STEP_TYPE_CAPTURE) {
steps.push({type: 'capture', name: captureNames[stepValueId]});
} else if (stepType === PREDICATE_STEP_TYPE_STRING) {
steps.push({type: 'string', value: stringValues[stepValueId]});
} else if (steps.length > 0) {
predicates[i].push(buildQueryPredicate(steps));
steps.length = 0;
}
}
}
C._free(sourceAddress);
return new Query(INTERNAL, address, captureNames, predicates);
}
static load(url) {
let bytes;
if (
@ -704,6 +783,170 @@ class Language {
}
}
// A compiled tree query. Instances are created by `Language.query`; the
// constructor rejects any other caller.
class Query {
  // `address` is the address of the compiled query, `captureNames` the
  // capture names indexed by capture id, and `predicates` one array of
  // predicate functions per pattern.
  constructor(internal, address, captureNames, predicates) {
    assertInternal(internal);
    this[0] = address;
    this.captureNames = captureNames;
    this.predicates = predicates;
  }

  // Free the compiled query. The object must not be used afterwards.
  delete() {
    C._ts_query_delete(this[0]);
  }

  // Run the query on `node` and return an array of `{pattern, captures}`
  // objects, one per match whose predicates all hold. `startPosition` and
  // `endPosition` are optional `{row, column}` points bounding the search;
  // each defaults to the zero point.
  matches(node, startPosition, endPosition) {
    if (!startPosition) startPosition = ZERO_POINT;
    if (!endPosition) endPosition = ZERO_POINT;

    marshalNode(node);

    C._ts_query_matches_wasm(
      this[0],
      node.tree[0],
      startPosition.row,
      startPosition.column,
      endPosition.row,
      endPosition.column
    );

    const count = getValue(TRANSFER_BUFFER, 'i32');
    const startAddress = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32');
    // Collect only the matches that pass their predicates. Preallocating
    // `count` slots and assigning by match index would leave holes in the
    // result for every rejected match.
    const result = [];

    let address = startAddress;
    for (let i = 0; i < count; i++) {
      const pattern = getValue(address, 'i32');
      address += SIZE_OF_INT;
      const captureCount = getValue(address, 'i32');
      address += SIZE_OF_INT;

      const captures = new Array(captureCount);
      address = unmarshalCaptures(this, node.tree, address, captures);
      if (this.predicates[pattern].every(p => p(captures))) {
        result.push({pattern, captures});
      }
    }

    // The result buffer was allocated by the wasm side for this call.
    C._free(startAddress);
    return result;
  }

  // Run the query on `node` and return an array of the individual captures,
  // in the order they are reported, keeping only captures whose match
  // satisfies all of its pattern's predicates. `startPosition` and
  // `endPosition` are optional `{row, column}` points bounding the search;
  // each defaults to the zero point.
  captures(node, startPosition, endPosition) {
    if (!startPosition) startPosition = ZERO_POINT;
    if (!endPosition) endPosition = ZERO_POINT;

    marshalNode(node);

    C._ts_query_captures_wasm(
      this[0],
      node.tree[0],
      startPosition.row,
      startPosition.column,
      endPosition.row,
      endPosition.column
    );

    const count = getValue(TRANSFER_BUFFER, 'i32');
    const startAddress = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32');
    const result = [];

    let address = startAddress;
    for (let i = 0; i < count; i++) {
      const pattern = getValue(address, 'i32');
      address += SIZE_OF_INT;
      const captureCount = getValue(address, 'i32');
      address += SIZE_OF_INT;
      // Position, within the match's captures, of the capture being reported.
      const captureIndex = getValue(address, 'i32');
      address += SIZE_OF_INT;

      const captures = new Array(captureCount);
      address = unmarshalCaptures(this, node.tree, address, captures);

      if (this.predicates[pattern].every(p => p(captures))) {
        result.push(captures[captureIndex]);
      }
    }

    // The result buffer was allocated by the wasm side for this call.
    C._free(startAddress);
    return result;
  }
}
/**
 * Turn one predicate, expressed as a list of steps, into a function that
 * decides whether a match's captures satisfy it.
 *
 * Each step is either `{type: 'string', value}` or `{type: 'capture', name}`.
 * The first step must be a string naming the predicate; the supported
 * predicates are:
 *   - `eq? @capture "literal"` / `eq? @capture1 @capture2`
 *   - `match? @capture "regex source"`
 *
 * @param steps - the predicate name followed by its arguments.
 * @returns a function `(captures) => boolean`, where `captures` is an array
 *   of `{name, node}` objects. A predicate that refers to a capture which is
 *   not present in `captures` evaluates to `false`.
 * @throws Error if the predicate is unknown or its arguments are malformed.
 */
function buildQueryPredicate(steps) {
  if (steps[0].type !== 'string') {
    throw new Error('Predicates must begin with a literal value');
  }

  switch (steps[0].value) {
    case 'eq?': {
      if (steps.length !== 3) throw new Error(
        `Wrong number of arguments to \`eq?\` predicate. Expected 2, got ${steps.length - 1}`
      );
      if (steps[1].type !== 'capture') throw new Error(
        `First argument of \`eq?\` predicate must be a capture. Got "${steps[1].value}"`
      );

      if (steps[2].type === 'capture') {
        const captureName1 = steps[1].name;
        const captureName2 = steps[2].name;
        return function(captures) {
          let node1, node2;
          for (const c of captures) {
            if (c.name === captureName1) node1 = c.node;
            if (c.name === captureName2) node2 = c.node;
          }
          // Either capture may be absent from this match; treat that as a
          // failed comparison (like the other predicates do) instead of
          // dereferencing `undefined`.
          if (!node1 || !node2) return false;
          return node1.text === node2.text;
        };
      }

      const captureName = steps[1].name;
      const stringValue = steps[2].value;
      return function(captures) {
        for (const c of captures) {
          if (c.name === captureName) return c.node.text === stringValue;
        }
        return false;
      };
    }

    case 'match?': {
      if (steps.length !== 3) throw new Error(
        `Wrong number of arguments to \`match?\` predicate. Expected 2, got ${steps.length - 1}.`
      );
      if (steps[1].type !== 'capture') throw new Error(
        `First argument of \`match?\` predicate must be a capture. Got "${steps[1].value}".`
      );
      // A non-string second argument is a capture step, which carries `name`
      // rather than `value`.
      if (steps[2].type !== 'string') throw new Error(
        `Second argument of \`match?\` predicate must be a string. Got @${steps[2].name}.`
      );

      const captureName = steps[1].name;
      // Compile the pattern once, when the predicate is built.
      const regex = new RegExp(steps[2].value);
      return function(captures) {
        for (const c of captures) {
          if (c.name === captureName) return regex.test(c.node.text);
        }
        return false;
      };
    }

    default:
      throw new Error(`Unknown query predicate \`${steps[0].value}\``);
  }
}
/**
 * Decode `result.length` consecutive capture records starting at `address`
 * and store them into `result` as `{name, node}` objects.
 *
 * Each record is an i32 capture id followed by a serialized node. The id is
 * translated to a name through `query.captureNames`.
 *
 * @returns the address immediately after the last record that was read.
 */
function unmarshalCaptures(query, tree, address, result) {
  const total = result.length;
  let cursor = address;
  let slot = 0;
  while (slot < total) {
    const nameId = getValue(cursor, 'i32');
    cursor += SIZE_OF_INT;
    const node = unmarshalNode(tree, cursor);
    cursor += SIZE_OF_NODE;
    result[slot] = {name: query.captureNames[nameId], node};
    slot += 1;
  }
  return cursor;
}
/**
 * Guard for constructors that must only be invoked from inside this module:
 * throws unless `x` is the module-private INTERNAL token.
 */
function assertInternal(x) {
  if (x === INTERNAL) return;
  throw new Error('Illegal constructor');
}
function isPoint(point) {
return (
point &&
@ -714,22 +957,36 @@ function isPoint(point) {
/**
 * Serialize `node` into the transfer buffer as five consecutive i32 values:
 * id, start index, start row, start column, and the value stored in
 * `node[0]`. This is the same order in which `unmarshalNode` reads a node
 * back.
 *
 * The previous body also wrote `node[0]`..`node[4]` ahead of these fields,
 * which emitted ten values and put the indexed slots where the id and start
 * position belong.
 */
function marshalNode(node) {
  let address = TRANSFER_BUFFER;
  setValue(address, node.id, 'i32');
  address += SIZE_OF_INT;
  setValue(address, node.startIndex, 'i32');
  address += SIZE_OF_INT;
  setValue(address, node.startPosition.row, 'i32');
  address += SIZE_OF_INT;
  setValue(address, node.startPosition.column, 'i32');
  address += SIZE_OF_INT;
  setValue(address, node[0], 'i32');
}
/**
 * Deserialize a node from five consecutive i32 values starting at `address`
 * (the transfer buffer by default): id, start index, start row, start
 * column, and one further value that is stored in slot `0` of the node.
 *
 * The previous body interleaved two different layouts and declared
 * `const result` twice, which is a SyntaxError; only the named-field layout
 * is kept.
 *
 * @param tree - the tree that the resulting node belongs to.
 * @param address - where the serialized node starts.
 * @returns a Node, or `null` when the serialized id is 0.
 */
function unmarshalNode(tree, address = TRANSFER_BUFFER) {
  const id = getValue(address, 'i32');
  address += SIZE_OF_INT;
  if (id === 0) return null;

  const index = getValue(address, 'i32');
  address += SIZE_OF_INT;
  const row = getValue(address, 'i32');
  address += SIZE_OF_INT;
  const column = getValue(address, 'i32');
  address += SIZE_OF_INT;
  const other = getValue(address, 'i32');

  const result = new Node(INTERNAL, tree);
  result.id = id;
  result.startIndex = index;
  result.startPosition = {row, column};
  result[0] = other;
  return result;
}

View file

@ -68,6 +68,18 @@
"_ts_parser_new_wasm",
"_ts_parser_parse_wasm",
"_ts_parser_set_language",
"_ts_query_capture_count",
"_ts_query_capture_name_for_id",
"_ts_query_captures_wasm",
"_ts_query_context_delete",
"_ts_query_context_new",
"_ts_query_delete",
"_ts_query_matches_wasm",
"_ts_query_new",
"_ts_query_pattern_count",
"_ts_query_predicates_for_pattern",
"_ts_query_string_count",
"_ts_query_string_value_for_id",
"_ts_tree_cursor_current_field_id_wasm",
"_ts_tree_cursor_current_node_id_wasm",
"_ts_tree_cursor_current_node_is_missing_wasm",

View file

@ -0,0 +1,212 @@
const {assert} = require('chai');
let Parser, JavaScript;
// Test suite for the query API of the web binding: construction errors,
// `Query.matches`, and `Query.captures`, all run against the JavaScript
// grammar loaded by ./helper.
describe("Query", () => {
// Handles shared by the tests; `tree` and `query` are assigned by the
// individual tests and released in `afterEach`.
let parser, tree, query;
// Load the Parser class and the JavaScript language once for the suite.
before(async () =>
({Parser, JavaScript} = await require('./helper'))
);
// Every test gets a fresh parser configured for JavaScript.
beforeEach(() => {
parser = new Parser().setLanguage(JavaScript);
});
// Release the native objects. `tree` and `query` are only deleted if some
// test has assigned them.
afterEach(() => {
parser.delete();
if (tree) tree.delete();
if (query) query.delete();
});
describe('construction', () => {
// Malformed syntax, unknown node names and unknown field names must all
// be rejected with a descriptive message.
it('throws an error on invalid patterns', () => {
assert.throws(() => {
JavaScript.query("(function_declaration wat)")
}, "Bad syntax at offset 22: \'wat)\'...");
assert.throws(() => {
JavaScript.query("(non_existent)")
}, "Bad node name 'non_existent'");
assert.throws(() => {
JavaScript.query("(a)")
}, "Bad node name 'a'");
assert.throws(() => {
JavaScript.query("(function_declaration non_existent:(identifier))")
}, "Bad field name 'non_existent'");
});
// Predicates are validated when the query is built: unknown capture
// names, wrong argument counts and unknown predicate names all throw.
it('throws an error on invalid predicates', () => {
assert.throws(() => {
JavaScript.query("((identifier) @abc (eq? @ab hi))")
}, "Bad capture name @ab");
// NOTE(review): this assertion is identical to the previous one.
assert.throws(() => {
JavaScript.query("((identifier) @abc (eq? @ab hi))")
}, "Bad capture name @ab");
assert.throws(() => {
JavaScript.query("((identifier) @abc (eq?))")
}, "Wrong number of arguments to `eq?` predicate. Expected 2, got 0");
assert.throws(() => {
JavaScript.query("((identifier) @a (eq? @a @a @a))")
}, "Wrong number of arguments to `eq?` predicate. Expected 2, got 3");
assert.throws(() => {
JavaScript.query("((identifier) @a (something-else? @a))")
}, "Unknown query predicate `something-else?`");
});
});
describe('.matches', () => {
// Two patterns (function definitions and call references) over a small
// program; the expected list records the pattern index of each match.
it('returns all of the matches for the given query', () => {
tree = parser.parse("function one() { two(); function three() {} }");
query = JavaScript.query(`
(function_declaration name:(identifier) @fn-def)
(call_expression function:(identifier) @fn-ref)
`);
const matches = query.matches(tree.rootNode);
assert.deepEqual(
formatMatches(matches),
[
{pattern: 0, captures: [{name: 'fn-def', text: 'one'}]},
{pattern: 1, captures: [{name: 'fn-ref', text: 'two'}]},
{pattern: 0, captures: [{name: 'fn-def', text: 'three'}]},
]
);
});
// Passing start and end positions restricts the search: with the range
// (1,1)-(3,1) only `d`, `e`, `f` and `g` are expected.
it('can search in a specified ranges', () => {
tree = parser.parse("[a, b,\nc, d,\ne, f,\ng, h]");
query = JavaScript.query('(identifier) @element');
const matches = query.matches(
tree.rootNode,
{row: 1, column: 1},
{row: 3, column: 1}
);
assert.deepEqual(
formatMatches(matches),
[
{pattern: 0, captures: [{name: 'element', text: 'd'}]},
{pattern: 0, captures: [{name: 'element', text: 'e'}]},
{pattern: 0, captures: [{name: 'element', text: 'f'}]},
{pattern: 0, captures: [{name: 'element', text: 'g'}]},
]
);
});
});
describe('.captures', () => {
// Captures coming from several different patterns are expected as one
// flat list, in the order given by the expectation below.
it('returns all of the captures for the given query, in order', () => {
tree = parser.parse(`
a({
bc: function de() {
const fg = function hi() {}
},
jk: function lm() {
const no = function pq() {}
},
});
`);
query = JavaScript.query(`
(pair
key: * @method.def
(function
name: (identifier) @method.alias))
(variable_declarator
name: * @function.def
value: (function
name: (identifier) @function.alias))
":" @delimiter
"=" @operator
`);
const captures = query.captures(tree.rootNode);
assert.deepEqual(
formatCaptures(captures),
[
{name: "method.def", text: "bc"},
{name: "delimiter", text: ":"},
{name: "method.alias", text: "de"},
{name: "function.def", text: "fg"},
{name: "operator", text: "="},
{name: "function.alias", text: "hi"},
{name: "method.def", text: "jk"},
{name: "delimiter", text: ":"},
{name: "method.alias", text: "lm"},
{name: "function.def", text: "no"},
{name: "operator", text: "="},
{name: "function.alias", text: "pq"},
]
);
});
// `eq?` against a string literal and `match?` against a regex filter
// which identifiers receive the more specific capture names.
it('handles conditions that compare the text of capture to literal strings', () => {
tree = parser.parse(`
const ab = require('./ab');
new Cd(EF);
`);
query = JavaScript.query(`
(identifier) @variable
((identifier) @function.builtin
(eq? @function.builtin "require"))
((identifier) @constructor
(match? @constructor "^[A-Z]"))
((identifier) @constant
(match? @constant "^[A-Z]{2,}$"))
`);
const captures = query.captures(tree.rootNode);
assert.deepEqual(
formatCaptures(captures),
[
{name: "variable", text: "ab"},
{name: "variable", text: "require"},
{name: "function.builtin", text: "require"},
{name: "variable", text: "Cd"},
{name: "constructor", text: "Cd"},
{name: "variable", text: "EF"},
{name: "constructor", text: "EF"},
{name: "constant", text: "EF"},
]
);
});
// `eq?` with two captures keeps only the assignment whose left-hand
// identifier has the same text as the first operand on the right.
it('handles conditions that compare the text of capture to each other', () => {
tree = parser.parse(`
ab = abc + 1;
def = de + 1;
ghi = ghi + 1;
`);
query = JavaScript.query(`
((assignment_expression
left: (identifier) @id1
right: (binary_expression
left: (identifier) @id2))
(eq? @id1 @id2))
`);
const captures = query.captures(tree.rootNode);
assert.deepEqual(
formatCaptures(captures),
[
{name: "id1", text: "ghi"},
{name: "id2", text: "ghi"},
]
);
});
});
});
// Reduce each match to its pattern index plus a plain {name, text} summary
// of its captures, so that results can be compared with deepEqual.
function formatMatches(matches) {
  return matches.map(match => {
    const {pattern, captures} = match;
    return {pattern, captures: formatCaptures(captures)};
  });
}
// Reduce each capture to its name and the text of the captured node.
function formatCaptures(captures) {
  return captures.map(capture => {
    const {name, node} = capture;
    return {name, text: node.text};
  });
}

View file

@ -26,6 +26,8 @@ typedef uint16_t TSFieldId;
typedef struct TSLanguage TSLanguage;
typedef struct TSParser TSParser;
typedef struct TSTree TSTree;
typedef struct TSQuery TSQuery;
typedef struct TSQueryCursor TSQueryCursor;
typedef enum {
TSInputEncodingUTF8,
@ -87,6 +89,37 @@ typedef struct {
uint32_t context[2];
} TSTreeCursor;
/**
* One captured node within a query match.
*/
typedef struct {
// The syntax node that was captured.
TSNode node;
// NOTE(review): presumably the id of the capture, as accepted by
// `ts_query_capture_name_for_id` — confirm against query.c.
uint32_t index;
} TSQueryCapture;
/**
* One match of a query pattern, as written by `ts_query_cursor_next_match`
* and `ts_query_cursor_next_capture`.
*/
typedef struct {
// NOTE(review): identifier of this match; its exact meaning is not visible
// from this header — confirm against query.c.
uint32_t id;
// Index of the pattern that matched.
uint16_t pattern_index;
// Number of entries in `captures`.
uint16_t capture_count;
// The captures belonging to this match.
const TSQueryCapture *captures;
} TSQueryMatch;
/**
* The kind of a single step in a pattern's predicate list, as returned by
* `ts_query_predicates_for_pattern`.
*/
typedef enum {
// Sentinel marking the end of one individual predicate.
TSQueryPredicateStepTypeDone,
// The step names a capture; resolve `value_id` with
// `ts_query_capture_name_for_id`.
TSQueryPredicateStepTypeCapture,
// The step is a string literal; resolve `value_id` with
// `ts_query_string_value_for_id`.
TSQueryPredicateStepTypeString,
} TSQueryPredicateStepType;
/**
* A single step of a predicate: its kind, plus the id of the capture name or
* string literal that it refers to.
*/
typedef struct {
// Which kind of step this is; selects how `value_id` is interpreted.
TSQueryPredicateStepType type;
// Capture id or string id, depending on `type`.
uint32_t value_id;
} TSQueryPredicateStep;
/**
* The kind of problem reported through the `error_type` parameter of
* `ts_query_new` when a pattern is invalid.
*/
typedef enum {
// No error.
TSQueryErrorNone = 0,
// NOTE(review): the variants below presumably correspond to malformed
// syntax, an unknown node type, an unknown field name and an unknown
// capture name respectively — confirm against query.c.
TSQueryErrorSyntax,
TSQueryErrorNodeType,
TSQueryErrorField,
TSQueryErrorCapture,
} TSQueryError;
/********************/
/* Section - Parser */
/********************/
@ -602,6 +635,148 @@ int64_t ts_tree_cursor_goto_first_child_for_byte(TSTreeCursor *, uint32_t);
TSTreeCursor ts_tree_cursor_copy(const TSTreeCursor *);
/*******************/
/* Section - Query */
/*******************/
/**
* Create a new query from a string containing one or more S-expression
* patterns. The query is associated with a particular language, and can
* only be run on syntax nodes parsed with that language.
*
* If all of the given patterns are valid, this returns a `TSQuery`.
* If a pattern is invalid, this returns `NULL`, and provides two pieces
* of information about the problem:
* 1. The byte offset of the error is written to the `error_offset` parameter.
* 2. The type of error is written to the `error_type` parameter.
*/
TSQuery *ts_query_new(
const TSLanguage *language,
const char *source,
uint32_t source_len,
uint32_t *error_offset,
TSQueryError *error_type
);
/**
* Delete a query, freeing all of the memory that it used.
*/
void ts_query_delete(TSQuery *);
/**
* Get the number of patterns, captures, or string literals in the query.
*/
uint32_t ts_query_pattern_count(const TSQuery *);
uint32_t ts_query_capture_count(const TSQuery *);
uint32_t ts_query_string_count(const TSQuery *);
/**
* Get the byte offset where the given pattern starts in the query's source.
*
* This can be useful when combining queries by concatenating their source
* code strings.
*/
uint32_t ts_query_start_byte_for_pattern(const TSQuery *, uint32_t);
/**
* Get all of the predicates for the given pattern in the query.
*
* The predicates are represented as a single array of steps. There are three
* types of steps in this array, which correspond to the three legal values for
* the `type` field:
* - `TSQueryPredicateStepTypeCapture` - Steps with this type represent names
* of captures. Their `value_id` can be used with the
* `ts_query_capture_name_for_id` function to obtain the name of the capture.
* - `TSQueryPredicateStepTypeString` - Steps with this type represent literal
* strings. Their `value_id` can be used with the
* `ts_query_string_value_for_id` function to obtain their string value.
* - `TSQueryPredicateStepTypeDone` - Steps with this type are *sentinels*
* that represent the end of an individual predicate. If a pattern has two
* predicates, then there will be two steps with this `type` in the array.
*/
const TSQueryPredicateStep *ts_query_predicates_for_pattern(
const TSQuery *self,
uint32_t pattern_index,
uint32_t *length
);
/**
* Get the name and length of one of the query's captures, or one of the
* query's string literals. Each capture and string is associated with a
* numeric id based on the order that it appeared in the query's source.
*/
const char *ts_query_capture_name_for_id(
const TSQuery *,
uint32_t id,
uint32_t *length
);
const char *ts_query_string_value_for_id(
const TSQuery *,
uint32_t id,
uint32_t *length
);
/**
* Create a new cursor for executing a given query.
*
* The cursor stores the state that is needed to iteratively search
* for matches. To use the query cursor, first call `ts_query_cursor_exec`
* to start running a given query on a given syntax node. Then, there are
* two options for consuming the results of the query:
* 1. Repeatedly call `ts_query_cursor_next_match` to iterate over all of
* the *matches* in the order that they were found. Each match contains the
* index of the pattern that matched, and an array of captures. Because
* multiple patterns can match the same set of nodes, one match may contain
* captures that appear *before* some of the captures from a previous match.
* 2. Repeatedly call `ts_query_cursor_next_capture` to iterate over all of the
* individual *captures* in the order that they appear. This is useful if
* you don't care about which pattern matched, and just want a single ordered
* sequence of captures.
*
* If you don't care about consuming all of the results, you can stop calling
* `ts_query_cursor_next_match` or `ts_query_cursor_next_capture` at any point.
* You can then start executing another query on another node by calling
* `ts_query_cursor_exec` again.
*
* The parameter list is spelled `(void)` so that this is a real prototype in
* C; an empty `()` would declare a function with unspecified parameters.
*/
TSQueryCursor *ts_query_cursor_new(void);
/**
* Delete a query cursor, freeing all of the memory that it used.
*/
void ts_query_cursor_delete(TSQueryCursor *);
/**
* Start running a given query on a given node.
*/
void ts_query_cursor_exec(TSQueryCursor *, const TSQuery *, TSNode);
/**
* Set the range of bytes or (row, column) positions in which the query
* will be executed.
*/
void ts_query_cursor_set_byte_range(TSQueryCursor *, uint32_t, uint32_t);
void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint);
/**
* Advance to the next match of the currently running query.
*
* If there is a match, write it to `*match` and return `true`.
* Otherwise, return `false`.
*/
bool ts_query_cursor_next_match(TSQueryCursor *, TSQueryMatch *match);
/**
* Advance to the next capture of the currently running query.
*
* If there is a capture, write its match to `*match` and its index within
* the match's capture list to `*capture_index`, and return `true`.
* Otherwise, return `false`.
*/
bool ts_query_cursor_next_capture(
TSQueryCursor *,
TSQueryMatch *match,
uint32_t *capture_index
);
/**********************/
/* Section - Language */
/**********************/

29
lib/src/bits.h Normal file
View file

@ -0,0 +1,29 @@
#ifndef TREE_SITTER_BITS_H_
#define TREE_SITTER_BITS_H_
#include <stdint.h>
// Build a 32-bit mask with exactly one bit set, counting positions from the
// most significant bit: index 0 yields 0x80000000 and index 31 yields 0x1.
// The index must be less than 32.
static inline uint32_t bitmask_for_index(uint16_t id) {
  const uint32_t top_bit = UINT32_C(1) << 31;
  return top_bit >> id;
}
#ifdef _WIN32
#include <intrin.h>
// Count the zero bits above the most significant set bit of `x`.
// Returns 32 when `x` is 0.
static inline uint32_t count_leading_zeros(uint32_t x) {
  if (x == 0) return 32;
  // `_BitScanReverse` writes the bit position through an `unsigned long *`,
  // so the output variable must have exactly that type; passing a
  // `uint32_t *` is an incompatible pointer conversion.
  unsigned long highest_set_bit;
  _BitScanReverse(&highest_set_bit, x);
  return 31 - (uint32_t)highest_set_bit;
}
#else
// Count the zero bits above the most significant set bit of `x`.
// `__builtin_clz` is undefined for 0, so that input is answered directly
// with 32.
static inline uint32_t count_leading_zeros(uint32_t x) {
  return x != 0 ? (uint32_t)__builtin_clz(x) : 32;
}
#endif
#endif // TREE_SITTER_BITS_H_

View file

@ -96,7 +96,8 @@ TSFieldId ts_language_field_id_for_name(
for (TSSymbol i = 1; i < count + 1; i++) {
switch (strncmp(name, self->field_names[i], name_length)) {
case 0:
return i;
if (self->field_names[i][name_length] == 0) return i;
break;
case -1:
return 0;
default:

View file

@ -12,6 +12,7 @@
#include "./lexer.c"
#include "./node.c"
#include "./parser.c"
#include "./query.c"
#include "./stack.c"
#include "./subtree.c"
#include "./tree_cursor.c"

View file

@ -3,6 +3,7 @@
#include "tree_sitter/api.h"
#define POINT_ZERO ((TSPoint) {0, 0})
#define POINT_MAX ((TSPoint) {UINT32_MAX, UINT32_MAX})
static inline TSPoint point__new(unsigned row, unsigned column) {

1306
lib/src/query.c Normal file

File diff suppressed because it is too large Load diff

View file

@ -244,6 +244,72 @@ TSNode ts_tree_cursor_current_node(const TSTreeCursor *_self) {
);
}
// Get the field id associated with the cursor's current node (0 if none is
// found) and report, through the two out-parameters, what can still follow
// it:
// - `*can_have_later_siblings` is set to true if any node visited on the way
//   up has a parent with more children after it.
// - `*can_have_later_siblings_with_this_field` is set to true if the field
//   map of a visited parent contains the same field at a later child index.
// Both out-parameters are always written; they start out as false.
TSFieldId ts_tree_cursor_current_status(
const TSTreeCursor *_self,
bool *can_have_later_siblings,
bool *can_have_later_siblings_with_this_field
) {
const TreeCursor *self = (const TreeCursor *)_self;
TSFieldId result = 0;
*can_have_later_siblings = false;
*can_have_later_siblings_with_this_field = false;
// Walk up the tree, visiting the current node and its invisible ancestors,
// because fields can refer to nodes through invisible *wrapper* nodes.
// The loop stops before index 0, so every visited entry has a parent entry.
for (unsigned i = self->stack.size - 1; i > 0; i--) {
TreeCursorEntry *entry = &self->stack.contents[i];
TreeCursorEntry *parent_entry = &self->stack.contents[i - 1];
// Stop walking up when a visible ancestor is found. An ancestor also
// counts as visible when the parent's alias sequence has an entry for it.
if (i != self->stack.size - 1) {
if (ts_subtree_visible(*entry->subtree)) break;
const TSSymbol *alias_sequence = ts_language_alias_sequence(
self->tree->language,
parent_entry->subtree->ptr->production_id
);
if (alias_sequence && alias_sequence[entry->structural_child_index]) {
break;
}
}
// The parent has at least one more child after this entry.
if (ts_subtree_child_count(*parent_entry->subtree) > entry->child_index + 1) {
*can_have_later_siblings = true;
}
// Extra nodes are not looked up in the field map; stop here.
if (ts_subtree_extra(*entry->subtree)) break;
const TSFieldMapEntry *field_map, *field_map_end;
ts_language_field_map(
self->tree->language,
parent_entry->subtree->ptr->production_id,
&field_map, &field_map_end
);
// Look for a field name associated with the current node.
// (The inner `i` shadows the outer loop index within this loop.)
if (!result) {
for (const TSFieldMapEntry *i = field_map; i < field_map_end; i++) {
if (!i->inherited && i->child_index == entry->structural_child_index) {
result = i->field_id;
*can_have_later_siblings_with_this_field = false;
break;
}
}
}
// Determine if there are other later siblings with the same field name.
if (result) {
for (const TSFieldMapEntry *i = field_map; i < field_map_end; i++) {
if (i->field_id == result && i->child_index > entry->structural_child_index) {
*can_have_later_siblings_with_this_field = true;
break;
}
}
}
}
return result;
}
TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) {
const TreeCursor *self = (const TreeCursor *)_self;
@ -264,20 +330,18 @@ TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) {
}
}
if (ts_subtree_extra(*entry->subtree)) break;
const TSFieldMapEntry *field_map, *field_map_end;
ts_language_field_map(
self->tree->language,
parent_entry->subtree->ptr->production_id,
&field_map, &field_map_end
);
while (field_map < field_map_end) {
if (
!ts_subtree_extra(*entry->subtree) &&
!field_map->inherited &&
field_map->child_index == entry->structural_child_index
) return field_map->field_id;
field_map++;
for (const TSFieldMapEntry *i = field_map; i < field_map_end; i++) {
if (!i->inherited && i->child_index == entry->structural_child_index) {
return i->field_id;
}
}
}
return 0;

View file

@ -16,5 +16,6 @@ typedef struct {
} TreeCursor;
void ts_tree_cursor_init(TreeCursor *, TSNode);
TSFieldId ts_tree_cursor_current_status(const TSTreeCursor *, bool *, bool *);
#endif // TREE_SITTER_TREE_CURSOR_H_

View file

@ -114,7 +114,6 @@ if [[ "$minify_js" == "1" ]]; then
${web_dir}/node_modules/.bin/terser \
--compress \
--mangle \
--keep-fnames \
--keep-classnames \
-- target/scratch/tree-sitter.js \
> $web_dir/tree-sitter.js

View file

@ -9,14 +9,14 @@ bundle exec ruby <<RUBY &
require "listen"
# Copy the tree-sitter JS and WASM build artifacts into the docs assets
# directory by shelling out to cp.
# NOTE(review): two source directories are listed below; this looks like the
# removed and the added line of a diff shown together. Confirm which one is
# current.
def copy_wasm_files
`cp $root/target/release/*.{js,wasm} $root/docs/assets/js/`
`cp $root/lib/binding_web/tree-sitter.{js,wasm} $root/docs/assets/js/`
end
puts "Copying WASM files to docs folder..."
copy_wasm_files
puts "Watching release directory"
listener = Listen.to("$root/target/release", only: /^tree-sitter\.(js|wasm)$/, wait_for_delay: 2) do
listener = Listen.to("$root/lib/binding_web", only: /^tree-sitter\.(js|wasm)$/, wait_for_delay: 2) do
puts "WASM files updated. Copying new files to docs folder..."
copy_wasm_files
end