Make queries work in languages with simple aliases

This commit is contained in:
Max Brunsfeld 2019-09-18 17:35:47 -07:00
parent 9ce3a53427
commit ff9a2c1f53
2 changed files with 60 additions and 1 deletions

View file

@ -321,6 +321,41 @@ fn test_query_matches_with_many() {
});
}
#[test]
fn test_query_matches_in_language_with_simple_aliases() {
allocations::record(|| {
let language = get_language("html");
// HTML uses different tokens to track start tags names, end
// tag names, script tag names, and style tag names. All of
// these tokens are aliased to `tag_name`.
let query = Query::new(language, "(tag_name) @tag").unwrap();
let source = "
<div>
<script>hi</script>
<style>hi</style>
</div>";
let mut parser = Parser::new();
parser.set_language(language).unwrap();
let tree = parser.parse(&source, None).unwrap();
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, tree.root_node(), to_callback(&source));
assert_eq!(
collect_matches(matches, &query, source),
&[
(0, vec![("tag", "div")]),
(0, vec![("tag", "script")]),
(0, vec![("tag", "script")]),
(0, vec![("tag", "style")]),
(0, vec![("tag", "style")]),
(0, vec![("tag", "div")]),
],
);
});
}
#[test]
fn test_query_matches_with_too_many_permutations_to_track() {
allocations::record(|| {

View file

@ -113,6 +113,7 @@ struct TSQuery {
const TSLanguage *language;
uint16_t max_capture_count;
uint16_t wildcard_root_pattern_count;
TSSymbol *symbol_map;
};
/*
@ -722,6 +723,27 @@ TSQuery *ts_query_new(
uint32_t *error_offset,
TSQueryError *error_type
) {
// Work around the fact that multiple symbols can currently be
// associated with the same name, due to "simple aliases".
// In the next language ABI version, this map should be contained
// within the language itself.
uint32_t symbol_count = ts_language_symbol_count(language);
TSSymbol *symbol_map = ts_malloc(sizeof(TSSymbol) * symbol_count);
for (unsigned i = 0; i < symbol_count; i++) {
const char *name = ts_language_symbol_name(language, i);
const char symbol_type = ts_language_symbol_type(language, i);
symbol_map[i] = i;
for (unsigned j = 0; j < i; j++) {
if (ts_language_symbol_type(language, j) == symbol_type) {
if (!strcmp(name, ts_language_symbol_name(language, j))) {
symbol_map[i] = j;
break;
}
}
}
}
TSQuery *self = ts_malloc(sizeof(TSQuery));
*self = (TSQuery) {
.steps = array_new(),
@ -730,6 +752,7 @@ TSQuery *ts_query_new(
.predicate_values = symbol_table_new(),
.predicate_steps = array_new(),
.predicates_by_pattern = array_new(),
.symbol_map = symbol_map,
.wildcard_root_pattern_count = 0,
.max_capture_count = 0,
.language = language,
@ -790,6 +813,7 @@ void ts_query_delete(TSQuery *self) {
array_delete(&self->start_bytes_by_pattern);
symbol_table_delete(&self->captures);
symbol_table_delete(&self->predicate_values);
ts_free(self->symbol_map);
ts_free(self);
}
}
@ -981,7 +1005,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
&can_have_later_siblings_with_this_field
);
TSNode node = ts_tree_cursor_current_node(&self->cursor);
TSSymbol symbol = ts_node_symbol(node);
TSSymbol symbol = self->query->symbol_map[ts_node_symbol(node)];
// If this node is before the selected range, then avoid descending
// into it.