From 56c620c0054e592caf48971a7a9573ae030b7a6c Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Thu, 5 Dec 2019 17:21:46 -0800 Subject: [PATCH] Store a mapping to ensure no two symbols map to the same metadata --- cli/src/generate/mod.rs | 3 +- cli/src/generate/render.rs | 54 ++++++++++++++++++++ lib/include/tree_sitter/parser.h | 1 + lib/src/language.c | 86 ++++++++++++++++++++++---------- lib/src/language.h | 3 ++ lib/src/node.c | 10 ++-- lib/src/query.c | 4 +- 7 files changed, 128 insertions(+), 33 deletions(-) diff --git a/cli/src/generate/mod.rs b/cli/src/generate/mod.rs index 866fa78b..12a59e1b 100644 --- a/cli/src/generate/mod.rs +++ b/cli/src/generate/mod.rs @@ -35,7 +35,8 @@ const NEW_HEADER_PARTS: [&'static str; 2] = [ " uint32_t large_state_count; const uint16_t *small_parse_table; - const uint32_t *small_parse_table_map;", + const uint32_t *small_parse_table_map; + const TSSymbol *public_symbol_map;", " #define SMALL_STATE(id) id - LARGE_STATE_COUNT ", diff --git a/cli/src/generate/render.rs b/cli/src/generate/render.rs index f937c245..f2699ece 100644 --- a/cli/src/generate/render.rs +++ b/cli/src/generate/render.rs @@ -80,6 +80,11 @@ impl Generator { self.add_stats(); self.add_symbol_enum(); self.add_symbol_names_list(); + + if self.next_abi { + self.add_unique_symbol_map(); + } + self.add_symbol_metadata_list(); if !self.field_names.is_empty() { @@ -320,6 +325,51 @@ impl Generator { add_line!(self, ""); } + fn add_unique_symbol_map(&mut self) { + add_line!(self, "static TSSymbol ts_symbol_map[] = {{"); + indent!(self); + for symbol in &self.parse_table.symbols { + let mut mapping = symbol; + + // If this symbol has a simple alias, then check if its alias has the same + // name and kind (e.g. named vs anonymous) as some other symbol in the grammar. + // If so, add an entry to the symbol map that deduplicates these two symbols, + // so that only one of them will ever be returned via the public API. 
+ if let Some(alias) = self.simple_aliases.get(symbol) { + let kind = if alias.is_named { + VariableType::Named + } else { + VariableType::Anonymous + }; + + for other_symbol in &self.parse_table.symbols { + if other_symbol == symbol { + continue; + } + if let Some(other_alias) = self.simple_aliases.get(other_symbol) { + if other_symbol < symbol && other_alias == alias { + mapping = other_symbol; + break; + } + } else if self.metadata_for_symbol(*other_symbol) == (&alias.value, kind) { + mapping = other_symbol; + break; + } + } + } + + add_line!( + self, + "[{}] = {},", + self.symbol_ids[&symbol], + self.symbol_ids[&mapping], + ); + } + dedent!(self); + add_line!(self, "}};"); + add_line!(self, ""); + } + fn add_field_name_enum(&mut self) { add_line!(self, "enum {{"); indent!(self); @@ -1072,6 +1122,10 @@ impl Generator { add_line!(self, ".lex_modes = ts_lex_modes,"); add_line!(self, ".symbol_names = ts_symbol_names,"); + if self.next_abi { + add_line!(self, ".public_symbol_map = ts_symbol_map,"); + } + if !self.parse_table.production_infos.is_empty() { add_line!( self, diff --git a/lib/include/tree_sitter/parser.h b/lib/include/tree_sitter/parser.h index c28e9d5c..9df91f8c 100644 --- a/lib/include/tree_sitter/parser.h +++ b/lib/include/tree_sitter/parser.h @@ -118,6 +118,7 @@ struct TSLanguage { uint32_t large_state_count; const uint16_t *small_parse_table; const uint32_t *small_parse_table_map; + const TSSymbol *public_symbol_map; }; /* diff --git a/lib/src/language.c b/lib/src/language.c index 3c08bba8..e240ef2a 100644 --- a/lib/src/language.c +++ b/lib/src/language.c @@ -3,8 +3,28 @@ #include "./error_costs.h" #include <string.h> -void ts_language_table_entry(const TSLanguage *self, TSStateId state, - TSSymbol symbol, TableEntry *result) { +uint32_t ts_language_symbol_count(const TSLanguage *self) { + return self->symbol_count + self->alias_count; +} + +uint32_t ts_language_version(const TSLanguage *self) { + return self->version; +} + +uint32_t 
ts_language_field_count(const TSLanguage *self) { + if (self->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_FIELDS) { + return self->field_count; + } else { + return 0; + } +} + +void ts_language_table_entry( + const TSLanguage *self, + TSStateId state, + TSSymbol symbol, + TableEntry *result +) { if (symbol == ts_builtin_sym_error || symbol == ts_builtin_sym_error_repeat) { result->action_count = 0; result->is_reusable = false; @@ -19,31 +39,41 @@ void ts_language_table_entry(const TSLanguage *self, TSStateId state, } } -uint32_t ts_language_symbol_count(const TSLanguage *language) { - return language->symbol_count + language->alias_count; -} - -uint32_t ts_language_version(const TSLanguage *language) { - return language->version; -} - -TSSymbolMetadata ts_language_symbol_metadata(const TSLanguage *language, TSSymbol symbol) { +TSSymbolMetadata ts_language_symbol_metadata( + const TSLanguage *self, + TSSymbol symbol +) { if (symbol == ts_builtin_sym_error) { return (TSSymbolMetadata){.visible = true, .named = true}; } else if (symbol == ts_builtin_sym_error_repeat) { return (TSSymbolMetadata){.visible = false, .named = false}; } else { - return language->symbol_metadata[symbol]; + return self->symbol_metadata[symbol]; } } -const char *ts_language_symbol_name(const TSLanguage *language, TSSymbol symbol) { +TSSymbol ts_language_public_symbol( + const TSLanguage *self, + TSSymbol symbol +) { + if (symbol == ts_builtin_sym_error) return symbol; + if (self->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_SYMBOL_DEDUPING) { + return self->public_symbol_map[symbol]; + } else { + return symbol; + } +} + +const char *ts_language_symbol_name( + const TSLanguage *self, + TSSymbol symbol +) { if (symbol == ts_builtin_sym_error) { return "ERROR"; } else if (symbol == ts_builtin_sym_error_repeat) { return "_ERROR"; } else { - return language->symbol_names[symbol]; + return self->symbol_names[symbol]; } } @@ -59,13 +89,22 @@ TSSymbol ts_language_symbol_for_name( TSSymbolMetadata 
metadata = ts_language_symbol_metadata(self, i); if (!metadata.visible || metadata.named != is_named) continue; const char *symbol_name = self->symbol_names[i]; - if (!strncmp(symbol_name, string, length) && !symbol_name[length]) return i; + if (!strncmp(symbol_name, string, length) && !symbol_name[length]) { + if (self->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_SYMBOL_DEDUPING) { + return self->public_symbol_map[i]; + } else { + return i; + } + } } return 0; } -TSSymbolType ts_language_symbol_type(const TSLanguage *language, TSSymbol symbol) { - TSSymbolMetadata metadata = ts_language_symbol_metadata(language, symbol); +TSSymbolType ts_language_symbol_type( + const TSLanguage *self, + TSSymbol symbol +) { + TSSymbolMetadata metadata = ts_language_symbol_metadata(self, symbol); if (metadata.named) { return TSSymbolTypeRegular; } else if (metadata.visible) { @@ -75,15 +114,10 @@ TSSymbolType ts_language_symbol_type(const TSLanguage *language, TSSymbol symbol } } -uint32_t ts_language_field_count(const TSLanguage *self) { - if (self->version >= TREE_SITTER_LANGUAGE_VERSION_WITH_FIELDS) { - return self->field_count; - } else { - return 0; - } -} - -const char *ts_language_field_name_for_id(const TSLanguage *self, TSFieldId id) { +const char *ts_language_field_name_for_id( + const TSLanguage *self, + TSFieldId id +) { uint32_t count = ts_language_field_count(self); if (count) { return self->field_names[id]; diff --git a/lib/src/language.h b/lib/src/language.h index 0741486a..d7e17c3d 100644 --- a/lib/src/language.h +++ b/lib/src/language.h @@ -10,6 +10,7 @@ extern "C" { #define ts_builtin_sym_error_repeat (ts_builtin_sym_error - 1) #define TREE_SITTER_LANGUAGE_VERSION_WITH_FIELDS 10 +#define TREE_SITTER_LANGUAGE_VERSION_WITH_SYMBOL_DEDUPING 11 #define TREE_SITTER_LANGUAGE_VERSION_WITH_SMALL_STATES 11 typedef struct { @@ -22,6 +23,8 @@ void ts_language_table_entry(const TSLanguage *, TSStateId, TSSymbol, TableEntry TSSymbolMetadata ts_language_symbol_metadata(const 
TSLanguage *, TSSymbol); +TSSymbol ts_language_public_symbol(const TSLanguage *, TSSymbol); + static inline bool ts_language_is_symbol_external(const TSLanguage *self, TSSymbol symbol) { return 0 < symbol && symbol < self->external_token_count + 1; } diff --git a/lib/src/node.c b/lib/src/node.c index 6b2be36e..b03e2fc9 100644 --- a/lib/src/node.c +++ b/lib/src/node.c @@ -415,13 +415,15 @@ TSPoint ts_node_end_point(TSNode self) { } TSSymbol ts_node_symbol(TSNode self) { - return ts_node__alias(&self) - ? ts_node__alias(&self) - : ts_subtree_symbol(ts_node__subtree(self)); + TSSymbol symbol = ts_node__alias(&self); + if (!symbol) symbol = ts_subtree_symbol(ts_node__subtree(self)); + return ts_language_public_symbol(self.tree->language, symbol); } const char *ts_node_type(TSNode self) { - return ts_language_symbol_name(self.tree->language, ts_node_symbol(self)); + TSSymbol symbol = ts_node__alias(&self); + if (!symbol) symbol = ts_subtree_symbol(ts_node__subtree(self)); + return ts_language_symbol_name(self.tree->language, symbol); } char *ts_node_string(TSNode self) { diff --git a/lib/src/query.c b/lib/src/query.c index 7fad2284..d08076c1 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -742,8 +742,8 @@ TSQuery *ts_query_new( ) { // Work around the fact that multiple symbols can currently be // associated with the same name, due to "simple aliases". - // In the next language ABI version, this map should be contained - // within the language itself. + // In the next language ABI version, this map will be contained + // in the language's `public_symbol_map` field. uint32_t symbol_count = ts_language_symbol_count(language); TSSymbol *symbol_map = ts_malloc(sizeof(TSSymbol) * symbol_count); for (unsigned i = 0; i < symbol_count; i++) {