diff --git a/cli/src/main.rs b/cli/src/main.rs index 4b387050..8de7ed67 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -84,7 +84,8 @@ fn run() -> error::Result<()> { .multiple(true) .required(true), ) - .arg(Arg::with_name("scope").long("scope").takes_value(true)), + .arg(Arg::with_name("scope").long("scope").takes_value(true)) + .arg(Arg::with_name("captures").long("captures").short("c")), ) .subcommand( SubCommand::with_name("test") @@ -210,6 +211,7 @@ fn run() -> error::Result<()> { return Error::err(String::new()); } } else if let Some(matches) = matches.subcommand_matches("query") { + let ordered_captures = matches.values_of("captures").is_some(); let paths = matches .values_of("path") .unwrap() @@ -224,7 +226,7 @@ fn run() -> error::Result<()> { matches.value_of("scope"), )?; let query_path = Path::new(matches.value_of("query-path").unwrap()); - query::query_files_at_paths(language, paths, query_path)?; + query::query_files_at_paths(language, paths, query_path, ordered_captures)?; } else if let Some(matches) = matches.subcommand_matches("highlight") { let paths = matches.values_of("path").unwrap().into_iter(); let html_mode = matches.is_present("html"); diff --git a/cli/src/query.rs b/cli/src/query.rs index 4a2f6abb..f373a314 100644 --- a/cli/src/query.rs +++ b/cli/src/query.rs @@ -2,12 +2,13 @@ use super::error::{Error, Result}; use std::fs; use std::io::{self, Write}; use std::path::Path; -use tree_sitter::{Language, Parser, Query, QueryCursor}; +use tree_sitter::{Language, Node, Parser, Query, QueryCursor}; pub fn query_files_at_paths( language: Language, paths: Vec<&Path>, query_path: &Path, + ordered_captures: bool, ) -> Result<()> { let stdout = io::stdout(); let mut stdout = stdout.lock(); @@ -29,19 +30,33 @@ pub fn query_files_at_paths( let source_code = fs::read(path).map_err(Error::wrap(|| { format!("Error reading source file {:?}", path) }))?; - + let text_callback = |n: Node| &source_code[n.byte_range()]; let tree = parser.parse(&source_code, None).unwrap(); - for mat in query_cursor.matches(&query, tree.root_node(), |n| &source_code[n.byte_range()]) { - writeln!(&mut stdout, " pattern: {}", mat.pattern_index())?; - for (capture_id, node) in mat.captures() { + if ordered_captures { + for (pattern_index, capture) in query_cursor.captures(&query, tree.root_node(), text_callback) { writeln!( &mut stdout, - " {}: {:?}", - &query.capture_names()[capture_id], - node.utf8_text(&source_code).unwrap_or("") + " pattern: {}, capture: {}, row: {}, text: {:?}", + pattern_index, + &query.capture_names()[capture.index], + capture.node.start_position().row, + capture.node.utf8_text(&source_code).unwrap_or("") )?; } + } else { + for m in query_cursor.matches(&query, tree.root_node(), text_callback) { + writeln!(&mut stdout, " pattern: {}", m.pattern_index)?; + for capture in m.captures() { + writeln!( + &mut stdout, + " capture: {}, row: {}, text: {:?}", + &query.capture_names()[capture.index], + capture.node.start_position().row, + capture.node.utf8_text(&source_code).unwrap_or("") + )?; + } + } } } diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index c64405e8..b2e78af9 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1,6 +1,6 @@ use super::helpers::allocations; use super::helpers::fixtures::get_language; -use tree_sitter::{Node, Parser, Query, QueryCursor, QueryError, QueryMatch}; +use tree_sitter::{Node, Parser, Query, QueryCapture, QueryCursor, QueryError, QueryMatch}; use std::fmt::Write; #[test] @@ -797,22 +797,22 @@ fn collect_matches<'a>( matches .map(|m| { ( - m.pattern_index(), - collect_captures(m.captures(), query, source), + m.pattern_index, + collect_captures(m.captures().map(|c| (m.pattern_index, c)), query, source), ) }) .collect() } fn collect_captures<'a, 'b>( - captures: impl Iterator)>, + captures: impl Iterator)>, query: &'b Query, source: &'b str, ) -> Vec<(&'b str, &'b str)> { captures - .map(|(capture_id, node)| { + .map(|(_, QueryCapture { index, node })| { ( - query.capture_names()[capture_id].as_str(), + query.capture_names()[index].as_str(), node.utf8_text(source.as_bytes()).unwrap(), ) }) diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index ee457b5f..2069e373 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -154,10 +154,15 @@ pub struct Query { pub struct QueryCursor(*mut ffi::TSQueryCursor); pub struct QueryMatch<'a> { - pattern_index: usize, + pub pattern_index: usize, captures: &'a [ffi::TSQueryCapture], } +pub struct QueryCapture<'a> { + pub index: usize, + pub node: Node<'a>, +} + #[derive(Debug, PartialEq, Eq)] pub enum QueryError<'a> { Syntax(usize), @@ -1135,6 +1140,10 @@ impl Query { unsafe { ffi::ts_query_start_byte_for_pattern(self.ptr, pattern_index as u32) as usize } } + pub fn pattern_count(&self) -> usize { + unsafe { ffi::ts_query_pattern_count(self.ptr) as usize } + } + pub fn capture_names(&self) -> &[String] { &self.capture_names } @@ -1185,37 +1194,38 @@ impl QueryCursor { query: &'a Query, node: Node<'a>, mut text_callback: impl FnMut(Node<'a>) -> &'a [u8] + 'a, - ) -> impl Iterator + 'a { + ) -> impl Iterator + 'a { unsafe { ffi::ts_query_cursor_exec(self.0, query.ptr, node.0); } - std::iter::from_fn(move || -> Option<(usize, Node<'a>)> { - loop { - unsafe { - let mut m = MaybeUninit::::uninit(); - let mut capture_index = 0u32; - if ffi::ts_query_cursor_next_capture( - self.0, - m.as_mut_ptr(), - &mut capture_index as *mut u32, + std::iter::from_fn(move || loop { + unsafe { + let mut m = MaybeUninit::::uninit(); + let mut capture_index = 0u32; + if ffi::ts_query_cursor_next_capture( + self.0, + m.as_mut_ptr(), + &mut capture_index as *mut u32, + ) { + let m = m.assume_init(); + let captures = slice::from_raw_parts(m.captures, m.capture_count as usize); + if self.captures_match_condition( + query, + captures, + m.pattern_index as usize, + &mut text_callback, ) { - let m = m.assume_init(); - let captures = slice::from_raw_parts(m.captures, m.capture_count as usize); - if self.captures_match_condition( - query, - captures, + let capture = captures[capture_index as usize]; + return Some(( m.pattern_index as usize, - &mut text_callback, - ) { - let capture = captures[capture_index as usize]; - return Some(( - capture.index as usize, - Node::new(capture.node).unwrap(), - )); - } - } else { - return None; + QueryCapture { + index: capture.index as usize, + node: Node::new(capture.node).unwrap(), + }, + )); } + } else { + return None; } } }) @@ -1272,14 +1282,11 @@ impl QueryCursor { } impl<'a> QueryMatch<'a> { - pub fn pattern_index(&self) -> usize { - self.pattern_index - } - - pub fn captures(&self) -> impl ExactSizeIterator { - self.captures - .iter() - .map(|capture| (capture.index as usize, Node::new(capture.node).unwrap())) + pub fn captures(&self) -> impl ExactSizeIterator { + self.captures.iter().map(|capture| QueryCapture { + index: capture.index as usize, + node: Node::new(capture.node).unwrap(), + }) } }