From ff9a2c1f5352dd241ea60de92a223ae7c72fb63a Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 18 Sep 2019 17:35:47 -0700 Subject: [PATCH] Make queries work in languages with simple aliases --- cli/src/tests/query_test.rs | 35 +++++++++++++++++++++++++++++++++++ lib/src/query.c | 26 +++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 11abc028..9344598d 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -321,6 +321,41 @@ fn test_query_matches_with_many() { }); } +#[test] +fn test_query_matches_in_language_with_simple_aliases() { + allocations::record(|| { + let language = get_language("html"); + + // HTML uses different tokens to track start tags names, end + // tag names, script tag names, and style tag names. All of + // these tokens are aliased to `tag_name`. + let query = Query::new(language, "(tag_name) @tag").unwrap(); + let source = " +
<div> + <script>hi</script> + <style>hi</style></div> +
"; + + let mut parser = Parser::new(); + parser.set_language(language).unwrap(); + let tree = parser.parse(&source, None).unwrap(); + let mut cursor = QueryCursor::new(); + let matches = cursor.matches(&query, tree.root_node(), to_callback(&source)); + + assert_eq!( + collect_matches(matches, &query, source), + &[ + (0, vec![("tag", "div")]), + (0, vec![("tag", "script")]), + (0, vec![("tag", "script")]), + (0, vec![("tag", "style")]), + (0, vec![("tag", "style")]), + (0, vec![("tag", "div")]), + ], + ); + }); +} + #[test] fn test_query_matches_with_too_many_permutations_to_track() { allocations::record(|| { diff --git a/lib/src/query.c b/lib/src/query.c index 19d6adbc..6b3c7ddf 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -113,6 +113,7 @@ struct TSQuery { const TSLanguage *language; uint16_t max_capture_count; uint16_t wildcard_root_pattern_count; + TSSymbol *symbol_map; }; /* @@ -722,6 +723,27 @@ TSQuery *ts_query_new( uint32_t *error_offset, TSQueryError *error_type ) { + // Work around the fact that multiple symbols can currently be + // associated with the same name, due to "simple aliases". + // In the next language ABI version, this map should be contained + // within the language itself. 
+ uint32_t symbol_count = ts_language_symbol_count(language); + TSSymbol *symbol_map = ts_malloc(sizeof(TSSymbol) * symbol_count); + for (unsigned i = 0; i < symbol_count; i++) { + const char *name = ts_language_symbol_name(language, i); + const char symbol_type = ts_language_symbol_type(language, i); + + symbol_map[i] = i; + for (unsigned j = 0; j < i; j++) { + if (ts_language_symbol_type(language, j) == symbol_type) { + if (!strcmp(name, ts_language_symbol_name(language, j))) { + symbol_map[i] = j; + break; + } + } + } + } + TSQuery *self = ts_malloc(sizeof(TSQuery)); *self = (TSQuery) { .steps = array_new(), @@ -730,6 +752,7 @@ TSQuery *ts_query_new( .predicate_values = symbol_table_new(), .predicate_steps = array_new(), .predicates_by_pattern = array_new(), + .symbol_map = symbol_map, .wildcard_root_pattern_count = 0, .max_capture_count = 0, .language = language, @@ -790,6 +813,7 @@ void ts_query_delete(TSQuery *self) { array_delete(&self->start_bytes_by_pattern); symbol_table_delete(&self->captures); symbol_table_delete(&self->predicate_values); + ts_free(self->symbol_map); ts_free(self); } } @@ -981,7 +1005,7 @@ static inline bool ts_query_cursor__advance(TSQueryCursor *self) { &can_have_later_siblings_with_this_field ); TSNode node = ts_tree_cursor_current_node(&self->cursor); - TSSymbol symbol = ts_node_symbol(node); + TSSymbol symbol = self->query->symbol_map[ts_node_symbol(node)]; // If this node is before the selected range, then avoid descending // into it.