diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs
index 11abc028..9344598d 100644
--- a/cli/src/tests/query_test.rs
+++ b/cli/src/tests/query_test.rs
@@ -321,6 +321,41 @@ fn test_query_matches_with_many() {
});
}
+#[test]
+fn test_query_matches_in_language_with_simple_aliases() {
+    allocations::record(|| {
+        let language = get_language("html");
+
+        // HTML tracks start, end, script, and style tag names with distinct
+        // tokens that are all aliased to `tag_name`, so this one pattern must
+        // capture every opening and closing tag name in the source below.
+        let query = Query::new(language, "(tag_name) @tag").unwrap();
+        let source = "
+            <div>
+                <script>hi</script>
+                <style>hi</style>
+            </div>
+        ";
+
+        let mut parser = Parser::new();
+        parser.set_language(language).unwrap();
+        let tree = parser.parse(&source, None).unwrap();
+        let mut cursor = QueryCursor::new();
+        let matches = cursor.matches(&query, tree.root_node(), to_callback(&source));
+
+        assert_eq!(
+            collect_matches(matches, &query, source),
+            &[
+                (0, vec![("tag", "div")]),
+                (0, vec![("tag", "script")]),
+                (0, vec![("tag", "script")]),
+                (0, vec![("tag", "style")]),
+                (0, vec![("tag", "style")]),
+                (0, vec![("tag", "div")]),
+            ],
+        );
+    });
+}
+
#[test]
fn test_query_matches_with_too_many_permutations_to_track() {
allocations::record(|| {
diff --git a/lib/src/query.c b/lib/src/query.c
index 19d6adbc..6b3c7ddf 100644
--- a/lib/src/query.c
+++ b/lib/src/query.c
@@ -113,6 +113,7 @@ struct TSQuery {
const TSLanguage *language;
uint16_t max_capture_count;
uint16_t wildcard_root_pattern_count;
+ TSSymbol *symbol_map;
};
/*
@@ -722,6 +723,27 @@ TSQuery *ts_query_new(
uint32_t *error_offset,
TSQueryError *error_type
) {
+ // Work around the fact that multiple symbols can currently be
+ // associated with the same name, due to "simple aliases".
+ // In the next language ABI version, this map should be contained
+ // within the language itself.
+ uint32_t symbol_count = ts_language_symbol_count(language);
+ TSSymbol *symbol_map = ts_malloc(sizeof(TSSymbol) * symbol_count);
+ for (unsigned i = 0; i < symbol_count; i++) {
+ const char *name = ts_language_symbol_name(language, i);
+ const char symbol_type = ts_language_symbol_type(language, i);
+
+ symbol_map[i] = i;
+ for (unsigned j = 0; j < i; j++) {
+ if (ts_language_symbol_type(language, j) == symbol_type) {
+ if (!strcmp(name, ts_language_symbol_name(language, j))) {
+ symbol_map[i] = j;
+ break;
+ }
+ }
+ }
+ }
+
TSQuery *self = ts_malloc(sizeof(TSQuery));
*self = (TSQuery) {
.steps = array_new(),
@@ -730,6 +752,7 @@ TSQuery *ts_query_new(
.predicate_values = symbol_table_new(),
.predicate_steps = array_new(),
.predicates_by_pattern = array_new(),
+ .symbol_map = symbol_map,
.wildcard_root_pattern_count = 0,
.max_capture_count = 0,
.language = language,
@@ -790,6 +813,7 @@ void ts_query_delete(TSQuery *self) {
array_delete(&self->start_bytes_by_pattern);
symbol_table_delete(&self->captures);
symbol_table_delete(&self->predicate_values);
+ ts_free(self->symbol_map);
ts_free(self);
}
}
@@ -981,7 +1005,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) {
&can_have_later_siblings_with_this_field
);
TSNode node = ts_tree_cursor_current_node(&self->cursor);
- TSSymbol symbol = ts_node_symbol(node);
+ TSSymbol symbol = self->query->symbol_map[ts_node_symbol(node)];
// If this node is before the selected range, then avoid descending
// into it.