Allow predicates in queries, to match on nodes' text

This commit is contained in:
Max Brunsfeld 2019-09-15 22:06:51 -07:00
parent 307a1a6c11
commit 096126d039
8 changed files with 781 additions and 186 deletions

View file

@ -32,7 +32,7 @@ pub fn query_files_at_paths(
let tree = parser.parse(&source_code, None).unwrap();
for mat in query_cursor.matches(&query, tree.root_node()) {
for mat in query_cursor.matches(&query, tree.root_node(), |n| &source_code[n.byte_range()]) {
writeln!(&mut stdout, " pattern: {}", mat.pattern_index())?;
for (capture_id, node) in mat.captures() {
writeln!(

View file

@ -67,6 +67,30 @@ fn test_query_errors_on_invalid_symbols() {
});
}
// Malformed predicate expressions must be rejected at query construction
// time, each with its own specific error.
#[test]
fn test_query_errors_on_invalid_conditions() {
    allocations::record(|| {
        let language = get_language("javascript");

        // The predicate starts with a capture instead of a function name.
        let missing_function_name = Query::new(language, "((identifier) @id (@id))");
        assert_eq!(
            missing_function_name,
            Err(QueryError::Predicate(
                "Expected predicate to start with a function name. Got @id.".to_string()
            ))
        );

        // `eq?` is given a single argument.
        let too_few_arguments = Query::new(language, "((identifier) @id (eq? @id))");
        assert_eq!(
            too_few_arguments,
            Err(QueryError::Predicate(
                "Wrong number of arguments to eq? predicate. Expected 2, got 1.".to_string()
            ))
        );

        // The predicate refers to a capture name that the query never defines.
        let unknown_capture = Query::new(language, "((identifier) @id (eq? @id @ok))");
        assert_eq!(unknown_capture, Err(QueryError::Capture("ok")));
    });
}
#[test]
fn test_query_matches_with_simple_pattern() {
allocations::record(|| {
@ -83,7 +107,7 @@ fn test_query_matches_with_simple_pattern() {
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node());
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
@ -123,7 +147,7 @@ fn test_query_matches_with_multiple_on_same_root() {
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node());
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
@ -170,7 +194,7 @@ fn test_query_matches_with_multiple_patterns_different_roots() {
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node());
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
@ -212,7 +236,7 @@ fn test_query_matches_with_multiple_patterns_same_root() {
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node());
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
@ -249,7 +273,7 @@ fn test_query_matches_with_nesting_and_no_fields() {
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node());
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
@ -275,7 +299,7 @@ fn test_query_matches_with_many() {
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node());
let matches = cursor.matches(&query, tree.root_node(), to_callback(&source));
assert_eq!(
collect_matches(matches, &query, source.as_str()),
@ -304,7 +328,7 @@ fn test_query_matches_with_too_many_permutations_to_track() {
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node());
let matches = cursor.matches(&query, tree.root_node(), to_callback(&source));
// For this pathological query, some match permutations will be dropped.
// Just check that a subset of the results are returned, and crash or
@ -335,7 +359,7 @@ fn test_query_matches_with_anonymous_tokens() {
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node());
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
@ -360,9 +384,10 @@ fn test_query_matches_within_byte_range() {
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor
.set_byte_range(5, 15)
.matches(&query, tree.root_node());
let matches =
cursor
.set_byte_range(5, 15)
.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
@ -412,13 +437,13 @@ fn test_query_matches_different_queries_same_cursor() {
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let matches = cursor.matches(&query1, tree.root_node());
let matches = cursor.matches(&query1, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query1, source),
&[(0, vec![("id1", "a")]),]
);
let matches = cursor.matches(&query3, tree.root_node());
let matches = cursor.matches(&query3, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query3, source),
&[
@ -428,7 +453,7 @@ fn test_query_matches_different_queries_same_cursor() {
]
);
let matches = cursor.matches(&query2, tree.root_node());
let matches = cursor.matches(&query2, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query2, source),
&[(0, vec![("id1", "a")]), (1, vec![("id2", "b")]),]
@ -474,7 +499,7 @@ fn test_query_captures() {
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node());
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
@ -490,7 +515,7 @@ fn test_query_captures() {
],
);
let captures = cursor.captures(&query, tree.root_node());
let captures = cursor.captures(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_captures(captures, &query, source),
&[
@ -511,6 +536,54 @@ fn test_query_captures() {
});
}
// Captures guarded by `eq?` / `match?` predicates should only be reported for
// nodes whose text satisfies the predicate, while the unconditional
// `@variable` capture is reported for every identifier.
#[test]
fn test_query_captures_with_text_conditions() {
    allocations::record(|| {
        let language = get_language("javascript");
        let query_source = r#"
(identifier) @variable
((identifier) @function.builtin
(eq? @function.builtin "require"))
((identifier) @constructor
(match? @constructor "^[A-Z]"))
((identifier) @constant
(match? @constant "^[A-Z]{2,}$"))
"#;
        let query = Query::new(language, query_source).unwrap();

        let source = "
const ab = require('./ab');
new Cd(EF);
";

        let mut parser = Parser::new();
        parser.set_language(language).unwrap();
        let tree = parser.parse(&source, None).unwrap();

        let expected = [
            ("variable", "ab"),
            ("variable", "require"),
            ("function.builtin", "require"),
            ("variable", "Cd"),
            ("constructor", "Cd"),
            ("variable", "EF"),
            ("constructor", "EF"),
            ("constant", "EF"),
        ];

        let mut cursor = QueryCursor::new();
        let captures = cursor.captures(&query, tree.root_node(), to_callback(source));
        assert_eq!(collect_captures(captures, &query, source), &expected);
    });
}
#[test]
fn test_query_capture_names() {
allocations::record(|| {
@ -564,7 +637,7 @@ fn test_query_comments() {
parser.set_language(language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node());
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
assert_eq!(
collect_matches(matches, &query, source),
&[(0, vec![("fn-name", "one")]),],
@ -601,3 +674,7 @@ fn collect_captures<'a, 'b>(
})
.collect()
}
// Builds the text callback handed to `QueryCursor::matches` / `captures` in
// these tests: given a node, it returns the bytes of `source` covered by the
// node's byte range.
fn to_callback<'a>(source: &'a str) -> impl Fn(Node) -> &'a [u8] {
    let bytes = source.as_bytes();
    move |node| &bytes[node.byte_range()]
}

View file

@ -109,10 +109,29 @@ pub struct TSQueryCapture {
pub node: TSNode,
pub index: u32,
}
// FFI mirror of the C `TSQueryMatch` struct that `ts_query_cursor_next_match`
// and `ts_query_cursor_next_capture` write into.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSQueryMatch {
// NOTE(review): presumably an identifier assigned to the match by the cursor;
// confirm against the C implementation.
pub id: u32,
// Index of the pattern (within the query) that produced this match.
pub pattern_index: u16,
// Number of entries readable through `captures`.
pub capture_count: u16,
// Pointer to the first of `capture_count` captures belonging to this match.
pub captures: *const TSQueryCapture,
}
// Values of the C `TSQueryPredicateStepType` enum.
// `Done` separates consecutive predicates in a pattern's step list.
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeDone: TSQueryPredicateStepType = 0;
// The step's `value_id` is a capture id.
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture: TSQueryPredicateStepType = 1;
// The step's `value_id` is the id of a string literal.
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeString: TSQueryPredicateStepType = 2;
pub type TSQueryPredicateStepType = u32;
// One element of the flat step list returned by
// `ts_query_predicates_for_pattern`.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSQueryPredicateStep {
// One of the `TSQueryPredicateStepType_*` constants above.
pub type_: TSQueryPredicateStepType,
// Capture id or string id, depending on `type_`.
pub value_id: u32,
}
// Values of the C `TSQueryError` enum, written to the `error_type` out
// parameter of `ts_query_new`.
pub const TSQueryError_TSQueryErrorNone: TSQueryError = 0;
pub const TSQueryError_TSQueryErrorSyntax: TSQueryError = 1;
pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2;
pub const TSQueryError_TSQueryErrorField: TSQueryError = 3;
pub const TSQueryError_TSQueryErrorCapture: TSQueryError = 4;
pub type TSQueryError = u32;
extern "C" {
#[doc = " Create a new parser."]
@ -582,27 +601,58 @@ extern "C" {
pub fn ts_query_delete(arg1: *mut TSQuery);
}
extern "C" {
#[doc = " Get the number of distinct capture names in the query."]
#[doc = " Get the number of patterns in the query."]
pub fn ts_query_pattern_count(arg1: *const TSQuery) -> u32;
}
extern "C" {
#[doc = " Get the predicates for the given pattern in the query."]
pub fn ts_query_predicates_for_pattern(
self_: *const TSQuery,
pattern_index: u32,
length: *mut u32,
) -> *const TSQueryPredicateStep;
}
extern "C" {
#[doc = " Get the number of distinct capture names in the query, or the number of"]
#[doc = " distinct string literals in the query."]
pub fn ts_query_capture_count(arg1: *const TSQuery) -> u32;
}
extern "C" {
#[doc = " Get the name and length of one of the query\'s capture. Each capture"]
#[doc = " is associated with a numeric id based on the order that it appeared"]
#[doc = " in the query\'s source."]
pub fn ts_query_string_count(arg1: *const TSQuery) -> u32;
}
extern "C" {
#[doc = " Get the name and length of one of the query\'s captures, or one of the"]
#[doc = " query\'s string literals. Each capture and string is associated with a"]
#[doc = " numeric id based on the order that it appeared in the query\'s source."]
pub fn ts_query_capture_name_for_id(
self_: *const TSQuery,
index: u32,
arg1: *const TSQuery,
id: u32,
length: *mut u32,
) -> *const ::std::os::raw::c_char;
}
extern "C" {
#[doc = " Get the numeric id of the capture with the given name."]
pub fn ts_query_string_value_for_id(
arg1: *const TSQuery,
id: u32,
length: *mut u32,
) -> *const ::std::os::raw::c_char;
}
extern "C" {
#[doc = " Get the numeric id of the capture with the given name, or string with the"]
#[doc = " given value."]
pub fn ts_query_capture_id_for_name(
self_: *const TSQuery,
name: *const ::std::os::raw::c_char,
length: u32,
) -> ::std::os::raw::c_int;
}
extern "C" {
pub fn ts_query_string_id_for_value(
self_: *const TSQuery,
value: *const ::std::os::raw::c_char,
length: u32,
) -> ::std::os::raw::c_int;
}
extern "C" {
#[doc = " Create a new cursor for executing a given query."]
#[doc = ""]
@ -645,24 +695,19 @@ extern "C" {
extern "C" {
#[doc = " Advance to the next match of the currently running query."]
#[doc = ""]
#[doc = " If there is another match, write its pattern index to `pattern_index`,"]
#[doc = " the number of captures to `capture_count`, and the captures themselves"]
#[doc = " to `*captures`, and return `true`. Otherwise, return `false`."]
pub fn ts_query_cursor_next_match(
self_: *mut TSQueryCursor,
pattern_index: *mut u32,
capture_count: *mut u32,
captures: *mut *const TSQueryCapture,
) -> bool;
#[doc = " If there is a match, write it to `*match` and return `true`."]
#[doc = " Otherwise, return `false`."]
pub fn ts_query_cursor_next_match(arg1: *mut TSQueryCursor, match_: *mut TSQueryMatch) -> bool;
}
extern "C" {
#[doc = " Advance to the next capture of the currently running query."]
#[doc = ""]
#[doc = " If there is another capture, write it to `capture` and return `true`."]
#[doc = " Otherwise, return `false`."]
#[doc = " If there is a capture, write its match to `*match` and its index within"]
#[doc = " the match\'s capture list to `*capture_index`. Otherwise, return `false`."]
pub fn ts_query_cursor_next_capture(
arg1: *mut TSQueryCursor,
capture: *mut TSQueryCapture,
match_: *mut TSQueryMatch,
capture_index: *mut u32,
) -> bool;
}
extern "C" {

View file

@ -15,10 +15,10 @@ use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::ffi::CStr;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::os::raw::{c_char, c_void};
use std::sync::atomic::AtomicUsize;
use std::{char, fmt, ptr, slice, str, u16};
use std::mem::MaybeUninit;
pub const LANGUAGE_VERSION: usize = ffi::TREE_SITTER_LANGUAGE_VERSION;
pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/parser.h");
@ -137,10 +137,18 @@ pub struct TreePropertyCursor<'a, P> {
source: &'a [u8],
}
/// A text predicate attached to a query pattern. The `u32` values are capture
/// ids; predicates are evaluated against the text of the captured nodes.
#[derive(Debug)]
enum QueryPredicate {
/// `(eq? @capture "string")`: the capture's text must equal the string.
CaptureEqString(u32, String),
/// `(eq? @a @b)`: the two captures' text must be equal.
CaptureEqCapture(u32, u32),
/// `(match? @capture "regex")`: the capture's text must match the regex.
CaptureMatchString(u32, regex::bytes::Regex),
}
/// A compiled tree query together with the metadata extracted from it when it
/// was constructed.
#[derive(Debug)]
pub struct Query {
/// Raw pointer to the underlying C query object.
ptr: *mut ffi::TSQuery,
/// Capture names, indexed by capture id.
capture_names: Vec<String>,
/// Text predicates, indexed by pattern index; one list per pattern.
predicates: Vec<Vec<QueryPredicate>>,
}
pub struct QueryCursor(*mut ffi::TSQueryCursor);
@ -157,6 +165,8 @@ pub enum QueryError<'a> {
Syntax(usize),
NodeType(&'a str),
Field(&'a str),
Capture(&'a str),
Predicate(String),
}
impl Language {
@ -331,7 +341,7 @@ impl Parser {
)
}
/// Parse a slice UTF16 text.
/// Parse a slice of UTF16 text.
///
/// # Arguments:
/// * `text` The UTF16-encoded text to parse.
@ -615,6 +625,10 @@ impl<'tree> Node<'tree> {
unsafe { ffi::ts_node_end_byte(self.0) as usize }
}
/// Returns the half-open range of byte offsets spanned by this node, from
/// `start_byte()` up to, but not including, `end_byte()`.
pub fn byte_range(&self) -> std::ops::Range<usize> {
    let start = self.start_byte();
    let end = self.end_byte();
    start..end
}
pub fn range(&self) -> Range {
Range {
start_byte: self.start_byte(),
@ -945,10 +959,12 @@ impl<'a, P> TreePropertyCursor<'a, P> {
}
impl Query {
pub fn new(language: Language, source: &str) -> Result<Self, QueryError> {
pub fn new<'a>(language: Language, source: &'a str) -> Result<Self, QueryError<'a>> {
let mut error_offset = 0u32;
let mut error_type: ffi::TSQueryError = 0;
let bytes = source.as_bytes();
// Compile the query.
let ptr = unsafe {
ffi::ts_query_new(
language.0,
@ -958,38 +974,156 @@ impl Query {
&mut error_type as *mut ffi::TSQueryError,
)
};
// On failure, build an error based on the error code and offset.
if ptr.is_null() {
let offset = error_offset as usize;
Err(match error_type {
ffi::TSQueryError_TSQueryErrorNodeType | ffi::TSQueryError_TSQueryErrorField => {
let suffix = source.split_at(offset).1;
let end_offset = suffix
.find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-')
.unwrap_or(source.len());
let name = suffix.split_at(end_offset).0;
if error_type == ffi::TSQueryError_TSQueryErrorNodeType {
QueryError::NodeType(name)
} else {
QueryError::Field(name)
}
return if error_type != ffi::TSQueryError_TSQueryErrorSyntax {
let suffix = source.split_at(offset).1;
let end_offset = suffix
.find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-')
.unwrap_or(source.len());
let name = suffix.split_at(end_offset).0;
match error_type {
ffi::TSQueryError_TSQueryErrorNodeType => Err(QueryError::NodeType(name)),
ffi::TSQueryError_TSQueryErrorField => Err(QueryError::Field(name)),
ffi::TSQueryError_TSQueryErrorCapture => Err(QueryError::Capture(name)),
_ => Err(QueryError::Syntax(offset)),
}
_ => QueryError::Syntax(offset),
})
} else {
let capture_count = unsafe { ffi::ts_query_capture_count(ptr) };
let capture_names = (0..capture_count)
.map(|i| unsafe {
let mut length = 0u32;
let name =
ffi::ts_query_capture_name_for_id(ptr, i as u32, &mut length as *mut u32)
as *const u8;
let name = slice::from_raw_parts(name, length as usize);
let name = str::from_utf8_unchecked(name);
name.to_string()
})
.collect();
Ok(Query { ptr, capture_names })
} else {
Err(QueryError::Syntax(offset))
};
}
let string_count = unsafe { ffi::ts_query_string_count(ptr) };
let capture_count = unsafe { ffi::ts_query_capture_count(ptr) };
let pattern_count = unsafe { ffi::ts_query_pattern_count(ptr) as usize };
let mut result = Query {
ptr,
capture_names: Vec::with_capacity(capture_count as usize),
predicates: Vec::with_capacity(pattern_count),
};
// Build a vector of strings to store the capture names.
for i in 0..capture_count {
unsafe {
let mut length = 0u32;
let name =
ffi::ts_query_capture_name_for_id(ptr, i, &mut length as *mut u32) as *const u8;
let name = slice::from_raw_parts(name, length as usize);
let name = str::from_utf8_unchecked(name);
result.capture_names.push(name.to_string());
}
}
// Build a vector of strings to represent literal values used in predicates.
let string_values = (0..string_count)
.map(|i| unsafe {
let mut length = 0u32;
let value =
ffi::ts_query_string_value_for_id(ptr, i as u32, &mut length as *mut u32)
as *const u8;
let value = slice::from_raw_parts(value, length as usize);
let value = str::from_utf8_unchecked(value);
value.to_string()
})
.collect::<Vec<_>>();
// Build a vector of predicates for each pattern.
for i in 0..pattern_count {
let predicate_steps = unsafe {
let mut length = 0u32;
let raw_predicates =
ffi::ts_query_predicates_for_pattern(ptr, i as u32, &mut length as *mut u32);
slice::from_raw_parts(raw_predicates, length as usize)
};
let type_done = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeDone;
let type_capture = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture;
let type_string = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeString;
let mut pattern_predicates = Vec::new();
for p in predicate_steps.split(|s| s.type_ == type_done) {
if p.is_empty() {
continue;
}
if p[0].type_ != type_string {
return Err(QueryError::Predicate(format!(
"Expected predicate to start with a function name. Got @{}.",
result.capture_names[p[0].value_id as usize],
)));
}
// Build a predicate for each of the known predicate function names.
let operator_name = &string_values[p[0].value_id as usize];
pattern_predicates.push(match operator_name.as_str() {
"eq?" => {
if p.len() != 3 {
return Err(QueryError::Predicate(format!(
"Wrong number of arguments to eq? predicate. Expected 2, got {}.",
p.len() - 1
)));
}
if p[1].type_ != type_capture {
return Err(QueryError::Predicate(format!(
"First argument to eq? predicate must be a capture name. Got literal \"{}\".",
string_values[p[1].value_id as usize],
)));
}
if p[2].type_ == type_capture {
Ok(QueryPredicate::CaptureEqCapture(
p[1].value_id,
p[2].value_id,
))
} else {
Ok(QueryPredicate::CaptureEqString(
p[1].value_id,
string_values[p[2].value_id as usize].clone(),
))
}
}
"match?" => {
if p.len() != 3 {
return Err(QueryError::Predicate(format!(
"Wrong number of arguments to match? predicate. Expected 2, got {}.",
p.len() - 1
)));
}
if p[1].type_ != type_capture {
return Err(QueryError::Predicate(format!(
"First argument to match? predicate must be a capture name. Got literal \"{}\".",
string_values[p[1].value_id as usize],
)));
}
if p[2].type_ == type_capture {
return Err(QueryError::Predicate(format!(
"Second argument to match? predicate must be a literal. Got capture @{}.",
result.capture_names[p[2].value_id as usize],
)));
}
let regex = &string_values[p[2].value_id as usize];
Ok(QueryPredicate::CaptureMatchString(
p[1].value_id,
regex::bytes::Regex::new(regex)
.map_err(|_| QueryError::Predicate(format!("Invalid regex '{}'", regex)))?,
))
}
_ => Err(QueryError::Predicate(format!(
"Unknown query predicate function {}",
operator_name,
))),
}?);
}
result.predicates.push(pattern_predicates);
}
Ok(result)
}
pub fn capture_names(&self) -> &[String] {
@ -1006,26 +1140,21 @@ impl QueryCursor {
&'a mut self,
query: &'a Query,
node: Node<'a>,
text_callback: impl FnMut(Node<'a>) -> &'a [u8],
) -> impl Iterator<Item = QueryMatch<'a>> + 'a {
unsafe {
ffi::ts_query_cursor_exec(self.0, query.ptr, node.0);
}
std::iter::from_fn(move || -> Option<QueryMatch<'a>> {
unsafe {
let mut pattern_index = 0u32;
let mut capture_count = 0u32;
let mut captures = ptr::null();
if ffi::ts_query_cursor_next_match(
self.0,
&mut pattern_index as *mut u32,
&mut capture_count as *mut u32,
&mut captures as *mut *const ffi::TSQueryCapture,
) {
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
if ffi::ts_query_cursor_next_match(self.0, m.as_mut_ptr()) {
let m = m.assume_init();
Some(QueryMatch {
pattern_index: pattern_index as usize,
capture_count: capture_count as usize,
captures_ptr: captures,
cursor: PhantomData
pattern_index: m.pattern_index as usize,
capture_count: m.capture_count as usize,
captures_ptr: m.captures,
cursor: PhantomData,
})
} else {
None
@ -1038,23 +1167,78 @@ impl QueryCursor {
&'a mut self,
query: &'a Query,
node: Node<'a>,
mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a,
) -> impl Iterator<Item = (usize, Node)> + 'a {
unsafe {
ffi::ts_query_cursor_exec(self.0, query.ptr, node.0);
}
std::iter::from_fn(move || -> Option<(usize, Node<'a>)> {
unsafe {
let mut capture = MaybeUninit::<ffi::TSQueryCapture>::uninit();
if ffi::ts_query_cursor_next_capture(self.0, capture.as_mut_ptr()) {
let capture = capture.assume_init();
Some((capture.index as usize, Node::new(capture.node).unwrap()))
} else {
None
loop {
unsafe {
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
let mut capture_index = 0u32;
if ffi::ts_query_cursor_next_capture(
self.0,
m.as_mut_ptr(),
&mut capture_index as *mut u32,
) {
let m = m.assume_init();
let captures = slice::from_raw_parts(m.captures, m.capture_count as usize);
if self.captures_match_condition(
query,
captures,
m.pattern_index as usize,
&mut text_callback,
) {
let capture = captures[capture_index as usize];
return Some((
capture.index as usize,
Node::new(capture.node).unwrap(),
));
}
} else {
return None;
}
}
}
})
}
/// Returns `true` when the captures of a match satisfy every text predicate
/// registered for the pattern at `pattern_index`.
///
/// * `query` - the query whose predicates are consulted.
/// * `captures` - the captures belonging to the match under test.
/// * `pattern_index` - index of the pattern that produced the match.
/// * `text_callback` - yields the source bytes covered by a node.
///
/// A predicate that refers to a capture which is not present among
/// `captures` (for example a capture that is only defined by a different
/// pattern of the same query) is treated as satisfied rather than causing a
/// panic, since there is no node text to test it against.
fn captures_match_condition<'a>(
    &self,
    query: &'a Query,
    captures: &'a [ffi::TSQueryCapture],
    pattern_index: usize,
    text_callback: &mut impl FnMut(Node<'a>) -> &'a [u8],
) -> bool {
    query.predicates[pattern_index]
        .iter()
        .all(|predicate| match predicate {
            QueryPredicate::CaptureEqCapture(i, j) => {
                let node1 = Self::capture_for_id(captures, *i);
                let node2 = Self::capture_for_id(captures, *j);
                match (node1, node2) {
                    (Some(node1), Some(node2)) => text_callback(node1) == text_callback(node2),
                    // At least one side was not captured: nothing to compare.
                    _ => true,
                }
            }
            QueryPredicate::CaptureEqString(i, s) => {
                match Self::capture_for_id(captures, *i) {
                    Some(node) => text_callback(node) == s.as_bytes(),
                    None => true,
                }
            }
            QueryPredicate::CaptureMatchString(i, r) => {
                match Self::capture_for_id(captures, *i) {
                    Some(node) => r.is_match(text_callback(node)),
                    None => true,
                }
            }
        })
}
/// Looks up the node captured under `capture_id` in `captures`.
///
/// Only the first capture with a matching id is considered; `None` is
/// returned when no capture has that id or when `Node::new` yields `None`
/// for the captured node.
fn capture_for_id(captures: &[ffi::TSQueryCapture], capture_id: u32) -> Option<Node> {
    captures
        .iter()
        .find(|capture| capture.index == capture_id)
        .and_then(|capture| Node::new(capture.node))
}
pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self {
unsafe {
ffi::ts_query_cursor_set_byte_range(self.0, start as u32, end as u32);
@ -1076,7 +1260,8 @@ impl<'a> QueryMatch<'a> {
}
pub fn captures(&self) -> impl ExactSizeIterator<Item = (usize, Node)> {
let captures = unsafe { slice::from_raw_parts(self.captures_ptr, self.capture_count as usize) };
let captures =
unsafe { slice::from_raw_parts(self.captures_ptr, self.capture_count as usize) };
captures
.iter()
.map(|capture| (capture.index as usize, Node::new(capture.node).unwrap()))

View file

@ -587,20 +587,14 @@ void ts_query_matches_wasm(
uint32_t match_count = 0;
Array(const void *) result = array_new();
uint32_t pattern_index, capture_count;
const TSQueryCapture *captures;
while (ts_query_cursor_next_match(
scratch_query_cursor,
&pattern_index,
&capture_count,
&captures
)) {
TSQueryMatch match;
while (ts_query_cursor_next_match(scratch_query_cursor, &match)) {
match_count++;
array_grow_by(&result, 2 + 6 * capture_count);
result.contents[index++] = (const void *)pattern_index;
result.contents[index++] = (const void *)capture_count;
for (unsigned i = 0; i < capture_count; i++) {
const TSQueryCapture *capture = &captures[i];
array_grow_by(&result, 2 + 6 * match.capture_count);
result.contents[index++] = (const void *)(uint32_t)match.pattern_index;
result.contents[index++] = (const void *)(uint32_t)match.capture_count;
for (unsigned i = 0; i < match.capture_count; i++) {
const TSQueryCapture *capture = &match.captures[i];
result.contents[index++] = (const void *)capture->index;
marshal_node(result.contents + index, capture->node);
index += 5;
@ -631,14 +625,25 @@ void ts_query_captures_wasm(
unsigned capture_count = 0;
Array(const void *) result = array_new();
TSQueryCapture capture;
while (ts_query_cursor_next_capture(scratch_query_cursor, &capture)) {
TSQueryMatch match;
uint32_t capture_index;
while (ts_query_cursor_next_capture(
scratch_query_cursor,
&match,
&capture_index
)) {
capture_count++;
array_grow_by(&result, 6);
result.contents[index++] = (const void *)capture.index;
marshal_node(result.contents + index, capture.node);
index += 5;
array_grow_by(&result, 3 + 6 * match.capture_count);
result.contents[index++] = (const void *)(uint32_t)match.pattern_index;
result.contents[index++] = (const void *)(uint32_t)match.capture_count;
result.contents[index++] = (const void *)(uint32_t)capture_index;
for (unsigned i = 0; i < match.capture_count; i++) {
const TSQueryCapture *capture = &match.captures[i];
result.contents[index++] = (const void *)capture->index;
marshal_node(result.contents + index, capture->node);
index += 5;
}
}
TRANSFER_BUFFER[0] = (const void *)(capture_count);

View file

@ -795,14 +795,34 @@ class Query {
const count = getValue(TRANSFER_BUFFER, 'i32');
const startAddress = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32');
const result = new Array(count);
unmarshalCaptures(this, node.tree, startAddress, result);
const result = [];
let address = startAddress;
for (let i = 0; i < count; i++) {
const pattern = getValue(address, 'i32');
address += SIZE_OF_INT;
const captureCount = getValue(address, 'i32');
address += SIZE_OF_INT;
const captureIndex = getValue(address, 'i32');
address += SIZE_OF_INT;
const captures = new Array(captureCount);
address = unmarshalCaptures(this, node.tree, address, captures);
if (capturesMatchConditions(this, node.tree, pattern, captures)) {
result.push(captures[captureIndex]);
}
}
C._free(startAddress);
return result;
}
}
// Decides whether the captures of a match pass the pattern's conditions.
// Currently a placeholder: every argument is ignored and all matches are
// accepted.
function capturesMatchConditions(query, tree, pattern, captures) {
return true;
}
function unmarshalCaptures(query, tree, address, result) {
for (let i = 0, n = result.length; i < n; i++) {
const captureIndex = getValue(address, 'i32');

View file

@ -94,11 +94,30 @@ typedef struct {
uint32_t index;
} TSQueryCapture;
// A match reported by a query cursor (see `ts_query_cursor_next_match` and
// `ts_query_cursor_next_capture`).
typedef struct {
// NOTE(review): presumably an identifier assigned by the cursor; confirm in query.c.
uint32_t id;
// Index of the pattern that produced the match.
uint16_t pattern_index;
// Number of entries in `captures`.
uint16_t capture_count;
// The captures belonging to this match.
const TSQueryCapture *captures;
} TSQueryMatch;
// Kind of a predicate step. The steps of a pattern's predicates are returned
// as one flat list by `ts_query_predicates_for_pattern`.
typedef enum {
// Marks the end of one predicate in the list.
TSQueryPredicateStepTypeDone,
// `value_id` is a capture id (see `ts_query_capture_name_for_id`).
TSQueryPredicateStepTypeCapture,
// `value_id` is a string id (see `ts_query_string_value_for_id`).
TSQueryPredicateStepTypeString,
} TSQueryPredicateStepType;
// A single step of a predicate: its kind plus the id it refers to.
typedef struct {
TSQueryPredicateStepType type;
uint32_t value_id;
} TSQueryPredicateStep;
// Kinds of error that can be reported when constructing a query.
typedef enum {
TSQueryErrorNone = 0,
TSQueryErrorSyntax,
TSQueryErrorNodeType,
TSQueryErrorField,
TSQueryErrorCapture,
} TSQueryError;
/********************/
@ -645,29 +664,56 @@ TSQuery *ts_query_new(
void ts_query_delete(TSQuery *);
/**
* Get the number of distinct capture names in the query.
* Get the number of patterns in the query.
*/
uint32_t ts_query_capture_count(const TSQuery *);
uint32_t ts_query_pattern_count(const TSQuery *);
/**
* Get the name and length of one of the query's capture. Each capture
* is associated with a numeric id based on the order that it appeared
* in the query's source.
* Get the predicates for the given pattern in the query.
*/
const char *ts_query_capture_name_for_id(
const TSQueryPredicateStep *ts_query_predicates_for_pattern(
const TSQuery *self,
uint32_t index,
uint32_t pattern_index,
uint32_t *length
);
/**
* Get the numeric id of the capture with the given name.
* Get the number of distinct capture names in the query, or the number of
* distinct string literals in the query.
*/
uint32_t ts_query_capture_count(const TSQuery *);
uint32_t ts_query_string_count(const TSQuery *);
/**
* Get the name and length of one of the query's captures, or one of the
* query's string literals. Each capture and string is associated with a
* numeric id based on the order that it appeared in the query's source.
*/
const char *ts_query_capture_name_for_id(
const TSQuery *,
uint32_t id,
uint32_t *length
);
const char *ts_query_string_value_for_id(
const TSQuery *,
uint32_t id,
uint32_t *length
);
/**
* Get the numeric id of the capture with the given name, or string with the
* given value.
*/
int ts_query_capture_id_for_name(
const TSQuery *self,
const char *name,
uint32_t length
);
int ts_query_string_id_for_value(
const TSQuery *self,
const char *value,
uint32_t length
);
/**
* Create a new cursor for executing a given query.
@ -713,24 +759,22 @@ void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint);
/**
* Advance to the next match of the currently running query.
*
* If there is another match, write its pattern index to `pattern_index`,
* the number of captures to `capture_count`, and the captures themselves
* to `*captures`, and return `true`. Otherwise, return `false`.
* If there is a match, write it to `*match` and return `true`.
* Otherwise, return `false`.
*/
bool ts_query_cursor_next_match(
TSQueryCursor *self,
uint32_t *pattern_index,
uint32_t *capture_count,
const TSQueryCapture **captures
);
bool ts_query_cursor_next_match(TSQueryCursor *, TSQueryMatch *match);
/**
* Advance to the next capture of the currently running query.
*
* If there is another capture, write it to `capture` and return `true`.
* Otherwise, return `false`.
* If there is a capture, write its match to `*match` and its index within
* the match's capture list to `*capture_index`. Otherwise, return `false`.
*/
bool ts_query_cursor_next_capture(TSQueryCursor *, TSQueryCapture *capture);
bool ts_query_cursor_next_capture(
TSQueryCursor *,
TSQueryMatch *match,
uint32_t *capture_index
);
/**********************/
/* Section - Language */

View file

@ -30,13 +30,17 @@ typedef struct {
} QueryStep;
/*
* CaptureSlice - The name of a capture, represented as a slice of a
* shared string.
* Slice - A string represented as a slice of a shared string.
*/
typedef struct {
uint32_t offset;
uint32_t length;
} CaptureSlice;
} Slice;
/*
 * SymbolTable - A set of interned strings. Each string's bytes are stored in
 * `characters`, followed by a terminating zero byte. `slices` records where
 * each string lives, and a string's id is its index into `slices`.
 */
typedef struct {
Array(char) characters;
Array(Slice) slices;
} SymbolTable;
/*
* PatternSlice - The set of steps needed to match a particular pattern,
@ -60,6 +64,7 @@ typedef struct {
uint8_t capture_count;
uint8_t capture_list_id;
uint8_t consumed_capture_count;
uint32_t id;
} QueryState;
/*
@ -73,6 +78,17 @@ typedef struct {
uint32_t usage_map;
} CaptureListPool;
// NOTE(review): `PredicateStepType` and `PredicateStep` are not referenced in
// the visible part of this file (`TSQuery` stores `TSQueryPredicateStep`
// values instead) -- confirm whether these internal types are still needed.
typedef enum {
PredicateStepTypeSymbol,
PredicateStepTypeCapture,
PredicateStepTypeDone,
} PredicateStepType;
typedef struct {
bool is_capture;
uint16_t value_id;
} PredicateStep;
/*
* TSQuery - A tree query, compiled from a string of S-expressions. The query
* itself is immutable. The mutable state used in the process of executing the
@ -80,9 +96,11 @@ typedef struct {
*/
struct TSQuery {
Array(QueryStep) steps;
Array(char) capture_data;
Array(CaptureSlice) capture_names;
SymbolTable captures;
SymbolTable predicate_values;
Array(PatternSlice) pattern_map;
Array(TSQueryPredicateStep) predicate_steps;
Array(Slice) predicates_by_pattern;
const TSLanguage *language;
uint16_t max_capture_count;
uint16_t wildcard_root_pattern_count;
@ -100,6 +118,7 @@ struct TSQueryCursor {
uint32_t depth;
uint32_t start_byte;
uint32_t end_byte;
uint32_t next_state_id;
TSPoint start_point;
TSPoint end_point;
bool ascending;
@ -177,7 +196,9 @@ static void stream_scan_identifier(Stream *stream) {
iswalnum(stream->next) ||
stream->next == '_' ||
stream->next == '-' ||
stream->next == '.'
stream->next == '.' ||
stream->next == '?' ||
stream->next == '!'
);
}
@ -222,6 +243,65 @@ static void capture_list_pool_release(CaptureListPool *self, uint16_t id) {
self->usage_map |= bitmask_for_index(id);
}
/**************
* SymbolTable
**************/
// Create an empty symbol table: no characters stored and no slices recorded.
//
// The parameter list is spelled `(void)` so that the definition is a real
// prototype; with an empty `()` a C compiler would not diagnose a call that
// mistakenly passes arguments.
static SymbolTable symbol_table_new(void) {
  return (SymbolTable) {
    .characters = array_new(),
    .slices = array_new(),
  };
}
// Dispose of a symbol table by deleting both of the arrays it is made of:
// the slice list and the shared character buffer.
static void symbol_table_delete(SymbolTable *self) {
  array_delete(&self->slices);
  array_delete(&self->characters);
}
// Find the id of the string `name` (`length` bytes long) in the table.
//
// Returns the index of the matching slice, or -1 when the table does not
// contain the string.
static int symbol_table_id_for_name(
  const SymbolTable *self,
  const char *name,
  uint32_t length
) {
  unsigned index = 0;
  while (index < self->slices.size) {
    const Slice *candidate = &self->slices.contents[index];
    // The lengths have to agree before the bytes are compared, so that a
    // stored string is never matched by one of its own prefixes.
    if (candidate->length == length) {
      const char *stored = &self->characters.contents[candidate->offset];
      if (strncmp(stored, name, length) == 0) return index;
    }
    index++;
  }
  return -1;
}
// Get the string stored under `id`.
//
// Writes the string's byte length to `*length` and returns a pointer to its
// first character inside the table's shared character buffer. The id is used
// as an index without any range check.
static const char *symbol_table_name_for_id(
  const SymbolTable *self,
  uint16_t id,
  uint32_t *length
) {
  const Slice *entry = &self->slices.contents[id];
  const char *name = self->characters.contents + entry->offset;
  *length = entry->length;
  return name;
}
// Intern the string `name` (`length` bytes long) and return its id.
//
// If the string is already in the table, its existing id is returned and the
// table is left untouched. Otherwise the bytes are appended to the character
// buffer, followed by a NUL terminator, and a new slice is recorded for them.
static uint16_t symbol_table_insert_name(
  SymbolTable *self,
  const char *name,
  uint32_t length
) {
  int existing_id = symbol_table_id_for_name(self, name, length);
  if (existing_id >= 0) return (uint16_t)existing_id;

  // Reserve room at the end of the buffer for the string plus its terminator.
  uint32_t start = self->characters.size;
  array_grow_by(&self->characters, length + 1);

  char *destination = &self->characters.contents[start];
  memcpy(destination, name, length);
  destination[length] = 0;

  array_push(&self->slices, ((Slice) {
    .offset = start,
    .length = length,
  }));
  return self->slices.size - 1;
}
/*********
* Query
*********/
@ -241,24 +321,6 @@ static TSSymbol ts_query_intern_node_name(
return 0;
}
// Intern a capture name (`length` bytes long) in the query and return its id.
//
// A name that is already known keeps its existing id. A new name is appended
// to `capture_data`, followed by a NUL terminator, and a slice describing it
// is pushed onto `capture_names`; the index of that slice is the new id.
static uint16_t ts_query_intern_capture_name(
  TSQuery *self,
  const char *name,
  uint32_t length
) {
  int existing_id = ts_query_capture_id_for_name(self, name, length);
  if (existing_id >= 0) return (uint16_t)existing_id;

  // Reserve room at the end of the buffer for the name plus its terminator.
  uint32_t start = self->capture_data.size;
  array_grow_by(&self->capture_data, length + 1);

  char *destination = &self->capture_data.contents[start];
  memcpy(destination, name, length);
  destination[length] = 0;

  array_push(&self->capture_names, ((CaptureSlice) {
    .offset = start,
    .length = length,
  }));
  return self->capture_names.size - 1;
}
// The `pattern_map` contains a mapping from TSSymbol values to indices in the
// `steps` array. For a given syntax node, the `pattern_map` makes it possible
// to quickly find the starting steps of all of the patterns whose root matches
@ -322,6 +384,110 @@ static inline void ts_query__pattern_map_insert(
}));
}
// Parse one predicate -- a parenthesized list such as `(eq? @a "b")` -- and
// append its elements to the query's `predicate_steps`, extending the slice
// of the pattern currently being parsed (the last entry of
// `predicates_by_pattern`) by one for every step added.
//
// Each element becomes one step:
//   - `@name`     -> a capture step holding the id of an existing capture
//   - `"text"`    -> a string step holding the id of the interned text
//   - bare symbol -> a string step holding the id of the interned symbol
// and the closing `)` adds a final "done" step that terminates the predicate.
//
// Returns:
//   - PARENT_DONE if the stream is positioned at a `)`, i.e. there is no
//     further predicate in the enclosing form. Nothing is consumed.
//   - 0 if a predicate was parsed. The stream is left after the closing `)`
//     and any whitespace that follows it.
//   - TSQueryErrorSyntax if the input is not a well-formed predicate.
//   - TSQueryErrorCapture if an `@name` does not refer to a known capture;
//     the stream is reset to the start of the offending name.
static TSQueryError ts_query_parse_predicate(
  TSQuery *self,
  Stream *stream
) {
  if (stream->next == ')') return PARENT_DONE;
  if (stream->next != '(') return TSQueryErrorSyntax;
  stream_advance(stream);
  stream_skip_whitespace(stream);

  for (;;) {
    // The end of the predicate
    if (stream->next == ')') {
      stream_advance(stream);

      // Skip trailing whitespace so that the caller finds either the next
      // predicate's `(` or the enclosing form's `)` directly at the cursor.
      stream_skip_whitespace(stream);

      array_back(&self->predicates_by_pattern)->length++;
      array_push(&self->predicate_steps, ((TSQueryPredicateStep) {
        .type = TSQueryPredicateStepTypeDone,
        .value_id = 0,
      }));
      break;
    }

    // Parse an `@`-prefixed capture
    else if (stream->next == '@') {
      stream_advance(stream);
      stream_skip_whitespace(stream);

      // Parse the capture name
      if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax;
      const char *capture_name = stream->input;
      stream_scan_identifier(stream);
      uint32_t length = stream->input - capture_name;

      // A predicate may only refer to captures that already exist, so the
      // name is looked up rather than inserted.
      int capture_id = symbol_table_id_for_name(
        &self->captures,
        capture_name,
        length
      );
      if (capture_id == -1) {
        stream_reset(stream, capture_name);
        return TSQueryErrorCapture;
      }

      array_back(&self->predicates_by_pattern)->length++;
      array_push(&self->predicate_steps, ((TSQueryPredicateStep) {
        .type = TSQueryPredicateStepTypeCapture,
        .value_id = capture_id,
      }));
    }

    // Parse a string literal
    else if (stream->next == '"') {
      stream_advance(stream);

      // Scan up to the closing quote. If the input ends first, rewind to the
      // opening quote and report a syntax error.
      const char *string_content = stream->input;
      while (stream->next != '"') {
        if (!stream_advance(stream)) {
          stream_reset(stream, string_content - 1);
          return TSQueryErrorSyntax;
        }
      }
      uint32_t length = stream->input - string_content;

      // Intern the string's content and add a step that refers to it
      uint16_t id = symbol_table_insert_name(
        &self->predicate_values,
        string_content,
        length
      );
      array_back(&self->predicates_by_pattern)->length++;
      array_push(&self->predicate_steps, ((TSQueryPredicateStep) {
        .type = TSQueryPredicateStepTypeString,
        .value_id = id,
      }));

      // The loop above only exits on the closing quote; consume it.
      stream_advance(stream);
    }

    // Parse a bare symbol
    else if (stream_is_ident_start(stream)) {
      const char *symbol_start = stream->input;
      stream_scan_identifier(stream);
      uint32_t length = stream->input - symbol_start;

      // Intern the symbol's text and add a step that refers to it
      uint16_t id = symbol_table_insert_name(
        &self->predicate_values,
        symbol_start,
        length
      );
      array_back(&self->predicates_by_pattern)->length++;
      array_push(&self->predicate_steps, ((TSQueryPredicateStep) {
        .type = TSQueryPredicateStepTypeString,
        .value_id = id,
      }));
    }

    // Anything else -- including the end of the input -- cannot be part of a
    // predicate. Without this branch nothing would be consumed and the loop
    // would never terminate.
    else {
      return TSQueryErrorSyntax;
    }

    stream_skip_whitespace(stream);
  }

  return 0;
}
// Read one S-expression pattern from the stream, and incorporate it into
// the query's internal state machine representation. For nested patterns,
// this function calls itself recursively.
@ -344,6 +510,26 @@ static TSQueryError ts_query_parse_pattern(
else if (stream->next == '(') {
stream_advance(stream);
stream_skip_whitespace(stream);
// Parse a pattern inside of a conditional form
if (stream->next == '(' && depth == 0) {
TSQueryError e = ts_query_parse_pattern(self, stream, 0, capture_count);
if (e) return e;
// Parse the child patterns
stream_skip_whitespace(stream);
for (;;) {
TSQueryError e = ts_query_parse_predicate(self, stream);
if (e == PARENT_DONE) {
stream_advance(stream);
stream_skip_whitespace(stream);
return 0;
} else if (e) {
return e;
}
}
}
TSSymbol symbol;
// Parse the wildcard symbol
@ -494,8 +680,8 @@ static TSQueryError ts_query_parse_pattern(
uint32_t length = stream->input - capture_name;
// Add the capture id to the first step of the pattern
uint16_t capture_id = ts_query_intern_capture_name(
self,
uint16_t capture_id = symbol_table_insert_name(
&self->captures,
capture_name,
length
);
@ -519,6 +705,10 @@ TSQuery *ts_query_new(
*self = (TSQuery) {
.steps = array_new(),
.pattern_map = array_new(),
.captures = symbol_table_new(),
.predicate_values = symbol_table_new(),
.predicate_steps = array_new(),
.predicates_by_pattern = array_new(),
.wildcard_root_pattern_count = 0,
.max_capture_count = 0,
.language = language,
@ -531,6 +721,10 @@ TSQuery *ts_query_new(
for (;;) {
start_step_index = self->steps.size;
uint32_t capture_count = 0;
array_push(&self->predicates_by_pattern, ((Slice) {
.offset = self->predicate_steps.size,
.length = 0,
}));
*error_type = ts_query_parse_pattern(self, &stream, 0, &capture_count);
array_push(&self->steps, ((QueryStep) { .depth = PATTERN_DONE_MARKER }));
@ -569,14 +763,24 @@ void ts_query_delete(TSQuery *self) {
if (self) {
array_delete(&self->steps);
array_delete(&self->pattern_map);
array_delete(&self->capture_data);
array_delete(&self->capture_names);
array_delete(&self->predicate_steps);
array_delete(&self->predicates_by_pattern);
symbol_table_delete(&self->captures);
symbol_table_delete(&self->predicate_values);
ts_free(self);
}
}
// Get the number of patterns in the query. Every pattern has exactly one
// entry in `predicates_by_pattern`, so that array's size is the count.
uint32_t ts_query_pattern_count(const TSQuery *self) {
  const uint32_t pattern_count = self->predicates_by_pattern.size;
  return pattern_count;
}
// Get the number of capture names in the query.
//
// NOTE(review): two `return` statements appear back to back below, so the
// second one is unreachable as written. They look like the old and new sides
// of the same change (`capture_names` vs. the `captures` symbol table) --
// confirm which one belongs here and drop the other.
uint32_t ts_query_capture_count(const TSQuery *self) {
return self->capture_names.size;
return self->captures.slices.size;
}
uint32_t ts_query_string_count(const TSQuery *self) {
return self->predicate_values.slices.size;
}
const char *ts_query_capture_name_for_id(
@ -584,9 +788,15 @@ const char *ts_query_capture_name_for_id(
uint32_t index,
uint32_t *length
) {
CaptureSlice name = self->capture_names.contents[index];
*length = name.length;
return &self->capture_data.contents[name.offset];
return symbol_table_name_for_id(&self->captures, index, length);
}
const char *ts_query_string_value_for_id(
const TSQuery *self,
uint32_t index,
uint32_t *length
) {
return symbol_table_name_for_id(&self->predicate_values, index, length);
}
int ts_query_capture_id_for_name(
@ -594,14 +804,25 @@ int ts_query_capture_id_for_name(
const char *name,
uint32_t length
) {
for (unsigned i = 0; i < self->capture_names.size; i++) {
CaptureSlice existing = self->capture_names.contents[i];
if (
existing.length == length &&
!strncmp(&self->capture_data.contents[existing.offset], name, length)
) return i;
}
return -1;
return symbol_table_id_for_name(&self->captures, name, length);
}
int ts_query_string_id_for_value(
const TSQuery *self,
const char *value,
uint32_t length
) {
return symbol_table_id_for_name(&self->predicate_values, value, length);
}
// Get the predicate steps that belong to the pattern at `pattern_index`.
//
// The number of steps is written to `*step_count`, and the return value
// points at the first of them inside the query's `predicate_steps` array.
// `pattern_index` is used as an index without any range check.
const TSQueryPredicateStep *ts_query_predicates_for_pattern(
  const TSQuery *self,
  uint32_t pattern_index,
  uint32_t *step_count
) {
  const Slice *range = &self->predicates_by_pattern.contents[pattern_index];
  *step_count = range->length;
  return self->predicate_steps.contents + range->offset;
}
/***************
@ -640,6 +861,7 @@ void ts_query_cursor_exec(
array_clear(&self->finished_states);
ts_tree_cursor_reset(&self->cursor, node);
capture_list_pool_reset(&self->capture_list_pool, query->max_capture_count);
self->next_state_id = 0;
self->depth = 0;
self->ascending = false;
self->query = query;
@ -891,6 +1113,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
if (next_step->depth == PATTERN_DONE_MARKER) {
LOG("finish pattern %u\n", next_state->pattern_index);
next_state->id = self->next_state_id++;
array_push(&self->finished_states, *next_state);
if (next_state == state) {
array_erase(&self->states, i);
@ -915,9 +1138,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
bool ts_query_cursor_next_match(
TSQueryCursor *self,
uint32_t *pattern_index,
uint32_t *capture_count,
const TSQueryCapture **captures
TSQueryMatch *match
) {
if (self->finished_states.size > 0) {
QueryState state = array_pop(&self->finished_states);
@ -927,9 +1148,10 @@ bool ts_query_cursor_next_match(
if (!ts_query_cursor__advance(self)) return false;
const QueryState *state = array_back(&self->finished_states);
*pattern_index = state->pattern_index;
*capture_count = state->capture_count;
*captures = capture_list_pool_get(
match->id = state->id;
match->pattern_index = state->pattern_index;
match->capture_count = state->capture_count;
match->captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
@ -939,7 +1161,8 @@ bool ts_query_cursor_next_match(
bool ts_query_cursor_next_capture(
TSQueryCursor *self,
TSQueryCapture *capture
TSQueryMatch *match,
uint32_t *capture_index
) {
for (;;) {
if (self->finished_states.size > 0) {
@ -991,19 +1214,15 @@ bool ts_query_cursor_next_capture(
QueryState *state = &self->finished_states.contents[
first_finished_state_index
];
const TSQueryCapture *captures = capture_list_pool_get(
match->id = state->id;
match->pattern_index = state->pattern_index;
match->capture_count = state->capture_count;
match->captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
*capture = captures[state->consumed_capture_count];
*capture_index = state->consumed_capture_count;
state->consumed_capture_count++;
if (state->consumed_capture_count == state->capture_count) {
capture_list_pool_release(
&self->capture_list_pool,
state->capture_list_id
);
array_erase(&self->finished_states, first_finished_state_index);
}
return true;
}
}