diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 2245f4f9..4fde4f2a 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -3039,32 +3039,49 @@ fn test_query_text_callback_returns_chunks() { #[test] fn test_query_captures_advance_to_byte() { allocations::record(|| { - let language = get_language("javascript"); + let language = get_language("rust"); let query = Query::new( language, r#" - (identifier) @id - (array - "[" @lbracket - "]" @rbracket) + (function_item + name: (identifier) @fn_name) + + (mod_item + name: (identifier) @mod_name + body: (declaration_list + "{" @lbrace + "}" @rbrace)) + + ; functions that return Result<()> + ((function_item + return_type: (generic_type + type: (type_identifier) @result + type_arguments: (type_arguments + (unit_type))) + body: _ @fallible_fn_body) + (#eq? @result "Result")) "#, ) .unwrap(); - let source = "[one, two, [three, four, five, six, seven, eight, nine, ten], eleven, twelve, thirteen]"; + let source = " + mod m1 { + mod m2 { + fn f1() -> Option<()> { Some(()) } + } + fn f2() -> Result<()> { Ok(()) } + fn f3() {} + } + "; let mut parser = Parser::new(); parser.set_language(language).unwrap(); let tree = parser.parse(&source, None).unwrap(); let mut cursor = QueryCursor::new(); - cursor.set_byte_range( - source.find("two").unwrap() + 1, - source.find(", twelve").unwrap(), - ); let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); - // Retrieve four captures. + // Retrieve some captures let mut results = Vec::new(); - for (mat, capture_ix) in captures.by_ref().take(4) { + for (mat, capture_ix) in captures.by_ref().take(5) { let capture = mat.captures[capture_ix as usize]; results.push(( query.capture_names()[capture.index as usize].as_str(), @@ -3074,16 +3091,18 @@ fn test_query_captures_advance_to_byte() { assert_eq!( results, vec![ - ("id", "two"), - ("lbracket", "["), - ("id", "three"), - ("id", "four") + ("mod_name", "m1"), + ("lbrace", "{"), + ("mod_name", "m2"), + ("lbrace", "{"), + ("fn_name", "f1"), ] ); - // Advance further ahead in the source, retrieve the remaining captures. results.clear(); - captures.advance_to_byte(source.find("ten").unwrap() + 1); + captures.advance_to_byte(source.find("Ok").unwrap()); + + // Advance further ahead in the source, retrieve the remaining captures. for (mat, capture_ix) in captures { let capture = mat.captures[capture_ix as usize]; results.push(( @@ -3093,7 +3112,11 @@ fn test_query_captures_advance_to_byte() { } assert_eq!( results, - vec![("id", "ten"), ("rbracket", "]"), ("id", "eleven"),] + vec![ + ("fallible_fn_body", "{ Ok(()) }"), + ("fn_name", "f3"), + ("rbrace", "}") + ] ); // Advance past the last capture. There are no more captures. @@ -3104,6 +3127,46 @@ fn test_query_captures_advance_to_byte() { }); } +#[test] +fn test_query_advance_to_byte_within_node() { + allocations::record(|| { + let language = get_language("rust"); + let query = Query::new( + language, + r#" + (fn_item + name: (identifier) @name + return_type: _? @ret) + + (mod_item + name: (identifier) @name + body: _ @body) + "#, + ) + .unwrap(); + let source = " + fn foo() -> i32 {} + + ... + + mod foo {} + "; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + let mut captures = cursor.captures(&query, tree.root_node(), source.as_bytes()); + + captures.advance_to_byte(source.find("{").unwrap()); + + assert_eq!( + collect_captures(captures, &query, source), + &[("body", "{}"),] + ); + }) +} + #[test] fn test_query_start_byte_for_pattern() { let language = get_language("javascript"); diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index a729c12c..dccc9aca 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -737,18 +737,8 @@ extern "C" { pub fn ts_query_cursor_did_exceed_match_limit(arg1: *const TSQueryCursor) -> bool; } extern "C" { - #[doc = " Get or set the range of bytes or (row, column) positions in which the query"] + #[doc = " Set the range of bytes or (row, column) positions in which the query"] #[doc = " will be executed."] - pub fn ts_query_cursor_byte_range(arg1: *const TSQueryCursor, arg2: *mut u32, arg3: *mut u32); -} -extern "C" { - pub fn ts_query_cursor_point_range( - arg1: *const TSQueryCursor, - arg2: *mut TSPoint, - arg3: *mut TSPoint, - ); -} -extern "C" { pub fn ts_query_cursor_set_byte_range(arg1: *mut TSQueryCursor, arg2: u32, arg3: u32); } extern "C" { @@ -764,6 +754,9 @@ extern "C" { extern "C" { pub fn ts_query_cursor_remove_match(arg1: *mut TSQueryCursor, id: u32); } +extern "C" { + pub fn ts_query_cursor_advance_to_byte(arg1: *mut TSQueryCursor, offset: u32); +} extern "C" { #[doc = " Advance to the next capture of the currently running query."] #[doc = ""] diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index ea99c067..da2b3252 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -1812,27 +1812,7 @@ impl<'a, 'tree, T: TextProvider<'a>> Iterator for QueryMatches<'a, 'tree, T> { impl<'a, 'tree, T: TextProvider<'a>> QueryCaptures<'a, 'tree, T> { pub fn advance_to_byte(&mut self, offset: usize) { unsafe { - let mut current_start = 0u32; - let mut current_end = 0u32; - ffi::ts_query_cursor_byte_range( - self.ptr, - &mut current_start as *mut u32, - &mut current_end as *mut u32, - ); - ffi::ts_query_cursor_set_byte_range(self.ptr, offset as u32, current_end); - } - } - - pub fn advance_to_point(&mut self, point: Point) { - unsafe { - let mut current_start = ffi::TSPoint { row: 0, column: 0 }; - let mut current_end = current_start; - ffi::ts_query_cursor_point_range( - self.ptr, - &mut current_start as *mut _, - &mut current_end as *mut _, - ); - ffi::ts_query_cursor_set_point_range(self.ptr, point.into(), current_end); + ffi::ts_query_cursor_advance_to_byte(self.ptr, offset as u32); } } } diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 01d84a6e..6889a121 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -809,11 +809,9 @@ void ts_query_cursor_exec(TSQueryCursor *, const TSQuery *, TSNode); bool ts_query_cursor_did_exceed_match_limit(const TSQueryCursor *); /** - * Get or set the range of bytes or (row, column) positions in which the query + * Set the range of bytes or (row, column) positions in which the query * will be executed. */ -void ts_query_cursor_byte_range(const TSQueryCursor *, uint32_t *, uint32_t *); -void ts_query_cursor_point_range(const TSQueryCursor *, TSPoint *, TSPoint *); void ts_query_cursor_set_byte_range(TSQueryCursor *, uint32_t, uint32_t); void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint); @@ -826,6 +824,8 @@ void ts_query_cursor_set_point_range(TSQueryCursor *, TSPoint, TSPoint); bool ts_query_cursor_next_match(TSQueryCursor *, TSQueryMatch *match); void ts_query_cursor_remove_match(TSQueryCursor *, uint32_t id); +void ts_query_cursor_advance_to_byte(TSQueryCursor *, uint32_t offset); + /** * Advance to the next capture of the currently running query. * diff --git a/lib/src/query.c b/lib/src/query.c index 278b3a3c..d70e5afd 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -256,10 +256,7 @@ struct TSQueryCursor { CaptureListPool capture_list_pool; uint32_t depth; uint32_t start_byte; - uint32_t end_byte; uint32_t next_state_id; - TSPoint start_point; - TSPoint end_point; bool ascending; bool halted; bool did_exceed_match_limit; @@ -2264,9 +2261,6 @@ TSQueryCursor *ts_query_cursor_new(void) { .finished_states = array_new(), .capture_list_pool = capture_list_pool_new(), .start_byte = 0, - .end_byte = UINT32_MAX, - .start_point = {0, 0}, - .end_point = POINT_MAX, }; array_reserve(&self->states, 8); array_reserve(&self->finished_states, 8); @@ -2296,40 +2290,18 @@ void ts_query_cursor_exec( capture_list_pool_reset(&self->capture_list_pool); self->next_state_id = 0; self->depth = 0; + self->start_byte = 0; self->ascending = false; self->halted = false; self->query = query; self->did_exceed_match_limit = false; } -void ts_query_cursor_byte_range( - const TSQueryCursor *self, - uint32_t *start_byte, - uint32_t *end_byte -) { - *start_byte = self->start_byte; - *end_byte = self->end_byte; -} - -void ts_query_cursor_point_range( - const TSQueryCursor *self, - TSPoint *start_point, - TSPoint *end_point -) { - *start_point = self->start_point; - *end_point = self->end_point; -} - void ts_query_cursor_set_byte_range( TSQueryCursor *self, uint32_t start_byte, uint32_t end_byte ) { - if (end_byte == 0) { - end_byte = UINT32_MAX; - } - self->start_byte = start_byte; - self->end_byte = end_byte; } void ts_query_cursor_set_point_range( @@ -2337,11 +2309,6 @@ void ts_query_cursor_set_point_range( TSPoint start_point, TSPoint end_point ) { - if (end_point.row == 0 && end_point.column == 0) { - end_point = POINT_MAX; - } - self->start_point = start_point; - self->end_point = end_point; } // Search through all of the in-progress states, and find the captured @@ -2358,31 +2325,41 @@ static bool ts_query_cursor__first_in_progress_capture( *byte_offset = UINT32_MAX; *pattern_index = UINT32_MAX; for (unsigned i = 0; i < self->states.size; i++) { - const QueryState *state = &self->states.contents[i]; + QueryState *state = &self->states.contents[i]; if (state->dead) continue; + const CaptureList *captures = capture_list_pool_get( &self->capture_list_pool, state->capture_list_id ); - if (captures->size > state->consumed_capture_count) { - uint32_t capture_byte = ts_node_start_byte(captures->contents[state->consumed_capture_count].node); - if ( - !result || - capture_byte < *byte_offset || - (capture_byte == *byte_offset && state->pattern_index < *pattern_index) - ) { - QueryStep *step = &self->query->steps.contents[state->step_index]; - if (is_definite) { - *is_definite = step->is_definite; - } else if (step->is_definite) { - continue; - } + if (state->consumed_capture_count >= captures->size) { + continue; + } - result = true; - *state_index = i; - *byte_offset = capture_byte; - *pattern_index = state->pattern_index; + TSNode node = captures->contents[state->consumed_capture_count].node; + if (ts_node_end_byte(node) <= self->start_byte) { + state->consumed_capture_count++; + i--; + continue; + } + + uint32_t node_start_byte = ts_node_start_byte(node); + if ( + !result || + node_start_byte < *byte_offset || + (node_start_byte == *byte_offset && state->pattern_index < *pattern_index) + ) { + QueryStep *step = &self->query->steps.contents[state->step_index]; + if (is_definite) { + *is_definite = step->is_definite; + } else if (step->is_definite) { + continue; } + + result = true; + *state_index = i; + *byte_offset = node_start_byte; + *pattern_index = state->pattern_index; } } return result; @@ -2707,26 +2684,8 @@ static inline bool ts_query_cursor__advance( else { // If this node is before the selected range, then avoid descending into it. TSNode node = ts_tree_cursor_current_node(&self->cursor); - if ( - ts_node_end_byte(node) <= self->start_byte || - point_lte(ts_node_end_point(node), self->start_point) - ) { - if (!ts_tree_cursor_goto_next_sibling(&self->cursor)) { - self->ascending = true; - } - LOG("skip until start of range\n"); - continue; - } - // If this node is after the selected range, then stop walking. - if ( - self->end_byte <= ts_node_start_byte(node) || - point_lte(self->end_point, ts_node_start_point(node)) - ) { - LOG("halt at end of range\n"); - self->halted = true; - continue; - } + bool node_exceeds_start_byte = ts_node_end_byte(node) > self->start_byte; // Get the properties of the current node. TSSymbol symbol = ts_node_symbol(node); @@ -2755,36 +2714,44 @@ static inline bool ts_query_cursor__advance( self->finished_states.size ); - // Add new states for any patterns whose root node is a wildcard. - for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { - PatternEntry *pattern = &self->query->pattern_map.contents[i]; - QueryStep *step = &self->query->steps.contents[pattern->step_index]; + if (node_exceeds_start_byte) { + // Add new states for any patterns whose root node is a wildcard. + for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) { + PatternEntry *pattern = &self->query->pattern_map.contents[i]; + QueryStep *step = &self->query->steps.contents[pattern->step_index]; - // If this node matches the first step of the pattern, then add a new - // state at the start of this pattern. - if (step->field && field_id != step->field) continue; - if (step->supertype_symbol && !supertype_count) continue; - ts_query_cursor__add_state(self, pattern); - } - - // Add new states for any patterns whose root node matches this node. - unsigned i; - if (ts_query__pattern_map_search(self->query, symbol, &i)) { - PatternEntry *pattern = &self->query->pattern_map.contents[i]; - QueryStep *step = &self->query->steps.contents[pattern->step_index]; - do { // If this node matches the first step of the pattern, then add a new // state at the start of this pattern. - if (!step->field || field_id == step->field) { - ts_query_cursor__add_state(self, pattern); - } + if (step->field && field_id != step->field) continue; + if (step->supertype_symbol && !supertype_count) continue; + ts_query_cursor__add_state(self, pattern); + } - // Advance to the next pattern whose root node matches this node. - i++; - if (i == self->query->pattern_map.size) break; - pattern = &self->query->pattern_map.contents[i]; - step = &self->query->steps.contents[pattern->step_index]; - } while (step->symbol == symbol); + // Add new states for any patterns whose root node matches this node. + unsigned i; + if (ts_query__pattern_map_search(self->query, symbol, &i)) { + PatternEntry *pattern = &self->query->pattern_map.contents[i]; + QueryStep *step = &self->query->steps.contents[pattern->step_index]; + do { + // If this node matches the first step of the pattern, then add a new + // state at the start of this pattern. + if (!step->field || field_id == step->field) { + ts_query_cursor__add_state(self, pattern); + } + + // Advance to the next pattern whose root node matches this node. + i++; + if (i == self->query->pattern_map.size) break; + pattern = &self->query->pattern_map.contents[i]; + step = &self->query->steps.contents[pattern->step_index]; + } while (step->symbol == symbol); + } + } else { + LOG( + " not starting new patterns. node end byte: %u, start_byte: %u\n", + ts_node_end_byte(node), + self->start_byte + ); } // Update all of the in-progress states with current node. @@ -3070,8 +3037,32 @@ static inline bool ts_query_cursor__advance( } } - // Continue descending if possible. - if (ts_tree_cursor_goto_first_child(&self->cursor)) { + // When the current node ends prior to the desired start offset, + // only descend for the purpose of continuing in-progress matches. + bool should_descend = node_exceeds_start_byte; + if (!should_descend) { + for (unsigned i = 0; i < self->states.size; i++) { + QueryState *state = &self->states.contents[i];; + QueryStep *next_step = &self->query->steps.contents[state->step_index]; + if ( + next_step->depth != PATTERN_DONE_MARKER && + state->start_depth + next_step->depth > self->depth + ) { + should_descend = true; + break; + } + } + } + + if (!should_descend) { + LOG( + " not descending. node end byte: %u, start byte: %u\n", + ts_node_end_byte(node), + self->start_byte + ); + } + + if (should_descend && ts_tree_cursor_goto_first_child(&self->cursor)) { self->depth++; } else { self->ascending = true; @@ -3080,6 +3071,14 @@ static inline bool ts_query_cursor__advance( } } +void ts_query_cursor_advance_to_byte( + TSQueryCursor *self, + uint32_t offset +) { + LOG("advance_to_byte %u\n", offset); + self->start_byte = offset; +} + bool ts_query_cursor_next_match( TSQueryCursor *self, TSQueryMatch *match @@ -3148,35 +3147,43 @@ bool ts_query_cursor_next_capture( QueryState *first_finished_state = NULL; uint32_t first_finished_capture_byte = first_unfinished_capture_byte; uint32_t first_finished_pattern_index = first_unfinished_pattern_index; - for (unsigned i = 0; i < self->finished_states.size; i++) { + for (unsigned i = 0; i < self->finished_states.size;) { QueryState *state = &self->finished_states.contents[i]; const CaptureList *captures = capture_list_pool_get( &self->capture_list_pool, state->capture_list_id ); - if (captures->size > state->consumed_capture_count) { - uint32_t capture_byte = ts_node_start_byte( - captures->contents[state->consumed_capture_count].node - ); - if ( - capture_byte < first_finished_capture_byte || - ( - capture_byte == first_finished_capture_byte && - state->pattern_index < first_finished_pattern_index - ) - ) { - first_finished_state = state; - first_finished_capture_byte = capture_byte; - first_finished_pattern_index = state->pattern_index; - } - } else { + + // Remove states whose captures are all consumed. + if (state->consumed_capture_count >= captures->size) { capture_list_pool_release( &self->capture_list_pool, state->capture_list_id ); array_erase(&self->finished_states, i); - i--; + continue; } + + // Skip captures that precede the cursor's start byte. + TSNode node = captures->contents[state->consumed_capture_count].node; + if (ts_node_end_byte(node) <= self->start_byte) { + state->consumed_capture_count++; + continue; + } + + uint32_t node_start_byte = ts_node_start_byte(node); + if ( + node_start_byte < first_finished_capture_byte || + ( + node_start_byte == first_finished_capture_byte && + state->pattern_index < first_finished_pattern_index + ) + ) { + first_finished_state = state; + first_finished_capture_byte = node_start_byte; + first_finished_pattern_index = state->pattern_index; + } + i++; } // If there is finished capture that is clearly before any unfinished