From fda35894d4c9c1cd078e6275ea31884177f456ba Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Tue, 25 May 2021 13:11:22 -0700 Subject: [PATCH] Stop matching new patterns past the end of QueryCursor's range This restores the original signatures of the `set_byte_range` and `set_point_range` functions. Now, the QueryCursor will properly report matches that intersect, but are not fully contained by its range. Co-Authored-By: Nathan Sobo --- cli/src/tests/query_test.rs | 217 ++++++++++++++-------------------- lib/binding_rust/lib.rs | 4 +- lib/include/tree_sitter/api.h | 2 - lib/src/point.h | 4 + lib/src/query.c | 47 +++++--- 5 files changed, 122 insertions(+), 152 deletions(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 4fde4f2a..daa3b04b 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1918,6 +1918,92 @@ fn test_query_captures_within_byte_range() { }); } +#[test] +fn test_query_captures_within_byte_range_assigned_after_iterating() { + allocations::record(|| { + let language = get_language("rust"); + let query = Query::new( + language, + r#" + (function_item + name: (identifier) @fn_name) + + (mod_item + name: (identifier) @mod_name + body: (declaration_list + "{" @lbrace + "}" @rbrace)) + + ; functions that return Result<()> + ((function_item + return_type: (generic_type + type: (type_identifier) @result + type_arguments: (type_arguments + (unit_type))) + body: _ @fallible_fn_body) + (#eq? @result "Result")) + "#, + ) + .unwrap(); + let source = " + mod m1 { + mod m2 { + fn f1() -> Option<()> { Some(()) } + } + fn f2() -> Result<()> { Ok(()) } + fn f3() {} + } + "; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); + + // Retrieve some captures + let mut results = Vec::new(); + for (mat, capture_ix) in captures.by_ref().take(5) { + let capture = mat.captures[capture_ix as usize]; + results.push(( + query.capture_names()[capture.index as usize].as_str(), + &source[capture.node.byte_range()], + )); + } + assert_eq!( + results, + vec![ + ("mod_name", "m1"), + ("lbrace", "{"), + ("mod_name", "m2"), + ("lbrace", "{"), + ("fn_name", "f1"), + ] + ); + + // Advance to a range that only partially intersects some matches. + // Captures from these matches are reported, but only those that + // intersect the range. + results.clear(); + captures.set_byte_range(source.find("Ok").unwrap(), source.len()); + for (mat, capture_ix) in captures { + let capture = mat.captures[capture_ix as usize]; + results.push(( + query.capture_names()[capture.index as usize].as_str(), + &source[capture.node.byte_range()], + )); + } + assert_eq!( + results, + vec![ + ("fallible_fn_body", "{ Ok(()) }"), + ("fn_name", "f3"), + ("rbrace", "}") + ] + ); + }); +} + #[test] fn test_query_matches_different_queries_same_cursor() { allocations::record(|| { @@ -3036,137 +3122,6 @@ fn test_query_text_callback_returns_chunks() { }); } -#[test] -fn test_query_captures_advance_to_byte() { - allocations::record(|| { - let language = get_language("rust"); - let query = Query::new( - language, - r#" - (function_item - name: (identifier) @fn_name) - - (mod_item - name: (identifier) @mod_name - body: (declaration_list - "{" @lbrace - "}" @rbrace)) - - ; functions that return Result<()> - ((function_item - return_type: (generic_type - type: (type_identifier) @result - type_arguments: (type_arguments - (unit_type))) - body: _ @fallible_fn_body) - (#eq? @result "Result")) - "#, - ) - .unwrap(); - let source = " - mod m1 { - mod m2 { - fn f1() -> Option<()> { Some(()) } - } - fn f2() -> Result<()> { Ok(()) } - fn f3() {} - } - "; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(&source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); - - // Retrieve some captures - let mut results = Vec::new(); - for (mat, capture_ix) in captures.by_ref().take(5) { - let capture = mat.captures[capture_ix as usize]; - results.push(( - query.capture_names()[capture.index as usize].as_str(), - &source[capture.node.byte_range()], - )); - } - assert_eq!( - results, - vec![ - ("mod_name", "m1"), - ("lbrace", "{"), - ("mod_name", "m2"), - ("lbrace", "{"), - ("fn_name", "f1"), - ] - ); - - results.clear(); - captures.advance_to_byte(source.find("Ok").unwrap()); - - // Advance further ahead in the source, retrieve the remaining captures. - for (mat, capture_ix) in captures { - let capture = mat.captures[capture_ix as usize]; - results.push(( - query.capture_names()[capture.index as usize].as_str(), - &source[capture.node.byte_range()], - )); - } - assert_eq!( - results, - vec![ - ("fallible_fn_body", "{ Ok(()) }"), - ("fn_name", "f3"), - ("rbrace", "}") - ] - ); - - // Advance past the last capture. There are no more captures. - let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); - captures.advance_to_byte(source.len()); - assert!(captures.next().is_none()); - assert!(captures.next().is_none()); - }); -} - -#[test] -fn test_query_advance_to_byte_within_node() { - allocations::record(|| { - let language = get_language("rust"); - let query = Query::new( - language, - r#" - (fn_item - name: (identifier) @name - return_type: _? @ret) - - (mod_item - name: (identifier) @name - body: _ @body) - "#, - ) - .unwrap(); - let source = " - fn foo() -> i32 {} - - ... - - mod foo {} - "; - - let mut parser = Parser::new(); - parser.set_language(language).unwrap(); - let tree = parser.parse(&source, None).unwrap(); - let mut cursor = QueryCursor::new(); - let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); - - captures.advance_to_byte(source.find("{").unwrap()); - - assert_eq!( - collect_captures(captures, &query, source), - &[("body", "{}"),] - ); - }) -} - #[test] fn test_query_start_byte_for_pattern() { let language = get_language("javascript"); diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index da2b3252..88124f08 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1810,9 +1810,9 @@ impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryMatches<'a, 'tree, T> { } impl<'a, 'tree, T: TextProvider<'a>> QueryCaptures<'a, 'tree, T> { - pub fn advance_to_byte(&mut self, offset: usize) { + pub fn set_byte_range(&mut self, start: usize, end: usize) { unsafe { - ffi::ts_query_cursor_advance_to_byte(self.ptr, offset as u32); + ffi::ts_query_cursor_set_byte_range(self.ptr, start as u32, end as u32); } } } diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 6889a121..43315415 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -824,8 +824,6 @@ void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint); bool ts_query_cursor_next_match(TSQueryCursor *, TSQueryMatch *match); void ts_query_cursor_remove_match(TSQueryCursor *, uint32_t id); -void ts_query_cursor_advance_to_byte(TSQueryCursor *, uint32_t offset); - /** * Advance to the next capture of the currently running query. * diff --git a/lib/src/point.h b/lib/src/point.h index a50d2021..c3bf3c26 100644 --- a/lib/src/point.h +++ b/lib/src/point.h @@ -33,6 +33,10 @@ static inline bool point_lt(TSPoint a, TSPoint b) { return (a.row < b.row) || (a.row == b.row && a.column < b.column); } +static inline bool point_gt(TSPoint a, TSPoint b) { + return (a.row > b.row) || (a.row == b.row && a.column > b.column); +} + static inline bool point_eq(TSPoint a, TSPoint b) { return a.row == b.row && a.column == b.column; } diff --git a/lib/src/query.c b/lib/src/query.c index d70e5afd..00f66ec0 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -256,6 +256,9 @@ struct TSQueryCursor { CaptureListPool capture_list_pool; uint32_t depth; uint32_t start_byte; + uint32_t end_byte; + TSPoint start_point; + TSPoint end_point; uint32_t next_state_id; bool ascending; bool halted; @@ -2261,6 +2264,9 @@ TSQueryCursor *ts_query_cursor_new(void) { .finished_states = array_new(), .capture_list_pool = capture_list_pool_new(), .start_byte = 0, + .end_byte = UINT32_MAX, + .start_point = {0, 0}, + .end_point = POINT_MAX, }; array_reserve(&self->states, 8); array_reserve(&self->finished_states, 8); @@ -2290,7 +2296,6 @@ void ts_query_cursor_exec( capture_list_pool_reset(&self->capture_list_pool); self->next_state_id = 0; self->depth = 0; - self->start_byte = 0; self->ascending = false; self->halted = false; self->query = query; @@ -2302,6 +2307,11 @@ void ts_query_cursor_set_byte_range( uint32_t start_byte, uint32_t end_byte ) { + if (end_byte == 0) { + end_byte = UINT32_MAX; + } + self->start_byte = start_byte; + self->end_byte = end_byte; } void ts_query_cursor_set_point_range( @@ -2309,6 +2319,11 @@ void ts_query_cursor_set_point_range( TSPoint start_point, TSPoint end_point ) { + if (end_point.row == 0 && end_point.column == 0) { + end_point = POINT_MAX; + } + self->start_point = start_point; + self->end_point = end_point; } // Search through all of the in-progress states, and find the captured @@ -2337,7 +2352,10 @@ static bool ts_query_cursor__first_in_progress_capture( } TSNode node = captures->contents[state->consumed_capture_count].node; - if (ts_node_end_byte(node) <= self->start_byte) { + if ( + ts_node_end_byte(node) <= self->start_byte || + point_lte(ts_node_end_point(node), self->start_point) + ) { state->consumed_capture_count++; i--; continue; @@ -2682,12 +2700,8 @@ static inline bool ts_query_cursor__advance( // Enter a new node. else { - // If this node is before the selected range, then avoid descending into it. - TSNode node = ts_tree_cursor_current_node(&self->cursor); - - bool node_exceeds_start_byte = ts_node_end_byte(node) > self->start_byte; - // Get the properties of the current node. + TSNode node = ts_tree_cursor_current_node(&self->cursor); TSSymbol symbol = ts_node_symbol(node); bool is_named = ts_node_is_named(node); bool has_later_siblings; @@ -2714,7 +2728,14 @@ static inline bool ts_query_cursor__advance( self->finished_states.size ); - if (node_exceeds_start_byte) { + bool node_intersects_range = ( + ts_node_end_byte(node) > self->start_byte && + ts_node_start_byte(node) < self->end_byte && + point_gt(ts_node_end_point(node), self->start_point) && + point_lt(ts_node_start_point(node), self->end_point) + ); + + if (node_intersects_range) { // Add new states for any patterns whose root node is a wildcard. for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { PatternEntry *pattern = &self->query->pattern_map.contents[i]; @@ -3039,7 +3060,7 @@ static inline bool ts_query_cursor__advance( // When the current node ends prior to the desired start offset, // only descend for the purpose of continuing in-progress matches. - bool should_descend = node_exceeds_start_byte; + bool should_descend = node_intersects_range; if (!should_descend) { for (unsigned i = 0; i < self->states.size; i++) { QueryState *state = &self->states.contents[i];; @@ -3071,14 +3092,6 @@ static inline bool ts_query_cursor__advance( } } -void ts_query_cursor_advance_to_byte( - TSQueryCursor *self, - uint32_t offset -) { - LOG("advance_to_byte %u\n", offset); - self->start_byte = offset; -} - bool ts_query_cursor_next_match( TSQueryCursor *self, TSQueryMatch *match