Allow predicates in queries, to match on nodes' text

This commit is contained in:
Max Brunsfeld 2019-09-15 22:06:51 -07:00
parent 307a1a6c11
commit 096126d039
8 changed files with 781 additions and 186 deletions

View file

@ -109,10 +109,29 @@ pub struct TSQueryCapture {
pub node: TSNode,
pub index: u32,
}
// One match produced by a query cursor. `ts_query_cursor_next_match` and
// `ts_query_cursor_next_capture` write a value of this type through their
// `match_` out-pointer.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSQueryMatch {
// NOTE(review): presumably an identifier for this match; its exact meaning is
// not visible in this file — confirm against the C header.
pub id: u32,
// Index of the pattern that matched. The Rust wrapper uses it to select that
// pattern's predicate list.
pub pattern_index: u16,
// Number of `TSQueryCapture` elements readable through `captures`.
pub capture_count: u16,
// Pointer to the first of `capture_count` captures; the Rust wrapper turns
// the pair into a slice with `slice::from_raw_parts`.
pub captures: *const TSQueryCapture,
}
// Kinds of step that can appear in a pattern's flattened predicate list
// (stored in `TSQueryPredicateStep::type_`).
//
// `Done` ends one predicate: the Rust wrapper splits the step list on it.
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeDone: TSQueryPredicateStepType = 0;
// The step's `value_id` is a capture id (the wrapper indexes its capture names with it).
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture: TSQueryPredicateStepType = 1;
// The step's `value_id` is a string-literal id (the wrapper indexes its string values with it).
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeString: TSQueryPredicateStepType = 2;
pub type TSQueryPredicateStepType = u32;
// One element of the array returned by `ts_query_predicates_for_pattern`.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct TSQueryPredicateStep {
// One of the `TSQueryPredicateStepType_*` constants.
pub type_: TSQueryPredicateStepType,
// Id whose meaning depends on `type_`: a capture id for capture steps, a
// string-literal id for string steps.
pub value_id: u32,
}
// Error codes written by `ts_query_new` through its error-type out-pointer.
// The Rust wrapper initializes that variable to 0 (`None`) before the call.
pub const TSQueryError_TSQueryErrorNone: TSQueryError = 0;
// Reported by the wrapper as `QueryError::Syntax` with the error offset.
pub const TSQueryError_TSQueryErrorSyntax: TSQueryError = 1;
// The next three are reported by the wrapper together with the offending name,
// read from the query source at the error offset.
pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2;
pub const TSQueryError_TSQueryErrorField: TSQueryError = 3;
pub const TSQueryError_TSQueryErrorCapture: TSQueryError = 4;
pub type TSQueryError = u32;
extern "C" {
#[doc = " Create a new parser."]
@ -582,27 +601,58 @@ extern "C" {
pub fn ts_query_delete(arg1: *mut TSQuery);
}
extern "C" {
#[doc = " Get the number of distinct capture names in the query."]
#[doc = " Get the number of patterns in the query."]
pub fn ts_query_pattern_count(arg1: *const TSQuery) -> u32;
}
extern "C" {
#[doc = " Get the predicates for the given pattern in the query."]
// Returns a pointer to an array of predicate steps. Callers in the Rust
// wrapper combine the pointer with `*length` to build a slice, then split it
// on steps whose type is `TSQueryPredicateStepTypeDone`.
pub fn ts_query_predicates_for_pattern(
self_: *const TSQuery,
pattern_index: u32,
// Out-parameter: receives the number of steps in the returned array.
length: *mut u32,
) -> *const TSQueryPredicateStep;
}
extern "C" {
#[doc = " Get the number of distinct capture names in the query, or the number of"]
#[doc = " distinct string literals in the query."]
// This binding is the capture-name half of the doc text above. The Rust
// wrapper iterates capture ids `0..count` to collect every capture name.
pub fn ts_query_capture_count(arg1: *const TSQuery) -> u32;
}
extern "C" {
#[doc = " Get the name and length of one of the query\'s capture. Each capture"]
#[doc = " is associated with a numeric id based on the order that it appeared"]
#[doc = " in the query\'s source."]
pub fn ts_query_string_count(arg1: *const TSQuery) -> u32;
}
extern "C" {
#[doc = " Get the name and length of one of the query\'s captures, or one of the"]
#[doc = " query\'s string literals. Each capture and string is associated with a"]
#[doc = " numeric id based on the order that it appeared in the query\'s source."]
pub fn ts_query_capture_name_for_id(
self_: *const TSQuery,
index: u32,
arg1: *const TSQuery,
id: u32,
length: *mut u32,
) -> *const ::std::os::raw::c_char;
}
extern "C" {
#[doc = " Get the numeric id of the capture with the given name."]
pub fn ts_query_string_value_for_id(
arg1: *const TSQuery,
id: u32,
length: *mut u32,
) -> *const ::std::os::raw::c_char;
}
extern "C" {
#[doc = " Get the numeric id of the capture with the given name, or string with the"]
#[doc = " given value."]
// This binding is the capture-name half of the doc text above.
pub fn ts_query_capture_id_for_name(
self_: *const TSQuery,
// NOTE(review): `name` is passed with an explicit `length`, so it is
// presumably not required to be NUL-terminated — confirm against the C header.
name: *const ::std::os::raw::c_char,
length: u32,
// NOTE(review): the signed return type suggests a negative "not found"
// sentinel — confirm against the C header before relying on it.
) -> ::std::os::raw::c_int;
}
extern "C" {
// String-literal counterpart of `ts_query_capture_id_for_name`: looks up the
// numeric id of a string literal by its value.
pub fn ts_query_string_id_for_value(
self_: *const TSQuery,
// NOTE(review): `value` is passed with an explicit `length`, so it is
// presumably not required to be NUL-terminated — confirm against the C header.
value: *const ::std::os::raw::c_char,
length: u32,
// NOTE(review): the signed return type suggests a negative "not found"
// sentinel — confirm against the C header before relying on it.
) -> ::std::os::raw::c_int;
}
extern "C" {
#[doc = " Create a new cursor for executing a given query."]
#[doc = ""]
@ -645,24 +695,19 @@ extern "C" {
extern "C" {
#[doc = " Advance to the next match of the currently running query."]
#[doc = ""]
#[doc = " If there is another match, write its pattern index to `pattern_index`,"]
#[doc = " the number of captures to `capture_count`, and the captures themselves"]
#[doc = " to `*captures`, and return `true`. Otherwise, return `false`."]
pub fn ts_query_cursor_next_match(
self_: *mut TSQueryCursor,
pattern_index: *mut u32,
capture_count: *mut u32,
captures: *mut *const TSQueryCapture,
) -> bool;
#[doc = " If there is a match, write it to `*match` and return `true`."]
#[doc = " Otherwise, return `false`."]
pub fn ts_query_cursor_next_match(arg1: *mut TSQueryCursor, match_: *mut TSQueryMatch) -> bool;
}
extern "C" {
#[doc = " Advance to the next capture of the currently running query."]
#[doc = ""]
#[doc = " If there is another capture, write it to `capture` and return `true`."]
#[doc = " Otherwise, return `false`."]
#[doc = " If there is a capture, write its match to `*match` and its index within"]
#[doc = " the match\'s capture list to `*capture_index`. Otherwise, return `false`."]
pub fn ts_query_cursor_next_capture(
arg1: *mut TSQueryCursor,
capture: *mut TSQueryCapture,
match_: *mut TSQueryMatch,
capture_index: *mut u32,
) -> bool;
}
extern "C" {

View file

@ -15,10 +15,10 @@ use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::ffi::CStr;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::os::raw::{c_char, c_void};
use std::sync::atomic::AtomicUsize;
use std::{char, fmt, ptr, slice, str, u16};
use std::mem::MaybeUninit;
pub const LANGUAGE_VERSION: usize = ffi::TREE_SITTER_LANGUAGE_VERSION;
pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/parser.h");
@ -137,10 +137,18 @@ pub struct TreePropertyCursor<'a, P> {
source: &'a [u8],
}
/// A text predicate attached to one pattern of a query, built by `Query::new`
/// from the query's `eq?` and `match?` forms.
#[derive(Debug)]
enum QueryPredicate {
/// `eq?` with a capture and a string literal: holds when the text of the
/// capture with the given id equals the string's bytes.
CaptureEqString(u32, String),
/// `eq?` with two captures: holds when the texts of the two captures
/// (given by id) are equal.
CaptureEqCapture(u32, u32),
/// `match?`: holds when the compiled regex matches the text of the capture
/// with the given id.
CaptureMatchString(u32, regex::bytes::Regex),
}
/// A compiled query together with the metadata the Rust wrapper extracts from
/// it at construction time.
#[derive(Debug)]
pub struct Query {
// Raw query handle returned by `ffi::ts_query_new`.
ptr: *mut ffi::TSQuery,
// Capture names, pushed in capture-id order, so a capture id is an index here.
capture_names: Vec<String>,
// One entry per pattern (indexed by pattern index): the predicates that a
// match of that pattern must satisfy.
predicates: Vec<Vec<QueryPredicate>>,
}
/// Wrapper around a raw `ffi::TSQueryCursor` pointer, used to execute a
/// `Query` and iterate over its matches or captures.
pub struct QueryCursor(*mut ffi::TSQueryCursor);
@ -157,6 +165,8 @@ pub enum QueryError<'a> {
Syntax(usize),
NodeType(&'a str),
Field(&'a str),
Capture(&'a str),
Predicate(String),
}
impl Language {
@ -331,7 +341,7 @@ impl Parser {
)
}
/// Parse a slice UTF16 text.
/// Parse a slice of UTF16 text.
///
/// # Arguments:
/// * `text` The UTF16-encoded text to parse.
@ -615,6 +625,10 @@ impl<'tree> Node<'tree> {
unsafe { ffi::ts_node_end_byte(self.0) as usize }
}
/// Get this node's extent as a byte range, running from `start_byte()` up to
/// (but not including) `end_byte()`.
pub fn byte_range(&self) -> std::ops::Range<usize> {
    let start = self.start_byte();
    let end = self.end_byte();
    start..end
}
pub fn range(&self) -> Range {
Range {
start_byte: self.start_byte(),
@ -945,10 +959,12 @@ impl<'a, P> TreePropertyCursor<'a, P> {
}
impl Query {
pub fn new(language: Language, source: &str) -> Result<Self, QueryError> {
pub fn new<'a>(language: Language, source: &'a str) -> Result<Self, QueryError<'a>> {
let mut error_offset = 0u32;
let mut error_type: ffi::TSQueryError = 0;
let bytes = source.as_bytes();
// Compile the query.
let ptr = unsafe {
ffi::ts_query_new(
language.0,
@ -958,38 +974,156 @@ impl Query {
&mut error_type as *mut ffi::TSQueryError,
)
};
// On failure, build an error based on the error code and offset.
if ptr.is_null() {
let offset = error_offset as usize;
Err(match error_type {
ffi::TSQueryError_TSQueryErrorNodeType | ffi::TSQueryError_TSQueryErrorField => {
let suffix = source.split_at(offset).1;
let end_offset = suffix
.find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-')
.unwrap_or(source.len());
let name = suffix.split_at(end_offset).0;
if error_type == ffi::TSQueryError_TSQueryErrorNodeType {
QueryError::NodeType(name)
} else {
QueryError::Field(name)
}
return if error_type != ffi::TSQueryError_TSQueryErrorSyntax {
let suffix = source.split_at(offset).1;
let end_offset = suffix
.find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-')
.unwrap_or(source.len());
let name = suffix.split_at(end_offset).0;
match error_type {
ffi::TSQueryError_TSQueryErrorNodeType => Err(QueryError::NodeType(name)),
ffi::TSQueryError_TSQueryErrorField => Err(QueryError::Field(name)),
ffi::TSQueryError_TSQueryErrorCapture => Err(QueryError::Capture(name)),
_ => Err(QueryError::Syntax(offset)),
}
_ => QueryError::Syntax(offset),
})
} else {
let capture_count = unsafe { ffi::ts_query_capture_count(ptr) };
let capture_names = (0..capture_count)
.map(|i| unsafe {
let mut length = 0u32;
let name =
ffi::ts_query_capture_name_for_id(ptr, i as u32, &mut length as *mut u32)
as *const u8;
let name = slice::from_raw_parts(name, length as usize);
let name = str::from_utf8_unchecked(name);
name.to_string()
})
.collect();
Ok(Query { ptr, capture_names })
} else {
Err(QueryError::Syntax(offset))
};
}
let string_count = unsafe { ffi::ts_query_string_count(ptr) };
let capture_count = unsafe { ffi::ts_query_capture_count(ptr) };
let pattern_count = unsafe { ffi::ts_query_pattern_count(ptr) as usize };
let mut result = Query {
ptr,
capture_names: Vec::with_capacity(capture_count as usize),
predicates: Vec::with_capacity(pattern_count),
};
// Build a vector of strings to store the capture names.
for i in 0..capture_count {
unsafe {
let mut length = 0u32;
let name =
ffi::ts_query_capture_name_for_id(ptr, i, &mut length as *mut u32) as *const u8;
let name = slice::from_raw_parts(name, length as usize);
let name = str::from_utf8_unchecked(name);
result.capture_names.push(name.to_string());
}
}
// Build a vector of strings to represent literal values used in predicates.
let string_values = (0..string_count)
.map(|i| unsafe {
let mut length = 0u32;
let value =
ffi::ts_query_string_value_for_id(ptr, i as u32, &mut length as *mut u32)
as *const u8;
let value = slice::from_raw_parts(value, length as usize);
let value = str::from_utf8_unchecked(value);
value.to_string()
})
.collect::<Vec<_>>();
// Build a vector of predicates for each pattern.
for i in 0..pattern_count {
let predicate_steps = unsafe {
let mut length = 0u32;
let raw_predicates =
ffi::ts_query_predicates_for_pattern(ptr, i as u32, &mut length as *mut u32);
slice::from_raw_parts(raw_predicates, length as usize)
};
let type_done = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeDone;
let type_capture = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture;
let type_string = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeString;
let mut pattern_predicates = Vec::new();
for p in predicate_steps.split(|s| s.type_ == type_done) {
if p.is_empty() {
continue;
}
if p[0].type_ != type_string {
return Err(QueryError::Predicate(format!(
"Expected predicate to start with a function name. Got @{}.",
result.capture_names[p[0].value_id as usize],
)));
}
// Build a predicate for each of the known predicate function names.
let operator_name = &string_values[p[0].value_id as usize];
pattern_predicates.push(match operator_name.as_str() {
"eq?" => {
if p.len() != 3 {
return Err(QueryError::Predicate(format!(
"Wrong number of arguments to eq? predicate. Expected 2, got {}.",
p.len() - 1
)));
}
if p[1].type_ != type_capture {
return Err(QueryError::Predicate(format!(
"First argument to eq? predicate must be a capture name. Got literal \"{}\".",
string_values[p[1].value_id as usize],
)));
}
if p[2].type_ == type_capture {
Ok(QueryPredicate::CaptureEqCapture(
p[1].value_id,
p[2].value_id,
))
} else {
Ok(QueryPredicate::CaptureEqString(
p[1].value_id,
string_values[p[2].value_id as usize].clone(),
))
}
}
"match?" => {
if p.len() != 3 {
return Err(QueryError::Predicate(format!(
"Wrong number of arguments to match? predicate. Expected 2, got {}.",
p.len() - 1
)));
}
if p[1].type_ != type_capture {
return Err(QueryError::Predicate(format!(
"First argument to match? predicate must be a capture name. Got literal \"{}\".",
string_values[p[1].value_id as usize],
)));
}
if p[2].type_ == type_capture {
return Err(QueryError::Predicate(format!(
"Second argument to match? predicate must be a literal. Got capture @{}.",
result.capture_names[p[2].value_id as usize],
)));
}
let regex = &string_values[p[2].value_id as usize];
Ok(QueryPredicate::CaptureMatchString(
p[1].value_id,
regex::bytes::Regex::new(regex)
.map_err(|_| QueryError::Predicate(format!("Invalid regex '{}'", regex)))?,
))
}
_ => Err(QueryError::Predicate(format!(
"Unknown query predicate function {}",
operator_name,
))),
}?);
}
result.predicates.push(pattern_predicates);
}
Ok(result)
}
pub fn capture_names(&self) -> &[String] {
@ -1006,26 +1140,21 @@ impl QueryCursor {
&'a mut self,
query: &'a Query,
node: Node<'a>,
text_callback: impl FnMut(Node<'a>) -> &'a [u8],
) -> impl Iterator<Item = QueryMatch<'a>> + 'a {
unsafe {
ffi::ts_query_cursor_exec(self.0, query.ptr, node.0);
}
std::iter::from_fn(move || -> Option<QueryMatch<'a>> {
unsafe {
let mut pattern_index = 0u32;
let mut capture_count = 0u32;
let mut captures = ptr::null();
if ffi::ts_query_cursor_next_match(
self.0,
&mut pattern_index as *mut u32,
&mut capture_count as *mut u32,
&mut captures as *mut *const ffi::TSQueryCapture,
) {
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
if ffi::ts_query_cursor_next_match(self.0, m.as_mut_ptr()) {
let m = m.assume_init();
Some(QueryMatch {
pattern_index: pattern_index as usize,
capture_count: capture_count as usize,
captures_ptr: captures,
cursor: PhantomData
pattern_index: m.pattern_index as usize,
capture_count: m.capture_count as usize,
captures_ptr: m.captures,
cursor: PhantomData,
})
} else {
None
@ -1038,23 +1167,78 @@ impl QueryCursor {
&'a mut self,
query: &'a Query,
node: Node<'a>,
mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a,
) -> impl Iterator<Item = (usize, Node)> + 'a {
unsafe {
ffi::ts_query_cursor_exec(self.0, query.ptr, node.0);
}
std::iter::from_fn(move || -> Option<(usize, Node<'a>)> {
unsafe {
let mut capture = MaybeUninit::<ffi::TSQueryCapture>::uninit();
if ffi::ts_query_cursor_next_capture(self.0, capture.as_mut_ptr()) {
let capture = capture.assume_init();
Some((capture.index as usize, Node::new(capture.node).unwrap()))
} else {
None
loop {
unsafe {
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
let mut capture_index = 0u32;
if ffi::ts_query_cursor_next_capture(
self.0,
m.as_mut_ptr(),
&mut capture_index as *mut u32,
) {
let m = m.assume_init();
let captures = slice::from_raw_parts(m.captures, m.capture_count as usize);
if self.captures_match_condition(
query,
captures,
m.pattern_index as usize,
&mut text_callback,
) {
let capture = captures[capture_index as usize];
return Some((
capture.index as usize,
Node::new(capture.node).unwrap(),
));
}
} else {
return None;
}
}
}
})
}
/// Check whether a match satisfies every predicate of its pattern.
///
/// * `query` - the query being executed; supplies the per-pattern predicates.
/// * `captures` - the captures belonging to the match under test.
/// * `pattern_index` - index of the matched pattern within `query`.
/// * `text_callback` - returns the source text for a given node.
///
/// Returns `true` when all predicates hold (trivially so when the pattern has
/// none). A predicate that refers to a capture which is not present in
/// `captures` cannot be evaluated and is treated as satisfied instead of
/// panicking: a predicate may name a capture that is declared elsewhere in the
/// query and therefore never appears in this pattern's matches.
fn captures_match_condition<'a>(
    &self,
    query: &'a Query,
    captures: &'a [ffi::TSQueryCapture],
    pattern_index: usize,
    text_callback: &mut impl FnMut(Node<'a>) -> &'a [u8],
) -> bool {
    query.predicates[pattern_index]
        .iter()
        .all(|predicate| match predicate {
            QueryPredicate::CaptureEqCapture(i, j) => {
                let node1 = Self::capture_for_id(captures, *i);
                let node2 = Self::capture_for_id(captures, *j);
                match (node1, node2) {
                    (Some(node1), Some(node2)) => text_callback(node1) == text_callback(node2),
                    // One of the captures is missing: nothing to compare.
                    _ => true,
                }
            }
            QueryPredicate::CaptureEqString(i, s) => {
                match Self::capture_for_id(captures, *i) {
                    Some(node) => text_callback(node) == s.as_bytes(),
                    None => true,
                }
            }
            QueryPredicate::CaptureMatchString(i, r) => {
                match Self::capture_for_id(captures, *i) {
                    Some(node) => r.is_match(text_callback(node)),
                    None => true,
                }
            }
        })
}
/// Look up the node captured under `capture_id` in a match's capture list.
///
/// Only the first capture whose `index` equals `capture_id` is considered; the
/// result is whatever `Node::new` yields for it. Returns `None` when no capture
/// carries that id.
fn capture_for_id(captures: &[ffi::TSQueryCapture], capture_id: u32) -> Option<Node> {
    captures
        .iter()
        .find(|capture| capture.index == capture_id)
        .and_then(|capture| Node::new(capture.node))
}
pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self {
unsafe {
ffi::ts_query_cursor_set_byte_range(self.0, start as u32, end as u32);
@ -1076,7 +1260,8 @@ impl<'a> QueryMatch<'a> {
}
pub fn captures(&self) -> impl ExactSizeIterator<Item = (usize, Node)> {
let captures = unsafe { slice::from_raw_parts(self.captures_ptr, self.capture_count as usize) };
let captures =
unsafe { slice::from_raw_parts(self.captures_ptr, self.capture_count as usize) };
captures
.iter()
.map(|capture| (capture.index as usize, Node::new(capture.node).unwrap()))