diff --git a/include/tree_sitter/parser.h b/include/tree_sitter/parser.h index 18df7722..f16a03da 100644 --- a/include/tree_sitter/parser.h +++ b/include/tree_sitter/parser.h @@ -71,6 +71,7 @@ typedef union { typedef struct TSLanguage { uint32_t version; uint32_t symbol_count; + uint32_t rename_symbol_count; uint32_t token_count; uint32_t external_token_count; const char **symbol_names; @@ -173,6 +174,7 @@ typedef struct TSLanguage { static TSLanguage language = { \ .version = LANGUAGE_VERSION, \ .symbol_count = SYMBOL_COUNT, \ + .rename_symbol_count = RENAME_SYMBOL_COUNT, \ .token_count = TOKEN_COUNT, \ .symbol_metadata = ts_symbol_metadata, \ .parse_table = (const unsigned short *)ts_parse_table, \ diff --git a/src/compiler/generate_code/c_code.cc b/src/compiler/generate_code/c_code.cc index 4b1b30f7..25f8934f 100644 --- a/src/compiler/generate_code/c_code.cc +++ b/src/compiler/generate_code/c_code.cc @@ -147,13 +147,16 @@ class CCodeGenerator { max_rename_sequence_length = rename_sequence.size(); } for (const string &name_replacement : rename_sequence) { - unique_replacement_names.insert(name_replacement); + if (!name_replacement.empty()) { + unique_replacement_names.insert(name_replacement); + } } } line("#define LANGUAGE_VERSION " + to_string(TREE_SITTER_LANGUAGE_VERSION)); line("#define STATE_COUNT " + to_string(parse_table.states.size())); line("#define SYMBOL_COUNT " + to_string(parse_table.symbols.size())); + line("#define RENAME_SYMBOL_COUNT " + to_string(unique_replacement_names.size())); line("#define TOKEN_COUNT " + to_string(token_count)); line("#define EXTERNAL_TOKEN_COUNT " + to_string(syntax_grammar.external_tokens.size())); line("#define MAX_RENAME_SEQUENCE_LENGTH " + to_string(max_rename_sequence_length)); @@ -229,7 +232,7 @@ class CCodeGenerator { } void add_symbol_metadata_list() { - line("static const TSSymbolMetadata ts_symbol_metadata[SYMBOL_COUNT] = {"); + line("static const TSSymbolMetadata ts_symbol_metadata[] = {"); indent([&]() { for (const auto &entry : parse_table.symbols) { const Symbol &symbol = entry.first; @@ -260,6 +263,17 @@ class CCodeGenerator { line("},"); } + + for (const string &replacement_name : unique_replacement_names) { + line("[" + rename_id(replacement_name) + "] = {"); + indent([&]() { + line(".visible = true,"); + line(".named = true,"); + line(".structural = true,"); + line(".extra = true,"); + }); + line("},"); + } }); line("};"); line(); diff --git a/src/runtime/language.c b/src/runtime/language.c index 41a41b2b..9706e7ff 100644 --- a/src/runtime/language.c +++ b/src/runtime/language.c @@ -32,7 +32,7 @@ void ts_language_table_entry(const TSLanguage *self, TSStateId state, } uint32_t ts_language_symbol_count(const TSLanguage *language) { - return language->symbol_count; + return language->symbol_count + language->rename_symbol_count; } uint32_t ts_language_version(const TSLanguage *language) { @@ -41,19 +41,21 @@ uint32_t ts_language_version(const TSLanguage *language) { TSSymbolMetadata ts_language_symbol_metadata(const TSLanguage *language, TSSymbol symbol) { - if (symbol == ts_builtin_sym_error) + if (symbol == ts_builtin_sym_error) { return (TSSymbolMetadata){ .visible = true, .named = true, .extra = false, .structural = true, }; - else + } else { return language->symbol_metadata[symbol]; + } } const char *ts_language_symbol_name(const TSLanguage *language, TSSymbol symbol) { - if (symbol == ts_builtin_sym_error) + if (symbol == ts_builtin_sym_error) { return "ERROR"; - else + } else { return language->symbol_names[symbol]; + } } TSSymbolType ts_language_symbol_type(const TSLanguage *language, TSSymbol symbol) { diff --git a/src/runtime/node.c b/src/runtime/node.c index 8179d529..35267964 100644 --- a/src/runtime/node.c +++ b/src/runtime/node.c @@ -264,7 +264,8 @@ TSPoint ts_node_end_point(TSNode self) { } TSSymbol ts_node_symbol(TSNode self) { - return ts_node__tree(self)->symbol; + const Tree *tree = ts_node__tree(self); + return tree->context.rename_symbol ? tree->context.rename_symbol : tree->symbol; } TSSymbolIterator ts_node_symbols(TSNode self) { @@ -288,9 +289,7 @@ void ts_symbol_iterator_next(TSSymbolIterator *self) { } const char *ts_node_type(TSNode self, const TSDocument *document) { - const Tree *tree = ts_node__tree(self); - TSSymbol symbol = tree->context.rename_symbol ? tree->context.rename_symbol : tree->symbol; - return ts_language_symbol_name(document->parser.language, symbol); + return ts_language_symbol_name(document->parser.language, ts_node_symbol(self)); } char *ts_node_string(TSNode self, const TSDocument *document) { diff --git a/test/runtime/document_test.cc b/test/runtime/document_test.cc index 99d04a51..7bd85ebf 100644 --- a/test/runtime/document_test.cc +++ b/test/runtime/document_test.cc @@ -15,6 +15,8 @@ TSPoint point(size_t row, size_t column) { START_TEST + + describe("Document", [&]() { TSDocument *document; TSNode root; diff --git a/test/runtime/language_test.cc b/test/runtime/language_test.cc new file mode 100644 index 00000000..8ca0d64a --- /dev/null +++ b/test/runtime/language_test.cc @@ -0,0 +1,54 @@ +#include "test_helper.h" +#include "runtime/alloc.h" +#include "helpers/load_language.h" + +START_TEST + +describe("Language", []() { + describe("symbol_name(TSSymbol symbol)", [&]() { + it("returns the correct name for renamed nodes", [&]() { + TSCompileResult compile_result = ts_compile_grammar(R"JSON({ + "name": "renamed_rules", + + "rules": { + "a": { + "type": "RENAME", + "value": "c", + "content": { + "type": "SYMBOL", + "name": "b" + } + }, + + "b": { + "type": "STRING", + "value": "b" + } + } + })JSON"); + + TSDocument *document = ts_document_new(); + const TSLanguage *language = load_test_language("renamed_rules", compile_result); + ts_document_set_language(document, language); + ts_document_set_input_string(document, "b"); + ts_document_parse(document); + + TSNode root_node = ts_document_root_node(document); + char *string = ts_node_string(root_node, document); + AssertThat(string, Equals("(a (c))")); + + TSNode renamed_node = ts_node_child(root_node, 0); + AssertThat(ts_node_type(renamed_node, document), Equals("c")); + + TSSymbol renamed_symbol = ts_node_symbol(renamed_node); + AssertThat(ts_language_symbol_count(language), IsGreaterThan(renamed_symbol)); + AssertThat(ts_language_symbol_name(language, renamed_symbol), Equals("c")); + AssertThat(ts_language_symbol_type(language, renamed_symbol), Equals(TSSymbolTypeRegular)); + + ts_free(string); + ts_document_free(document); + }); + }); +}); + +END_TEST