Allow predicates in queries, to match on nodes' text
This commit is contained in:
parent
307a1a6c11
commit
096126d039
8 changed files with 781 additions and 186 deletions
|
|
@ -109,10 +109,29 @@ pub struct TSQueryCapture {
|
|||
pub node: TSNode,
|
||||
pub index: u32,
|
||||
}
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct TSQueryMatch {
|
||||
pub id: u32,
|
||||
pub pattern_index: u16,
|
||||
pub capture_count: u16,
|
||||
pub captures: *const TSQueryCapture,
|
||||
}
|
||||
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeDone: TSQueryPredicateStepType = 0;
|
||||
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture: TSQueryPredicateStepType = 1;
|
||||
pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeString: TSQueryPredicateStepType = 2;
|
||||
pub type TSQueryPredicateStepType = u32;
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct TSQueryPredicateStep {
|
||||
pub type_: TSQueryPredicateStepType,
|
||||
pub value_id: u32,
|
||||
}
|
||||
pub const TSQueryError_TSQueryErrorNone: TSQueryError = 0;
|
||||
pub const TSQueryError_TSQueryErrorSyntax: TSQueryError = 1;
|
||||
pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2;
|
||||
pub const TSQueryError_TSQueryErrorField: TSQueryError = 3;
|
||||
pub const TSQueryError_TSQueryErrorCapture: TSQueryError = 4;
|
||||
pub type TSQueryError = u32;
|
||||
extern "C" {
|
||||
#[doc = " Create a new parser."]
|
||||
|
|
@ -582,27 +601,58 @@ extern "C" {
|
|||
pub fn ts_query_delete(arg1: *mut TSQuery);
|
||||
}
|
||||
extern "C" {
|
||||
#[doc = " Get the number of distinct capture names in the query."]
|
||||
#[doc = " Get the number of patterns in the query."]
|
||||
pub fn ts_query_pattern_count(arg1: *const TSQuery) -> u32;
|
||||
}
|
||||
extern "C" {
|
||||
#[doc = " Get the predicates for the given pattern in the query."]
|
||||
pub fn ts_query_predicates_for_pattern(
|
||||
self_: *const TSQuery,
|
||||
pattern_index: u32,
|
||||
length: *mut u32,
|
||||
) -> *const TSQueryPredicateStep;
|
||||
}
|
||||
extern "C" {
|
||||
#[doc = " Get the number of distinct capture names in the query, or the number of"]
|
||||
#[doc = " distinct string literals in the query."]
|
||||
pub fn ts_query_capture_count(arg1: *const TSQuery) -> u32;
|
||||
}
|
||||
extern "C" {
|
||||
#[doc = " Get the name and length of one of the query\'s capture. Each capture"]
|
||||
#[doc = " is associated with a numeric id based on the order that it appeared"]
|
||||
#[doc = " in the query\'s source."]
|
||||
pub fn ts_query_string_count(arg1: *const TSQuery) -> u32;
|
||||
}
|
||||
extern "C" {
|
||||
#[doc = " Get the name and length of one of the query\'s captures, or one of the"]
|
||||
#[doc = " query\'s string literals. Each capture and string is associated with a"]
|
||||
#[doc = " numeric id based on the order that it appeared in the query\'s source."]
|
||||
pub fn ts_query_capture_name_for_id(
|
||||
self_: *const TSQuery,
|
||||
index: u32,
|
||||
arg1: *const TSQuery,
|
||||
id: u32,
|
||||
length: *mut u32,
|
||||
) -> *const ::std::os::raw::c_char;
|
||||
}
|
||||
extern "C" {
|
||||
#[doc = " Get the numeric id of the capture with the given name."]
|
||||
pub fn ts_query_string_value_for_id(
|
||||
arg1: *const TSQuery,
|
||||
id: u32,
|
||||
length: *mut u32,
|
||||
) -> *const ::std::os::raw::c_char;
|
||||
}
|
||||
extern "C" {
|
||||
#[doc = " Get the numeric id of the capture with the given name, or string with the"]
|
||||
#[doc = " given value."]
|
||||
pub fn ts_query_capture_id_for_name(
|
||||
self_: *const TSQuery,
|
||||
name: *const ::std::os::raw::c_char,
|
||||
length: u32,
|
||||
) -> ::std::os::raw::c_int;
|
||||
}
|
||||
extern "C" {
|
||||
pub fn ts_query_string_id_for_value(
|
||||
self_: *const TSQuery,
|
||||
value: *const ::std::os::raw::c_char,
|
||||
length: u32,
|
||||
) -> ::std::os::raw::c_int;
|
||||
}
|
||||
extern "C" {
|
||||
#[doc = " Create a new cursor for executing a given query."]
|
||||
#[doc = ""]
|
||||
|
|
@ -645,24 +695,19 @@ extern "C" {
|
|||
extern "C" {
|
||||
#[doc = " Advance to the next match of the currently running query."]
|
||||
#[doc = ""]
|
||||
#[doc = " If there is another match, write its pattern index to `pattern_index`,"]
|
||||
#[doc = " the number of captures to `capture_count`, and the captures themselves"]
|
||||
#[doc = " to `*captures`, and return `true`. Otherwise, return `false`."]
|
||||
pub fn ts_query_cursor_next_match(
|
||||
self_: *mut TSQueryCursor,
|
||||
pattern_index: *mut u32,
|
||||
capture_count: *mut u32,
|
||||
captures: *mut *const TSQueryCapture,
|
||||
) -> bool;
|
||||
#[doc = " If there is a match, write it to `*match` and return `true`."]
|
||||
#[doc = " Otherwise, return `false`."]
|
||||
pub fn ts_query_cursor_next_match(arg1: *mut TSQueryCursor, match_: *mut TSQueryMatch) -> bool;
|
||||
}
|
||||
extern "C" {
|
||||
#[doc = " Advance to the next capture of the currently running query."]
|
||||
#[doc = ""]
|
||||
#[doc = " If there is another capture, write it to `capture` and return `true`."]
|
||||
#[doc = " Otherwise, return `false`."]
|
||||
#[doc = " If there is a capture, write its match to `*match` and its index within"]
|
||||
#[doc = " the matche\'s capture list to `*capture_index`. Otherwise, return `false`."]
|
||||
pub fn ts_query_cursor_next_capture(
|
||||
arg1: *mut TSQueryCursor,
|
||||
capture: *mut TSQueryCapture,
|
||||
match_: *mut TSQueryMatch,
|
||||
capture_index: *mut u32,
|
||||
) -> bool;
|
||||
}
|
||||
extern "C" {
|
||||
|
|
|
|||
|
|
@ -15,10 +15,10 @@ use serde::de::DeserializeOwned;
|
|||
use std::collections::HashMap;
|
||||
use std::ffi::CStr;
|
||||
use std::marker::PhantomData;
|
||||
use std::mem::MaybeUninit;
|
||||
use std::os::raw::{c_char, c_void};
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::{char, fmt, ptr, slice, str, u16};
|
||||
use std::mem::MaybeUninit;
|
||||
|
||||
pub const LANGUAGE_VERSION: usize = ffi::TREE_SITTER_LANGUAGE_VERSION;
|
||||
pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/parser.h");
|
||||
|
|
@ -137,10 +137,18 @@ pub struct TreePropertyCursor<'a, P> {
|
|||
source: &'a [u8],
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum QueryPredicate {
|
||||
CaptureEqString(u32, String),
|
||||
CaptureEqCapture(u32, u32),
|
||||
CaptureMatchString(u32, regex::bytes::Regex),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Query {
|
||||
ptr: *mut ffi::TSQuery,
|
||||
capture_names: Vec<String>,
|
||||
predicates: Vec<Vec<QueryPredicate>>,
|
||||
}
|
||||
|
||||
pub struct QueryCursor(*mut ffi::TSQueryCursor);
|
||||
|
|
@ -157,6 +165,8 @@ pub enum QueryError<'a> {
|
|||
Syntax(usize),
|
||||
NodeType(&'a str),
|
||||
Field(&'a str),
|
||||
Capture(&'a str),
|
||||
Predicate(String),
|
||||
}
|
||||
|
||||
impl Language {
|
||||
|
|
@ -331,7 +341,7 @@ impl Parser {
|
|||
)
|
||||
}
|
||||
|
||||
/// Parse a slice UTF16 text.
|
||||
/// Parse a slice of UTF16 text.
|
||||
///
|
||||
/// # Arguments:
|
||||
/// * `text` The UTF16-encoded text to parse.
|
||||
|
|
@ -615,6 +625,10 @@ impl<'tree> Node<'tree> {
|
|||
unsafe { ffi::ts_node_end_byte(self.0) as usize }
|
||||
}
|
||||
|
||||
pub fn byte_range(&self) -> std::ops::Range<usize> {
|
||||
self.start_byte()..self.end_byte()
|
||||
}
|
||||
|
||||
pub fn range(&self) -> Range {
|
||||
Range {
|
||||
start_byte: self.start_byte(),
|
||||
|
|
@ -945,10 +959,12 @@ impl<'a, P> TreePropertyCursor<'a, P> {
|
|||
}
|
||||
|
||||
impl Query {
|
||||
pub fn new(language: Language, source: &str) -> Result<Self, QueryError> {
|
||||
pub fn new<'a>(language: Language, source: &'a str) -> Result<Self, QueryError<'a>> {
|
||||
let mut error_offset = 0u32;
|
||||
let mut error_type: ffi::TSQueryError = 0;
|
||||
let bytes = source.as_bytes();
|
||||
|
||||
// Compile the query.
|
||||
let ptr = unsafe {
|
||||
ffi::ts_query_new(
|
||||
language.0,
|
||||
|
|
@ -958,38 +974,156 @@ impl Query {
|
|||
&mut error_type as *mut ffi::TSQueryError,
|
||||
)
|
||||
};
|
||||
|
||||
// On failure, build an error based on the error code and offset.
|
||||
if ptr.is_null() {
|
||||
let offset = error_offset as usize;
|
||||
Err(match error_type {
|
||||
ffi::TSQueryError_TSQueryErrorNodeType | ffi::TSQueryError_TSQueryErrorField => {
|
||||
let suffix = source.split_at(offset).1;
|
||||
let end_offset = suffix
|
||||
.find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-')
|
||||
.unwrap_or(source.len());
|
||||
let name = suffix.split_at(end_offset).0;
|
||||
if error_type == ffi::TSQueryError_TSQueryErrorNodeType {
|
||||
QueryError::NodeType(name)
|
||||
} else {
|
||||
QueryError::Field(name)
|
||||
}
|
||||
return if error_type != ffi::TSQueryError_TSQueryErrorSyntax {
|
||||
let suffix = source.split_at(offset).1;
|
||||
let end_offset = suffix
|
||||
.find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-')
|
||||
.unwrap_or(source.len());
|
||||
let name = suffix.split_at(end_offset).0;
|
||||
match error_type {
|
||||
ffi::TSQueryError_TSQueryErrorNodeType => Err(QueryError::NodeType(name)),
|
||||
ffi::TSQueryError_TSQueryErrorField => Err(QueryError::Field(name)),
|
||||
ffi::TSQueryError_TSQueryErrorCapture => Err(QueryError::Capture(name)),
|
||||
_ => Err(QueryError::Syntax(offset)),
|
||||
}
|
||||
_ => QueryError::Syntax(offset),
|
||||
})
|
||||
} else {
|
||||
let capture_count = unsafe { ffi::ts_query_capture_count(ptr) };
|
||||
let capture_names = (0..capture_count)
|
||||
.map(|i| unsafe {
|
||||
let mut length = 0u32;
|
||||
let name =
|
||||
ffi::ts_query_capture_name_for_id(ptr, i as u32, &mut length as *mut u32)
|
||||
as *const u8;
|
||||
let name = slice::from_raw_parts(name, length as usize);
|
||||
let name = str::from_utf8_unchecked(name);
|
||||
name.to_string()
|
||||
})
|
||||
.collect();
|
||||
Ok(Query { ptr, capture_names })
|
||||
} else {
|
||||
Err(QueryError::Syntax(offset))
|
||||
};
|
||||
}
|
||||
|
||||
let string_count = unsafe { ffi::ts_query_string_count(ptr) };
|
||||
let capture_count = unsafe { ffi::ts_query_capture_count(ptr) };
|
||||
let pattern_count = unsafe { ffi::ts_query_pattern_count(ptr) as usize };
|
||||
let mut result = Query {
|
||||
ptr,
|
||||
capture_names: Vec::with_capacity(capture_count as usize),
|
||||
predicates: Vec::with_capacity(pattern_count),
|
||||
};
|
||||
|
||||
// Build a vector of strings to store the capture names.
|
||||
for i in 0..capture_count {
|
||||
unsafe {
|
||||
let mut length = 0u32;
|
||||
let name =
|
||||
ffi::ts_query_capture_name_for_id(ptr, i, &mut length as *mut u32) as *const u8;
|
||||
let name = slice::from_raw_parts(name, length as usize);
|
||||
let name = str::from_utf8_unchecked(name);
|
||||
result.capture_names.push(name.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Build a vector of strings to represent literal values used in predicates.
|
||||
let string_values = (0..string_count)
|
||||
.map(|i| unsafe {
|
||||
let mut length = 0u32;
|
||||
let value =
|
||||
ffi::ts_query_string_value_for_id(ptr, i as u32, &mut length as *mut u32)
|
||||
as *const u8;
|
||||
let value = slice::from_raw_parts(value, length as usize);
|
||||
let value = str::from_utf8_unchecked(value);
|
||||
value.to_string()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Build a vector of predicates for each pattern.
|
||||
for i in 0..pattern_count {
|
||||
let predicate_steps = unsafe {
|
||||
let mut length = 0u32;
|
||||
let raw_predicates =
|
||||
ffi::ts_query_predicates_for_pattern(ptr, i as u32, &mut length as *mut u32);
|
||||
slice::from_raw_parts(raw_predicates, length as usize)
|
||||
};
|
||||
|
||||
let type_done = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeDone;
|
||||
let type_capture = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture;
|
||||
let type_string = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeString;
|
||||
|
||||
let mut pattern_predicates = Vec::new();
|
||||
for p in predicate_steps.split(|s| s.type_ == type_done) {
|
||||
if p.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if p[0].type_ != type_string {
|
||||
return Err(QueryError::Predicate(format!(
|
||||
"Expected predicate to start with a function name. Got @{}.",
|
||||
result.capture_names[p[0].value_id as usize],
|
||||
)));
|
||||
}
|
||||
|
||||
// Build a predicate for each of the known predicate function names.
|
||||
let operator_name = &string_values[p[0].value_id as usize];
|
||||
pattern_predicates.push(match operator_name.as_str() {
|
||||
"eq?" => {
|
||||
if p.len() != 3 {
|
||||
return Err(QueryError::Predicate(format!(
|
||||
"Wrong number of arguments to eq? predicate. Expected 2, got {}.",
|
||||
p.len() - 1
|
||||
)));
|
||||
}
|
||||
if p[1].type_ != type_capture {
|
||||
return Err(QueryError::Predicate(format!(
|
||||
"First argument to eq? predicate must be a capture name. Got literal \"{}\".",
|
||||
string_values[p[1].value_id as usize],
|
||||
)));
|
||||
}
|
||||
|
||||
if p[2].type_ == type_capture {
|
||||
Ok(QueryPredicate::CaptureEqCapture(
|
||||
p[1].value_id,
|
||||
p[2].value_id,
|
||||
))
|
||||
} else {
|
||||
Ok(QueryPredicate::CaptureEqString(
|
||||
p[1].value_id,
|
||||
string_values[p[2].value_id as usize].clone(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
"match?" => {
|
||||
if p.len() != 3 {
|
||||
return Err(QueryError::Predicate(format!(
|
||||
"Wrong number of arguments to match? predicate. Expected 2, got {}.",
|
||||
p.len() - 1
|
||||
)));
|
||||
}
|
||||
if p[1].type_ != type_capture {
|
||||
return Err(QueryError::Predicate(format!(
|
||||
"First argument to match? predicate must be a capture name. Got literal \"{}\".",
|
||||
string_values[p[1].value_id as usize],
|
||||
)));
|
||||
}
|
||||
if p[2].type_ == type_capture {
|
||||
return Err(QueryError::Predicate(format!(
|
||||
"Second argument to match? predicate must be a literal. Got capture @{}.",
|
||||
result.capture_names[p[2].value_id as usize],
|
||||
)));
|
||||
}
|
||||
|
||||
let regex = &string_values[p[2].value_id as usize];
|
||||
Ok(QueryPredicate::CaptureMatchString(
|
||||
p[1].value_id,
|
||||
regex::bytes::Regex::new(regex)
|
||||
.map_err(|_| QueryError::Predicate(format!("Invalid regex '{}'", regex)))?,
|
||||
))
|
||||
}
|
||||
|
||||
_ => Err(QueryError::Predicate(format!(
|
||||
"Unknown query predicate function {}",
|
||||
operator_name,
|
||||
))),
|
||||
}?);
|
||||
}
|
||||
|
||||
result.predicates.push(pattern_predicates);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn capture_names(&self) -> &[String] {
|
||||
|
|
@ -1006,26 +1140,21 @@ impl QueryCursor {
|
|||
&'a mut self,
|
||||
query: &'a Query,
|
||||
node: Node<'a>,
|
||||
text_callback: impl FnMut(Node<'a>) -> &'a [u8],
|
||||
) -> impl Iterator<Item = QueryMatch<'a>> + 'a {
|
||||
unsafe {
|
||||
ffi::ts_query_cursor_exec(self.0, query.ptr, node.0);
|
||||
}
|
||||
std::iter::from_fn(move || -> Option<QueryMatch<'a>> {
|
||||
unsafe {
|
||||
let mut pattern_index = 0u32;
|
||||
let mut capture_count = 0u32;
|
||||
let mut captures = ptr::null();
|
||||
if ffi::ts_query_cursor_next_match(
|
||||
self.0,
|
||||
&mut pattern_index as *mut u32,
|
||||
&mut capture_count as *mut u32,
|
||||
&mut captures as *mut *const ffi::TSQueryCapture,
|
||||
) {
|
||||
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
|
||||
if ffi::ts_query_cursor_next_match(self.0, m.as_mut_ptr()) {
|
||||
let m = m.assume_init();
|
||||
Some(QueryMatch {
|
||||
pattern_index: pattern_index as usize,
|
||||
capture_count: capture_count as usize,
|
||||
captures_ptr: captures,
|
||||
cursor: PhantomData
|
||||
pattern_index: m.pattern_index as usize,
|
||||
capture_count: m.capture_count as usize,
|
||||
captures_ptr: m.captures,
|
||||
cursor: PhantomData,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
|
|
@ -1038,23 +1167,78 @@ impl QueryCursor {
|
|||
&'a mut self,
|
||||
query: &'a Query,
|
||||
node: Node<'a>,
|
||||
mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a,
|
||||
) -> impl Iterator<Item = (usize, Node)> + 'a {
|
||||
unsafe {
|
||||
ffi::ts_query_cursor_exec(self.0, query.ptr, node.0);
|
||||
}
|
||||
std::iter::from_fn(move || -> Option<(usize, Node<'a>)> {
|
||||
unsafe {
|
||||
let mut capture = MaybeUninit::<ffi::TSQueryCapture>::uninit();
|
||||
if ffi::ts_query_cursor_next_capture(self.0, capture.as_mut_ptr()) {
|
||||
let capture = capture.assume_init();
|
||||
Some((capture.index as usize, Node::new(capture.node).unwrap()))
|
||||
} else {
|
||||
None
|
||||
loop {
|
||||
unsafe {
|
||||
let mut m = MaybeUninit::<ffi::TSQueryMatch>::uninit();
|
||||
let mut capture_index = 0u32;
|
||||
if ffi::ts_query_cursor_next_capture(
|
||||
self.0,
|
||||
m.as_mut_ptr(),
|
||||
&mut capture_index as *mut u32,
|
||||
) {
|
||||
let m = m.assume_init();
|
||||
let captures = slice::from_raw_parts(m.captures, m.capture_count as usize);
|
||||
if self.captures_match_condition(
|
||||
query,
|
||||
captures,
|
||||
m.pattern_index as usize,
|
||||
&mut text_callback,
|
||||
) {
|
||||
let capture = captures[capture_index as usize];
|
||||
return Some((
|
||||
capture.index as usize,
|
||||
Node::new(capture.node).unwrap(),
|
||||
));
|
||||
}
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn captures_match_condition<'a>(
|
||||
&self,
|
||||
query: &'a Query,
|
||||
captures: &'a [ffi::TSQueryCapture],
|
||||
pattern_index: usize,
|
||||
text_callback: &mut impl FnMut(Node<'a>) -> &'a [u8],
|
||||
) -> bool {
|
||||
query.predicates[pattern_index]
|
||||
.iter()
|
||||
.all(|predicate| match predicate {
|
||||
QueryPredicate::CaptureEqCapture(i, j) => {
|
||||
let node1 = Self::capture_for_id(captures, *i).unwrap();
|
||||
let node2 = Self::capture_for_id(captures, *j).unwrap();
|
||||
text_callback(node1) == text_callback(node2)
|
||||
}
|
||||
QueryPredicate::CaptureEqString(i, s) => {
|
||||
let node = Self::capture_for_id(captures, *i).unwrap();
|
||||
text_callback(node) == s.as_bytes()
|
||||
}
|
||||
QueryPredicate::CaptureMatchString(i, r) => {
|
||||
let node = Self::capture_for_id(captures, *i).unwrap();
|
||||
r.is_match(text_callback(node))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn capture_for_id(captures: &[ffi::TSQueryCapture], capture_id: u32) -> Option<Node> {
|
||||
for c in captures {
|
||||
if c.index == capture_id {
|
||||
return Node::new(c.node);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self {
|
||||
unsafe {
|
||||
ffi::ts_query_cursor_set_byte_range(self.0, start as u32, end as u32);
|
||||
|
|
@ -1076,7 +1260,8 @@ impl<'a> QueryMatch<'a> {
|
|||
}
|
||||
|
||||
pub fn captures(&self) -> impl ExactSizeIterator<Item = (usize, Node)> {
|
||||
let captures = unsafe { slice::from_raw_parts(self.captures_ptr, self.capture_count as usize) };
|
||||
let captures =
|
||||
unsafe { slice::from_raw_parts(self.captures_ptr, self.capture_count as usize) };
|
||||
captures
|
||||
.iter()
|
||||
.map(|capture| (capture.index as usize, Node::new(capture.node).unwrap()))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue