Fix management of capture lists in query execution

This commit is contained in:
Max Brunsfeld 2019-09-11 12:06:38 -07:00
parent 60467ae701
commit 4fa0b02d67
3 changed files with 322 additions and 210 deletions

View file

@ -107,6 +107,12 @@ static const uint16_t NONE = UINT16_MAX;
static const TSSymbol WILDCARD_SYMBOL = 0;
static const uint16_t MAX_STATE_COUNT = 32;
#ifdef DEBUG_QUERY
#define LOG printf
#else
#define LOG(...)
#endif
/**********
* Stream
**********/
@ -183,15 +189,23 @@ static TSQueryCapture *capture_list_pool_get(CaptureListPool *self, uint16_t id)
return &self->contents[id * self->list_size];
}
static inline uint32_t capture_list_bitmask_for_id(uint16_t id) {
// An id of zero corresponds to the highest-order bit in the bitmask.
return (1u << (31 - id));
}
static uint16_t capture_list_pool_acquire(CaptureListPool *self) {
// In the usage_map bitmask, ones represent free lists, and zeros represent
// lists that are in use. A free list can quickly be found by counting
// the leading zeros in the usage map.
uint16_t id = count_leading_zeros(self->usage_map);
if (id == 32) return NONE;
self->usage_map &= ~(1 << id);
self->usage_map &= ~capture_list_bitmask_for_id(id);
return id;
}
static void capture_list_pool_release(CaptureListPool *self, uint16_t id) {
self->usage_map |= (1 << id);
self->usage_map |= capture_list_bitmask_for_id(id);
}
/*********
@ -586,9 +600,31 @@ void ts_query_context_exec(TSQueryContext *self, TSNode node) {
self->ascending = false;
}
static QueryState *ts_query_context_copy_state(
TSQueryContext *self,
QueryState *state
) {
uint32_t capture_list_id = capture_list_pool_acquire(&self->capture_list_pool);
if (capture_list_id == NONE) return NULL;
array_push(&self->states, *state);
QueryState *new_state = array_back(&self->states);
new_state->capture_list_id = capture_list_id;
TSQueryCapture *old_captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
TSQueryCapture *new_captures = capture_list_pool_get(
&self->capture_list_pool,
capture_list_id
);
memcpy(new_captures, old_captures, state->capture_count * sizeof(TSQueryCapture));
return new_state;
}
bool ts_query_context_next(TSQueryContext *self) {
if (self->finished_states.size > 0) {
array_pop(&self->finished_states);
QueryState state = array_pop(&self->finished_states);
capture_list_pool_release(&self->capture_list_pool, state.capture_list_id);
}
while (self->finished_states.size == 0) {
@ -598,9 +634,14 @@ bool ts_query_context_next(TSQueryContext *self) {
uint32_t deleted_count = 0;
for (unsigned i = 0, n = self->states.size; i < n; i++) {
QueryState *state = &self->states.contents[i];
if (state->start_depth == self->depth) {
QueryStep *step = &self->query->steps.contents[state->step_index];
// printf("FAIL STATE pattern: %u, step: %u\n", state->pattern_index, state->step_index);
if (state->start_depth + step->depth > self->depth) {
LOG(
"fail state with pattern: %u, step: %u\n",
state->pattern_index,
state->step_index
);
capture_list_pool_release(
&self->capture_list_pool,
@ -612,9 +653,9 @@ bool ts_query_context_next(TSQueryContext *self) {
}
}
// if (deleted_count) {
// printf("FAILED %u of %u states\n", deleted_count, self->states.size);
// }
if (deleted_count) {
LOG("failed %u of %u states\n", deleted_count, self->states.size);
}
self->states.size -= deleted_count;
@ -631,7 +672,7 @@ bool ts_query_context_next(TSQueryContext *self) {
TSNode node = ts_tree_cursor_current_node(&self->cursor);
TSSymbol symbol = ts_node_symbol(node);
// printf("DESCEND INTO NODE: %s\n", ts_node_type(node));
LOG("enter node %s\n", ts_node_type(node));
// Add new states for any patterns whose root node is a wildcard.
for (unsigned i = 0; i < self->query->wildcard_root_pattern_count; i++) {
@ -678,7 +719,7 @@ bool ts_query_context_next(TSQueryContext *self) {
if (field_id != step->field) continue;
}
// printf("START NEW STATE: %u\n", slice->pattern_index);
LOG("start pattern %u\n", slice->pattern_index);
// If the node matches the first step of the pattern, then add
// a new in-progress state. First, acquire a list to hold the
@ -733,19 +774,15 @@ bool ts_query_context_next(TSQueryContext *self) {
// siblings.
QueryState *next_state = state;
if (step->depth > 0 && (!step->field || field_occurs_in_later_sibling)) {
uint32_t capture_list_id = capture_list_pool_acquire(
&self->capture_list_pool
);
if (capture_list_id != NONE) {
array_push(&self->states, *state);
next_state = array_back(&self->states);
next_state->capture_list_id = capture_list_id;
}
QueryState *copy = ts_query_context_copy_state(self, state);
if (copy) next_state = copy;
}
LOG("advance state for pattern %u\n", next_state->pattern_index);
// Record captures
if (step->capture_id != NONE) {
// printf("CAPTURE id: %u\n", step->capture_id);
LOG("capture id %u\n", step->capture_id);
TSQueryCapture *capture_list = capture_list_pool_get(
&self->capture_list_pool,
@ -762,7 +799,7 @@ bool ts_query_context_next(TSQueryContext *self) {
next_state->step_index++;
QueryStep *next_step = step + 1;
if (next_step->depth == PATTERN_DONE_MARKER) {
// printf("FINISHED MATCH pattern: %u\n", next_state->pattern_index);
LOG("finish pattern %u\n", next_state->pattern_index);
array_push(&self->finished_states, *next_state);
if (next_state == state) {
@ -808,3 +845,5 @@ const TSQueryCapture *ts_query_context_matched_captures(
}
return NULL;
}
#undef LOG