diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 927df294..6131d1ea 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -342,6 +342,32 @@ fn test_query_exec_with_anonymous_tokens() { }); } +#[test] +fn test_query_exec_within_byte_range() { + allocations::record(|| { + let language = get_language("javascript"); + let query = Query::new(language, "(identifier) @element").unwrap(); + + let source = "[a, b, c, d, e, f, g]"; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + + let mut context = query.context(); + let matches = context.set_byte_range(5, 15).exec(tree.root_node()); + + assert_eq!( + collect_matches(matches, &query, source), + &[ + (0, vec![("element", "c")]), + (0, vec![("element", "d")]), + (0, vec![("element", "e")]), + ] + ); + }); +} + #[test] fn test_query_capture_names() { allocations::record(|| { diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 53b77405..b44b5622 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -560,8 +560,9 @@ extern "C" { pub fn ts_tree_cursor_copy(arg1: *const TSTreeCursor) -> TSTreeCursor; } extern "C" { - #[doc = " Create a new query based on a given language and string containing"] - #[doc = " one or more S-expression patterns."] + #[doc = " Create a new query from a string containing one or more S-expression"] + #[doc = " patterns. The query is associated with a particular language, and can"] + #[doc = " only be run on syntax nodes parsed with that language."] #[doc = ""] #[doc = " If all of the given patterns are valid, this returns a `TSQuery`."] #[doc = " If a pattern is invalid, this returns `NULL`, and provides two pieces"] @@ -569,7 +570,7 @@ extern "C" { #[doc = " 1. The byte offset of the error is written to the `error_offset` parameter."] #[doc = " 2. The type of error is written to the `error_type` parameter."] pub fn ts_query_new( - arg1: *const TSLanguage, + language: *const TSLanguage, source: *const ::std::os::raw::c_char, source_len: u32, error_offset: *mut u32, @@ -581,9 +582,13 @@ extern "C" { pub fn ts_query_delete(arg1: *mut TSQuery); } extern "C" { + #[doc = " Get the number of distinct capture names in the query."] pub fn ts_query_capture_count(arg1: *const TSQuery) -> u32; } extern "C" { + #[doc = " Get the name and length of one of the query\'s capture. Each capture"] + #[doc = " is associated with a numeric id based on the order that it appeared"] + #[doc = " in the query\'s source."] pub fn ts_query_capture_name_for_id( self_: *const TSQuery, index: u32, @@ -591,6 +596,7 @@ extern "C" { ) -> *const ::std::os::raw::c_char; } extern "C" { + #[doc = " Get the numeric id of the capture with the given name."] pub fn ts_query_capture_id_for_name( self_: *const TSQuery, name: *const ::std::os::raw::c_char, @@ -598,21 +604,54 @@ extern "C" { ) -> ::std::os::raw::c_int; } extern "C" { + #[doc = " Create a new context for executing a given query."] + #[doc = ""] + #[doc = " The context stores the state that is needed to iteratively search"] + #[doc = " for matches. To use the query context:"] + #[doc = " 1. First call `ts_query_context_exec` to start running the query"] + #[doc = " on a particular syntax node."] + #[doc = " 2. Then repeatedly call `ts_query_context_next` to iterate over"] + #[doc = " the matches."] + #[doc = " 3. After each successful call to `ts_query_context_next`, you can call"] + #[doc = " `ts_query_context_matched_pattern_index` to determine which pattern"] + #[doc = " matched. You can also call `ts_query_context_matched_captures` to"] + #[doc = " determine which nodes were captured by which capture names."] + #[doc = ""] + #[doc = " If you don\'t care about finding all of the matches, you can stop calling"] + #[doc = " `ts_query_context_next` at any point. And you can start executing the"] + #[doc = " query against a different node by calling `ts_query_context_exec` again."] pub fn ts_query_context_new(arg1: *const TSQuery) -> *mut TSQueryContext; } extern "C" { + #[doc = " Delete a query context, freeing all of the memory that it used."] pub fn ts_query_context_delete(arg1: *mut TSQueryContext); } extern "C" { + #[doc = " Start running a query on a given node."] pub fn ts_query_context_exec(arg1: *mut TSQueryContext, arg2: TSNode); } extern "C" { + #[doc = " Set the range of bytes or (row, column) positions in which the query"] + #[doc = " will be executed."] + pub fn ts_query_context_set_byte_range(arg1: *mut TSQueryContext, arg2: u32, arg3: u32); +} +extern "C" { + pub fn ts_query_context_set_point_range( + arg1: *mut TSQueryContext, + arg2: TSPoint, + arg3: TSPoint, + ); +} +extern "C" { + #[doc = " Advance to the next match of the currently running query."] pub fn ts_query_context_next(arg1: *mut TSQueryContext) -> bool; } extern "C" { + #[doc = " Check which pattern matched."] pub fn ts_query_context_matched_pattern_index(arg1: *const TSQueryContext) -> u32; } extern "C" { + #[doc = " Check which pattern matched."] pub fn ts_query_context_matched_captures( arg1: *const TSQueryContext, arg2: *mut u32, diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 80e56ba9..8d29a3c3 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1011,6 +1011,20 @@ impl<'a> QueryContext<'a> { } }) } + + pub fn set_byte_range(&mut self, start: usize, end: usize) -> &mut Self { + unsafe { + ffi::ts_query_context_set_byte_range(self.0, start as u32, end as u32); + } + self + } + + pub fn set_point_range(&mut self, start: Point, end: Point) -> &mut Self { + unsafe { + ffi::ts_query_context_set_point_range(self.0, start.into(), end.into()); + } + self + } } impl<'a> QueryMatch<'a> { diff --git a/lib/binding_web/binding.c b/lib/binding_web/binding.c index e94c5aa0..f46d1def 100644 --- a/lib/binding_web/binding.c +++ b/lib/binding_web/binding.c @@ -2,6 +2,7 @@ #include #include #include "array.h" +#include "point.h" /*****************************/ /* Section - Data marshaling */ @@ -464,12 +465,6 @@ void ts_node_named_children_wasm(const TSTree *tree) { TRANSFER_BUFFER[1] = result; } -bool point_lte(TSPoint a, TSPoint b) { - if (a.row < b.row) return true; - if (a.row > b.row) return false; - return a.column <= b.column; -} - bool symbols_contain(const uint32_t *set, uint32_t length, uint32_t value) { for (unsigned i = 0; i < length; i++) { if (set[i] == value) return true; diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index ad991818..d951a35a 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -644,12 +644,12 @@ TSQuery *ts_query_new( */ void ts_query_delete(TSQuery *); -/* +/** * Get the number of distinct capture names in the query. */ uint32_t ts_query_capture_count(const TSQuery *); -/* +/** * Get the name and length of one of the query's capture. Each capture * is associated with a numeric id based on the order that it appeared * in the query's source. @@ -660,7 +660,7 @@ const char *ts_query_capture_name_for_id( uint32_t *length ); -/* +/** * Get the numeric id of the capture with the given name. */ int ts_query_capture_id_for_name( @@ -669,7 +669,7 @@ int ts_query_capture_id_for_name( uint32_t length ); -/* +/** * Create a new context for executing a given query. * * The context stores the state that is needed to iteratively search @@ -678,10 +678,10 @@ int ts_query_capture_id_for_name( * on a particular syntax node. * 2. Then repeatedly call `ts_query_context_next` to iterate over * the matches. - * 3. For each match, you can call `ts_query_context_matched_pattern_index` - * to determine which pattern matched. You can also call - * `ts_query_context_matched_captures` to determine which nodes - * were captured by which capture names. + * 3. After each successful call to `ts_query_context_next`, you can call + * `ts_query_context_matched_pattern_index` to determine which pattern + * matched. You can also call `ts_query_context_matched_captures` to + * determine which nodes were captured by which capture names. * * If you don't care about finding all of the matches, you can stop calling * `ts_query_context_next` at any point. And you can start executing the @@ -689,27 +689,34 @@ int ts_query_capture_id_for_name( */ TSQueryContext *ts_query_context_new(const TSQuery *); -/* +/** * Delete a query context, freeing all of the memory that it used. */ void ts_query_context_delete(TSQueryContext *); -/* +/** * Start running a query on a given node. */ void ts_query_context_exec(TSQueryContext *, TSNode); -/* +/** + * Set the range of bytes or (row, column) positions in which the query + * will be executed. + */ +void ts_query_context_set_byte_range(TSQueryContext *, uint32_t, uint32_t); +void ts_query_context_set_point_range(TSQueryContext *, TSPoint, TSPoint); + +/** * Advance to the next match of the currently running query. */ bool ts_query_context_next(TSQueryContext *); -/* +/** * Check which pattern matched. */ uint32_t ts_query_context_matched_pattern_index(const TSQueryContext *); -/* +/** * Check which pattern matched. */ const TSQueryCapture *ts_query_context_matched_captures( diff --git a/lib/src/point.h b/lib/src/point.h index 4d0aed18..a50d2021 100644 --- a/lib/src/point.h +++ b/lib/src/point.h @@ -3,6 +3,7 @@ #include "tree_sitter/api.h" +#define POINT_ZERO ((TSPoint) {0, 0}) #define POINT_MAX ((TSPoint) {UINT32_MAX, UINT32_MAX}) static inline TSPoint point__new(unsigned row, unsigned column) { diff --git a/lib/src/query.c b/lib/src/query.c index 042b5d9e..8b4deb81 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -2,6 +2,7 @@ #include "./alloc.h" #include "./array.h" #include "./bits.h" +#include "./point.h" #include "utf8proc.h" #include @@ -99,6 +100,10 @@ struct TSQueryContext { CaptureListPool capture_list_pool; bool ascending; uint32_t depth; + uint32_t start_byte; + uint32_t end_byte; + TSPoint start_point; + TSPoint end_point; }; static const TSQueryError PARENT_DONE = -1; @@ -605,6 +610,10 @@ TSQueryContext *ts_query_context_new(const TSQuery *query) { .states = array_new(), .finished_states = array_new(), .capture_list_pool = capture_list_pool_new(query->max_capture_count), + .start_byte = 0, + .end_byte = UINT32_MAX, + .start_point = {0, 0}, + .end_point = POINT_MAX, }; return self; } @@ -626,6 +635,32 @@ void ts_query_context_exec(TSQueryContext *self, TSNode node) { self->ascending = false; } +void ts_query_context_set_byte_range( + TSQueryContext *self, + uint32_t start_byte, + uint32_t end_byte +) { + if (end_byte == 0) { + start_byte = 0; + end_byte = UINT32_MAX; + } + self->start_byte = start_byte; + self->end_byte = end_byte; +} + +void ts_query_context_set_point_range( + TSQueryContext *self, + TSPoint start_point, + TSPoint end_point +) { + if (end_point.row == 0 && end_point.column == 0) { + start_point = POINT_ZERO; + end_point = POINT_MAX; + } + self->start_point = start_point; + self->end_point = end_point; +} + static QueryState *ts_query_context_copy_state( TSQueryContext *self, QueryState *state @@ -698,6 +733,24 @@ bool ts_query_context_next(TSQueryContext *self) { TSNode node = ts_tree_cursor_current_node(&self->cursor); TSSymbol symbol = ts_node_symbol(node); + // If this node is before the selected range, then avoid + // descending into it. + if ( + ts_node_end_byte(node) <= self->start_byte || + point_lte(ts_node_end_point(node), self->start_point) + ) { + if (!ts_tree_cursor_goto_next_sibling(&self->cursor)) { + self->ascending = true; + } + continue; + } + + // If this node is after the selected range, then stop walking. + if ( + self->end_byte <= ts_node_start_byte(node) || + point_lte(self->end_point, ts_node_start_point(node)) + ) return false; + LOG("enter node %s\n", ts_node_type(node)); // Add new states for any patterns whose root node is a wildcard.