diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 87420501..c6980d45 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -464,6 +464,7 @@ fn test_query_matches_with_wildcard_at_the_root() { ); }); } + #[test] fn test_query_with_immediate_siblings() { allocations::record(|| { @@ -515,6 +516,73 @@ fn test_query_with_immediate_siblings() { }); } +#[test] +fn test_query_matches_with_repeated_nodes() { + allocations::record(|| { + let language = get_language("javascript"); + + let query = Query::new( + language, + " + (* + (comment)+ @doc + . + (class_declaration + name: (identifier) @name)) + + (* + (comment)+ @doc + . + (function_declaration + name: (identifier) @name)) + ", + ) + .unwrap(); + + let source = " + // one + // two + a(); + + // three + { + // four + // five + // six + class B {} + + // seven + c(); + + // eight + function d() {} + } + "; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(source, None).unwrap(); + let mut cursor = QueryCursor::new(); + let matches = cursor.matches(&query, tree.root_node(), to_callback(source)); + + assert_eq!( + collect_matches(matches, &query, source), + &[ + ( + 0, + vec![ + ("doc", "// four"), + ("doc", "// five"), + ("doc", "// six"), + ("name", "B") + ] + ), + (1, vec![("doc", "// eight"), ("name", "d")]), + ] + ); + }); +} + #[test] fn test_query_matches_in_language_with_simple_aliases() { allocations::record(|| { diff --git a/lib/src/query.c b/lib/src/query.c index 65144395..b93c2ea4 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -40,10 +40,11 @@ typedef struct { TSSymbol symbol; TSFieldId field; uint16_t capture_ids[MAX_STEP_CAPTURE_COUNT]; - uint16_t depth: 13; + uint16_t depth: 12; bool contains_captures: 1; bool is_immediate: 1; bool is_last: 1; + bool is_repeated: 1; } QueryStep; /* @@ -88,12 +89,15 @@ typedef struct { uint16_t start_depth; uint16_t pattern_index; uint16_t step_index; - uint16_t capture_count; uint16_t capture_list_id; uint16_t consumed_capture_count; uint32_t id; + uint16_t current_step_match_count; + bool seeking_non_match; } QueryState; +typedef Array(TSQueryCapture) CaptureList; + /* * CaptureListPool - A collection of *lists* of captures. Each QueryState * needs to maintain its own list of captures. They are all represented as @@ -101,7 +105,7 @@ typedef struct { * parts of the shared array are currently in use by a QueryState. */ typedef struct { - Array(TSQueryCapture) list; + CaptureList list[32]; uint32_t usage_map; } CaptureListPool; @@ -233,24 +237,22 @@ static void stream_scan_identifier(Stream *stream) { static CaptureListPool capture_list_pool_new() { return (CaptureListPool) { - .list = array_new(), .usage_map = UINT32_MAX, }; } -static void capture_list_pool_reset(CaptureListPool *self, uint16_t list_size) { +static void capture_list_pool_reset(CaptureListPool *self) { self->usage_map = UINT32_MAX; - uint32_t total_size = MAX_STATE_COUNT * list_size; - array_reserve(&self->list, total_size); - self->list.size = total_size; } static void capture_list_pool_delete(CaptureListPool *self) { - array_delete(&self->list); + for (unsigned i = 0; i < 32; i++) { + array_delete(&self->list[i]); + } } -static TSQueryCapture *capture_list_pool_get(CaptureListPool *self, uint16_t id) { - return &self->list.contents[id * (self->list.size / MAX_STATE_COUNT)]; +static CaptureList *capture_list_pool_get(CaptureListPool *self, uint16_t id) { + return &self->list[id]; } static bool capture_list_pool_is_empty(const CaptureListPool *self) { @@ -269,6 +271,7 @@ static uint16_t capture_list_pool_acquire(CaptureListPool *self) { } static void capture_list_pool_release(CaptureListPool *self, uint16_t id) { + array_clear(&self->list[id]); self->usage_map |= bitmask_for_index(id); } @@ -408,6 +411,7 @@ static QueryStep query_step__new( .capture_ids = {NONE, NONE, NONE, NONE}, .contains_captures = false, .is_immediate = is_immediate, + .is_repeated = false, }; } @@ -842,27 +846,42 @@ static TSQueryError ts_query__parse_pattern( stream_skip_whitespace(stream); - // Parse an '@'-prefixed capture pattern - while (stream->next == '@') { - stream_advance(stream); - - // Parse the capture name - if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax; - const char *capture_name = stream->input; - stream_scan_identifier(stream); - uint32_t length = stream->input - capture_name; - - // Add the capture id to the first step of the pattern - uint16_t capture_id = symbol_table_insert_name( - &self->captures, - capture_name, - length - ); + // Parse suffixes modifiers for this pattern + for (;;) { QueryStep *step = &self->steps.contents[starting_step_index]; - query_step__add_capture(step, capture_id); - (*capture_count)++; - stream_skip_whitespace(stream); + if (stream->next == '+') { + stream_advance(stream); + step->is_repeated = true; + stream_skip_whitespace(stream); + } + + // Parse an '@'-prefixed capture pattern + else if (stream->next == '@') { + stream_advance(stream); + + // Parse the capture name + if (!stream_is_ident_start(stream)) return TSQueryErrorSyntax; + const char *capture_name = stream->input; + stream_scan_identifier(stream); + uint32_t length = stream->input - capture_name; + + // Add the capture id to the first step of the pattern + uint16_t capture_id = symbol_table_insert_name( + &self->captures, + capture_name, + length + ); + query_step__add_capture(step, capture_id); + (*capture_count)++; + + stream_skip_whitespace(stream); + } + + // No more suffix modifiers + else { + break; + } } return 0; @@ -1089,7 +1108,7 @@ void ts_query_cursor_exec( array_clear(&self->states); array_clear(&self->finished_states); ts_tree_cursor_reset(&self->cursor, node); - capture_list_pool_reset(&self->capture_list_pool, query->max_capture_count); + capture_list_pool_reset(&self->capture_list_pool); self->next_state_id = 0; self->depth = 0; self->ascending = false; @@ -1133,12 +1152,12 @@ static bool ts_query_cursor__first_in_progress_capture( bool result = false; for (unsigned i = 0; i < self->states.size; i++) { const QueryState *state = &self->states.contents[i]; - if (state->capture_count > 0) { - const TSQueryCapture *captures = capture_list_pool_get( - &self->capture_list_pool, - state->capture_list_id - ); - uint32_t capture_byte = ts_node_start_byte(captures[0].node); + const CaptureList *captures = capture_list_pool_get( + &self->capture_list_pool, + state->capture_list_id + ); + if (captures->size > 0) { + uint32_t capture_byte = ts_node_start_byte(captures->contents[0].node); if ( !result || capture_byte < *byte_offset || @@ -1192,8 +1211,9 @@ static bool ts_query__cursor_add_state( .step_index = pattern->step_index, .pattern_index = pattern->pattern_index, .start_depth = self->depth, - .capture_count = 0, .consumed_capture_count = 0, + .current_step_match_count = 0, + .seeking_non_match = false, })); return true; } @@ -1207,15 +1227,15 @@ static QueryState *ts_query__cursor_copy_state( array_push(&self->states, *state); QueryState *new_state = array_back(&self->states); new_state->capture_list_id = new_list_id; - TSQueryCapture *old_captures = capture_list_pool_get( + CaptureList *old_captures = capture_list_pool_get( &self->capture_list_pool, state->capture_list_id ); - TSQueryCapture *new_captures = capture_list_pool_get( + CaptureList *new_captures = capture_list_pool_get( &self->capture_list_pool, new_list_id ); - memcpy(new_captures, old_captures, state->capture_count * sizeof(TSQueryCapture)); + array_push_all(new_captures, old_captures); return new_state; } @@ -1371,7 +1391,27 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { } } - if (!node_does_match) { + if (node_does_match) { + // The `seeking_non_match` flag indicates that a previous QueryState + // has already begun processing this repeating sequence, so that *this* + // QueryState should not begin matching until a separate repeating sequence + // is found. + if (state->seeking_non_match) continue; + } else { + // If this QueryState has processed a repeating sequence, and that repeating + // sequence has ended, move on to the *next* step of this state's pattern. + if (state->current_step_match_count > 0) { + LOG( + " finish repetition state. pattern:%u, step:%u\n", + state->pattern_index, + state->step_index + ); + state->step_index++; + state->current_step_match_count = 0; + i--; + continue; + } + if (!later_sibling_can_match) { LOG( " discard state. pattern:%u, step:%u\n", @@ -1386,6 +1426,8 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { i--; n--; } + + state->seeking_non_match = false; continue; } @@ -1400,9 +1442,18 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { if ( step->depth > 0 && step->contains_captures && - later_sibling_can_match + later_sibling_can_match && + state->current_step_match_count == 0 ) { QueryState *copy = ts_query__cursor_copy_state(self, state); + + // The QueryState that matched this node has begun matching a repeating + // sequence. The QueryState that *skipped* this node should not start + // matching later elements of the same repeating sequence. + if (step->is_repeated) { + state->seeking_non_match = true; + } + if (copy) { LOG( " split state. pattern:%u, step:%u\n", @@ -1411,7 +1462,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { ); next_state = copy; } else { - LOG(" canot split state.\n"); + LOG(" cannot split state.\n"); } } @@ -1431,35 +1482,44 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { next_state->pattern_index, capture_id ); - TSQueryCapture *capture_list = capture_list_pool_get( + CaptureList *capture_list = capture_list_pool_get( &self->capture_list_pool, next_state->capture_list_id ); - capture_list[next_state->capture_count++] = (TSQueryCapture) { + array_push(capture_list, ((TSQueryCapture) { node, capture_id - }; + })); } - // If the pattern is now done, then remove it from the list of - // in-progress states, and add it to the list of finished states. - next_state->step_index++; - QueryStep *next_step = step + 1; - if (next_step->depth == PATTERN_DONE_MARKER) { - LOG(" finish pattern %u\n", next_state->pattern_index); + // If this step repeats, then don't move to the next step until + // this step no longer matches. + if (step->is_repeated) { + next_state->current_step_match_count++; + } else { + next_state->step_index++; + next_state->current_step_match_count = 0; + QueryStep *next_step = step + 1; - next_state->id = self->next_state_id++; - array_push(&self->finished_states, *next_state); - if (next_state == state) { - array_erase(&self->states, i); - i--; - n--; - } else { - self->states.size--; + // If the pattern is now done, then remove it from the list of + // in-progress states, and add it to the list of finished states. + if (next_step->depth == PATTERN_DONE_MARKER) { + LOG(" finish pattern %u\n", next_state->pattern_index); + + next_state->id = self->next_state_id++; + array_push(&self->finished_states, *next_state); + if (next_state == state) { + array_erase(&self->states, i); + i--; + n--; + } else { + self->states.size--; + } } } } + // Continue descending if possible. if (ts_tree_cursor_goto_first_child(&self->cursor)) { self->depth++; @@ -1485,11 +1545,12 @@ bool ts_query_cursor_next_match( QueryState *state = &self->finished_states.contents[0]; match->id = state->id; match->pattern_index = state->pattern_index; - match->capture_count = state->capture_count; - match->captures = capture_list_pool_get( + CaptureList *captures = capture_list_pool_get( &self->capture_list_pool, state->capture_list_id ); + match->captures = captures->contents; + match->capture_count = captures->size; capture_list_pool_release(&self->capture_list_pool, state->capture_list_id); array_erase(&self->finished_states, 0); return true; @@ -1542,13 +1603,13 @@ bool ts_query_cursor_next_capture( uint32_t first_finished_pattern_index = first_unfinished_pattern_index; for (unsigned i = 0; i < self->finished_states.size; i++) { const QueryState *state = &self->finished_states.contents[i]; - if (state->capture_count > state->consumed_capture_count) { - const TSQueryCapture *captures = capture_list_pool_get( - &self->capture_list_pool, - state->capture_list_id - ); + CaptureList *captures = capture_list_pool_get( + &self->capture_list_pool, + state->capture_list_id + ); + if (captures->size > state->consumed_capture_count) { uint32_t capture_byte = ts_node_start_byte( - captures[state->consumed_capture_count].node + captures->contents[state->consumed_capture_count].node ); if ( capture_byte < first_finished_capture_byte || @@ -1580,11 +1641,12 @@ bool ts_query_cursor_next_capture( ]; match->id = state->id; match->pattern_index = state->pattern_index; - match->capture_count = state->capture_count; - match->captures = capture_list_pool_get( + CaptureList *captures = capture_list_pool_get( &self->capture_list_pool, state->capture_list_id ); + match->captures = captures->contents; + match->capture_count = captures->size; *capture_index = state->consumed_capture_count; state->consumed_capture_count++; return true;