Merge pull request #492 from tree-sitter/query-fixes

Fix handling of tricky patterns in tree queries
This commit is contained in:
Max Brunsfeld 2019-11-22 13:05:20 -08:00 committed by GitHub
commit a23e8b3dcb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 222 additions and 72 deletions

View file

@ -913,6 +913,97 @@ fn test_query_captures_with_many_nested_results_with_fields() {
});
}
#[test]
fn test_query_captures_with_too_many_nested_results() {
allocations::record(|| {
let language = get_language("javascript");
// Search for method calls in general, and also method calls with a template string
// in place of an argument list (aka "tagged template strings") in particular.
//
// This second pattern, which looks for the tagged template strings, is expensive to
// use with the `captures()` method, because:
// 1. When calling `captures`, all of the captures must be returned in order of their
// appearance.
// 2. This pattern captures the root `call_expression`.
// 3. This pattern's result also depends on the final child (the template string).
// 4. In between the `call_expression` and the possible `template_string`, there can
// be an arbitrarily deep subtree.
//
// This means that, if any patterns match *after* the initial `call_expression` is
// captured, but before the final `template_string` is found, those matches must
// be buffered, in order to prevent captures from being returned out-of-order.
let query = Query::new(
language,
r#"
;; easy 👇
(call_expression
function: (member_expression
property: (property_identifier) @method-name))
;; hard 👇
(call_expression
function: (member_expression
property: (property_identifier) @template-tag)
arguments: (template_string)) @template-call
"#,
)
.unwrap();
// There are a *lot* of matches in between the beginning of the outer `call_expression`
// (the call to `a(...).f`), which starts at the beginning of the file, and the final
// template string, which occurs at the end of the file. The query algorithm imposes a
// limit on the total number of matches which can be buffered at a time. But we don't
// want to neglect the inner matches just because of the expensive outer match, so we
// abandon the outer match (which would have captured `f` as a `template-tag`).
let source = "
a(b => {
b.c0().d0 `😄`;
b.c1().d1 `😄`;
b.c2().d2 `😄`;
b.c3().d3 `😄`;
b.c4().d4 `😄`;
b.c5().d5 `😄`;
b.c6().d6 `😄`;
b.c7().d7 `😄`;
b.c8().d8 `😄`;
b.c9().d9 `😄`;
}).e().f ``;
"
.trim();
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let captures = cursor.captures(&query, tree.root_node(), to_callback(&source));
let captures = collect_captures(captures, &query, &source);
assert_eq!(
&captures[0..4],
&[
("template-call", "b.c0().d0 `😄`"),
("method-name", "c0"),
("method-name", "d0"),
("template-tag", "d0"),
]
);
assert_eq!(
&captures[36..40],
&[
("template-call", "b.c9().d9 `😄`"),
("method-name", "c9"),
("method-name", "d9"),
("template-tag", "d9"),
]
);
assert_eq!(
&captures[40..],
&[("method-name", "e"), ("method-name", "f"),]
);
});
}
#[test]
fn test_query_captures_ordered_by_both_start_and_end_positions() {
allocations::record(|| {

View file

@ -140,7 +140,7 @@ static const uint16_t NONE = UINT16_MAX;
static const TSSymbol WILDCARD_SYMBOL = 0;
static const uint16_t MAX_STATE_COUNT = 32;
// #define LOG printf
// #define LOG(...) fprintf(stderr, __VA_ARGS__)
#define LOG(...)
/**********
@ -244,6 +244,10 @@ static TSQueryCapture *capture_list_pool_get(CaptureListPool *self, uint16_t id)
return &self->list.contents[id * (self->list.size / MAX_STATE_COUNT)];
}
static bool capture_list_pool_is_empty(const CaptureListPool *self) {
return self->usage_map == 0;
}
static uint16_t capture_list_pool_acquire(CaptureListPool *self) {
// In the usage_map bitmask, ones represent free lists, and zeros represent
// lists that are in use. A free list id can quickly be found by counting
@ -412,7 +416,7 @@ static void ts_query__finalize_steps(TSQuery *self) {
// a higher level of abstraction, such as the Rust/JavaScript bindings. They
// can contain '@'-prefixed capture names, double-quoted strings, and bare
// symbols, which also represent strings.
static TSQueryError ts_query_parse_predicate(
static TSQueryError ts_query__parse_predicate(
TSQuery *self,
Stream *stream
) {
@ -523,7 +527,7 @@ static TSQueryError ts_query_parse_predicate(
// Read one S-expression pattern from the stream, and incorporate it into
// the query's internal state machine representation. For nested patterns,
// this function calls itself recursively.
static TSQueryError ts_query_parse_pattern(
static TSQueryError ts_query__parse_pattern(
TSQuery *self,
Stream *stream,
uint32_t depth,
@ -546,13 +550,13 @@ static TSQueryError ts_query_parse_pattern(
// Parse a nested list, which represents a pattern followed by
// zero-or-more predicates.
if (stream->next == '(' && depth == 0) {
TSQueryError e = ts_query_parse_pattern(self, stream, 0, capture_count);
TSQueryError e = ts_query__parse_pattern(self, stream, 0, capture_count);
if (e) return e;
// Parse the predicates.
stream_skip_whitespace(stream);
for (;;) {
TSQueryError e = ts_query_parse_predicate(self, stream);
TSQueryError e = ts_query__parse_predicate(self, stream);
if (e == PARENT_DONE) {
stream_advance(stream);
stream_skip_whitespace(stream);
@ -602,7 +606,7 @@ static TSQueryError ts_query_parse_pattern(
// Parse the child patterns
stream_skip_whitespace(stream);
for (;;) {
TSQueryError e = ts_query_parse_pattern(self, stream, depth + 1, capture_count);
TSQueryError e = ts_query__parse_pattern(self, stream, depth + 1, capture_count);
if (e == PARENT_DONE) {
stream_advance(stream);
break;
@ -666,7 +670,7 @@ static TSQueryError ts_query_parse_pattern(
// Parse the pattern
uint32_t step_index = self->steps.size;
TSQueryError e = ts_query_parse_pattern(self, stream, depth, capture_count);
TSQueryError e = ts_query__parse_pattern(self, stream, depth, capture_count);
if (e == PARENT_DONE) return TSQueryErrorSyntax;
if (e) return e;
@ -782,7 +786,7 @@ TSQuery *ts_query_new(
.offset = self->predicate_steps.size,
.length = 0,
}));
*error_type = ts_query_parse_pattern(self, &stream, 0, &capture_count);
*error_type = ts_query__parse_pattern(self, &stream, 0, &capture_count);
array_push(&self->steps, ((QueryStep) { .depth = PATTERN_DONE_MARKER }));
// If any pattern could not be parsed, then report the error information
@ -961,7 +965,83 @@ void ts_query_cursor_set_point_range(
self->end_point = end_point;
}
static QueryState *ts_query_cursor_copy_state(
// Search through all of the in-progress states, and find the captured
// node that occurs earliest in the document.
static bool ts_query_cursor__first_in_progress_capture(
TSQueryCursor *self,
uint32_t *state_index,
uint32_t *byte_offset,
uint32_t *pattern_index
) {
bool result = false;
for (unsigned i = 0; i < self->states.size; i++) {
const QueryState *state = &self->states.contents[i];
if (state->capture_count > 0) {
const TSQueryCapture *captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
uint32_t capture_byte = ts_node_start_byte(captures[0].node);
if (
!result ||
capture_byte < *byte_offset ||
(
capture_byte == *byte_offset &&
state->pattern_index < *pattern_index
)
) {
result = true;
*state_index = i;
*byte_offset = capture_byte;
*pattern_index = state->pattern_index;
}
}
}
return result;
}
static bool ts_query__cursor_add_state(
TSQueryCursor *self,
const PatternEntry *slice
) {
uint32_t list_id = capture_list_pool_acquire(&self->capture_list_pool);
// If there are no capture lists left in the pool, then terminate whichever
// state has captured the earliest node in the document, and steal its
// capture list.
if (list_id == NONE) {
uint32_t state_index, byte_offset, pattern_index;
if (ts_query_cursor__first_in_progress_capture(
self,
&state_index,
&byte_offset,
&pattern_index
)) {
LOG(
" abandon state. index:%u, pattern:%u, offset:%u.\n",
state_index, pattern_index, byte_offset
);
list_id = self->states.contents[state_index].capture_list_id;
array_erase(&self->states, state_index);
} else {
LOG(" too many finished states.\n");
return false;
}
}
LOG(" start state. pattern:%u\n", slice->pattern_index);
array_push(&self->states, ((QueryState) {
.capture_list_id = list_id,
.step_index = slice->step_index,
.pattern_index = slice->pattern_index,
.start_depth = self->depth,
.capture_count = 0,
.consumed_capture_count = 0,
}));
return true;
}
static QueryState *ts_query__cursor_copy_state(
TSQueryCursor *self,
const QueryState *state
) {
@ -989,7 +1069,7 @@ static QueryState *ts_query_cursor_copy_state(
static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
do {
if (self->ascending) {
LOG("leave node %s\n", ts_node_type(ts_tree_cursor_current_node(&self->cursor)));
LOG("leave node. type:%s\n", ts_node_type(ts_tree_cursor_current_node(&self->cursor)));
// When leaving a node, remove any unfinished states whose next step
// needed to match something within that node.
@ -1057,11 +1137,14 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
) return false;
LOG(
"enter node %s. row:%u state_count:%u, finished_state_count: %u\n",
"enter node. type:%s, field:%s, row:%u state_count:%u, finished_state_count:%u, can_have_later_siblings:%d, can_have_later_siblings_with_this_field:%d\n",
ts_node_type(node),
ts_language_field_name_for_id(self->query->language, field_id),
ts_node_start_point(node).row,
self->states.size,
self->finished_states.size
self->finished_states.size,
can_have_later_siblings,
can_have_later_siblings_with_this_field
);
// Add new states for any patterns whose root node is a wildcard.
@ -1072,17 +1155,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
// If this node matches the first step of the pattern, then add a new
// state at the start of this pattern.
if (step->field && field_id != step->field) continue;
uint32_t capture_list_id = capture_list_pool_acquire(
&self->capture_list_pool
);
if (capture_list_id == NONE) break;
array_push(&self->states, ((QueryState) {
.step_index = slice->step_index,
.pattern_index = slice->pattern_index,
.capture_list_id = capture_list_id,
.capture_count = 0,
.consumed_capture_count = 0,
}));
if (!ts_query__cursor_add_state(self, slice)) break;
}
// Add new states for any patterns whose root node matches this node.
@ -1091,29 +1164,10 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
PatternEntry *slice = &self->query->pattern_map.contents[i];
QueryStep *step = &self->query->steps.contents[slice->step_index];
do {
// If this node matches the first step of the pattern, then add a new
// state at the start of this pattern.
if (step->field && field_id != step->field) continue;
LOG(" start state. pattern:%u\n", slice->pattern_index);
// If this node matches the first step of the pattern, then add a
// new in-progress state. First, acquire a list to hold the pattern's
// captures.
uint32_t capture_list_id = capture_list_pool_acquire(
&self->capture_list_pool
);
if (capture_list_id == NONE) {
LOG(" too many states.");
break;
}
array_push(&self->states, ((QueryState) {
.pattern_index = slice->pattern_index,
.step_index = slice->step_index,
.start_depth = self->depth,
.capture_list_id = capture_list_id,
.capture_count = 0,
.consumed_capture_count = 0,
}));
if (!ts_query__cursor_add_state(self, slice)) break;
// Advance to the next pattern whose root node matches this node.
i++;
@ -1178,13 +1232,17 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
step->contains_captures &&
later_sibling_can_match
) {
LOG(
" split state. pattern:%u, step:%u\n",
state->pattern_index,
state->step_index
);
QueryState *copy = ts_query_cursor_copy_state(self, state);
if (copy) next_state = copy;
QueryState *copy = ts_query__cursor_copy_state(self, state);
if (copy) {
LOG(
" split state. pattern:%u, step:%u\n",
copy->pattern_index,
copy->step_index
);
next_state = copy;
} else {
LOG(" canot split state.\n");
}
}
LOG(
@ -1298,26 +1356,13 @@ bool ts_query_cursor_next_capture(
// this position.
uint32_t first_unfinished_capture_byte = UINT32_MAX;
uint32_t first_unfinished_pattern_index = UINT32_MAX;
for (unsigned i = 0; i < self->states.size; i++) {
const QueryState *state = &self->states.contents[i];
if (state->capture_count > 0) {
const TSQueryCapture *captures = capture_list_pool_get(
&self->capture_list_pool,
state->capture_list_id
);
uint32_t capture_byte = ts_node_start_byte(captures[0].node);
if (
capture_byte < first_unfinished_capture_byte ||
(
capture_byte == first_unfinished_capture_byte &&
state->pattern_index < first_unfinished_pattern_index
)
) {
first_unfinished_capture_byte = capture_byte;
first_unfinished_pattern_index = state->pattern_index;
}
}
}
uint32_t first_unfinished_state_index;
ts_query_cursor__first_in_progress_capture(
self,
&first_unfinished_state_index,
&first_unfinished_capture_byte,
&first_unfinished_pattern_index
);
// Find the earliest capture in a finished match.
int first_finished_state_index = -1;
@ -1372,6 +1417,20 @@ bool ts_query_cursor_next_capture(
state->consumed_capture_count++;
return true;
}
if (capture_list_pool_is_empty(&self->capture_list_pool)) {
LOG(
" abandon state. index:%u, pattern:%u, offset:%u.\n",
first_unfinished_state_index,
first_unfinished_pattern_index,
first_unfinished_capture_byte
);
capture_list_pool_release(
&self->capture_list_pool,
self->states.contents[first_unfinished_state_index].capture_list_id
);
array_erase(&self->states, first_unfinished_state_index);
}
}
// If there are no finished matches that are ready to be returned, then