query: Allow multiple captures on a single node
This commit is contained in:
parent
631710bada
commit
de8e3ee188
2 changed files with 110 additions and 33 deletions
|
|
@ -591,7 +591,60 @@ fn test_query_matches_different_queries_same_cursor() {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_captures() {
|
||||
fn test_query_matches_with_multiple_captures_on_a_node() {
|
||||
allocations::record(|| {
|
||||
let language = get_language("javascript");
|
||||
let mut query = Query::new(
|
||||
language,
|
||||
"(function_declaration
|
||||
(identifier) @name1 @name2 @name3
|
||||
(statement_block) @body1 @body2)",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let source = "function foo() { return 1; }";
|
||||
let mut parser = Parser::new();
|
||||
let mut cursor = QueryCursor::new();
|
||||
|
||||
parser.set_language(language).unwrap();
|
||||
let tree = parser.parse(&source, None).unwrap();
|
||||
|
||||
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
|
||||
assert_eq!(
|
||||
collect_matches(matches, &query, source),
|
||||
&[(
|
||||
0,
|
||||
vec![
|
||||
("name1", "foo"),
|
||||
("name2", "foo"),
|
||||
("name3", "foo"),
|
||||
("body1", "{ return 1; }"),
|
||||
("body2", "{ return 1; }"),
|
||||
]
|
||||
),]
|
||||
);
|
||||
|
||||
// disabling captures still works when there are multiple captures on a
|
||||
// single node.
|
||||
query.disable_capture("name2");
|
||||
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
|
||||
assert_eq!(
|
||||
collect_matches(matches, &query, source),
|
||||
&[(
|
||||
0,
|
||||
vec![
|
||||
("name1", "foo"),
|
||||
("name3", "foo"),
|
||||
("body1", "{ return 1; }"),
|
||||
("body2", "{ return 1; }"),
|
||||
]
|
||||
),]
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_captures_basic() {
|
||||
allocations::record(|| {
|
||||
let language = get_language("javascript");
|
||||
let query = Query::new(
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ typedef struct {
|
|||
uint8_t next_size;
|
||||
} Stream;
|
||||
|
||||
#define MAX_STEP_CAPTURE_COUNT 4
|
||||
|
||||
/*
|
||||
* QueryStep - A step in the process of matching a query. Each node within
|
||||
* a query S-expression maps to one of these steps. An entire pattern is
|
||||
|
|
@ -37,7 +39,7 @@ typedef struct {
|
|||
typedef struct {
|
||||
TSSymbol symbol;
|
||||
TSFieldId field;
|
||||
uint16_t capture_id;
|
||||
uint16_t capture_ids[MAX_STEP_CAPTURE_COUNT];
|
||||
uint16_t depth: 15;
|
||||
bool contains_captures: 1;
|
||||
} QueryStep;
|
||||
|
|
@ -326,6 +328,44 @@ static uint16_t symbol_table_insert_name(
|
|||
return self->slices.size - 1;
|
||||
}
|
||||
|
||||
/************
|
||||
* QueryStep
|
||||
************/
|
||||
|
||||
static QueryStep query_step__new(TSSymbol symbol, uint16_t depth) {
|
||||
return (QueryStep) {
|
||||
.symbol = symbol,
|
||||
.depth = depth,
|
||||
.field = 0,
|
||||
.capture_ids = {NONE, NONE, NONE, NONE},
|
||||
.contains_captures = false,
|
||||
};
|
||||
}
|
||||
|
||||
static void query_step__add_capture(QueryStep *self, uint16_t capture_id) {
|
||||
for (unsigned i = 0; i < MAX_STEP_CAPTURE_COUNT; i++) {
|
||||
if (self->capture_ids[i] == NONE) {
|
||||
self->capture_ids[i] = capture_id;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void query_step__remove_capture(QueryStep *self, uint16_t capture_id) {
|
||||
for (unsigned i = 0; i < MAX_STEP_CAPTURE_COUNT; i++) {
|
||||
if (self->capture_ids[i] == capture_id) {
|
||||
self->capture_ids[i] = NONE;
|
||||
while (i + 1 < MAX_STEP_CAPTURE_COUNT) {
|
||||
if (self->capture_ids[i + 1] == NONE) break;
|
||||
self->capture_ids[i] = self->capture_ids[i + 1];
|
||||
self->capture_ids[i + 1] = NONE;
|
||||
i++;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*********
|
||||
* Query
|
||||
*********/
|
||||
|
|
@ -401,14 +441,14 @@ static void ts_query__finalize_steps(TSQuery *self) {
|
|||
for (unsigned i = 0; i < self->steps.size; i++) {
|
||||
QueryStep *step = &self->steps.contents[i];
|
||||
uint32_t depth = step->depth;
|
||||
if (step->capture_id != NONE) {
|
||||
if (step->capture_ids[0] != NONE) {
|
||||
step->contains_captures = true;
|
||||
} else {
|
||||
step->contains_captures = false;
|
||||
for (unsigned j = i + 1; j < self->steps.size; j++) {
|
||||
QueryStep *s = &self->steps.contents[j];
|
||||
if (s->depth == PATTERN_DONE_MARKER || s->depth <= depth) break;
|
||||
if (s->capture_id != NONE) step->contains_captures = true;
|
||||
if (s->capture_ids[0] != NONE) step->contains_captures = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -599,13 +639,7 @@ static TSQueryError ts_query__parse_pattern(
|
|||
}
|
||||
|
||||
// Add a step for the node.
|
||||
array_push(&self->steps, ((QueryStep) {
|
||||
.depth = depth,
|
||||
.symbol = symbol,
|
||||
.field = 0,
|
||||
.capture_id = NONE,
|
||||
.contains_captures = false,
|
||||
}));
|
||||
array_push(&self->steps, query_step__new(symbol, depth));
|
||||
|
||||
// Parse the child patterns
|
||||
stream_skip_whitespace(stream);
|
||||
|
|
@ -645,13 +679,7 @@ static TSQueryError ts_query__parse_pattern(
|
|||
stream_reset(stream, string_content);
|
||||
return TSQueryErrorNodeType;
|
||||
}
|
||||
array_push(&self->steps, ((QueryStep) {
|
||||
.depth = depth,
|
||||
.symbol = symbol,
|
||||
.field = 0,
|
||||
.capture_id = NONE,
|
||||
.contains_captures = false,
|
||||
}));
|
||||
array_push(&self->steps, query_step__new(symbol, depth));
|
||||
|
||||
if (stream->next != '"') return TSQueryErrorSyntax;
|
||||
stream_advance(stream);
|
||||
|
|
@ -697,12 +725,7 @@ static TSQueryError ts_query__parse_pattern(
|
|||
stream_skip_whitespace(stream);
|
||||
|
||||
// Add a step that matches any kind of node
|
||||
array_push(&self->steps, ((QueryStep) {
|
||||
.depth = depth,
|
||||
.symbol = WILDCARD_SYMBOL,
|
||||
.field = 0,
|
||||
.contains_captures = false,
|
||||
}));
|
||||
array_push(&self->steps, query_step__new(WILDCARD_SYMBOL, depth));
|
||||
}
|
||||
|
||||
else {
|
||||
|
|
@ -712,7 +735,7 @@ static TSQueryError ts_query__parse_pattern(
|
|||
stream_skip_whitespace(stream);
|
||||
|
||||
// Parse an '@'-prefixed capture pattern
|
||||
if (stream->next == '@') {
|
||||
while (stream->next == '@') {
|
||||
stream_advance(stream);
|
||||
|
||||
// Parse the capture name
|
||||
|
|
@ -727,7 +750,8 @@ static TSQueryError ts_query__parse_pattern(
|
|||
capture_name,
|
||||
length
|
||||
);
|
||||
self->steps.contents[starting_step_index].capture_id = capture_id;
|
||||
QueryStep *step = &self->steps.contents[starting_step_index];
|
||||
query_step__add_capture(step, capture_id);
|
||||
(*capture_count)++;
|
||||
|
||||
stream_skip_whitespace(stream);
|
||||
|
|
@ -797,7 +821,7 @@ TSQuery *ts_query_new(
|
|||
.length = 0,
|
||||
}));
|
||||
*error_type = ts_query__parse_pattern(self, &stream, 0, &capture_count);
|
||||
array_push(&self->steps, ((QueryStep) { .depth = PATTERN_DONE_MARKER }));
|
||||
array_push(&self->steps, query_step__new(0, PATTERN_DONE_MARKER));
|
||||
|
||||
// If any pattern could not be parsed, then report the error information
|
||||
// and terminate.
|
||||
|
|
@ -899,9 +923,7 @@ void ts_query_disable_capture(
|
|||
if (id != -1) {
|
||||
for (unsigned i = 0; i < self->steps.size; i++) {
|
||||
QueryStep *step = &self->steps.contents[i];
|
||||
if (step->capture_id == id) {
|
||||
step->capture_id = NONE;
|
||||
}
|
||||
query_step__remove_capture(step, id);
|
||||
}
|
||||
ts_query__finalize_steps(self);
|
||||
}
|
||||
|
|
@ -1280,11 +1302,13 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
|
|||
|
||||
// If the current node is captured in this pattern, add it to the
|
||||
// capture list.
|
||||
if (step->capture_id != NONE) {
|
||||
for (unsigned j = 0; j < MAX_STEP_CAPTURE_COUNT; j++) {
|
||||
uint16_t capture_id = step->capture_ids[j];
|
||||
if (step->capture_ids[j] == NONE) break;
|
||||
LOG(
|
||||
" capture node. pattern:%u, capture_id:%u\n",
|
||||
next_state->pattern_index,
|
||||
step->capture_id
|
||||
capture_id
|
||||
);
|
||||
TSQueryCapture *capture_list = capture_list_pool_get(
|
||||
&self->capture_list_pool,
|
||||
|
|
@ -1292,7 +1316,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
|
|||
);
|
||||
capture_list[next_state->capture_count++] = (TSQueryCapture) {
|
||||
node,
|
||||
step->capture_id
|
||||
capture_id
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue