diff --git a/cli/src/generate/build_tables/minimize_parse_table.rs b/cli/src/generate/build_tables/minimize_parse_table.rs index aa4801c8..d159a2c4 100644 --- a/cli/src/generate/build_tables/minimize_parse_table.rs +++ b/cli/src/generate/build_tables/minimize_parse_table.rs @@ -68,6 +68,7 @@ impl<'a> Minimizer<'a> { .. } => { if !self.simple_aliases.contains_key(&symbol) + && !self.syntax_grammar.supertype_symbols.contains(&symbol) && !aliased_symbols.contains(&symbol) && self.syntax_grammar.variables[symbol.index].kind != VariableType::Named diff --git a/cli/src/generate/render.rs b/cli/src/generate/render.rs index 2758eb58..f33539d6 100644 --- a/cli/src/generate/render.rs +++ b/cli/src/generate/render.rs @@ -460,6 +460,9 @@ impl Generator { VariableType::Hidden => { add_line!(self, ".visible = false,"); add_line!(self, ".named = true,"); + if self.syntax_grammar.supertype_symbols.contains(symbol) { + add_line!(self, ".supertype = true,"); + } } VariableType::Auxiliary => { add_line!(self, ".visible = false,"); diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index f3521bb5..900b7be1 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -701,7 +701,6 @@ fn test_query_matches_with_immediate_siblings() { (2, vec![("last-stmt", "g()")]), ], ); - }); } @@ -1395,6 +1394,45 @@ fn test_query_matches_with_anonymous_tokens() { }); } +#[test] +fn test_query_matches_with_supertypes() { + allocations::record(|| { + let language = get_language("python"); + let query = Query::new( + language, + r#" + ((_simple_statement) @before . (_simple_statement) @after) + + (assignment + left: (left_hand_side (identifier) @def)) + + (_primary_expression/identifier) @ref + "#, + ) + .unwrap(); + + assert_query_matches( + language, + &query, + " + a = b + print c + if d: print e.f; print g.h.i + ", + &[ + (1, vec![("def", "a")]), + (2, vec![("ref", "b")]), + (0, vec![("before", "a = b"), ("after", "print c")]), + (2, vec![("ref", "c")]), + (2, vec![("ref", "d")]), + (2, vec![("ref", "e")]), + (0, vec![("before", "print e.f"), ("after", "print g.h.i")]), + (2, vec![("ref", "g")]), + ], + ); + }); +} + #[test] fn test_query_matches_within_byte_range() { allocations::record(|| { diff --git a/lib/include/tree_sitter/parser.h b/lib/include/tree_sitter/parser.h index 84096132..c5a788ff 100644 --- a/lib/include/tree_sitter/parser.h +++ b/lib/include/tree_sitter/parser.h @@ -35,6 +35,7 @@ typedef uint16_t TSStateId; typedef struct { bool visible : 1; bool named : 1; + bool supertype: 1; } TSSymbolMetadata; typedef struct TSLexer TSLexer; diff --git a/lib/src/language.c b/lib/src/language.c index c00c49e3..9ccf2bc3 100644 --- a/lib/src/language.c +++ b/lib/src/language.c @@ -89,7 +89,7 @@ TSSymbol ts_language_symbol_for_name( uint32_t count = ts_language_symbol_count(self); for (TSSymbol i = 0; i < count; i++) { TSSymbolMetadata metadata = ts_language_symbol_metadata(self, i); - if (!metadata.visible || metadata.named != is_named) continue; + if ((!metadata.visible && !metadata.supertype) || metadata.named != is_named) continue; const char *symbol_name = self->symbol_names[i]; if (!strncmp(symbol_name, string, length) && !symbol_name[length]) { if (self->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_SYMBOL_DEDUPING) { diff --git a/lib/src/query.c b/lib/src/query.c index 45aa3877..0ca03782 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -47,6 +47,7 @@ typedef struct { */ typedef struct { TSSymbol symbol; + TSSymbol supertype_symbol; TSFieldId field; uint16_t capture_ids[MAX_STEP_CAPTURE_COUNT]; uint16_t alternative_index; @@ -1626,14 +1627,9 @@ static TSQueryError ts_query__parse_pattern( else { TSSymbol symbol; - // Parse the wildcard symbol - if ( - stream->next == '_' || - - // TODO - remove. - // For temporary backward compatibility, handle '*' as a wildcard. - stream->next == '*' - ) { + // TODO - remove. + // For temporary backward compatibility, handle '*' as a wildcard. + if (stream->next == '*') { symbol = depth > 0 ? NAMED_WILDCARD_SYMBOL : WILDCARD_SYMBOL; stream_advance(stream); } @@ -1651,15 +1647,22 @@ static TSQueryError ts_query__parse_pattern( return ts_query__parse_predicate(self, stream); } - symbol = ts_language_symbol_for_name( - self->language, - node_name, - length, - true - ); - if (!symbol) { - stream_reset(stream, node_name); - return TSQueryErrorNodeType; + // Parse the wildcard symbol + else if (length == 1 && node_name[0] == '_') { + symbol = depth > 0 ? NAMED_WILDCARD_SYMBOL : WILDCARD_SYMBOL; + } + + else { + symbol = ts_language_symbol_for_name( + self->language, + node_name, + length, + true + ); + if (!symbol) { + stream_reset(stream, node_name); + return TSQueryErrorNodeType; + } } } else { return TSQueryErrorSyntax; @@ -1667,9 +1670,38 @@ static TSQueryError ts_query__parse_pattern( // Add a step for the node. array_push(&self->steps, query_step__new(symbol, depth, is_immediate)); + if (ts_language_symbol_metadata(self->language, symbol).supertype) { + QueryStep *step = array_back(&self->steps); + step->supertype_symbol = step->symbol; + step->symbol = WILDCARD_SYMBOL; + } + + stream_skip_whitespace(stream); + + if (stream->next == '/') { + stream_advance(stream); + if (!stream_is_ident_start(stream)) { + return TSQueryErrorSyntax; + } + + const char *node_name = stream->input; + stream_scan_identifier(stream); + uint32_t length = stream->input - node_name; + + QueryStep *step = array_back(&self->steps); + step->symbol = ts_language_symbol_for_name( + self->language, + node_name, + length, + true + ); + if (!step->symbol) { + stream_reset(stream, node_name); + return TSQueryErrorNodeType; + } + } // Parse the child patterns - stream_skip_whitespace(stream); bool child_is_immediate = false; uint16_t child_start_step_index = self->steps.size; for (;;) { @@ -2622,6 +2654,21 @@ static inline bool ts_query_cursor__advance( if (step->is_last_child && has_later_named_siblings) { node_does_match = false; } + if (step->supertype_symbol) { + bool has_supertype = ts_tree_cursor_has_supertype(&self->cursor, step->supertype_symbol); + + if (symbol == 1) { + LOG( + " has supertype %s: %d", + ts_language_symbol_name(self->query->language, step->supertype_symbol), + has_supertype + ); + } + + if (!has_supertype) { + node_does_match = false; + } + } if (step->field) { if (step->field == field_id) { if (!can_have_later_siblings_with_this_field) { diff --git a/lib/src/tree_cursor.c b/lib/src/tree_cursor.c index b193a754..8ef17960 100644 --- a/lib/src/tree_cursor.c +++ b/lib/src/tree_cursor.c @@ -352,6 +352,50 @@ TSFieldId ts_tree_cursor_current_status( return result; } +bool ts_tree_cursor_has_supertype( + const TSTreeCursor *_self, + TSSymbol supertype_symbol +) { + const TreeCursor *self = (const TreeCursor *)_self; + + // Walk up the tree, visiting the current node and its invisible ancestors, + // because fields can refer to nodes through invisible *wrapper* nodes, + for (unsigned i = self->stack.size - 1; i > 0; i--) { + TreeCursorEntry *entry = &self->stack.contents[i]; + TreeCursorEntry *parent_entry = &self->stack.contents[i - 1]; + + const TSSymbol *alias_sequence = ts_language_alias_sequence( + self->tree->language, + parent_entry->subtree->ptr->production_id + ); + + // If the subtree is visible, return its public-facing symbol. + // Otherwise, return zero. + #define subtree_visible_symbol(subtree, structural_child_index) \ + (( \ + !ts_subtree_extra(subtree) && \ + alias_sequence && \ + alias_sequence[structural_child_index] \ + ) ? \ + alias_sequence[structural_child_index] : \ + ts_subtree_visible(subtree) ? \ + ts_subtree_symbol(subtree) : \ + 0) \ + + // Stop walking up when a visible ancestor is found. + if ( + i != self->stack.size - 1 && + subtree_visible_symbol(*entry->subtree, entry->structural_child_index) + ) break; + + if (ts_subtree_symbol(*entry->subtree) == supertype_symbol) { + return true; + } + } + + return false; +} + TSFieldId ts_tree_cursor_current_field_id(const TSTreeCursor *_self) { const TreeCursor *self = (const TreeCursor *)_self; diff --git a/lib/src/tree_cursor.h b/lib/src/tree_cursor.h index 0bb486d7..7829e8b9 100644 --- a/lib/src/tree_cursor.h +++ b/lib/src/tree_cursor.h @@ -17,5 +17,6 @@ typedef struct { void ts_tree_cursor_init(TreeCursor *, TSNode); TSFieldId ts_tree_cursor_current_status(const TSTreeCursor *, bool *, bool *, bool *); +bool ts_tree_cursor_has_supertype(const TSTreeCursor *, TSSymbol); #endif // TREE_SITTER_TREE_CURSOR_H_