diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index e97c9cce..3ba22e45 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -448,6 +448,14 @@ struct Query { /// The range of rows in which the query will be executed #[arg(long)] pub row_range: Option, + /// The range of byte offsets in which the query will be executed. Only the matches that are fully contained within the provided + /// byte range will be returned. + #[arg(long)] + pub containing_byte_range: Option, + /// The range of rows in which the query will be executed. Only the matches that are fully contained within the provided row range + /// will be returned. + #[arg(long)] + pub containing_row_range: Option, /// Select a language by the scope instead of a file extension #[arg(long)] pub scope: Option, @@ -1486,6 +1494,18 @@ impl Query { let end = parts.next().unwrap().parse().ok()?; Some(Point::new(start, 0)..Point::new(end, 0)) }); + let containing_byte_range = self.containing_byte_range.as_ref().and_then(|range| { + let mut parts = range.split(':'); + let start = parts.next()?.parse().ok()?; + let end = parts.next().unwrap().parse().ok()?; + Some(start..end) + }); + let containing_point_range = self.containing_row_range.as_ref().and_then(|range| { + let mut parts = range.split(':'); + let start = parts.next()?.parse().ok()?; + let end = parts.next().unwrap().parse().ok()?; + Some(Point::new(start, 0)..Point::new(end, 0)) + }); let cancellation_flag = util::cancel_on_signal(); @@ -1514,6 +1534,8 @@ impl Query { ordered_captures: self.captures, byte_range, point_range, + containing_byte_range, + containing_point_range, quiet: self.quiet, print_time: self.time, stdin: false, @@ -1557,6 +1579,8 @@ impl Query { ordered_captures: self.captures, byte_range, point_range, + containing_byte_range, + containing_point_range, quiet: self.quiet, print_time: self.time, stdin: true, @@ -1575,6 +1599,8 @@ impl Query { ordered_captures: self.captures, byte_range, point_range, + containing_byte_range, + containing_point_range, quiet: self.quiet, print_time: self.time, stdin: true, diff --git a/crates/cli/src/query.rs b/crates/cli/src/query.rs index 7343049f..54674115 100644 --- a/crates/cli/src/query.rs +++ b/crates/cli/src/query.rs @@ -21,6 +21,8 @@ pub struct QueryFileOptions { pub ordered_captures: bool, pub byte_range: Option>, pub point_range: Option>, + pub containing_byte_range: Option>, + pub containing_point_range: Option>, pub quiet: bool, pub print_time: bool, pub stdin: bool, @@ -48,6 +50,12 @@ pub fn query_file_at_path( if let Some(ref range) = opts.point_range { query_cursor.set_point_range(range.clone()); } + if let Some(ref range) = opts.containing_byte_range { + query_cursor.set_containing_byte_range(range.clone()); + } + if let Some(ref range) = opts.containing_point_range { + query_cursor.set_containing_point_range(range.clone()); + } let mut parser = Parser::new(); parser.set_language(language)?; diff --git a/crates/cli/src/tests/query_test.rs b/crates/cli/src/tests/query_test.rs index 4eb74f6c..3f1467e5 100644 --- a/crates/cli/src/tests/query_test.rs +++ b/crates/cli/src/tests/query_test.rs @@ -2669,6 +2669,64 @@ fn test_query_matches_within_range_of_long_repetition() { }); } +#[test] +fn test_query_matches_contained_within_range() { + allocations::record(|| { + let language = get_language("json"); + let query = Query::new( + &language, + r#" + ("[" @l_bracket "]" @r_bracket) + ("{" @l_brace "}" @r_brace) + "#, + ) + .unwrap(); + + let source = r#" + [ + {"key1": "value1"}, + {"key2": "value2"}, + {"key3": "value3"}, + {"key4": "value4"}, + {"key5": "value5"}, + {"key6": "value6"}, + {"key7": "value7"}, + {"key8": "value8"}, + {"key9": "value9"}, + {"key10": "value10"}, + {"key11": "value11"}, + {"key12": "value12"}, + ] + "# + .unindent(); + + let mut parser = Parser::new(); + parser.set_language(&language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + + let expected_matches = [ + (1, vec![("l_brace", "{"), ("r_brace", "}")]), + (1, vec![("l_brace", "{"), ("r_brace", "}")]), + ]; + { + let mut cursor = QueryCursor::new(); + let matches = cursor + .set_containing_point_range(Point::new(5, 0)..Point::new(7, 0)) + .matches(&query, tree.root_node(), source.as_bytes()); + assert_eq!(collect_matches(matches, &query, &source), &expected_matches); + } + { + let mut cursor = QueryCursor::new(); + let matches = cursor.set_containing_byte_range(78..120).matches( + &query, + tree.root_node(), + source.as_bytes(), + ); + assert_eq!(collect_matches(matches, &query, &source), &expected_matches); + } + }); +} + #[test] fn test_query_matches_different_queries_same_cursor() { allocations::record(|| { diff --git a/crates/xtask/src/check_wasm_exports.rs b/crates/xtask/src/check_wasm_exports.rs index 124725b7..c93f7cd3 100644 --- a/crates/xtask/src/check_wasm_exports.rs +++ b/crates/xtask/src/check_wasm_exports.rs @@ -16,7 +16,7 @@ use notify_debouncer_full::new_debouncer; use crate::{bail_on_err, watch_wasm, CheckWasmExports}; -const EXCLUDES: [&str; 23] = [ +const EXCLUDES: [&str; 25] = [ // Unneeded because the JS side has its own way of implementing it "ts_node_child_by_field_name", "ts_node_edit", @@ -44,6 +44,8 @@ const EXCLUDES: [&str; 23] = [ "ts_query_cursor_delete", "ts_query_cursor_match_limit", "ts_query_cursor_remove_match", + "ts_query_cursor_set_point_range", + "ts_query_cursor_set_containing_byte_range", ]; pub fn run(args: &CheckWasmExports) -> Result<()> { diff --git a/docs/src/cli/query.md b/docs/src/cli/query.md index 395ca486..08ff2654 100644 --- a/docs/src/cli/query.md +++ b/docs/src/cli/query.md @@ -36,10 +36,20 @@ The path to a file that contains paths to source files in which the query will b The range of byte offsets in which the query will be executed. The format is `start_byte:end_byte`. +### `--containing-byte-range ` + +The range of byte offsets in which the query will be executed. Only the matches that are fully contained within the provided +byte range will be returned. + ### `--row-range ` The range of rows in which the query will be executed. The format is `start_row:end_row`. +### `--containing-row-range ` + +The range of rows in which the query will be executed. Only the matches that are fully contained within the provided row range +will be returned. + ### `--scope ` The language scope to use for parsing and querying. This is useful when the language is ambiguous. diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 3feb8409..ec00a7c6 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -712,6 +712,22 @@ extern "C" { end_point: TSPoint, ) -> bool; } +extern "C" { + #[doc = " Set the byte range within which all matches must be fully contained.\n\n Set the range of bytes in which matches will be searched for. In contrast to\n `ts_query_cursor_set_byte_range`, this will restrict the query cursor to only return\n matches where _all_ nodes are _fully_ contained within the given range. Both functions\n can be used together, e.g. to search for any matches that intersect line 5000, as\n long as they are fully contained within lines 4500-5500"] + pub fn ts_query_cursor_set_containing_byte_range( + self_: *mut TSQueryCursor, + start_byte: u32, + end_byte: u32, + ) -> bool; +} +extern "C" { + #[doc = " Set the point range within which all matches must be fully contained.\n\n Set the range of bytes in which matches will be searched for. In contrast to\n `ts_query_cursor_set_point_range`, this will restrict the query cursor to only return\n matches where _all_ nodes are _fully_ contained within the given range. Both functions\n can be used together, e.g. to search for any matches that intersect line 5000, as\n long as they are fully contained within lines 4500-5500"] + pub fn ts_query_cursor_set_containing_point_range( + self_: *mut TSQueryCursor, + start_point: TSPoint, + end_point: TSPoint, + ) -> bool; +} extern "C" { #[doc = " Advance to the next match of the currently running query.\n\n If there is a match, write it to `*match` and return `true`.\n Otherwise, return `false`."] pub fn ts_query_cursor_next_match(self_: *mut TSQueryCursor, match_: *mut TSQueryMatch) diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index a02fa173..bf86cf74 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -3181,6 +3181,44 @@ impl QueryCursor { self } + /// Set the byte range within which all matches must be fully contained. + /// + /// Set the range of bytes in which matches will be searched for. In contrast to + /// `ts_query_cursor_set_byte_range`, this will restrict the query cursor to only return + /// matches where _all_ nodes are _fully_ contained within the given range. Both functions + /// can be used together, e.g. to search for any matches that intersect line 5000, as + /// long as they are fully contained within lines 4500-5500 + #[doc(alias = "ts_query_cursor_set_containing_byte_range")] + pub fn set_containing_byte_range(&mut self, range: ops::Range) -> &mut Self { + unsafe { + ffi::ts_query_cursor_set_containing_byte_range( + self.ptr.as_ptr(), + range.start as u32, + range.end as u32, + ); + } + self + } + + /// Set the point range within which all matches must be fully contained. + /// + /// Set the range of bytes in which matches will be searched for. In contrast to + /// `ts_query_cursor_set_point_range`, this will restrict the query cursor to only return + /// matches where _all_ nodes are _fully_ contained within the given range. Both functions + /// can be used together, e.g. to search for any matches that intersect line 5000, as + /// long as they are fully contained within lines 4500-5500 + #[doc(alias = "ts_query_cursor_set_containing_point_range")] + pub fn set_containing_point_range(&mut self, range: ops::Range) -> &mut Self { + unsafe { + ffi::ts_query_cursor_set_containing_point_range( + self.ptr.as_ptr(), + range.start.into(), + range.end.into(), + ); + } + self + } + /// Set the maximum start depth for a query cursor. /// /// This prevents cursors from exploring children nodes at a certain depth. diff --git a/lib/binding_web/lib/tree-sitter.c b/lib/binding_web/lib/tree-sitter.c index db6c108b..828132ac 100644 --- a/lib/binding_web/lib/tree-sitter.c +++ b/lib/binding_web/lib/tree-sitter.c @@ -874,6 +874,12 @@ void ts_query_matches_wasm( uint32_t end_column, uint32_t start_index, uint32_t end_index, + uint32_t start_containing_row, + uint32_t start_containing_column, + uint32_t end_containing_row, + uint32_t end_containing_column, + uint32_t start_containing_index, + uint32_t end_containing_index, uint32_t match_limit, uint32_t max_start_depth ) { @@ -889,8 +895,20 @@ void ts_query_matches_wasm( TSNode node = unmarshal_node(tree); TSPoint start_point = {start_row, code_unit_to_byte(start_column)}; TSPoint end_point = {end_row, code_unit_to_byte(end_column)}; + TSPoint start_containing_point = {start_containing_row, code_unit_to_byte(start_containing_column)}; + TSPoint end_containing_point = {end_containing_row, code_unit_to_byte(end_containing_column)}; ts_query_cursor_set_point_range(scratch_query_cursor, start_point, end_point); ts_query_cursor_set_byte_range(scratch_query_cursor, start_index, end_index); + ts_query_cursor_set_containing_point_range( + scratch_query_cursor, + start_containing_point, + end_containing_point + ); + ts_query_cursor_set_containing_byte_range( + scratch_query_cursor, + start_containing_index, + end_containing_index + ); ts_query_cursor_set_match_limit(scratch_query_cursor, match_limit); ts_query_cursor_set_max_start_depth(scratch_query_cursor, max_start_depth); @@ -932,6 +950,12 @@ void ts_query_captures_wasm( uint32_t end_column, uint32_t start_index, uint32_t end_index, + uint32_t start_containing_row, + uint32_t start_containing_column, + uint32_t end_containing_row, + uint32_t end_containing_column, + uint32_t start_containing_index, + uint32_t end_containing_index, uint32_t match_limit, uint32_t max_start_depth ) { @@ -944,8 +968,20 @@ void ts_query_captures_wasm( TSNode node = unmarshal_node(tree); TSPoint start_point = {start_row, code_unit_to_byte(start_column)}; TSPoint end_point = {end_row, code_unit_to_byte(end_column)}; + TSPoint start_containing_point = {start_containing_row, code_unit_to_byte(start_containing_column)}; + TSPoint end_containing_point = {end_containing_row, code_unit_to_byte(end_containing_column)}; ts_query_cursor_set_point_range(scratch_query_cursor, start_point, end_point); ts_query_cursor_set_byte_range(scratch_query_cursor, start_index, end_index); + ts_query_cursor_set_containing_point_range( + scratch_query_cursor, + start_containing_point, + end_containing_point + ); + ts_query_cursor_set_containing_byte_range( + scratch_query_cursor, + start_containing_index, + end_containing_index + ); ts_query_cursor_set_match_limit(scratch_query_cursor, match_limit); ts_query_cursor_set_max_start_depth(scratch_query_cursor, max_start_depth); ts_query_cursor_exec(scratch_query_cursor, self, node); diff --git a/lib/binding_web/lib/web-tree-sitter.d.ts b/lib/binding_web/lib/web-tree-sitter.d.ts index c19d7bf4..c1e0e0dd 100644 --- a/lib/binding_web/lib/web-tree-sitter.d.ts +++ b/lib/binding_web/lib/web-tree-sitter.d.ts @@ -175,8 +175,8 @@ interface WasmModule { _ts_node_is_extra_wasm(_0: number): number; _ts_node_parse_state_wasm(_0: number): number; _ts_node_next_parse_state_wasm(_0: number): number; - _ts_query_matches_wasm(_0: number, _1: number, _2: number, _3: number, _4: number, _5: number, _6: number, _7: number, _8: number, _9: number): void; - _ts_query_captures_wasm(_0: number, _1: number, _2: number, _3: number, _4: number, _5: number, _6: number, _7: number, _8: number, _9: number): void; + _ts_query_matches_wasm(_0: number, _1: number, _2: number, _3: number, _4: number, _5: number, _6: number, _7: number, _8: number, _9: number, _10: number, _11: number, _12: number, _13: number, _14: number, _15: number): void; + _ts_query_captures_wasm(_0: number, _1: number, _2: number, _3: number, _4: number, _5: number, _6: number, _7: number, _8: number, _9: number, _10: number, _11: number, _12: number, _13: number, _14: number, _15: number): void; _memset(_0: number, _1: number, _2: number): number; _memcpy(_0: number, _1: number, _2: number): number; _memmove(_0: number, _1: number, _2: number): number; diff --git a/lib/binding_web/src/query.ts b/lib/binding_web/src/query.ts index 6f3064a8..b9cd1971 100644 --- a/lib/binding_web/src/query.ts +++ b/lib/binding_web/src/query.ts @@ -20,12 +20,32 @@ export interface QueryOptions { /** The end position of the range to query */ endPosition?: Point; + /** The start position of the range to query Only the matches that are fully + * contained within provided range will be returned. + **/ + startContainingPosition?: Point; + + /** The end position of the range to query Only the matches that are fully + * contained within provided range will be returned. + **/ + endContainingPosition?: Point; + /** The start index of the range to query */ startIndex?: number; /** The end index of the range to query */ endIndex?: number; + /** The start index of the range to query Only the matches that are fully + * contained within provided range will be returned. + **/ + startContainingIndex?: number; + + /** The end index of the range to query Only the matches that are fully + * contained within provided range will be returned. + **/ + endContainingIndex?: number; + /** * The maximum number of in-progress matches for this query. * The limit must be > 0 and <= 65536. @@ -695,6 +715,10 @@ export class Query { const endPosition = options.endPosition ?? ZERO_POINT; const startIndex = options.startIndex ?? 0; const endIndex = options.endIndex ?? 0; + const startContainingPosition = options.startContainingPosition ?? ZERO_POINT; + const endContainingPosition = options.endContainingPosition ?? ZERO_POINT; + const startContainingIndex = options.startContainingIndex ?? 0; + const endContainingIndex = options.endContainingIndex ?? 0; const matchLimit = options.matchLimit ?? 0xFFFFFFFF; const maxStartDepth = options.maxStartDepth ?? 0xFFFFFFFF; const progressCallback = options.progressCallback; @@ -715,6 +739,18 @@ export class Query { throw new Error('`startPosition` cannot be greater than `endPosition`'); } + if (endContainingIndex !== 0 && startContainingIndex > endContainingIndex) { + throw new Error('`startContainingIndex` cannot be greater than `endContainingIndex`'); + } + + if (endContainingPosition !== ZERO_POINT && ( + startContainingPosition.row > endContainingPosition.row || + (startContainingPosition.row === endContainingPosition.row && + startContainingPosition.column > endContainingPosition.column) + )) { + throw new Error('`startContainingPosition` cannot be greater than `endContainingPosition`'); + } + if (progressCallback) { C.currentQueryProgressCallback = progressCallback; } @@ -730,6 +766,12 @@ export class Query { endPosition.column, startIndex, endIndex, + startContainingPosition.row, + startContainingPosition.column, + endContainingPosition.row, + endContainingPosition.column, + startContainingIndex, + endContainingIndex, matchLimit, maxStartDepth, ); @@ -788,6 +830,10 @@ export class Query { const endPosition = options.endPosition ?? ZERO_POINT; const startIndex = options.startIndex ?? 0; const endIndex = options.endIndex ?? 0; + const startContainingPosition = options.startContainingPosition ?? ZERO_POINT; + const endContainingPosition = options.endContainingPosition ?? ZERO_POINT; + const startContainingIndex = options.startContainingIndex ?? 0; + const endContainingIndex = options.endContainingIndex ?? 0; const matchLimit = options.matchLimit ?? 0xFFFFFFFF; const maxStartDepth = options.maxStartDepth ?? 0xFFFFFFFF; const progressCallback = options.progressCallback; @@ -808,6 +854,18 @@ export class Query { throw new Error('`startPosition` cannot be greater than `endPosition`'); } + if (endContainingIndex !== 0 && startContainingIndex > endContainingIndex) { + throw new Error('`startContainingIndex` cannot be greater than `endContainingIndex`'); + } + + if (endContainingPosition !== ZERO_POINT && ( + startContainingPosition.row > endContainingPosition.row || + (startContainingPosition.row === endContainingPosition.row && + startContainingPosition.column > endContainingPosition.column) + )) { + throw new Error('`startContainingPosition` cannot be greater than `endContainingPosition`'); + } + if (progressCallback) { C.currentQueryProgressCallback = progressCallback; } @@ -823,6 +881,12 @@ export class Query { endPosition.column, startIndex, endIndex, + startContainingPosition.row, + startContainingPosition.column, + endContainingPosition.row, + endContainingPosition.column, + startContainingIndex, + endContainingIndex, matchLimit, maxStartDepth, ); diff --git a/lib/binding_web/test/query.test.ts b/lib/binding_web/test/query.test.ts index ad6a6660..f90e9464 100644 --- a/lib/binding_web/test/query.test.ts +++ b/lib/binding_web/test/query.test.ts @@ -96,6 +96,64 @@ describe('Query', () => { ]); }); + it('can search in contained within point ranges', () => { + tree = parser.parse(`[ + {"key1": "value1"}, + {"key2": "value2"}, + {"key3": "value3"}, + {"key4": "value4"}, + {"key5": "value5"}, + {"key6": "value6"}, + {"key7": "value7"}, + {"key8": "value8"}, + {"key9": "value9"}, + {"key10": "value10"}, + {"key11": "value11"}, + {"key12": "value12"}, +]`)!; + query = new Query(JavaScript, '("[" @l_bracket "]" @r_bracket) ("{" @l_brace "}" @r_brace)'); + const matches = query.matches( + tree.rootNode, + { + startContainingPosition: { row: 5, column: 0 }, + endContainingPosition: { row: 7, column: 0 }, + } + ); + expect(formatMatches(matches)).toEqual([ + { patternIndex: 1, captures: [{ patternIndex: 1, name: 'l_brace', text: '{' }, { patternIndex: 1, name: 'r_brace', text: '}' },] }, + { patternIndex: 1, captures: [{ patternIndex: 1, name: 'l_brace', text: '{' }, { patternIndex: 1, name: 'r_brace', text: '}' },] }, + ]); + }); + + it('can search in contained within byte ranges', () => { + tree = parser.parse(`[ + {"key1": "value1"}, + {"key2": "value2"}, + {"key3": "value3"}, + {"key4": "value4"}, + {"key5": "value5"}, + {"key6": "value6"}, + {"key7": "value7"}, + {"key8": "value8"}, + {"key9": "value9"}, + {"key10": "value10"}, + {"key11": "value11"}, + {"key12": "value12"}, +]`)!; + query = new Query(JavaScript, '("[" @l_bracket "]" @r_bracket) ("{" @l_brace "}" @r_brace)'); + const matches = query.matches( + tree.rootNode, + { + startContainingIndex: 290, + endContainingIndex: 432, + } + ); + expect(formatMatches(matches)).toEqual([ + { patternIndex: 1, captures: [{ patternIndex: 1, name: 'l_brace', text: '{' }, { patternIndex: 1, name: 'r_brace', text: '}' },] }, + { patternIndex: 1, captures: [{ patternIndex: 1, name: 'l_brace', text: '{' }, { patternIndex: 1, name: 'r_brace', text: '}' },] }, + ]); + }); + it('handles predicates that compare the text of capture to literal strings', () => { tree = parser.parse(` giraffe(1, 2, []); diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 264d405d..22c85d48 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -1101,6 +1101,28 @@ bool ts_query_cursor_set_byte_range(TSQueryCursor *self, uint32_t start_byte, ui */ bool ts_query_cursor_set_point_range(TSQueryCursor *self, TSPoint start_point, TSPoint end_point); +/** + * Set the byte range within which all matches must be fully contained. + * + * Set the range of bytes in which matches will be searched for. In contrast to + * `ts_query_cursor_set_byte_range`, this will restrict the query cursor to only return + * matches where _all_ nodes are _fully_ contained within the given range. Both functions + * can be used together, e.g. to search for any matches that intersect line 5000, as + * long as they are fully contained within lines 4500-5500 + */ +bool ts_query_cursor_set_containing_byte_range(TSQueryCursor *self, uint32_t start_byte, uint32_t end_byte); + +/** + * Set the point range within which all matches must be fully contained. + * + * Set the range of bytes in which matches will be searched for. In contrast to + * `ts_query_cursor_set_point_range`, this will restrict the query cursor to only return + * matches where _all_ nodes are _fully_ contained within the given range. Both functions + * can be used together, e.g. to search for any matches that intersect line 5000, as + * long as they are fully contained within lines 4500-5500 + */ +bool ts_query_cursor_set_containing_point_range(TSQueryCursor *self, TSPoint start_point, TSPoint end_point); + /** * Advance to the next match of the currently running query. * diff --git a/lib/src/query.c b/lib/src/query.c index d1695549..7a8b855b 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -318,10 +318,8 @@ struct TSQueryCursor { CaptureListPool capture_list_pool; uint32_t depth; uint32_t max_start_depth; - uint32_t start_byte; - uint32_t end_byte; - TSPoint start_point; - TSPoint end_point; + TSRange included_range; + TSRange containing_range; uint32_t next_state_id; const TSQueryCursorOptions *query_options; TSQueryCursorState query_state; @@ -1336,7 +1334,7 @@ static void ts_query__perform_analysis( // of the query pattern. bool does_match = false; - // ERROR nodes can appear anywhere, so if the step is + // ERROR nodes can appear anywhere, so if the step is // looking for an ERROR node, consider it potentially matchable. if (step->symbol == ts_builtin_sym_error) { does_match = true; @@ -3155,10 +3153,18 @@ TSQueryCursor *ts_query_cursor_new(void) { .states = array_new(), .finished_states = array_new(), .capture_list_pool = capture_list_pool_new(), - .start_byte = 0, - .end_byte = UINT32_MAX, - .start_point = {0, 0}, - .end_point = POINT_MAX, + .included_range = { + .start_point = {0, 0}, + .end_point = POINT_MAX, + .start_byte = 0, + .end_byte = UINT32_MAX, + }, + .containing_range = { + .start_point = {0, 0}, + .end_point = POINT_MAX, + .start_byte = 0, + .end_byte = UINT32_MAX, + }, .max_start_depth = UINT32_MAX, .operation_count = 0, }; @@ -3266,8 +3272,8 @@ bool ts_query_cursor_set_byte_range( if (start_byte > end_byte) { return false; } - self->start_byte = start_byte; - self->end_byte = end_byte; + self->included_range.start_byte = start_byte; + self->included_range.end_byte = end_byte; return true; } @@ -3282,8 +3288,40 @@ bool ts_query_cursor_set_point_range( if (point_gt(start_point, end_point)) { return false; } - self->start_point = start_point; - self->end_point = end_point; + self->included_range.start_point = start_point; + self->included_range.end_point = end_point; + return true; +} + +bool ts_query_cursor_set_containing_byte_range( + TSQueryCursor *self, + uint32_t start_byte, + uint32_t end_byte +) { + if (end_byte == 0) { + end_byte = UINT32_MAX; + } + if (start_byte > end_byte) { + return false; + } + self->containing_range.start_byte = start_byte; + self->containing_range.end_byte = end_byte; + return true; +} + +bool ts_query_cursor_set_containing_point_range( + TSQueryCursor *self, + TSPoint start_point, + TSPoint end_point +) { + if (end_point.row == 0 && end_point.column == 0) { + end_point = POINT_MAX; + } + if (point_gt(start_point, end_point)) { + return false; + } + self->containing_range.start_point = start_point; + self->containing_range.end_point = end_point; return true; } @@ -3314,8 +3352,8 @@ static bool ts_query_cursor__first_in_progress_capture( TSNode node = array_get(captures, state->consumed_capture_count)->node; if ( - ts_node_end_byte(node) <= self->start_byte || - point_lte(ts_node_end_point(node), self->start_point) + ts_node_end_byte(node) <= self->included_range.start_byte || + point_lte(ts_node_end_point(node), self->included_range.start_point) ) { state->consumed_capture_count++; i--; @@ -3771,28 +3809,38 @@ static inline bool ts_query_cursor__advance( bool is_empty = start_byte == end_byte; bool parent_precedes_range = !ts_node_is_null(parent_node) && ( - ts_node_end_byte(parent_node) <= self->start_byte || - point_lte(ts_node_end_point(parent_node), self->start_point) + ts_node_end_byte(parent_node) <= self->included_range.start_byte || + point_lte(ts_node_end_point(parent_node), self->included_range.start_point) ); bool parent_follows_range = !ts_node_is_null(parent_node) && ( - ts_node_start_byte(parent_node) >= self->end_byte || - point_gte(ts_node_start_point(parent_node), self->end_point) + ts_node_start_byte(parent_node) >= self->included_range.end_byte || + point_gte(ts_node_start_point(parent_node), self->included_range.end_point) ); bool node_precedes_range = parent_precedes_range || - end_byte < self->start_byte || - point_lt(end_point, self->start_point) || - (!is_empty && end_byte == self->start_byte) || - (!is_empty && point_eq(end_point, self->start_point)); + end_byte < self->included_range.start_byte || + point_lt(end_point, self->included_range.start_point) || + (!is_empty && end_byte == self->included_range.start_byte) || + (!is_empty && point_eq(end_point, self->included_range.start_point)); bool node_follows_range = parent_follows_range || ( - start_byte >= self->end_byte || - point_gte(start_point, self->end_point) + start_byte >= self->included_range.end_byte || + point_gte(start_point, self->included_range.end_point) ); bool parent_intersects_range = !parent_precedes_range && !parent_follows_range; bool node_intersects_range = !node_precedes_range && !node_follows_range; + bool node_within_containing_range = + start_byte >= self->containing_range.start_byte && + point_gte(start_point, self->containing_range.start_point) && + end_byte <= self->containing_range.end_byte && + point_lte(end_point, self->containing_range.end_point); + bool node_intersects_containing_range = + end_byte > self->containing_range.start_byte && + point_gt(end_point, self->containing_range.start_point) && + start_byte < self->containing_range.end_byte && + point_lt(start_point, self->containing_range.end_point); - if (self->on_visible_node) { + if (node_within_containing_range && self->on_visible_node) { TSSymbol symbol = ts_node_symbol(node); bool is_named = ts_node_is_named(node); bool is_missing = ts_node_is_missing(node); @@ -4182,7 +4230,7 @@ static inline bool ts_query_cursor__advance( } } - if (ts_query_cursor__should_descend(self, node_intersects_range)) { + if (node_intersects_containing_range && ts_query_cursor__should_descend(self, node_intersects_range)) { switch (ts_tree_cursor_goto_first_child_internal(&self->cursor)) { case TreeCursorStepVisible: self->depth++; @@ -4304,12 +4352,12 @@ bool ts_query_cursor_next_capture( TSNode node = array_get(captures, state->consumed_capture_count)->node; bool node_precedes_range = ( - ts_node_end_byte(node) <= self->start_byte || - point_lte(ts_node_end_point(node), self->start_point) + ts_node_end_byte(node) <= self->included_range.start_byte || + point_lte(ts_node_end_point(node), self->included_range.start_point) ); bool node_follows_range = ( - ts_node_start_byte(node) >= self->end_byte || - point_gte(ts_node_start_point(node), self->end_point) + ts_node_start_byte(node) >= self->included_range.end_byte || + point_gte(ts_node_start_point(node), self->included_range.end_point) ); bool node_outside_of_range = node_precedes_range || node_follows_range;