From 096126d039dea2e0bdd30f2280c0054a1ba0c0ec Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Sun, 15 Sep 2019 22:06:51 -0700 Subject: [PATCH] Allow predicates in queries, to match on nodes' text --- cli/src/query.rs | 2 +- cli/src/tests/query_test.rs | 111 ++++++++++-- lib/binding_rust/bindings.rs | 83 +++++++-- lib/binding_rust/lib.rs | 289 +++++++++++++++++++++++------ lib/binding_web/binding.c | 43 +++-- lib/binding_web/binding.js | 24 ++- lib/include/tree_sitter/api.h | 84 +++++++-- lib/src/query.c | 331 ++++++++++++++++++++++++++++------ 8 files changed, 781 insertions(+), 186 deletions(-) diff --git a/cli/src/query.rs b/cli/src/query.rs index 9e58c263..4a2f6abb 100644 --- a/cli/src/query.rs +++ b/cli/src/query.rs @@ -32,7 +32,7 @@ pub fn query_files_at_paths( let tree = parser.parse(&source_code, None).unwrap(); - for mat in query_cursor.matches(&query, tree.root_node()) { + for mat in query_cursor.matches(&query, tree.root_node(), |n| &source_code[n.byte_range()]) { writeln!(&mut stdout, " pattern: {}", mat.pattern_index())?; for (capture_id, node) in mat.captures() { writeln!( diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 2ac178fc..91f4cc35 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -67,6 +67,30 @@ fn test_query_errors_on_invalid_symbols() { }); } +#[test] +fn test_query_errors_on_invalid_conditions() { + allocations::record(|| { + let language = get_language("javascript"); + + assert_eq!( + Query::new(language, "((identifier) @id (@id))"), + Err(QueryError::Predicate( + "Expected predicate to start with a function name. Got @id.".to_string() + )) + ); + assert_eq!( + Query::new(language, "((identifier) @id (eq? @id))"), + Err(QueryError::Predicate( + "Wrong number of arguments to eq? predicate. Expected 2, got 1.".to_string() + )) + ); + assert_eq!( + Query::new(language, "((identifier) @id (eq? @id @ok))"), + Err(QueryError::Capture("ok")) + ); + }); +} + #[test] fn test_query_matches_with_simple_pattern() { allocations::record(|| { @@ -83,7 +107,7 @@ fn test_query_matches_with_simple_pattern() { let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); assert_eq!( collect_matches(matches, &query, source), @@ -123,7 +147,7 @@ fn test_query_matches_with_multiple_on_same_root() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); assert_eq!( collect_matches(matches, &query, source), @@ -170,7 +194,7 @@ fn test_query_matches_with_multiple_patterns_different_roots() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); assert_eq!( collect_matches(matches, &query, source), @@ -212,7 +236,7 @@ fn test_query_matches_with_multiple_patterns_same_root() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); assert_eq!( collect_matches(matches, &query, source), @@ -249,7 +273,7 @@ fn test_query_matches_with_nesting_and_no_fields() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); assert_eq!( collect_matches(matches, &query, source), @@ -275,7 +299,7 @@ fn test_query_matches_with_many() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node(), to_callback(&source)); assert_eq!( collect_matches(matches, &query, source.as_str()), @@ -304,7 +328,7 @@ fn test_query_matches_with_too_many_permutations_to_track() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node(), to_callback(&source)); // For this pathological query, some match permutations will be dropped. // Just check that a subset of the results are returned, and crash or @@ -335,7 +359,7 @@ fn test_query_matches_with_anonymous_tokens() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); assert_eq!( collect_matches(matches, &query, source), @@ -360,9 +384,10 @@ fn test_query_matches_within_byte_range() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor - .set_byte_range(5, 15) - .matches(&query, tree.root_node()); + let matches = + cursor + .set_byte_range(5, 15) + .matches(&query, tree.root_node(), to_callback(source)); assert_eq!( collect_matches(matches, &query, source), @@ -412,13 +437,13 @@ fn test_query_matches_different_queries_same_cursor() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); - let matches = cursor.matches(&query1, tree.root_node()); + let matches = cursor.matches(&query1, tree.root_node(), to_callback(source)); assert_eq!( collect_matches(matches, &query1, source), &[(0, vec![("id1", "a")]),] ); - let matches = cursor.matches(&query3, tree.root_node()); + let matches = cursor.matches(&query3, tree.root_node(), to_callback(source)); assert_eq!( collect_matches(matches, &query3, source), &[ @@ -428,7 +453,7 @@ fn test_query_matches_different_queries_same_cursor() { ] ); - let matches = cursor.matches(&query2, tree.root_node()); + let matches = cursor.matches(&query2, tree.root_node(), to_callback(source)); assert_eq!( collect_matches(matches, &query2, source), &[(0, vec![("id1", "a")]), (1, vec![("id2", "b")]),] @@ -474,7 +499,7 @@ fn test_query_captures() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); assert_eq!( collect_matches(matches, &query, source), @@ -490,7 +515,7 @@ fn test_query_captures() { ], ); - let captures = cursor.captures(&query, tree.root_node()); + let captures = cursor.captures(&query, tree.root_node(), to_callback(source)); assert_eq!( collect_captures(captures, &query, source), &[ @@ -511,6 +536,54 @@ fn test_query_captures() { }); } +#[test] +fn test_query_captures_with_text_conditions() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + r#" + (identifier) @variable + + ((identifier) @function.builtin + (eq? @function.builtin "require")) + + ((identifier) @constructor + (match? @constructor "^[A-Z]")) + + ((identifier) @constant + (match? @constant "^[A-Z]{2,}$")) + "#, + ) + .unwrap(); + + let source = " + const ab = require('./ab'); + new Cd(EF); + "; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + + let captures = cursor.captures(&query, tree.root_node(), to_callback(source)); + assert_eq!( + collect_captures(captures, &query, source), + &[ + ("variable", "ab"), + ("variable", "require"), + ("function.builtin", "require"), + ("variable", "Cd"), + ("constructor", "Cd"), + ("variable", "EF"), + ("constructor", "EF"), + ("constant", "EF"), + ], + ); + }); +} + #[test] fn test_query_capture_names() { allocations::record(|| { @@ -564,7 +637,7 @@ fn test_query_comments() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.matches(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); assert_eq!( collect_matches(matches, &query, source), &[(0, vec![("fn-name", "one")]),], @@ -601,3 +674,7 @@ fn collect_captures<'a, 'b>( }) .collect() } + +fn to_callback<'a>(source: &'a str) -> impl Fn(Node) -> &'a [u8] { + move |n| &source.as_bytes()[n.byte_range()] +} diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 2c8ac77d..1be6472d 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -109,10 +109,29 @@ pub struct TSQueryCapture { pub node: TSNode, pub index: u32, } +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct TSQueryMatch { + pub id: u32, + pub pattern_index: u16, + pub capture_count: u16, + pub captures: *const TSQueryCapture, +} +pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeDone: TSQueryPredicateStepType = 0; +pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture: TSQueryPredicateStepType = 1; +pub const TSQueryPredicateStepType_TSQueryPredicateStepTypeString: TSQueryPredicateStepType = 2; +pub type TSQueryPredicateStepType = u32; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct TSQueryPredicateStep { + pub type_: TSQueryPredicateStepType, + pub value_id: u32, +} pub const TSQueryError_TSQueryErrorNone: TSQueryError = 0; pub const TSQueryError_TSQueryErrorSyntax: TSQueryError = 1; pub const TSQueryError_TSQueryErrorNodeType: TSQueryError = 2; pub const TSQueryError_TSQueryErrorField: TSQueryError = 3; +pub const TSQueryError_TSQueryErrorCapture: TSQueryError = 4; pub type TSQueryError = u32; extern "C" { #[doc = " Create a new parser."] @@ -582,27 +601,58 @@ extern "C" { pub fn ts_query_delete(arg1: *mut TSQuery); } extern "C" { - #[doc = " Get the number of distinct capture names in the query."] + #[doc = " Get the number of patterns in the query."] + pub fn ts_query_pattern_count(arg1: *const TSQuery) -> u32; +} +extern "C" { + #[doc = " Get the predicates for the given pattern in the query."] + pub fn ts_query_predicates_for_pattern( + self_: *const TSQuery, + pattern_index: u32, + length: *mut u32, + ) -> *const TSQueryPredicateStep; +} +extern "C" { + #[doc = " Get the number of distinct capture names in the query, or the number of"] + #[doc = " distinct string literals in the query."] pub fn ts_query_capture_count(arg1: *const TSQuery) -> u32; } extern "C" { - #[doc = " Get the name and length of one of the query\'s capture. Each capture"] - #[doc = " is associated with a numeric id based on the order that it appeared"] - #[doc = " in the query\'s source."] + pub fn ts_query_string_count(arg1: *const TSQuery) -> u32; +} +extern "C" { + #[doc = " Get the name and length of one of the query\'s captures, or one of the"] + #[doc = " query\'s string literals. Each capture and string is associated with a"] + #[doc = " numeric id based on the order that it appeared in the query\'s source."] pub fn ts_query_capture_name_for_id( - self_: *const TSQuery, - index: u32, + arg1: *const TSQuery, + id: u32, length: *mut u32, ) -> *const ::std::os::raw::c_char; } extern "C" { - #[doc = " Get the numeric id of the capture with the given name."] + pub fn ts_query_string_value_for_id( + arg1: *const TSQuery, + id: u32, + length: *mut u32, + ) -> *const ::std::os::raw::c_char; +} +extern "C" { + #[doc = " Get the numeric id of the capture with the given name, or string with the"] + #[doc = " given value."] pub fn ts_query_capture_id_for_name( self_: *const TSQuery, name: *const ::std::os::raw::c_char, length: u32, ) -> ::std::os::raw::c_int; } +extern "C" { + pub fn ts_query_string_id_for_value( + self_: *const TSQuery, + value: *const ::std::os::raw::c_char, + length: u32, + ) -> ::std::os::raw::c_int; +} extern "C" { #[doc = " Create a new cursor for executing a given query."] #[doc = ""] @@ -645,24 +695,19 @@ extern "C" { extern "C" { #[doc = " Advance to the next match of the currently running query."] #[doc = ""] - #[doc = " If there is another match, write its pattern index to `pattern_index`,"] - #[doc = " the number of captures to `capture_count`, and the captures themselves"] - #[doc = " to `*captures`, and return `true`. Otherwise, return `false`."] - pub fn ts_query_cursor_next_match( - self_: *mut TSQueryCursor, - pattern_index: *mut u32, - capture_count: *mut u32, - captures: *mut *const TSQueryCapture, - ) -> bool; + #[doc = " If there is a match, write it to `*match` and return `true`."] + #[doc = " Otherwise, return `false`."] + pub fn ts_query_cursor_next_match(arg1: *mut TSQueryCursor, match_: *mut TSQueryMatch) -> bool; } extern "C" { #[doc = " Advance to the next capture of the currently running query."] #[doc = ""] - #[doc = " If there is another capture, write it to `capture` and return `true`."] - #[doc = " Otherwise, return `false`."] + #[doc = " If there is a capture, write its match to `*match` and its index within"] + #[doc = " the matche\'s capture list to `*capture_index`. Otherwise, return `false`."] pub fn ts_query_cursor_next_capture( arg1: *mut TSQueryCursor, - capture: *mut TSQueryCapture, + match_: *mut TSQueryMatch, + capture_index: *mut u32, ) -> bool; } extern "C" { diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 3eea8c2f..17bb0ffe 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -15,10 +15,10 @@ use serde::de::DeserializeOwned; use std::collections::HashMap; use std::ffi::CStr; use std::marker::PhantomData; +use std::mem::MaybeUninit; use std::os::raw::{c_char, c_void}; use std::sync::atomic::AtomicUsize; use std::{char, fmt, ptr, slice, str, u16}; -use std::mem::MaybeUninit; pub const LANGUAGE_VERSION: usize = ffi::TREE_SITTER_LANGUAGE_VERSION; pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/parser.h"); @@ -137,10 +137,18 @@ pub struct TreePropertyCursor<'a, P> { source: &'a [u8], } +#[derive(Debug)] +enum QueryPredicate { + CaptureEqString(u32, String), + CaptureEqCapture(u32, u32), + CaptureMatchString(u32, regex::bytes::Regex), +} + #[derive(Debug)] pub struct Query { ptr: *mut ffi::TSQuery, capture_names: Vec, + predicates: Vec>, } pub struct QueryCursor(*mut ffi::TSQueryCursor); @@ -157,6 +165,8 @@ pub enum QueryError<'a> { Syntax(usize), NodeType(&'a str), Field(&'a str), + Capture(&'a str), + Predicate(String), } impl Language { @@ -331,7 +341,7 @@ impl Parser { ) } - /// Parse a slice UTF16 text. + /// Parse a slice of UTF16 text. /// /// # Arguments: /// * `text` The UTF16-encoded text to parse. @@ -615,6 +625,10 @@ impl<'tree> Node<'tree> { unsafe { ffi::ts_node_end_byte(self.0) as usize } } + pub fn byte_range(&self) -> std::ops::Range { + self.start_byte()..self.end_byte() + } + pub fn range(&self) -> Range { Range { start_byte: self.start_byte(), @@ -945,10 +959,12 @@ impl<'a, P> TreePropertyCursor<'a, P> { } impl Query { - pub fn new(language: Language, source: &str) -> Result { + pub fn new<'a>(language: Language, source: &'a str) -> Result> { let mut error_offset = 0u32; let mut error_type: ffi::TSQueryError = 0; let bytes = source.as_bytes(); + + // Compile the query. let ptr = unsafe { ffi::ts_query_new( language.0, @@ -958,38 +974,156 @@ impl Query { &mut error_type as *mut ffi::TSQueryError, ) }; + + // On failure, build an error based on the error code and offset. if ptr.is_null() { let offset = error_offset as usize; - Err(match error_type { - ffi::TSQueryError_TSQueryErrorNodeType | ffi::TSQueryError_TSQueryErrorField => { - let suffix = source.split_at(offset).1; - let end_offset = suffix - .find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-') - .unwrap_or(source.len()); - let name = suffix.split_at(end_offset).0; - if error_type == ffi::TSQueryError_TSQueryErrorNodeType { - QueryError::NodeType(name) - } else { - QueryError::Field(name) - } + return if error_type != ffi::TSQueryError_TSQueryErrorSyntax { + let suffix = source.split_at(offset).1; + let end_offset = suffix + .find(|c| !char::is_alphanumeric(c) && c != '_' && c != '-') + .unwrap_or(source.len()); + let name = suffix.split_at(end_offset).0; + match error_type { + ffi::TSQueryError_TSQueryErrorNodeType => Err(QueryError::NodeType(name)), + ffi::TSQueryError_TSQueryErrorField => Err(QueryError::Field(name)), + ffi::TSQueryError_TSQueryErrorCapture => Err(QueryError::Capture(name)), + _ => Err(QueryError::Syntax(offset)), } - _ => QueryError::Syntax(offset), - }) - } else { - let capture_count = unsafe { ffi::ts_query_capture_count(ptr) }; - let capture_names = (0..capture_count) - .map(|i| unsafe { - let mut length = 0u32; - let name = - ffi::ts_query_capture_name_for_id(ptr, i as u32, &mut length as *mut u32) - as *const u8; - let name = slice::from_raw_parts(name, length as usize); - let name = str::from_utf8_unchecked(name); - name.to_string() - }) - .collect(); - Ok(Query { ptr, capture_names }) + } else { + Err(QueryError::Syntax(offset)) + }; } + + let string_count = unsafe { ffi::ts_query_string_count(ptr) }; + let capture_count = unsafe { ffi::ts_query_capture_count(ptr) }; + let pattern_count = unsafe { ffi::ts_query_pattern_count(ptr) as usize }; + let mut result = Query { + ptr, + capture_names: Vec::with_capacity(capture_count as usize), + predicates: Vec::with_capacity(pattern_count), + }; + + // Build a vector of strings to store the capture names. + for i in 0..capture_count { + unsafe { + let mut length = 0u32; + let name = + ffi::ts_query_capture_name_for_id(ptr, i, &mut length as *mut u32) as *const u8; + let name = slice::from_raw_parts(name, length as usize); + let name = str::from_utf8_unchecked(name); + result.capture_names.push(name.to_string()); + } + } + + // Build a vector of strings to represent literal values used in predicates. + let string_values = (0..string_count) + .map(|i| unsafe { + let mut length = 0u32; + let value = + ffi::ts_query_string_value_for_id(ptr, i as u32, &mut length as *mut u32) + as *const u8; + let value = slice::from_raw_parts(value, length as usize); + let value = str::from_utf8_unchecked(value); + value.to_string() + }) + .collect::>(); + + // Build a vector of predicates for each pattern. + for i in 0..pattern_count { + let predicate_steps = unsafe { + let mut length = 0u32; + let raw_predicates = + ffi::ts_query_predicates_for_pattern(ptr, i as u32, &mut length as *mut u32); + slice::from_raw_parts(raw_predicates, length as usize) + }; + + let type_done = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeDone; + let type_capture = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeCapture; + let type_string = ffi::TSQueryPredicateStepType_TSQueryPredicateStepTypeString; + + let mut pattern_predicates = Vec::new(); + for p in predicate_steps.split(|s| s.type_ == type_done) { + if p.is_empty() { + continue; + } + + if p[0].type_ != type_string { + return Err(QueryError::Predicate(format!( + "Expected predicate to start with a function name. Got @{}.", + result.capture_names[p[0].value_id as usize], + ))); + } + + // Build a predicate for each of the known predicate function names. + let operator_name = &string_values[p[0].value_id as usize]; + pattern_predicates.push(match operator_name.as_str() { + "eq?" => { + if p.len() != 3 { + return Err(QueryError::Predicate(format!( + "Wrong number of arguments to eq? predicate. Expected 2, got {}.", + p.len() - 1 + ))); + } + if p[1].type_ != type_capture { + return Err(QueryError::Predicate(format!( + "First argument to eq? predicate must be a capture name. Got literal \"{}\".", + string_values[p[1].value_id as usize], + ))); + } + + if p[2].type_ == type_capture { + Ok(QueryPredicate::CaptureEqCapture( + p[1].value_id, + p[2].value_id, + )) + } else { + Ok(QueryPredicate::CaptureEqString( + p[1].value_id, + string_values[p[2].value_id as usize].clone(), + )) + } + } + + "match?" => { + if p.len() != 3 { + return Err(QueryError::Predicate(format!( + "Wrong number of arguments to match? predicate. Expected 2, got {}.", + p.len() - 1 + ))); + } + if p[1].type_ != type_capture { + return Err(QueryError::Predicate(format!( + "First argument to match? predicate must be a capture name. Got literal \"{}\".", + string_values[p[1].value_id as usize], + ))); + } + if p[2].type_ == type_capture { + return Err(QueryError::Predicate(format!( + "Second argument to match? predicate must be a literal. Got capture @{}.", + result.capture_names[p[2].value_id as usize], + ))); + } + + let regex = &string_values[p[2].value_id as usize]; + Ok(QueryPredicate::CaptureMatchString( + p[1].value_id, + regex::bytes::Regex::new(regex) + .map_err(|_| QueryError::Predicate(format!("Invalid regex '{}'", regex)))?, + )) + } + + _ => Err(QueryError::Predicate(format!( + "Unknown query predicate function {}", + operator_name, + ))), + }?); + } + + result.predicates.push(pattern_predicates); + } + + Ok(result) } pub fn capture_names(&self) -> &[String] { @@ -1006,26 +1140,21 @@ impl QueryCursor { &'a mut self, query: &'a Query, node: Node<'a>, + text_callback: impl FnMut(Node<'a>) -> &'a [u8], ) -> impl Iterator> + 'a { unsafe { ffi::ts_query_cursor_exec(self.0, query.ptr, node.0); } std::iter::from_fn(move || -> Option> { unsafe { - let mut pattern_index = 0u32; - let mut capture_count = 0u32; - let mut captures = ptr::null(); - if ffi::ts_query_cursor_next_match( - self.0, - &mut pattern_index as *mut u32, - &mut capture_count as *mut u32, - &mut captures as *mut *const ffi::TSQueryCapture, - ) { + let mut m = MaybeUninit::::uninit(); + if ffi::ts_query_cursor_next_match(self.0, m.as_mut_ptr()) { + let m = m.assume_init(); Some(QueryMatch { - pattern_index: pattern_index as usize, - capture_count: capture_count as usize, - captures_ptr: captures, - cursor: PhantomData + pattern_index: m.pattern_index as usize, + capture_count: m.capture_count as usize, + captures_ptr: m.captures, + cursor: PhantomData, }) } else { None @@ -1038,23 +1167,78 @@ impl QueryCursor { &'a mut self, query: &'a Query, node: Node<'a>, + mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a, ) -> impl Iterator + 'a { unsafe { ffi::ts_query_cursor_exec(self.0, query.ptr, node.0); } std::iter::from_fn(move || -> Option<(usize, Node<'a>)> { - unsafe { - let mut capture = MaybeUninit::::uninit(); - if ffi::ts_query_cursor_next_capture(self.0, capture.as_mut_ptr()) { - let capture = capture.assume_init(); - Some((capture.index as usize, Node::new(capture.node).unwrap())) - } else { - None + loop { + unsafe { + let mut m = MaybeUninit::::uninit(); + let mut capture_index = 0u32; + if ffi::ts_query_cursor_next_capture( + self.0, + m.as_mut_ptr(), + &mut capture_index as *mut u32, + ) { + let m = m.assume_init(); + let captures = slice::from_raw_parts(m.captures, m.capture_count as usize); + if self.captures_match_condition( + query, + captures, + m.pattern_index as usize, + &mut text_callback, + ) { + let capture = captures[capture_index as usize]; + return Some(( + capture.index as usize, + Node::new(capture.node).unwrap(), + )); + } + } else { + return None; + } } } }) } + fn captures_match_condition<'a>( + &self, + query: &'a Query, + captures: &'a [ffi::TSQueryCapture], + pattern_index: usize, + text_callback: &mut impl FnMut(Node<'a>) -> &'a [u8], + ) -> bool { + query.predicates[pattern_index] + .iter() + .all(|predicate| match predicate { + QueryPredicate::CaptureEqCapture(i, j) => { + let node1 = Self::capture_for_id(captures, *i).unwrap(); + let node2 = Self::capture_for_id(captures, *j).unwrap(); + text_callback(node1) == text_callback(node2) + } + QueryPredicate::CaptureEqString(i, s) => { + let node = Self::capture_for_id(captures, *i).unwrap(); + text_callback(node) == s.as_bytes() + } + QueryPredicate::CaptureMatchString(i, r) => { + let node = Self::capture_for_id(captures, *i).unwrap(); + r.is_match(text_callback(node)) + } + }) + } + + fn capture_for_id(captures: &[ffi::TSQueryCapture], capture_id: u32) -> Option { + for c in captures { + if c.index == capture_id { + return Node::new(c.node); + } + } + None + } + pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self { unsafe { ffi::ts_query_cursor_set_byte_range(self.0, start as u32, end as u32); @@ -1076,7 +1260,8 @@ impl<'a> QueryMatch<'a> { } pub fn captures(&self) -> impl ExactSizeIterator { - let captures = unsafe { slice::from_raw_parts(self.captures_ptr, self.capture_count as usize) }; + let captures = + unsafe { slice::from_raw_parts(self.captures_ptr, self.capture_count as usize) }; captures .iter() .map(|capture| (capture.index as usize, Node::new(capture.node).unwrap())) diff --git a/lib/binding_web/binding.c b/lib/binding_web/binding.c index 1e95bf0a..eb463b26 100644 --- a/lib/binding_web/binding.c +++ b/lib/binding_web/binding.c @@ -587,20 +587,14 @@ void ts_query_matches_wasm( uint32_t match_count = 0; Array(const void *) result = array_new(); - uint32_t pattern_index, capture_count; - const TSQueryCapture *captures; - while (ts_query_cursor_next_match( - scratch_query_cursor, - &pattern_index, - &capture_count, - &captures - )) { + TSQueryMatch match; + while (ts_query_cursor_next_match(scratch_query_cursor, &match)) { match_count++; - array_grow_by(&result, 2 + 6 * capture_count); - result.contents[index++] = (const void *)pattern_index; - result.contents[index++] = (const void *)capture_count; - for (unsigned i = 0; i < capture_count; i++) { - const TSQueryCapture *capture = &captures[i]; + array_grow_by(&result, 2 + 6 * match.capture_count); + result.contents[index++] = (const void *)(uint32_t)match.pattern_index; + result.contents[index++] = (const void *)(uint32_t)match.capture_count; + for (unsigned i = 0; i < match.capture_count; i++) { + const TSQueryCapture *capture = &match.captures[i]; result.contents[index++] = (const void *)capture->index; marshal_node(result.contents + index, capture->node); index += 5; @@ -631,14 +625,25 @@ void ts_query_captures_wasm( unsigned capture_count = 0; Array(const void *) result = array_new(); - TSQueryCapture capture; - while (ts_query_cursor_next_capture(scratch_query_cursor, &capture)) { + TSQueryMatch match; + uint32_t capture_index; + while (ts_query_cursor_next_capture( + scratch_query_cursor, + &match, + &capture_index + )) { capture_count++; - array_grow_by(&result, 6); - result.contents[index++] = (const void *)capture.index; - marshal_node(result.contents + index, capture.node); - index += 5; + array_grow_by(&result, 3 + 6 * match.capture_count); + result.contents[index++] = (const void *)(uint32_t)match.pattern_index; + result.contents[index++] = (const void *)(uint32_t)match.capture_count; + result.contents[index++] = (const void *)(uint32_t)capture_index; + for (unsigned i = 0; i < match.capture_count; i++) { + const TSQueryCapture *capture = &match.captures[i]; + result.contents[index++] = (const void *)capture->index; + marshal_node(result.contents + index, capture->node); + index += 5; + } } TRANSFER_BUFFER[0] = (const void *)(capture_count); diff --git a/lib/binding_web/binding.js b/lib/binding_web/binding.js index e1ac4910..599f4fb5 100644 --- a/lib/binding_web/binding.js +++ b/lib/binding_web/binding.js @@ -795,14 +795,34 @@ class Query { const count = getValue(TRANSFER_BUFFER, 'i32'); const startAddress = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32'); - const result = new Array(count); - unmarshalCaptures(this, node.tree, startAddress, result); + const result = []; + + let address = startAddress; + for (let i = 0; i < count; i++) { + const pattern = getValue(address, 'i32'); + address += SIZE_OF_INT; + const captureCount = getValue(address, 'i32'); + address += SIZE_OF_INT; + const captureIndex = getValue(address, 'i32'); + address += SIZE_OF_INT; + + const captures = new Array(captureCount); + address = unmarshalCaptures(this, node.tree, address, captures); + + if (capturesMatchConditions(this, node.tree, pattern, captures)) { + result.push(captures[captureIndex]); + } + } C._free(startAddress); return result; } } +function capturesMatchConditions(query, tree, pattern, captures) { + return true; +} + function unmarshalCaptures(query, tree, address, result) { for (let i = 0, n = result.length; i < n; i++) { const captureIndex = getValue(address, 'i32'); diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index a5c22eb9..0f96dc65 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -94,11 +94,30 @@ typedef struct { uint32_t index; } TSQueryCapture; +typedef struct { + uint32_t id; + uint16_t pattern_index; + uint16_t capture_count; + const TSQueryCapture *captures; +} TSQueryMatch; + +typedef enum { + TSQueryPredicateStepTypeDone, + TSQueryPredicateStepTypeCapture, + TSQueryPredicateStepTypeString, +} TSQueryPredicateStepType; + +typedef struct { + TSQueryPredicateStepType type; + uint32_t value_id; +} TSQueryPredicateStep; + typedef enum { TSQueryErrorNone = 0, TSQueryErrorSyntax, TSQueryErrorNodeType, TSQueryErrorField, + TSQueryErrorCapture, } TSQueryError; /********************/ @@ -645,29 +664,56 @@ TSQuery *ts_query_new( void ts_query_delete(TSQuery *); /** - * Get the number of distinct capture names in the query. + * Get the number of patterns in the query. */ -uint32_t ts_query_capture_count(const TSQuery *); +uint32_t ts_query_pattern_count(const TSQuery *); /** - * Get the name and length of one of the query's capture. Each capture - * is associated with a numeric id based on the order that it appeared - * in the query's source. + * Get the predicates for the given pattern in the query. */ -const char *ts_query_capture_name_for_id( +const TSQueryPredicateStep *ts_query_predicates_for_pattern( const TSQuery *self, - uint32_t index, + uint32_t pattern_index, uint32_t *length ); /** - * Get the numeric id of the capture with the given name. + * Get the number of distinct capture names in the query, or the number of + * distinct string literals in the query. + */ +uint32_t ts_query_capture_count(const TSQuery *); +uint32_t ts_query_string_count(const TSQuery *); + +/** + * Get the name and length of one of the query's captures, or one of the + * query's string literals. Each capture and string is associated with a + * numeric id based on the order that it appeared in the query's source. + */ +const char *ts_query_capture_name_for_id( + const TSQuery *, + uint32_t id, + uint32_t *length +); +const char *ts_query_string_value_for_id( + const TSQuery *, + uint32_t id, + uint32_t *length +); + +/** + * Get the numeric id of the capture with the given name, or string with the + * given value. */ int ts_query_capture_id_for_name( const TSQuery *self, const char *name, uint32_t length ); +int ts_query_string_id_for_value( + const TSQuery *self, + const char *value, + uint32_t length +); /** * Create a new cursor for executing a given query. @@ -713,24 +759,22 @@ void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint); /** * Advance to the next match of the currently running query. * - * If there is another match, write its pattern index to `pattern_index`, - * the number of captures to `capture_count`, and the captures themselves - * to `*captures`, and return `true`. Otherwise, return `false`. + * If there is a match, write it to `*match` and return `true`. + * Otherwise, return `false`. */ -bool ts_query_cursor_next_match( - TSQueryCursor *self, - uint32_t *pattern_index, - uint32_t *capture_count, - const TSQueryCapture **captures -); +bool ts_query_cursor_next_match(TSQueryCursor *, TSQueryMatch *match); /** * Advance to the next capture of the currently running query. * - * If there is another capture, write it to `capture` and return `true`. - * Otherwise, return `false`. + * If there is a capture, write its match to `*match` and its index within + * the matche's capture list to `*capture_index`. Otherwise, return `false`. */ -bool ts_query_cursor_next_capture(TSQueryCursor *, TSQueryCapture *capture); +bool ts_query_cursor_next_capture( + TSQueryCursor *, + TSQueryMatch *match, + uint32_t *capture_index +); /**********************/ /* Section - Language */ diff --git a/lib/src/query.c b/lib/src/query.c index 10d409ed..7a90b5eb 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -30,13 +30,17 @@ typedef struct { } QueryStep; /* - * CaptureSlice - The name of a capture, represented as a slice of a - * shared string. + * Slice - A string represented as a slice of a shared string. */ typedef struct { uint32_t offset; uint32_t length; -} CaptureSlice; +} Slice; + +typedef struct { + Array(char) characters; + Array(Slice) slices; +} SymbolTable; /* * PatternSlice - The set of steps needed to match a particular pattern, @@ -60,6 +64,7 @@ typedef struct { uint8_t capture_count; uint8_t capture_list_id; uint8_t consumed_capture_count; + uint32_t id; } QueryState; /* @@ -73,6 +78,17 @@ typedef struct { uint32_t usage_map; } CaptureListPool; +typedef enum { + PredicateStepTypeSymbol, + PredicateStepTypeCapture, + PredicateStepTypeDone, +} PredicateStepType; + +typedef struct { + bool is_capture; + uint16_t value_id; +} PredicateStep; + /* * TSQuery - A tree query, compiled from a string of S-expressions. The query * itself is immutable. The mutable state used in the process of executing the @@ -80,9 +96,11 @@ typedef struct { */ struct TSQuery { Array(QueryStep) steps; - Array(char) capture_data; - Array(CaptureSlice) capture_names; + SymbolTable captures; + SymbolTable predicate_values; Array(PatternSlice) pattern_map; + Array(TSQueryPredicateStep) predicate_steps; + Array(Slice) predicates_by_pattern; const TSLanguage *language; uint16_t max_capture_count; uint16_t wildcard_root_pattern_count; @@ -100,6 +118,7 @@ struct TSQueryCursor { uint32_t depth; uint32_t start_byte; uint32_t end_byte; + uint32_t next_state_id; TSPoint start_point; TSPoint end_point; bool ascending; @@ -177,7 +196,9 @@ static void stream_scan_identifier(Stream *stream) { iswalnum(stream->next) || stream->next == '_' || stream->next == '-' || - stream->next == '.' + stream->next == '.' || + stream->next == '?' || + stream->next == '!' ); } @@ -222,6 +243,65 @@ static void capture_list_pool_release(CaptureListPool *self, uint16_t id) { self->usage_map |= bitmask_for_index(id); } +/************** + * SymbolTable + **************/ + +static SymbolTable symbol_table_new() { + return (SymbolTable) { + .characters = array_new(), + .slices = array_new(), + }; +} + +static void symbol_table_delete(SymbolTable *self) { + array_delete(&self->characters); + array_delete(&self->slices); +} + +static int symbol_table_id_for_name( + const SymbolTable *self, + const char *name, + uint32_t length +) { + for (unsigned i = 0; i < self->slices.size; i++) { + Slice slice = self->slices.contents[i]; + if ( + slice.length == length && + !strncmp(&self->characters.contents[slice.offset], name, length) + ) return i; + } + return -1; +} + +static const char *symbol_table_name_for_id( + const SymbolTable *self, + uint16_t id, + uint32_t *length +) { + Slice slice = self->slices.contents[id]; + *length = slice.length; + return &self->characters.contents[slice.offset]; +} + +static uint16_t symbol_table_insert_name( + SymbolTable *self, + const char *name, + uint32_t length +) { + int id = symbol_table_id_for_name(self, name, length); + if (id >= 0) return (uint16_t)id; + Slice slice = { + .offset = self->characters.size, + .length = length, + }; + array_grow_by(&self->characters, length + 1); + memcpy(&self->characters.contents[slice.offset], name, length); + self->characters.contents[self->characters.size - 1] = 0; + array_push(&self->slices, slice); + return self->slices.size - 1; +} + /********* * Query *********/ @@ -241,24 +321,6 @@ static TSSymbol ts_query_intern_node_name( return 0; } -static uint16_t ts_query_intern_capture_name( - TSQuery *self, - const char *name, - uint32_t length -) { - int id = ts_query_capture_id_for_name(self, name, length); - if (id >= 0) return (uint16_t)id; - CaptureSlice capture = { - .offset = self->capture_data.size, - .length = length, - }; - array_grow_by(&self->capture_data, length + 1); - memcpy(&self->capture_data.contents[capture.offset], name, length); - self->capture_data.contents[self->capture_data.size - 1] = 0; - array_push(&self->capture_names, capture); - return self->capture_names.size - 1; -} - // The `pattern_map` contains a mapping from TSSymbol values to indices in the // `steps` array. For a given syntax node, the `pattern_map` makes it possible // to quickly find the starting steps of all of the patterns whose root matches @@ -322,6 +384,110 @@ static inline void ts_query__pattern_map_insert( })); } +static TSQueryError ts_query_parse_predicate( + TSQuery *self, + Stream *stream +) { + if (stream->next == ')') return PARENT_DONE; + if (stream->next != '(') return TSQueryErrorSyntax; + stream_advance(stream); + stream_skip_whitespace(stream); + + unsigned step_count = 0; + for (;;) { + if (stream->next == ')') { + stream_advance(stream); + array_back(&self->predicates_by_pattern)->length++; + array_push(&self->predicate_steps, ((TSQueryPredicateStep) { + .type = TSQueryPredicateStepTypeDone, + .value_id = 0, + })); + break; + } + + // Parse an `@`-prefixed capture + else if (stream->next == '@') { + stream_advance(stream); + stream_skip_whitespace(stream); + + // Parse the capture name + if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax; + const char *capture_name = stream->input; + stream_scan_identifier(stream); + uint32_t length = stream->input - capture_name; + + // Add the capture id to the first step of the pattern + int capture_id = symbol_table_id_for_name( + &self->captures, + capture_name, + length + ); + if (capture_id == -1) { + stream_reset(stream, capture_name); + return TSQueryErrorCapture; + } + + array_back(&self->predicates_by_pattern)->length++; + array_push(&self->predicate_steps, ((TSQueryPredicateStep) { + .type = TSQueryPredicateStepTypeCapture, + .value_id = capture_id, + })); + } + + // Parse a string literal + else if (stream->next == '"') { + stream_advance(stream); + + // Parse the string content + const char *string_content = stream->input; + while (stream->next != '"') { + if (!stream_advance(stream)) { + stream_reset(stream, string_content - 1); + return TSQueryErrorSyntax; + } + } + uint32_t length = stream->input - string_content; + + // Add a step for the node + uint16_t id = symbol_table_insert_name( + &self->predicate_values, + string_content, + length + ); + array_back(&self->predicates_by_pattern)->length++; + array_push(&self->predicate_steps, ((TSQueryPredicateStep) { + .type = TSQueryPredicateStepTypeString, + .value_id = id, + })); + + if (stream->next != '"') return TSQueryErrorSyntax; + stream_advance(stream); + } + + // Parse a bare symbol + else if (stream_is_ident_start(stream)) { + const char *symbol_start = stream->input; + stream_scan_identifier(stream); + uint32_t length = stream->input - symbol_start; + uint16_t id = symbol_table_insert_name( + &self->predicate_values, + symbol_start, + length + ); + array_back(&self->predicates_by_pattern)->length++; + array_push(&self->predicate_steps, ((TSQueryPredicateStep) { + .type = TSQueryPredicateStepTypeString, + .value_id = id, + })); + } + + step_count++; + stream_skip_whitespace(stream); + } + + return 0; +} + // Read one S-expression pattern from the stream, and incorporate it into // the query's internal state machine representation. For nested patterns, // this function calls itself recursively. @@ -344,6 +510,26 @@ static TSQueryError ts_query_parse_pattern( else if (stream->next == '(') { stream_advance(stream); stream_skip_whitespace(stream); + + // Parse a pattern inside of a conditional form + if (stream->next == '(' && depth == 0) { + TSQueryError e = ts_query_parse_pattern(self, stream, 0, capture_count); + if (e) return e; + + // Parse the child patterns + stream_skip_whitespace(stream); + for (;;) { + TSQueryError e = ts_query_parse_predicate(self, stream); + if (e == PARENT_DONE) { + stream_advance(stream); + stream_skip_whitespace(stream); + return 0; + } else if (e) { + return e; + } + } + } + TSSymbol symbol; // Parse the wildcard symbol @@ -494,8 +680,8 @@ static TSQueryError ts_query_parse_pattern( uint32_t length = stream->input - capture_name; // Add the capture id to the first step of the pattern - uint16_t capture_id = ts_query_intern_capture_name( - self, + uint16_t capture_id = symbol_table_insert_name( + &self->captures, capture_name, length ); @@ -519,6 +705,10 @@ TSQuery *ts_query_new( *self = (TSQuery) { .steps = array_new(), .pattern_map = array_new(), + .captures = symbol_table_new(), + .predicate_values = symbol_table_new(), + .predicate_steps = array_new(), + .predicates_by_pattern = array_new(), .wildcard_root_pattern_count = 0, .max_capture_count = 0, .language = language, @@ -531,6 +721,10 @@ TSQuery *ts_query_new( for (;;) { start_step_index = self->steps.size; uint32_t capture_count = 0; + array_push(&self->predicates_by_pattern, ((Slice) { + .offset = self->predicate_steps.size, + .length = 0, + })); *error_type = ts_query_parse_pattern(self, &stream, 0, &capture_count); array_push(&self->steps, ((QueryStep) { .depth = PATTERN_DONE_MARKER })); @@ -569,14 +763,24 @@ void ts_query_delete(TSQuery *self) { if (self) { array_delete(&self->steps); array_delete(&self->pattern_map); - array_delete(&self->capture_data); - array_delete(&self->capture_names); + array_delete(&self->predicate_steps); + array_delete(&self->predicates_by_pattern); + symbol_table_delete(&self->captures); + symbol_table_delete(&self->predicate_values); ts_free(self); } } +uint32_t ts_query_pattern_count(const TSQuery *self) { + return self->predicates_by_pattern.size; +} + uint32_t ts_query_capture_count(const TSQuery *self) { - return self->capture_names.size; + return self->captures.slices.size; +} + +uint32_t ts_query_string_count(const TSQuery *self) { + return self->predicate_values.slices.size; } const char *ts_query_capture_name_for_id( @@ -584,9 +788,15 @@ const char *ts_query_capture_name_for_id( uint32_t index, uint32_t *length ) { - CaptureSlice name = self->capture_names.contents[index]; - *length = name.length; - return &self->capture_data.contents[name.offset]; + return symbol_table_name_for_id(&self->captures, index, length); +} + +const char *ts_query_string_value_for_id( + const TSQuery *self, + uint32_t index, + uint32_t *length +) { + return symbol_table_name_for_id(&self->predicate_values, index, length); } int ts_query_capture_id_for_name( @@ -594,14 +804,25 @@ int ts_query_capture_id_for_name( const char *name, uint32_t length ) { - for (unsigned i = 0; i < self->capture_names.size; i++) { - CaptureSlice existing = self->capture_names.contents[i]; - if ( - existing.length == length && - !strncmp(&self->capture_data.contents[existing.offset], name, length) - ) return i; - } - return -1; + return symbol_table_id_for_name(&self->captures, name, length); +} + +int ts_query_string_id_for_value( + const TSQuery *self, + const char *value, + uint32_t length +) { + return symbol_table_id_for_name(&self->predicate_values, value, length); +} + +const TSQueryPredicateStep *ts_query_predicates_for_pattern( + const TSQuery *self, + uint32_t pattern_index, + uint32_t *step_count +) { + Slice slice = self->predicates_by_pattern.contents[pattern_index]; + *step_count = slice.length; + return &self->predicate_steps.contents[slice.offset]; } /*************** @@ -640,6 +861,7 @@ void ts_query_cursor_exec( array_clear(&self->finished_states); ts_tree_cursor_reset(&self->cursor, node); capture_list_pool_reset(&self->capture_list_pool, query->max_capture_count); + self->next_state_id = 0; self->depth = 0; self->ascending = false; self->query = query; @@ -891,6 +1113,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { if (next_step->depth == PATTERN_DONE_MARKER) { LOG("finish pattern %u\n", next_state->pattern_index); + next_state->id = self->next_state_id++; array_push(&self->finished_states, *next_state); if (next_state == state) { array_erase(&self->states, i); @@ -915,9 +1138,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { bool ts_query_cursor_next_match( TSQueryCursor *self, - uint32_t *pattern_index, - uint32_t *capture_count, - const TSQueryCapture **captures + TSQueryMatch *match ) { if (self->finished_states.size > 0) { QueryState state = array_pop(&self->finished_states); @@ -927,9 +1148,10 @@ bool ts_query_cursor_next_match( if (!ts_query_cursor__advance(self)) return false; const QueryState *state = array_back(&self->finished_states); - *pattern_index = state->pattern_index; - *capture_count = state->capture_count; - *captures = capture_list_pool_get( + match->id = state->id; + match->pattern_index = state->pattern_index; + match->capture_count = state->capture_count; + match->captures = capture_list_pool_get( &self->capture_list_pool, state->capture_list_id ); @@ -939,7 +1161,8 @@ bool ts_query_cursor_next_match( bool ts_query_cursor_next_capture( TSQueryCursor *self, - TSQueryCapture *capture + TSQueryMatch *match, + uint32_t *capture_index ) { for (;;) { if (self->finished_states.size > 0) { @@ -991,19 +1214,15 @@ bool ts_query_cursor_next_capture( QueryState *state = &self->finished_states.contents[ first_finished_state_index ]; - const TSQueryCapture *captures = capture_list_pool_get( + match->id = state->id; + match->pattern_index = state->pattern_index; + match->capture_count = state->capture_count; + match->captures = capture_list_pool_get( &self->capture_list_pool, state->capture_list_id ); - *capture = captures[state->consumed_capture_count]; + *capture_index = state->consumed_capture_count; state->consumed_capture_count++; - if (state->consumed_capture_count == state->capture_count) { - capture_list_pool_release( - &self->capture_list_pool, - state->capture_list_id - ); - array_erase(&self->finished_states, first_finished_state_index); - } return true; } }