query: Allow multiple captures on a single node
This commit is contained in:
parent
631710bada
commit
de8e3ee188
2 changed files with 110 additions and 33 deletions
|
|
@ -591,7 +591,60 @@ fn test_query_matches_different_queries_same_cursor() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_query_captures() {
|
fn test_query_matches_with_multiple_captures_on_a_node() {
|
||||||
|
allocations::record(|| {
|
||||||
|
let language = get_language("javascript");
|
||||||
|
let mut query = Query::new(
|
||||||
|
language,
|
||||||
|
"(function_declaration
|
||||||
|
(identifier) @name1 @name2 @name3
|
||||||
|
(statement_block) @body1 @body2)",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let source = "function foo() { return 1; }";
|
||||||
|
let mut parser = Parser::new();
|
||||||
|
let mut cursor = QueryCursor::new();
|
||||||
|
|
||||||
|
parser.set_language(language).unwrap();
|
||||||
|
let tree = parser.parse(&source, None).unwrap();
|
||||||
|
|
||||||
|
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
|
||||||
|
assert_eq!(
|
||||||
|
collect_matches(matches, &query, source),
|
||||||
|
&[(
|
||||||
|
0,
|
||||||
|
vec![
|
||||||
|
("name1", "foo"),
|
||||||
|
("name2", "foo"),
|
||||||
|
("name3", "foo"),
|
||||||
|
("body1", "{ return 1; }"),
|
||||||
|
("body2", "{ return 1; }"),
|
||||||
|
]
|
||||||
|
),]
|
||||||
|
);
|
||||||
|
|
||||||
|
// disabling captures still works when there are multiple captures on a
|
||||||
|
// single node.
|
||||||
|
query.disable_capture("name2");
|
||||||
|
let matches = cursor.matches(&query, tree.root_node(), to_callback(source));
|
||||||
|
assert_eq!(
|
||||||
|
collect_matches(matches, &query, source),
|
||||||
|
&[(
|
||||||
|
0,
|
||||||
|
vec![
|
||||||
|
("name1", "foo"),
|
||||||
|
("name3", "foo"),
|
||||||
|
("body1", "{ return 1; }"),
|
||||||
|
("body2", "{ return 1; }"),
|
||||||
|
]
|
||||||
|
),]
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_query_captures_basic() {
|
||||||
allocations::record(|| {
|
allocations::record(|| {
|
||||||
let language = get_language("javascript");
|
let language = get_language("javascript");
|
||||||
let query = Query::new(
|
let query = Query::new(
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,8 @@ typedef struct {
|
||||||
uint8_t next_size;
|
uint8_t next_size;
|
||||||
} Stream;
|
} Stream;
|
||||||
|
|
||||||
|
#define MAX_STEP_CAPTURE_COUNT 4
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* QueryStep - A step in the process of matching a query. Each node within
|
* QueryStep - A step in the process of matching a query. Each node within
|
||||||
* a query S-expression maps to one of these steps. An entire pattern is
|
* a query S-expression maps to one of these steps. An entire pattern is
|
||||||
|
|
@ -37,7 +39,7 @@ typedef struct {
|
||||||
typedef struct {
|
typedef struct {
|
||||||
TSSymbol symbol;
|
TSSymbol symbol;
|
||||||
TSFieldId field;
|
TSFieldId field;
|
||||||
uint16_t capture_id;
|
uint16_t capture_ids[MAX_STEP_CAPTURE_COUNT];
|
||||||
uint16_t depth: 15;
|
uint16_t depth: 15;
|
||||||
bool contains_captures: 1;
|
bool contains_captures: 1;
|
||||||
} QueryStep;
|
} QueryStep;
|
||||||
|
|
@ -326,6 +328,44 @@ static uint16_t symbol_table_insert_name(
|
||||||
return self->slices.size - 1;
|
return self->slices.size - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/************
|
||||||
|
* QueryStep
|
||||||
|
************/
|
||||||
|
|
||||||
|
static QueryStep query_step__new(TSSymbol symbol, uint16_t depth) {
|
||||||
|
return (QueryStep) {
|
||||||
|
.symbol = symbol,
|
||||||
|
.depth = depth,
|
||||||
|
.field = 0,
|
||||||
|
.capture_ids = {NONE, NONE, NONE, NONE},
|
||||||
|
.contains_captures = false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
static void query_step__add_capture(QueryStep *self, uint16_t capture_id) {
|
||||||
|
for (unsigned i = 0; i < MAX_STEP_CAPTURE_COUNT; i++) {
|
||||||
|
if (self->capture_ids[i] == NONE) {
|
||||||
|
self->capture_ids[i] = capture_id;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void query_step__remove_capture(QueryStep *self, uint16_t capture_id) {
|
||||||
|
for (unsigned i = 0; i < MAX_STEP_CAPTURE_COUNT; i++) {
|
||||||
|
if (self->capture_ids[i] == capture_id) {
|
||||||
|
self->capture_ids[i] = NONE;
|
||||||
|
while (i + 1 < MAX_STEP_CAPTURE_COUNT) {
|
||||||
|
if (self->capture_ids[i + 1] == NONE) break;
|
||||||
|
self->capture_ids[i] = self->capture_ids[i + 1];
|
||||||
|
self->capture_ids[i + 1] = NONE;
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/*********
|
/*********
|
||||||
* Query
|
* Query
|
||||||
*********/
|
*********/
|
||||||
|
|
@ -401,14 +441,14 @@ static void ts_query__finalize_steps(TSQuery *self) {
|
||||||
for (unsigned i = 0; i < self->steps.size; i++) {
|
for (unsigned i = 0; i < self->steps.size; i++) {
|
||||||
QueryStep *step = &self->steps.contents[i];
|
QueryStep *step = &self->steps.contents[i];
|
||||||
uint32_t depth = step->depth;
|
uint32_t depth = step->depth;
|
||||||
if (step->capture_id != NONE) {
|
if (step->capture_ids[0] != NONE) {
|
||||||
step->contains_captures = true;
|
step->contains_captures = true;
|
||||||
} else {
|
} else {
|
||||||
step->contains_captures = false;
|
step->contains_captures = false;
|
||||||
for (unsigned j = i + 1; j < self->steps.size; j++) {
|
for (unsigned j = i + 1; j < self->steps.size; j++) {
|
||||||
QueryStep *s = &self->steps.contents[j];
|
QueryStep *s = &self->steps.contents[j];
|
||||||
if (s->depth == PATTERN_DONE_MARKER || s->depth <= depth) break;
|
if (s->depth == PATTERN_DONE_MARKER || s->depth <= depth) break;
|
||||||
if (s->capture_id != NONE) step->contains_captures = true;
|
if (s->capture_ids[0] != NONE) step->contains_captures = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -599,13 +639,7 @@ static TSQueryError ts_query__parse_pattern(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a step for the node.
|
// Add a step for the node.
|
||||||
array_push(&self->steps, ((QueryStep) {
|
array_push(&self->steps, query_step__new(symbol, depth));
|
||||||
.depth = depth,
|
|
||||||
.symbol = symbol,
|
|
||||||
.field = 0,
|
|
||||||
.capture_id = NONE,
|
|
||||||
.contains_captures = false,
|
|
||||||
}));
|
|
||||||
|
|
||||||
// Parse the child patterns
|
// Parse the child patterns
|
||||||
stream_skip_whitespace(stream);
|
stream_skip_whitespace(stream);
|
||||||
|
|
@ -645,13 +679,7 @@ static TSQueryError ts_query__parse_pattern(
|
||||||
stream_reset(stream, string_content);
|
stream_reset(stream, string_content);
|
||||||
return TSQueryErrorNodeType;
|
return TSQueryErrorNodeType;
|
||||||
}
|
}
|
||||||
array_push(&self->steps, ((QueryStep) {
|
array_push(&self->steps, query_step__new(symbol, depth));
|
||||||
.depth = depth,
|
|
||||||
.symbol = symbol,
|
|
||||||
.field = 0,
|
|
||||||
.capture_id = NONE,
|
|
||||||
.contains_captures = false,
|
|
||||||
}));
|
|
||||||
|
|
||||||
if (stream->next != '"') return TSQueryErrorSyntax;
|
if (stream->next != '"') return TSQueryErrorSyntax;
|
||||||
stream_advance(stream);
|
stream_advance(stream);
|
||||||
|
|
@ -697,12 +725,7 @@ static TSQueryError ts_query__parse_pattern(
|
||||||
stream_skip_whitespace(stream);
|
stream_skip_whitespace(stream);
|
||||||
|
|
||||||
// Add a step that matches any kind of node
|
// Add a step that matches any kind of node
|
||||||
array_push(&self->steps, ((QueryStep) {
|
array_push(&self->steps, query_step__new(WILDCARD_SYMBOL, depth));
|
||||||
.depth = depth,
|
|
||||||
.symbol = WILDCARD_SYMBOL,
|
|
||||||
.field = 0,
|
|
||||||
.contains_captures = false,
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
else {
|
else {
|
||||||
|
|
@ -712,7 +735,7 @@ static TSQueryError ts_query__parse_pattern(
|
||||||
stream_skip_whitespace(stream);
|
stream_skip_whitespace(stream);
|
||||||
|
|
||||||
// Parse an '@'-prefixed capture pattern
|
// Parse an '@'-prefixed capture pattern
|
||||||
if (stream->next == '@') {
|
while (stream->next == '@') {
|
||||||
stream_advance(stream);
|
stream_advance(stream);
|
||||||
|
|
||||||
// Parse the capture name
|
// Parse the capture name
|
||||||
|
|
@ -727,7 +750,8 @@ static TSQueryError ts_query__parse_pattern(
|
||||||
capture_name,
|
capture_name,
|
||||||
length
|
length
|
||||||
);
|
);
|
||||||
self->steps.contents[starting_step_index].capture_id = capture_id;
|
QueryStep *step = &self->steps.contents[starting_step_index];
|
||||||
|
query_step__add_capture(step, capture_id);
|
||||||
(*capture_count)++;
|
(*capture_count)++;
|
||||||
|
|
||||||
stream_skip_whitespace(stream);
|
stream_skip_whitespace(stream);
|
||||||
|
|
@ -797,7 +821,7 @@ TSQuery *ts_query_new(
|
||||||
.length = 0,
|
.length = 0,
|
||||||
}));
|
}));
|
||||||
*error_type = ts_query__parse_pattern(self, &stream, 0, &capture_count);
|
*error_type = ts_query__parse_pattern(self, &stream, 0, &capture_count);
|
||||||
array_push(&self->steps, ((QueryStep) { .depth = PATTERN_DONE_MARKER }));
|
array_push(&self->steps, query_step__new(0, PATTERN_DONE_MARKER));
|
||||||
|
|
||||||
// If any pattern could not be parsed, then report the error information
|
// If any pattern could not be parsed, then report the error information
|
||||||
// and terminate.
|
// and terminate.
|
||||||
|
|
@ -899,9 +923,7 @@ void ts_query_disable_capture(
|
||||||
if (id != -1) {
|
if (id != -1) {
|
||||||
for (unsigned i = 0; i < self->steps.size; i++) {
|
for (unsigned i = 0; i < self->steps.size; i++) {
|
||||||
QueryStep *step = &self->steps.contents[i];
|
QueryStep *step = &self->steps.contents[i];
|
||||||
if (step->capture_id == id) {
|
query_step__remove_capture(step, id);
|
||||||
step->capture_id = NONE;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
ts_query__finalize_steps(self);
|
ts_query__finalize_steps(self);
|
||||||
}
|
}
|
||||||
|
|
@ -1280,11 +1302,13 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
|
||||||
|
|
||||||
// If the current node is captured in this pattern, add it to the
|
// If the current node is captured in this pattern, add it to the
|
||||||
// capture list.
|
// capture list.
|
||||||
if (step->capture_id != NONE) {
|
for (unsigned j = 0; j < MAX_STEP_CAPTURE_COUNT; j++) {
|
||||||
|
uint16_t capture_id = step->capture_ids[j];
|
||||||
|
if (step->capture_ids[j] == NONE) break;
|
||||||
LOG(
|
LOG(
|
||||||
" capture node. pattern:%u, capture_id:%u\n",
|
" capture node. pattern:%u, capture_id:%u\n",
|
||||||
next_state->pattern_index,
|
next_state->pattern_index,
|
||||||
step->capture_id
|
capture_id
|
||||||
);
|
);
|
||||||
TSQueryCapture *capture_list = capture_list_pool_get(
|
TSQueryCapture *capture_list = capture_list_pool_get(
|
||||||
&self->capture_list_pool,
|
&self->capture_list_pool,
|
||||||
|
|
@ -1292,7 +1316,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
|
||||||
);
|
);
|
||||||
capture_list[next_state->capture_count++] = (TSQueryCapture) {
|
capture_list[next_state->capture_count++] = (TSQueryCapture) {
|
||||||
node,
|
node,
|
||||||
step->capture_id
|
capture_id
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue