diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs
index 7caf5dcb..fc1d4b3b 100644
--- a/cli/src/tests/query_test.rs
+++ b/cli/src/tests/query_test.rs
@@ -591,7 +591,60 @@ fn test_query_matches_different_queries_same_cursor() {
 }
 
 #[test]
-fn test_query_captures() {
+fn test_query_matches_with_multiple_captures_on_a_node() {
+    allocations::record(|| {
+        let language = get_language("javascript");
+        let mut query = Query::new(
+            language,
+            "(function_declaration
+                (identifier) @name1 @name2 @name3
+                (statement_block) @body1 @body2)",
+        )
+        .unwrap();
+
+        let source = "function foo() { return 1; }";
+        let mut parser = Parser::new();
+        let mut cursor = QueryCursor::new();
+
+        parser.set_language(language).unwrap();
+        let tree = parser.parse(&source, None).unwrap();
+
+        let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
+        assert_eq!(
+            collect_matches(matches, &query, source),
+            &[(
+                0,
+                vec![
+                    ("name1", "foo"),
+                    ("name2", "foo"),
+                    ("name3", "foo"),
+                    ("body1", "{ return 1; }"),
+                    ("body2", "{ return 1; }"),
+                ]
+            ),]
+        );
+
+        // disabling captures still works when there are multiple captures on a
+        // single node.
+        query.disable_capture("name2");
+        let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
+        assert_eq!(
+            collect_matches(matches, &query, source),
+            &[(
+                0,
+                vec![
+                    ("name1", "foo"),
+                    ("name3", "foo"),
+                    ("body1", "{ return 1; }"),
+                    ("body2", "{ return 1; }"),
+                ]
+            ),]
+        );
+    });
+}
+
+#[test]
+fn test_query_captures_basic() {
     allocations::record(|| {
         let language = get_language("javascript");
         let query = Query::new(
diff --git a/lib/src/query.c b/lib/src/query.c
index 21e8bd9a..7e13978d 100644
--- a/lib/src/query.c
+++ b/lib/src/query.c
@@ -19,6 +19,8 @@ typedef struct {
   uint8_t next_size;
 } Stream;
 
+#define MAX_STEP_CAPTURE_COUNT 4
+
 /*
  * QueryStep - A step in the process of matching a query. Each node within
  * a query S-expression maps to one of these steps. An entire pattern is
@@ -37,7 +39,7 @@ typedef struct {
 typedef struct {
   TSSymbol symbol;
   TSFieldId field;
-  uint16_t capture_id;
+  uint16_t capture_ids[MAX_STEP_CAPTURE_COUNT];
   uint16_t depth: 15;
   bool contains_captures: 1;
 } QueryStep;
@@ -326,6 +328,56 @@ static uint16_t symbol_table_insert_name(
   return self->slices.size - 1;
 }
 
+/************
+ * QueryStep
+ ************/
+
+// Build a step for `symbol` at `depth`, with field 0, every capture slot
+// set to NONE, and `contains_captures` cleared.
+static QueryStep query_step__new(TSSymbol symbol, uint16_t depth) {
+  QueryStep step = {
+    .symbol = symbol,
+    .depth = depth,
+    .field = 0,
+    .contains_captures = false,
+  };
+  // Fill the slots in a loop so the initialization stays correct if
+  // MAX_STEP_CAPTURE_COUNT changes; a fixed brace list would leave any
+  // additional slot zero-initialized instead of set to NONE.
+  for (unsigned i = 0; i < MAX_STEP_CAPTURE_COUNT; i++) {
+    step.capture_ids[i] = NONE;
+  }
+  return step;
+}
+
+// Store `capture_id` in the first slot that is still NONE. If none of the
+// MAX_STEP_CAPTURE_COUNT slots is NONE, the id is not stored.
+static void query_step__add_capture(QueryStep *self, uint16_t capture_id) {
+  for (unsigned i = 0; i < MAX_STEP_CAPTURE_COUNT; i++) {
+    if (self->capture_ids[i] == NONE) {
+      self->capture_ids[i] = capture_id;
+      break;
+    }
+  }
+}
+
+// Clear the first slot holding `capture_id`, then move each following
+// non-NONE id down one slot so the used slots stay at the front.
+static void query_step__remove_capture(QueryStep *self, uint16_t capture_id) {
+  for (unsigned i = 0; i < MAX_STEP_CAPTURE_COUNT; i++) {
+    if (self->capture_ids[i] == capture_id) {
+      self->capture_ids[i] = NONE;
+      while (i + 1 < MAX_STEP_CAPTURE_COUNT) {
+        if (self->capture_ids[i + 1] == NONE) break;
+        self->capture_ids[i] = self->capture_ids[i + 1];
+        self->capture_ids[i + 1] = NONE;
+        i++;
+      }
+      break;
+    }
+  }
+}
+
 /*********
  * Query
  *********/
@@ -401,14 +453,14 @@ static void ts_query__finalize_steps(TSQuery *self) {
   for (unsigned i = 0; i < self->steps.size; i++) {
     QueryStep *step = &self->steps.contents[i];
     uint32_t depth = step->depth;
-    if (step->capture_id != NONE) {
+    if (step->capture_ids[0] != NONE) {
       step->contains_captures = true;
     } else {
       step->contains_captures = false;
       for (unsigned j = i + 1; j < self->steps.size; j++) {
         QueryStep *s = &self->steps.contents[j];
         if (s->depth == PATTERN_DONE_MARKER || s->depth <= depth) break;
-        if (s->capture_id != NONE) step->contains_captures = true;
+        if (s->capture_ids[0] != NONE) step->contains_captures = true;
       }
     }
   }
@@ -599,13 +651,7 @@ static TSQueryError ts_query__parse_pattern(
     }
 
     // Add a step for the node.
-    array_push(&self->steps, ((QueryStep) {
-      .depth = depth,
-      .symbol = symbol,
-      .field = 0,
-      .capture_id = NONE,
-      .contains_captures = false,
-    }));
+    array_push(&self->steps, query_step__new(symbol, depth));
 
     // Parse the child patterns
     stream_skip_whitespace(stream);
@@ -645,13 +691,7 @@ static TSQueryError ts_query__parse_pattern(
       stream_reset(stream, string_content);
       return TSQueryErrorNodeType;
     }
-    array_push(&self->steps, ((QueryStep) {
-      .depth = depth,
-      .symbol = symbol,
-      .field = 0,
-      .capture_id = NONE,
-      .contains_captures = false,
-    }));
+    array_push(&self->steps, query_step__new(symbol, depth));
 
     if (stream->next != '"') return TSQueryErrorSyntax;
     stream_advance(stream);
@@ -697,12 +737,7 @@ static TSQueryError ts_query__parse_pattern(
     stream_skip_whitespace(stream);
 
     // Add a step that matches any kind of node
-    array_push(&self->steps, ((QueryStep) {
-      .depth = depth,
-      .symbol = WILDCARD_SYMBOL,
-      .field = 0,
-      .contains_captures = false,
-    }));
+    array_push(&self->steps, query_step__new(WILDCARD_SYMBOL, depth));
   }
 
   else {
@@ -712,7 +747,7 @@ static TSQueryError ts_query__parse_pattern(
   stream_skip_whitespace(stream);
 
   // Parse an '@'-prefixed capture pattern
-  if (stream->next == '@') {
+  while (stream->next == '@') {
     stream_advance(stream);
 
     // Parse the capture name
@@ -727,7 +762,8 @@ static TSQueryError ts_query__parse_pattern(
       capture_name,
       length
     );
-    self->steps.contents[starting_step_index].capture_id = capture_id;
+    QueryStep *step = &self->steps.contents[starting_step_index];
+    query_step__add_capture(step, capture_id);
     (*capture_count)++;
 
     stream_skip_whitespace(stream);
@@ -797,7 +833,7 @@ TSQuery *ts_query_new(
       .length = 0,
     }));
     *error_type = ts_query__parse_pattern(self, &stream, 0, &capture_count);
-    array_push(&self->steps, ((QueryStep) { .depth = PATTERN_DONE_MARKER }));
+    array_push(&self->steps, query_step__new(0, PATTERN_DONE_MARKER));
 
     // If any pattern could not be parsed, then report the error information
     // and terminate.
@@ -899,9 +935,7 @@ void ts_query_disable_capture(
   if (id != -1) {
     for (unsigned i = 0; i < self->steps.size; i++) {
       QueryStep *step = &self->steps.contents[i];
-      if (step->capture_id == id) {
-        step->capture_id = NONE;
-      }
+      query_step__remove_capture(step, id);
     }
     ts_query__finalize_steps(self);
   }
@@ -1280,11 +1314,13 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
 
           // If the current node is captured in this pattern, add it to the
           // capture list.
-          if (step->capture_id != NONE) {
+          for (unsigned j = 0; j < MAX_STEP_CAPTURE_COUNT; j++) {
+            uint16_t capture_id = step->capture_ids[j];
+            if (capture_id == NONE) break;
             LOG(
               " capture node. pattern:%u, capture_id:%u\n",
               next_state->pattern_index,
-              step->capture_id
+              capture_id
             );
             TSQueryCapture *capture_list = capture_list_pool_get(
               &self->capture_list_pool,
@@ -1292,7 +1328,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
             );
             capture_list[next_state->capture_count++] = (TSQueryCapture) {
               node,
-              step->capture_id
+              capture_id
             };
           }
 