diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index d4f18c7d..06ecc42e 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -1189,6 +1189,45 @@ fn test_query_matches_within_byte_range() { }); } +#[test] +fn test_query_captures_within_byte_range() { + allocations::record(|| { + let language = get_language("c"); + let query = Query::new( + language, + " + (call_expression + function: (identifier) @function + arguments: (argument_list (string_literal) @string.arg)) + + (string_literal) @string + ", + ) + .unwrap(); + + let source = r#"DEFUN ("safe-length", Fsafe_length, Ssafe_length, 1, 1, 0)"#; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + + let mut cursor = QueryCursor::new(); + let captures = + cursor + .set_byte_range(3, 27) + .captures(&query, tree.root_node(), to_callback(source)); + + assert_eq!( + collect_captures(captures, &query, source), + &[ + ("function", "DEFUN"), + ("string.arg", "\"safe-length\""), + ("string", "\"safe-length\""), + ] + ); + }); +} + #[test] fn test_query_matches_different_queries_same_cursor() { allocations::record(|| { diff --git a/lib/src/query.c b/lib/src/query.c index ff243494..b95ba057 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -172,6 +172,7 @@ struct TSQueryCursor { TSPoint start_point; TSPoint end_point; bool ascending; + bool halted; }; static const TSQueryError PARENT_DONE = -1; @@ -1286,6 +1287,7 @@ TSQueryCursor *ts_query_cursor_new(void) { TSQueryCursor *self = ts_malloc(sizeof(TSQueryCursor)); *self = (TSQueryCursor) { .ascending = false, + .halted = false, .states = array_new(), .finished_states = array_new(), .capture_list_pool = capture_list_pool_new(), @@ -1319,6 +1321,7 @@ void ts_query_cursor_exec( self->next_state_id = 0; self->depth = 0; self->ascending = false; + self->halted = false; self->query = query; } @@ -1522,18 +1525,30 @@ static QueryState *ts_query__cursor_copy_state( // `finished_states` array. Multiple patterns can finish on the same node. If // there are no more matches, return `false`. static inline bool ts_query_cursor__advance(TSQueryCursor *self) { - do { + bool did_match = false; + for (;;) { + if (self->halted) { + while (self->states.size > 0) { + QueryState state = array_pop(&self->states); + capture_list_pool_release( + &self->capture_list_pool, + state.capture_list_id + ); + } + } + + if (did_match || self->halted) return did_match; + if (self->ascending) { LOG("leave node. type:%s\n", ts_node_type(ts_tree_cursor_current_node(&self->cursor))); // Leave this node by stepping to its next sibling or to its parent. - bool did_move = true; if (ts_tree_cursor_goto_next_sibling(&self->cursor)) { self->ascending = false; } else if (ts_tree_cursor_goto_parent(&self->cursor)) { self->depth--; } else { - did_move = false; + self->halted = true; } // After leaving a node, remove any states that cannot make further progress. @@ -1545,10 +1560,11 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { // If a state completed its pattern inside of this node, but was deferred from finishing // in order to search for longer matches, mark it as finished. if (step->depth == PATTERN_DONE_MARKER) { - if (state->start_depth > self->depth || !did_move) { + if (state->start_depth > self->depth || self->halted) { LOG(" finish pattern %u\n", state->pattern_index); state->id = self->next_state_id++; array_push(&self->finished_states, *state); + did_match = true; deleted_count++; continue; } @@ -1575,10 +1591,6 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { } } self->states.size -= deleted_count; - - if (!did_move) { - return self->finished_states.size > 0; - } } else { // If this node is before the selected range, then avoid descending into it. TSNode node = ts_tree_cursor_current_node(&self->cursor); @@ -1596,7 +1608,10 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { if ( self->end_byte <= ts_node_start_byte(node) || point_lte(self->end_point, ts_node_start_point(node)) - ) return false; + ) { + self->halted = true; + continue; + } // Get the properties of the current node. TSSymbol symbol = ts_node_symbol(node); @@ -1888,6 +1903,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { state->id = self->next_state_id++; array_push(&self->finished_states, *state); array_erase(&self->states, state - self->states.contents); + did_match = true; i--; } } @@ -1901,9 +1917,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { self->ascending = true; } } - } while (self->finished_states.size == 0); - - return true; + } } bool ts_query_cursor_next_match( @@ -2043,7 +2057,10 @@ bool ts_query_cursor_next_capture( // If there are no finished matches that are ready to be returned, then // continue finding more matches. - if (!ts_query_cursor__advance(self)) return false; + if ( + !ts_query_cursor__advance(self) && + self->finished_states.size == 0 + ) return false; } }