diff --git a/cli/src/query.rs b/cli/src/query.rs index 2f50325a..9e58c263 100644 --- a/cli/src/query.rs +++ b/cli/src/query.rs @@ -32,7 +32,7 @@ pub fn query_files_at_paths( let tree = parser.parse(&source_code, None).unwrap(); - for mat in query_cursor.exec(&query, tree.root_node()) { + for mat in query_cursor.matches(&query, tree.root_node()) { writeln!(&mut stdout, " pattern: {}", mat.pattern_index())?; for (capture_id, node) in mat.captures() { writeln!( diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index ff4dd7a1..5c37a9ab 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1,6 +1,6 @@ use super::helpers::allocations; use super::helpers::fixtures::get_language; -use tree_sitter::{Parser, Query, QueryCursor, QueryError, QueryMatch}; +use tree_sitter::{Node, Parser, Query, QueryCursor, QueryError, QueryMatch}; #[test] fn test_query_errors_on_invalid_syntax() { @@ -67,7 +67,7 @@ fn test_query_errors_on_invalid_symbols() { } #[test] -fn test_query_exec_with_simple_pattern() { +fn test_query_matches_with_simple_pattern() { allocations::record(|| { let language = get_language("javascript"); let query = Query::new( @@ -82,7 +82,7 @@ fn test_query_exec_with_simple_pattern() { let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.exec(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node()); assert_eq!( collect_matches(matches, &query, source), @@ -95,7 +95,7 @@ fn test_query_exec_with_simple_pattern() { } #[test] -fn test_query_exec_with_multiple_matches_same_root() { +fn test_query_matches_with_multiple_on_same_root() { allocations::record(|| { let language = get_language("javascript"); let query = Query::new( @@ -122,7 +122,7 @@ fn test_query_exec_with_multiple_matches_same_root() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = 
cursor.exec(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node()); assert_eq!( collect_matches(matches, &query, source), @@ -147,7 +147,7 @@ fn test_query_exec_with_multiple_matches_same_root() { } #[test] -fn test_query_exec_multiple_patterns_different_roots() { +fn test_query_matches_with_multiple_patterns_different_roots() { allocations::record(|| { let language = get_language("javascript"); let query = Query::new( @@ -169,7 +169,7 @@ fn test_query_exec_multiple_patterns_different_roots() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.exec(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node()); assert_eq!( collect_matches(matches, &query, source), @@ -183,7 +183,7 @@ fn test_query_exec_multiple_patterns_different_roots() { } #[test] -fn test_query_exec_multiple_patterns_same_root() { +fn test_query_matches_with_multiple_patterns_same_root() { allocations::record(|| { let language = get_language("javascript"); let query = Query::new( @@ -211,7 +211,7 @@ fn test_query_exec_multiple_patterns_same_root() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.exec(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node()); assert_eq!( collect_matches(matches, &query, source), @@ -224,7 +224,7 @@ fn test_query_exec_multiple_patterns_same_root() { } #[test] -fn test_query_exec_nested_matches_without_fields() { +fn test_query_matches_with_nesting_and_no_fields() { allocations::record(|| { let language = get_language("javascript"); let query = Query::new( @@ -248,7 +248,7 @@ fn test_query_exec_nested_matches_without_fields() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.exec(&query, 
tree.root_node()); + let matches = cursor.matches(&query, tree.root_node()); assert_eq!( collect_matches(matches, &query, source), @@ -263,7 +263,7 @@ fn test_query_exec_nested_matches_without_fields() { } #[test] -fn test_query_exec_many_matches() { +fn test_query_matches_with_many() { allocations::record(|| { let language = get_language("javascript"); let query = Query::new(language, "(array (identifier) @element)").unwrap(); @@ -274,7 +274,7 @@ fn test_query_exec_many_matches() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.exec(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node()); assert_eq!( collect_matches(matches, &query, source.as_str()), @@ -284,7 +284,7 @@ fn test_query_exec_many_matches() { } #[test] -fn test_query_exec_too_many_match_permutations_to_track() { +fn test_query_matches_with_too_many_permutations_to_track() { allocations::record(|| { let language = get_language("javascript"); let query = Query::new( @@ -303,7 +303,7 @@ fn test_query_exec_too_many_match_permutations_to_track() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.exec(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node()); // For this pathological query, some match permutations will be dropped. 
// Just check that a subset of the results are returned, and crash or @@ -316,7 +316,7 @@ fn test_query_exec_too_many_match_permutations_to_track() { } #[test] -fn test_query_exec_with_anonymous_tokens() { +fn test_query_matches_with_anonymous_tokens() { allocations::record(|| { let language = get_language("javascript"); let query = Query::new( @@ -334,7 +334,7 @@ fn test_query_exec_with_anonymous_tokens() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.exec(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node()); assert_eq!( collect_matches(matches, &query, source), @@ -347,7 +347,7 @@ fn test_query_exec_with_anonymous_tokens() { } #[test] -fn test_query_exec_within_byte_range() { +fn test_query_matches_within_byte_range() { allocations::record(|| { let language = get_language("javascript"); let query = Query::new(language, "(identifier) @element").unwrap(); @@ -359,7 +359,9 @@ fn test_query_exec_within_byte_range() { let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.set_byte_range(5, 15).exec(&query, tree.root_node()); + let matches = cursor + .set_byte_range(5, 15) + .matches(&query, tree.root_node()); assert_eq!( collect_matches(matches, &query, source), @@ -373,7 +375,7 @@ fn test_query_exec_within_byte_range() { } #[test] -fn test_query_exec_different_queries() { +fn test_query_matches_different_queries_same_cursor() { allocations::record(|| { let language = get_language("javascript"); let query1 = Query::new( @@ -409,13 +411,13 @@ fn test_query_exec_different_queries() { parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); - let matches = cursor.exec(&query1, tree.root_node()); + let matches = cursor.matches(&query1, tree.root_node()); assert_eq!( collect_matches(matches, &query1, source), &[(0, vec![("id1", "a")]),] ); - let matches = 
cursor.exec(&query3, tree.root_node()); + let matches = cursor.matches(&query3, tree.root_node()); assert_eq!( collect_matches(matches, &query3, source), &[ @@ -425,7 +427,7 @@ fn test_query_exec_different_queries() { ] ); - let matches = cursor.exec(&query2, tree.root_node()); + let matches = cursor.matches(&query2, tree.root_node()); assert_eq!( collect_matches(matches, &query2, source), &[(0, vec![("id1", "a")]), (1, vec![("id2", "b")]),] @@ -433,6 +435,81 @@ fn test_query_exec_different_queries() { }); } +#[test] +fn test_query_captures() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new( + language, + r#" + (pair + key: * @method.def + (function + name: (identifier) @method.alias)) + + (variable_declarator + name: * @function.def + value: (function + name: (identifier) @function.alias)) + + ":" @delimiter + "=" @operator + "#, + ) + .unwrap(); + + let source = " + a({ + bc: function de() { + const fg = function hi() {} + }, + jk: function lm() { + const no = function pq() {} + }, + }); + "; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + let matches = cursor.matches(&query, tree.root_node()); + + assert_eq!( + collect_matches(matches, &query, source), + &[ + (2, vec![("delimiter", ":")]), + (0, vec![("method.def", "bc"), ("method.alias", "de")]), + (3, vec![("operator", "=")]), + (1, vec![("function.def", "fg"), ("function.alias", "hi")]), + (2, vec![("delimiter", ":")]), + (0, vec![("method.def", "jk"), ("method.alias", "lm")]), + (3, vec![("operator", "=")]), + (1, vec![("function.def", "no"), ("function.alias", "pq")]), + ], + ); + + let captures = cursor.captures(&query, tree.root_node()); + assert_eq!( + collect_captures(captures, &query, source), + &[ + ("method.def", "bc"), + ("delimiter", ":"), + ("method.alias", "de"), + ("function.def", "fg"), + ("operator", "="), + 
("function.alias", "hi"), + ("method.def", "jk"), + ("delimiter", ":"), + ("method.alias", "lm"), + ("function.def", "no"), + ("operator", "="), + ("function.alias", "pq"), + ] + ); + }); +} + #[test] fn test_query_capture_names() { allocations::record(|| { @@ -486,7 +563,7 @@ fn test_query_comments() { parser.set_language(language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); - let matches = cursor.exec(&query, tree.root_node()); + let matches = cursor.matches(&query, tree.root_node()); assert_eq!( collect_matches(matches, &query, source), &[(0, vec![("fn-name", "one")]),], @@ -503,14 +580,22 @@ fn collect_matches<'a>( .map(|m| { ( m.pattern_index(), - m.captures() - .map(|(capture_id, node)| { - ( - query.capture_names()[capture_id].as_str(), - node.utf8_text(source.as_bytes()).unwrap(), - ) - }) - .collect(), + collect_captures(m.captures(), query, source), + ) + }) + .collect() +} + +fn collect_captures<'a, 'b>( + captures: impl Iterator)>, + query: &'b Query, + source: &'b str, +) -> Vec<(&'b str, &'b str)> { + captures + .map(|(capture_id, node)| { + ( + query.capture_names()[capture_id].as_str(), + node.utf8_text(source.as_bytes()).unwrap(), ) }) .collect() diff --git a/docs/assets/js/playground.js b/docs/assets/js/playground.js index 2366ed2f..d7d4ff33 100644 --- a/docs/assets/js/playground.js +++ b/docs/assets/js/playground.js @@ -47,7 +47,7 @@ let tree; }); const renderTreeOnCodeChange = debounce(renderTree, 50); const saveStateOnChange = debounce(saveState, 2000); - const runTreeQueryOnChange = debounce(runTreeQuery, 150); + const runTreeQueryOnChange = debounce(runTreeQuery, 50); let languageName = languageSelect.value; let treeRows = null; @@ -208,24 +208,22 @@ let tree; marks.forEach(m => m.clear()); if (tree && query) { - const matches = query.exec( + const captures = query.captures( tree.rootNode, {row: startRow, column: 0}, {row: endRow, column: 0}, ); - for (const {captures} of matches) { - for 
(const {name, node} of captures) { - const {startPosition, endPosition} = node; - codeEditor.markText( - {line: startPosition.row, ch: startPosition.column}, - {line: endPosition.row, ch: endPosition.column}, - { - inclusiveLeft: true, - inclusiveRight: true, - css: `color: ${colorForCaptureName(name)}` - } - ); - } + for (const {name, node} of captures) { + const {startPosition, endPosition} = node; + codeEditor.markText( + {line: startPosition.row, ch: startPosition.column}, + {line: endPosition.row, ch: endPosition.column}, + { + inclusiveLeft: true, + inclusiveRight: true, + css: `color: ${colorForCaptureName(name)}` + } + ); } } }); diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index effd0c81..2c8ac77d 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -607,19 +607,23 @@ extern "C" { #[doc = " Create a new cursor for executing a given query."] #[doc = ""] #[doc = " The cursor stores the state that is needed to iteratively search"] - #[doc = " for matches. To use the query cursor:"] - #[doc = " 1. First call `ts_query_cursor_exec` to start running a given query on"] - #[doc = "a given syntax node."] - #[doc = " 2. Then repeatedly call `ts_query_cursor_next` to iterate over the matches."] - #[doc = " This will return `false` when there are no more matches left."] - #[doc = " 3. After each successful call to `ts_query_cursor_next`, you can call"] - #[doc = " `ts_query_cursor_matched_pattern_index` to determine which pattern"] - #[doc = " matched. You can also call `ts_query_cursor_matched_captures` to"] - #[doc = " determine which nodes were captured, and by which capture names."] + #[doc = " for matches. To use the query cursor, first call `ts_query_cursor_exec`"] + #[doc = " to start running a given query on a given syntax node. Then, there are"] + #[doc = " two options for consuming the results of the query:"] + #[doc = " 1. 
Repeatedly call `ts_query_cursor_next_match` to iterate over all of the"] + #[doc = " *matches* in the order that they were found. Each match contains the"] + #[doc = " index of the pattern that matched, and an array of captures. Because"] + #[doc = " multiple patterns can match the same set of nodes, one match may contain"] + #[doc = " captures that appear *before* some of the captures from a previous match."] + #[doc = " 2. Repeatedly call `ts_query_cursor_next_capture` to iterate over all of the"] + #[doc = " individual *captures* in the order that they appear. This is useful if you"] + #[doc = " don\'t care about which pattern matched, and just want a single ordered"] + #[doc = " sequence of captures."] #[doc = ""] - #[doc = " If you don\'t care about finding all of the matches, you can stop calling"] - #[doc = " `ts_query_cursor_next` at any point. And you can start executing another"] - #[doc = " query on another node by calling `ts_query_cursor_exec` again."] + #[doc = " If you don\'t care about consuming all of the results, you can stop calling"] + #[doc = " `ts_query_cursor_next_match` or `ts_query_cursor_next_capture` at any point."] + #[doc = " You can then start executing another query on another node by calling"] + #[doc = " `ts_query_cursor_exec` again."] pub fn ts_query_cursor_new() -> *mut TSQueryCursor; } extern "C" { @@ -640,18 +644,26 @@ extern "C" { } extern "C" { #[doc = " Advance to the next match of the currently running query."] - pub fn ts_query_cursor_next(arg1: *mut TSQueryCursor) -> bool; + #[doc = ""] + #[doc = " If there is another match, write its pattern index to `pattern_index`,"] + #[doc = " the number of captures to `capture_count`, and the captures themselves"] + #[doc = " to `*captures`, and return `true`. 
Otherwise, return `false`."] + pub fn ts_query_cursor_next_match( + self_: *mut TSQueryCursor, + pattern_index: *mut u32, + capture_count: *mut u32, + captures: *mut *const TSQueryCapture, + ) -> bool; } extern "C" { - #[doc = " Check which pattern matched."] - pub fn ts_query_cursor_matched_pattern_index(arg1: *const TSQueryCursor) -> u32; -} -extern "C" { - #[doc = " Check which pattern matched."] - pub fn ts_query_cursor_matched_captures( - arg1: *const TSQueryCursor, - arg2: *mut u32, - ) -> *const TSQueryCapture; + #[doc = " Advance to the next capture of the currently running query."] + #[doc = ""] + #[doc = " If there is another capture, write it to `capture` and return `true`."] + #[doc = " Otherwise, return `false`."] + pub fn ts_query_cursor_next_capture( + arg1: *mut TSQueryCursor, + capture: *mut TSQueryCapture, + ) -> bool; } extern "C" { #[doc = " Get the number of distinct node types in the language."] diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 2c20fd12..3eea8c2f 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -18,6 +18,7 @@ use std::marker::PhantomData; use std::os::raw::{c_char, c_void}; use std::sync::atomic::AtomicUsize; use std::{char, fmt, ptr, slice, str, u16}; +use std::mem::MaybeUninit; pub const LANGUAGE_VERSION: usize = ffi::TREE_SITTER_LANGUAGE_VERSION; pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/parser.h"); @@ -144,7 +145,12 @@ pub struct Query { pub struct QueryCursor(*mut ffi::TSQueryCursor); -pub struct QueryMatch<'a>(*mut ffi::TSQueryCursor, PhantomData<&'a ()>); +pub struct QueryMatch<'a> { + pattern_index: usize, + capture_count: usize, + captures_ptr: *const ffi::TSQueryCapture, + cursor: PhantomData<&'a ()>, +} #[derive(Debug, PartialEq, Eq)] pub enum QueryError<'a> { @@ -996,14 +1002,52 @@ impl QueryCursor { QueryCursor(unsafe { ffi::ts_query_cursor_new() }) } - pub fn exec<'a>(&'a mut self, query: &'a Query, node: Node<'a>) -> impl Iterator> + 'a { 
+ pub fn matches<'a>( + &'a mut self, + query: &'a Query, + node: Node<'a>, + ) -> impl Iterator> + 'a { unsafe { ffi::ts_query_cursor_exec(self.0, query.ptr, node.0); } std::iter::from_fn(move || -> Option> { unsafe { - if ffi::ts_query_cursor_next(self.0) { - Some(QueryMatch(self.0, PhantomData)) + let mut pattern_index = 0u32; + let mut capture_count = 0u32; + let mut captures = ptr::null(); + if ffi::ts_query_cursor_next_match( + self.0, + &mut pattern_index as *mut u32, + &mut capture_count as *mut u32, + &mut captures as *mut *const ffi::TSQueryCapture, + ) { + Some(QueryMatch { + pattern_index: pattern_index as usize, + capture_count: capture_count as usize, + captures_ptr: captures, + cursor: PhantomData + }) + } else { + None + } + } + }) + } + + pub fn captures<'a>( + &'a mut self, + query: &'a Query, + node: Node<'a>, + ) -> impl Iterator + 'a { + unsafe { + ffi::ts_query_cursor_exec(self.0, query.ptr, node.0); + } + std::iter::from_fn(move || -> Option<(usize, Node<'a>)> { + unsafe { + let mut capture = MaybeUninit::::uninit(); + if ffi::ts_query_cursor_next_capture(self.0, capture.as_mut_ptr()) { + let capture = capture.assume_init(); + Some((capture.index as usize, Node::new(capture.node).unwrap())) } else { None } @@ -1028,19 +1072,14 @@ impl QueryCursor { impl<'a> QueryMatch<'a> { pub fn pattern_index(&self) -> usize { - unsafe { ffi::ts_query_cursor_matched_pattern_index(self.0) as usize } + self.pattern_index } pub fn captures(&self) -> impl ExactSizeIterator { - unsafe { - let mut capture_count = 0u32; - let captures = - ffi::ts_query_cursor_matched_captures(self.0, &mut capture_count as *mut u32); - let captures = slice::from_raw_parts(captures, capture_count as usize); - captures - .iter() - .map(move |capture| (capture.index as usize, Node::new(capture.node).unwrap())) - } + let captures = unsafe { slice::from_raw_parts(self.captures_ptr, self.capture_count as usize) }; + captures + .iter() + .map(|capture| (capture.index as usize, 
Node::new(capture.node).unwrap())) } } diff --git a/lib/binding_web/binding.c b/lib/binding_web/binding.c index db4449a2..9a4dec0b 100644 --- a/lib/binding_web/binding.c +++ b/lib/binding_web/binding.c @@ -567,7 +567,7 @@ int ts_node_is_missing_wasm(const TSTree *tree) { /* Section - Query */ /******************/ -void ts_query_exec_wasm( +void ts_query_matches_wasm( const TSQuery *self, const TSTree *tree, uint32_t start_row, @@ -580,24 +580,23 @@ void ts_query_exec_wasm( TSNode node = unmarshal_node(tree); TSPoint start_point = {start_row, code_unit_to_byte(start_column)}; TSPoint end_point = {end_row, code_unit_to_byte(end_column)}; - - Array(const void *) result = array_new(); - - unsigned index = 0; - unsigned match_count = 0; ts_query_cursor_set_point_range(scratch_query_cursor, start_point, end_point); ts_query_cursor_exec(scratch_query_cursor, self, node); - while (ts_query_cursor_next(scratch_query_cursor)) { + + uint32_t index = 0; + uint32_t match_count = 0; + Array(const void *) result = array_new(); + + uint32_t pattern_index, capture_count; + const TSQueryCapture *captures; + while (ts_query_cursor_next_match( + scratch_query_cursor, + &pattern_index, + &capture_count, + &captures + )) { match_count++; - uint32_t pattern_index = ts_query_cursor_matched_pattern_index(scratch_query_cursor); - uint32_t capture_count; - const TSQueryCapture *captures = ts_query_cursor_matched_captures( - scratch_query_cursor, - &capture_count - ); - array_grow_by(&result, 2 + 6 * capture_count); - result.contents[index++] = (const void *)pattern_index; result.contents[index++] = (const void *)capture_count; for (unsigned i = 0; i < capture_count; i++) { @@ -611,3 +610,37 @@ void ts_query_exec_wasm( TRANSFER_BUFFER[0] = (const void *)(match_count); TRANSFER_BUFFER[1] = result.contents; } + +void ts_query_captures_wasm( + const TSQuery *self, + const TSTree *tree, + uint32_t start_row, + uint32_t start_column, + uint32_t end_row, + uint32_t end_column +) { + if 
(!scratch_query_cursor) scratch_query_cursor = ts_query_cursor_new(); + + TSNode node = unmarshal_node(tree); + TSPoint start_point = {start_row, code_unit_to_byte(start_column)}; + TSPoint end_point = {end_row, code_unit_to_byte(end_column)}; + ts_query_cursor_set_point_range(scratch_query_cursor, start_point, end_point); + ts_query_cursor_exec(scratch_query_cursor, self, node); + + unsigned index = 0; + unsigned capture_count = 0; + Array(const void *) result = array_new(); + + TSQueryCapture capture; + while (ts_query_cursor_next_capture(scratch_query_cursor, &capture)) { + capture_count++; + + array_grow_by(&result, 6); + result.contents[index++] = (const void *)capture.index; + marshal_node(result.contents + index, capture.node); + index += 5; + } + + TRANSFER_BUFFER[0] = (const void *)(capture_count); + TRANSFER_BUFFER[1] = result.contents; +} diff --git a/lib/binding_web/binding.js b/lib/binding_web/binding.js index 0fd1ea63..8ed7fca7 100644 --- a/lib/binding_web/binding.js +++ b/lib/binding_web/binding.js @@ -5,7 +5,7 @@ const SIZE_OF_NODE = 5 * SIZE_OF_INT; const SIZE_OF_POINT = 2 * SIZE_OF_INT; const SIZE_OF_RANGE = 2 * SIZE_OF_INT + 2 * SIZE_OF_POINT; const ZERO_POINT = {row: 0, column: 0}; -const QUERY_WORD_REGEX = /[\w-.]*/; +const QUERY_WORD_REGEX = /[\w-.]*/g; var VERSION; var MIN_COMPATIBLE_VERSION; @@ -694,7 +694,7 @@ class Language { const errorId = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32'); const errorByte = getValue(TRANSFER_BUFFER, 'i32'); const errorIndex = UTF8ToString(sourceAddress, errorByte).length; - const suffix = source.slice(errorIndex, 100); + const suffix = source.substr(errorIndex, 100); const word = suffix.match(QUERY_WORD_REGEX)[0]; let error; switch (errorId) { @@ -758,46 +758,75 @@ class Query { C._ts_query_delete(this[0]); } - exec(queryNode, startPosition, endPosition) { + matches(node, startPosition, endPosition) { if (!startPosition) startPosition = ZERO_POINT; if (!endPosition) endPosition = ZERO_POINT; - 
marshalNode(queryNode); + marshalNode(node); - C._ts_query_exec_wasm( + C._ts_query_matches_wasm( this[0], - queryNode.tree[0], + node.tree[0], startPosition.row, startPosition.column, endPosition.row, endPosition.column ); - const matchCount = getValue(TRANSFER_BUFFER, 'i32'); - const nodesAddress = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32'); - const result = new Array(matchCount); + const count = getValue(TRANSFER_BUFFER, 'i32'); + const startAddress = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32'); + const result = new Array(count); - let address = nodesAddress; - for (let i = 0; i < matchCount; i++) { + let address = startAddress; + for (let i = 0; i < count; i++) { const pattern = getValue(address, 'i32'); address += SIZE_OF_INT; - const captures = new Array(getValue(address, 'i32')); + const captureCount = getValue(address, 'i32'); address += SIZE_OF_INT; - for (let j = 0, n = captures.length; j < n; j++) { - const captureIndex = getValue(address, 'i32'); - address += SIZE_OF_INT; - const node = unmarshalNode(queryNode.tree, address); - address += SIZE_OF_NODE; - captures[j] = {name: this.captureNames[captureIndex], node}; - } + + const captures = new Array(captureCount); + address = unmarshalCaptures(this, node.tree, address, captures); result[i] = {pattern, captures}; } - // Free the intermediate buffers - C._free(nodesAddress); - + C._free(startAddress); return result; } + + captures(node, startPosition, endPosition) { + if (!startPosition) startPosition = ZERO_POINT; + if (!endPosition) endPosition = ZERO_POINT; + + marshalNode(node); + + C._ts_query_captures_wasm( + this[0], + node.tree[0], + startPosition.row, + startPosition.column, + endPosition.row, + endPosition.column + ); + + const count = getValue(TRANSFER_BUFFER, 'i32'); + const startAddress = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32'); + const result = new Array(count); + unmarshalCaptures(this, node.tree, startAddress, result); + + C._free(startAddress); + return result; + } +} + 
+function unmarshalCaptures(query, tree, address, result) { + for (let i = 0, n = result.length; i < n; i++) { + const captureIndex = getValue(address, 'i32'); + address += SIZE_OF_INT; + const node = unmarshalNode(tree, address); + address += SIZE_OF_NODE; + result[i] = {name: query.captureNames[captureIndex], node}; + } + return address; } function assertInternal(x) { diff --git a/lib/binding_web/exports.json b/lib/binding_web/exports.json index e2b187f7..6b0eab30 100644 --- a/lib/binding_web/exports.json +++ b/lib/binding_web/exports.json @@ -73,7 +73,8 @@ "_ts_query_context_delete", "_ts_query_context_new", "_ts_query_delete", - "_ts_query_exec_wasm", + "_ts_query_matches_wasm", + "_ts_query_captures_wasm", "_ts_query_new", "_ts_tree_cursor_current_field_id_wasm", "_ts_tree_cursor_current_node_id_wasm", diff --git a/lib/binding_web/test/query-test.js b/lib/binding_web/test/query-test.js index 4fd44165..5d7ce620 100644 --- a/lib/binding_web/test/query-test.js +++ b/lib/binding_web/test/query-test.js @@ -18,64 +18,117 @@ describe("Query", () => { if (query) query.delete(); }); - it('throws an error on invalid syntax', () => { - assert.throws(() => { - JavaScript.query("(function_declaration wat)") - }, "Bad syntax at offset 22: \'wat)\'..."); - assert.throws(() => { - JavaScript.query("(non_existent)") - }, "Bad node name 'non_existent'"); - assert.throws(() => { - JavaScript.query("(a)") - }, "Bad node name 'a'"); - assert.throws(() => { - JavaScript.query("(function_declaration non_existent:(identifier))") - }, "Bad field name 'non_existent'"); + describe('construction', () => { + it('throws an error on invalid syntax', () => { + assert.throws(() => { + JavaScript.query("(function_declaration wat)") + }, "Bad syntax at offset 22: \'wat)\'..."); + assert.throws(() => { + JavaScript.query("(non_existent)") + }, "Bad node name 'non_existent'"); + assert.throws(() => { + JavaScript.query("(a)") + }, "Bad node name 'a'"); + assert.throws(() => { + 
JavaScript.query("(function_declaration non_existent:(identifier))") + }, "Bad field name 'non_existent'"); + }); }); - it('matches simple queries', () => { - tree = parser.parse("function one() { two(); function three() {} }"); - query = JavaScript.query(` - (function_declaration name:(identifier) @fn-def) - (call_expression function:(identifier) @fn-ref) - `); - const matches = query.exec(tree.rootNode); - assert.deepEqual( - formatMatches(matches), - [ - {pattern: 0, captures: [{name: 'fn-def', text: 'one'}]}, - {pattern: 1, captures: [{name: 'fn-ref', text: 'two'}]}, - {pattern: 0, captures: [{name: 'fn-def', text: 'three'}]}, - ] - ); + describe('.matches', () => { + it('returns all of the matches for the given query', () => { + tree = parser.parse("function one() { two(); function three() {} }"); + query = JavaScript.query(` + (function_declaration name:(identifier) @fn-def) + (call_expression function:(identifier) @fn-ref) + `); + const matches = query.matches(tree.rootNode); + assert.deepEqual( + formatMatches(matches), + [ + {pattern: 0, captures: [{name: 'fn-def', text: 'one'}]}, + {pattern: 1, captures: [{name: 'fn-ref', text: 'two'}]}, + {pattern: 0, captures: [{name: 'fn-def', text: 'three'}]}, + ] + ); + }); + + it('can search in a specified ranges', () => { + tree = parser.parse("[a, b,\nc, d,\ne, f,\ng, h]"); + query = JavaScript.query('(identifier) @element'); + const matches = query.matches( + tree.rootNode, + {row: 1, column: 1}, + {row: 3, column: 1} + ); + assert.deepEqual( + formatMatches(matches), + [ + {pattern: 0, captures: [{name: 'element', text: 'd'}]}, + {pattern: 0, captures: [{name: 'element', text: 'e'}]}, + {pattern: 0, captures: [{name: 'element', text: 'f'}]}, + {pattern: 0, captures: [{name: 'element', text: 'g'}]}, + ] + ); + }); }); - it('matches queries in specified ranges', () => { - tree = parser.parse("[a, b,\nc, d,\ne, f,\ng, h]"); - query = JavaScript.query('(identifier) @element'); - const matches = query.exec( - 
tree.rootNode, - {row: 1, column: 1}, - {row: 3, column: 1} - ); - assert.deepEqual( - formatMatches(matches), - [ - {pattern: 0, captures: [{name: 'element', text: 'd'}]}, - {pattern: 0, captures: [{name: 'element', text: 'e'}]}, - {pattern: 0, captures: [{name: 'element', text: 'f'}]}, - {pattern: 0, captures: [{name: 'element', text: 'g'}]}, - ] - ); + describe('.captures', () => { + it('returns all of the captures for the given query, in order', () => { + tree = parser.parse(` + a({ + bc: function de() { + const fg = function hi() {} + }, + jk: function lm() { + const no = function pq() {} + }, + }); + `); + query = JavaScript.query(` + (pair + key: * @method.def + (function + name: (identifier) @method.alias)) + + (variable_declarator + name: * @function.def + value: (function + name: (identifier) @function.alias)) + + ":" @delimiter + "=" @operator + `); + + const captures = query.captures(tree.rootNode); + assert.deepEqual( + formatCaptures(captures), + [ + {name: "method.def", text: "bc"}, + {name: "delimiter", text: ":"}, + {name: "method.alias", text: "de"}, + {name: "function.def", text: "fg"}, + {name: "operator", text: "="}, + {name: "function.alias", text: "hi"}, + {name: "method.def", text: "jk"}, + {name: "delimiter", text: ":"}, + {name: "method.alias", text: "lm"}, + {name: "function.def", text: "no"}, + {name: "operator", text: "="}, + {name: "function.alias", text: "pq"}, + ] + ); + }); }); }); function formatMatches(matches) { return matches.map(({pattern, captures}) => ({ pattern, - captures: captures.map(({name, node}) => ({ - name, - text: node.text - })) + captures: formatCaptures(captures) })) } + +function formatCaptures(captures) { + return captures.map(({name, node}) => ({ name, text: node.text })) +} diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 624658b4..a5c22eb9 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -673,19 +673,23 @@ int ts_query_capture_id_for_name( * 
Create a new cursor for executing a given query. * * The cursor stores the state that is needed to iteratively search - * for matches. To use the query cursor: - * 1. First call `ts_query_cursor_exec` to start running a given query on - a given syntax node. - * 2. Then repeatedly call `ts_query_cursor_next` to iterate over the matches. - * This will return `false` when there are no more matches left. - * 3. After each successful call to `ts_query_cursor_next`, you can call - * `ts_query_cursor_matched_pattern_index` to determine which pattern - * matched. You can also call `ts_query_cursor_matched_captures` to - * determine which nodes were captured, and by which capture names. + * for matches. To use the query cursor, first call `ts_query_cursor_exec` + * to start running a given query on a given syntax node. Then, there are + * two options for consuming the results of the query: + * 1. Repeatedly call `ts_query_cursor_next_match` to iterate over all of the + * *matches* in the order that they were found. Each match contains the + * index of the pattern that matched, and an array of captures. Because + * multiple patterns can match the same set of nodes, one match may contain + * captures that appear *before* some of the captures from a previous match. + * 2. Repeatedly call `ts_query_cursor_next_capture` to iterate over all of the + * individual *captures* in the order that they appear. This is useful if you + * don't care about which pattern matched, and just want a single ordered + * sequence of captures. * - * If you don't care about finding all of the matches, you can stop calling - * `ts_query_cursor_next` at any point. And you can start executing another - * query on another node by calling `ts_query_cursor_exec` again. + * If you don't care about consuming all of the results, you can stop calling + * `ts_query_cursor_next_match` or `ts_query_cursor_next_capture` at any point. 
+ * You can then start executing another query on another node by calling + * `ts_query_cursor_exec` again. */ TSQueryCursor *ts_query_cursor_new(); @@ -708,22 +712,26 @@ void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint); /** * Advance to the next match of the currently running query. + * + * If there is another match, write its pattern index to `pattern_index`, + * the number of captures to `capture_count`, and the captures themselves + * to `*captures`, and return `true`. Otherwise, return `false`. */ -bool ts_query_cursor_next(TSQueryCursor *); - -/** - * Check which pattern matched. - */ -uint32_t ts_query_cursor_matched_pattern_index(const TSQueryCursor *); - -/** - * Check which pattern matched. - */ -const TSQueryCapture *ts_query_cursor_matched_captures( - const TSQueryCursor *, - uint32_t * +bool ts_query_cursor_next_match( + TSQueryCursor *self, + uint32_t *pattern_index, + uint32_t *capture_count, + const TSQueryCapture **captures ); +/** + * Advance to the next capture of the currently running query. + * + * If there is another capture, write it to `capture` and return `true`. + * Otherwise, return `false`. + */ +bool ts_query_cursor_next_capture(TSQueryCursor *, TSQueryCapture *capture); + /**********************/ /* Section - Language */ /**********************/ diff --git a/lib/src/query.c b/lib/src/query.c index 5c20f0f3..ea01cf24 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -54,11 +54,12 @@ typedef struct { * represented as one of these states. 
*/ typedef struct { - uint16_t step_index; - uint16_t pattern_index; uint16_t start_depth; - uint16_t capture_list_id; - uint16_t capture_count; + uint16_t pattern_index; + uint8_t step_index; + uint8_t capture_count; + uint8_t capture_list_id; + uint8_t consumed_capture_count; } QueryState; /* @@ -96,12 +97,12 @@ struct TSQueryCursor { Array(QueryState) states; Array(QueryState) finished_states; CaptureListPool capture_list_pool; - bool ascending; uint32_t depth; uint32_t start_byte; uint32_t end_byte; TSPoint start_point; TSPoint end_point; + bool ascending; }; static const TSQueryError PARENT_DONE = -1; @@ -686,13 +687,8 @@ static QueryState *ts_query_cursor_copy_state( return new_state; } -bool ts_query_cursor_next(TSQueryCursor *self) { - if (self->finished_states.size > 0) { - QueryState state = array_pop(&self->finished_states); - capture_list_pool_release(&self->capture_list_pool, state.capture_list_id); - } - - while (self->finished_states.size == 0) { +static inline bool ts_query_cursor__advance(TSQueryCursor *self) { + do { if (self->ascending) { // When leaving a node, remove any unfinished states whose next step // needed to match something within that node. @@ -784,6 +780,7 @@ bool ts_query_cursor_next(TSQueryCursor *self) { .pattern_index = slice->pattern_index, .capture_list_id = capture_list_id, .capture_count = 0, + .consumed_capture_count = 0, })); } @@ -821,6 +818,7 @@ bool ts_query_cursor_next(TSQueryCursor *self) { .start_depth = self->depth, .capture_list_id = capture_list_id, .capture_count = 0, + .consumed_capture_count = 0, })); // Advance to the next pattern whose root node matches this node. 
@@ -905,32 +903,108 @@ bool ts_query_cursor_next(TSQueryCursor *self) { self->ascending = true; } } - } + } while (self->finished_states.size == 0); return true; } -uint32_t ts_query_cursor_matched_pattern_index(const TSQueryCursor *self) { - if (self->finished_states.size > 0) { - QueryState *state = array_back(&self->finished_states); - return state->pattern_index; - } - return 0; -} - -const TSQueryCapture *ts_query_cursor_matched_captures( - const TSQueryCursor *self, - uint32_t *count +bool ts_query_cursor_next_match( + TSQueryCursor *self, + uint32_t *pattern_index, + uint32_t *capture_count, + const TSQueryCapture **captures ) { if (self->finished_states.size > 0) { - QueryState *state = array_back(&self->finished_states); - *count = state->capture_count; - return capture_list_pool_get( - (CaptureListPool *)&self->capture_list_pool, - state->capture_list_id - ); + QueryState state = array_pop(&self->finished_states); + capture_list_pool_release(&self->capture_list_pool, state.capture_list_id); + } + + if (!ts_query_cursor__advance(self)) return false; + + const QueryState *state = array_back(&self->finished_states); + *pattern_index = state->pattern_index; + *capture_count = state->capture_count; + *captures = capture_list_pool_get( + &self->capture_list_pool, + state->capture_list_id + ); + + return true; +} + +bool ts_query_cursor_next_capture( + TSQueryCursor *self, + TSQueryCapture *capture +) { + for (;;) { + if (self->finished_states.size > 0) { + // Find the position of the earliest capture in an unfinished match. 
+ uint32_t first_unfinished_capture_byte = UINT32_MAX; + for (unsigned i = 0; i < self->states.size; i++) { + const QueryState *state = &self->states.contents[i]; + if (state->capture_count > 0) { + const TSQueryCapture *captures = capture_list_pool_get( + &self->capture_list_pool, + state->capture_list_id + ); + uint32_t capture_byte = ts_node_start_byte(captures[0].node); + if (capture_byte < first_unfinished_capture_byte) { + first_unfinished_capture_byte = capture_byte; + } + } + } + + // Find the earliest capture in a finished match. It must not start + // after the first unfinished capture. + int first_finished_state_index = -1; + uint32_t first_finished_capture_byte = first_unfinished_capture_byte; + for (unsigned i = 0; i < self->finished_states.size; i++) { + const QueryState *state = &self->finished_states.contents[i]; + if (state->capture_count > state->consumed_capture_count) { + const TSQueryCapture *captures = capture_list_pool_get( + &self->capture_list_pool, + state->capture_list_id + ); + uint32_t capture_byte = ts_node_start_byte( + captures[state->consumed_capture_count].node + ); + if (capture_byte <= first_finished_capture_byte) { + first_finished_state_index = i; + first_finished_capture_byte = capture_byte; + } + } else { + capture_list_pool_release( + &self->capture_list_pool, + state->capture_list_id + ); + array_erase(&self->finished_states, i); + i--; + } + } + + if (first_finished_state_index != -1) { + QueryState *state = &self->finished_states.contents[ + first_finished_state_index + ]; + const TSQueryCapture *captures = capture_list_pool_get( + &self->capture_list_pool, + state->capture_list_id + ); + *capture = captures[state->consumed_capture_count]; + state->consumed_capture_count++; + if (state->consumed_capture_count == state->capture_count) { + capture_list_pool_release( + &self->capture_list_pool, + state->capture_list_id + ); + array_erase(&self->finished_states, first_finished_state_index); + } + return true; + } + } + + if 
(!ts_query_cursor__advance(self)) return false; } - return NULL; } #undef LOG