diff --git a/cli/generate/src/lib.rs b/cli/generate/src/lib.rs index 14f20672..6f8f6ee7 100644 --- a/cli/generate/src/lib.rs +++ b/cli/generate/src/lib.rs @@ -124,6 +124,8 @@ fn generate_parser_for_grammar_with_opts( &simple_aliases, &variable_info, ); + let supertype_symbol_map = + node_types::get_supertype_symbol_map(&syntax_grammar, &simple_aliases, &variable_info); let tables = build_tables( &syntax_grammar, &lexical_grammar, @@ -139,6 +141,7 @@ fn generate_parser_for_grammar_with_opts( lexical_grammar, simple_aliases, abi_version, + supertype_symbol_map, ); Ok(GeneratedParser { c_code, diff --git a/cli/generate/src/node_types.rs b/cli/generate/src/node_types.rs index debd8ae1..03341661 100644 --- a/cli/generate/src/node_types.rs +++ b/cli/generate/src/node_types.rs @@ -369,6 +369,76 @@ pub fn get_variable_info( Ok(result) } +fn get_aliases_by_symbol( + syntax_grammar: &SyntaxGrammar, + default_aliases: &AliasMap, +) -> HashMap>> { + let mut aliases_by_symbol = HashMap::new(); + for (symbol, alias) in default_aliases { + aliases_by_symbol.insert(*symbol, { + let mut aliases = HashSet::new(); + aliases.insert(Some(alias.clone())); + aliases + }); + } + for extra_symbol in &syntax_grammar.extra_symbols { + if !default_aliases.contains_key(extra_symbol) { + aliases_by_symbol + .entry(*extra_symbol) + .or_insert_with(HashSet::new) + .insert(None); + } + } + for variable in &syntax_grammar.variables { + for production in &variable.productions { + for step in &production.steps { + aliases_by_symbol + .entry(step.symbol) + .or_insert_with(HashSet::new) + .insert( + step.alias + .as_ref() + .or_else(|| default_aliases.get(&step.symbol)) + .cloned(), + ); + } + } + } + aliases_by_symbol.insert( + Symbol::non_terminal(0), + std::iter::once(&None).cloned().collect(), + ); + aliases_by_symbol +} + +pub fn get_supertype_symbol_map( + syntax_grammar: &SyntaxGrammar, + default_aliases: &AliasMap, + variable_info: &[VariableInfo], +) -> BTreeMap> { + let aliases_by_symbol = get_aliases_by_symbol(syntax_grammar, default_aliases); + let mut supertype_symbol_map = BTreeMap::new(); + + let mut symbols_by_alias = HashMap::new(); + for (symbol, aliases) in &aliases_by_symbol { + for alias in aliases.iter().flatten() { + symbols_by_alias + .entry(alias) + .or_insert_with(Vec::new) + .push(*symbol); + } + } + + for (i, info) in variable_info.iter().enumerate() { + let symbol = Symbol::non_terminal(i); + if syntax_grammar.supertype_symbols.contains(&symbol) { + let subtypes = info.children.types.clone(); + supertype_symbol_map.insert(symbol, subtypes); + } + } + supertype_symbol_map +} + pub fn generate_node_types_json( syntax_grammar: &SyntaxGrammar, lexical_grammar: &LexicalGrammar, @@ -430,41 +500,7 @@ pub fn generate_node_types_json( } }; - let mut aliases_by_symbol = HashMap::new(); - for (symbol, alias) in default_aliases { - aliases_by_symbol.insert(*symbol, { - let mut aliases = HashSet::new(); - aliases.insert(Some(alias.clone())); - aliases - }); - } - for extra_symbol in &syntax_grammar.extra_symbols { - if !default_aliases.contains_key(extra_symbol) { - aliases_by_symbol - .entry(*extra_symbol) - .or_insert_with(HashSet::new) - .insert(None); - } - } - for variable in &syntax_grammar.variables { - for production in &variable.productions { - for step in &production.steps { - aliases_by_symbol - .entry(step.symbol) - .or_insert_with(HashSet::new) - .insert( - step.alias - .as_ref() - .or_else(|| default_aliases.get(&step.symbol)) - .cloned(), - ); - } - } - } - aliases_by_symbol.insert( - Symbol::non_terminal(0), - std::iter::once(&None).cloned().collect(), - ); + let aliases_by_symbol = get_aliases_by_symbol(syntax_grammar, default_aliases); let mut subtype_map = Vec::new(); for (i, info) in variable_info.iter().enumerate() { diff --git a/cli/generate/src/render.rs b/cli/generate/src/render.rs index 1712d9ff..32c467e0 100644 --- a/cli/generate/src/render.rs +++ b/cli/generate/src/render.rs @@ -1,6 +1,6 @@ use std::{ cmp, - collections::{HashMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, fmt::Write, mem::swap, }; @@ -9,6 +9,7 @@ use super::{ build_tables::Tables, grammars::{ExternalToken, LexicalGrammar, SyntaxGrammar, VariableType}, nfa::CharacterSet, + node_types::ChildType, rules::{Alias, AliasMap, Symbol, SymbolType, TokenSet}, tables::{ AdvanceAction, FieldLocation, GotoAction, LexState, LexTable, ParseAction, ParseTable, @@ -80,6 +81,8 @@ struct Generator { reserved_word_sets: Vec, reserved_word_set_ids_by_parse_state: Vec, field_names: Vec, + supertype_symbol_map: BTreeMap>, + supertype_map: BTreeMap>, #[allow(unused)] abi_version: usize, @@ -115,6 +118,10 @@ impl Generator { self.add_non_terminal_alias_map(); self.add_primary_state_id_list(); + if self.abi_version >= ABI_VERSION_WITH_RESERVED_WORDS && !self.supertype_map.is_empty() { + self.add_supertype_map(); + } + let buffer_offset_before_lex_functions = self.buffer.len(); let mut main_lex_table = LexTable::default(); @@ -224,33 +231,24 @@ impl Generator { for alias in &production_info.alias_sequence { // Generate a mapping from aliases to C identifiers. if let Some(alias) = &alias { - let existing_symbol = self.parse_table.symbols.iter().copied().find(|symbol| { - self.default_aliases.get(symbol).map_or_else( - || { - let (name, kind) = self.metadata_for_symbol(*symbol); - name == alias.value && kind == alias.kind() - }, - |default_alias| default_alias == alias, - ) - }); - // Some aliases match an existing symbol in the grammar. - let alias_id = if let Some(existing_symbol) = existing_symbol { - self.symbol_ids[&self.symbol_map[&existing_symbol]].clone() - } - // Other aliases don't match any existing symbol, and need their own - // identifiers. - else { - if let Err(i) = self.unique_aliases.binary_search(alias) { - self.unique_aliases.insert(i, alias.clone()); + let alias_id = + if let Some(existing_symbol) = self.symbols_for_alias(alias).first() { + self.symbol_ids[&self.symbol_map[existing_symbol]].clone() } + // Other aliases don't match any existing symbol, and need their own + // identifiers. + else { + if let Err(i) = self.unique_aliases.binary_search(alias) { + self.unique_aliases.insert(i, alias.clone()); + } - if alias.is_named { - format!("alias_sym_{}", self.sanitize_identifier(&alias.value)) - } else { - format!("anon_alias_sym_{}", self.sanitize_identifier(&alias.value)) - } - }; + if alias.is_named { + format!("alias_sym_{}", self.sanitize_identifier(&alias.value)) + } else { + format!("anon_alias_sym_{}", self.sanitize_identifier(&alias.value)) + } + }; self.alias_ids.entry(alias.clone()).or_insert(alias_id); } @@ -290,6 +288,18 @@ impl Generator { self.reserved_word_set_ids_by_parse_state.push(id); } + if self.abi_version >= ABI_VERSION_WITH_RESERVED_WORDS { + for (supertype, subtypes) in &self.supertype_symbol_map { + if let Some(supertype) = self.symbol_ids.get(supertype) { + self.supertype_map + .entry(supertype.clone()) + .or_insert_with(|| subtypes.clone()); + } + } + + self.supertype_symbol_map.clear(); + } + // Determine which states should use the "small state" representation, and which should // use the normal array representation. let threshold = cmp::min(SMALL_STATE_THRESHOLD, self.parse_table.symbols.len() / 2); @@ -404,6 +414,7 @@ impl Generator { "#define PRODUCTION_ID_COUNT {}", self.parse_table.production_infos.len() ); + add_line!(self, "#define SUPERTYPE_COUNT {}", self.supertype_map.len()); add_line!(self, ""); } @@ -689,7 +700,7 @@ impl Generator { add_line!( self, - "static const TSFieldMapSlice ts_field_map_slices[PRODUCTION_ID_COUNT] = {{", + "static const TSMapSlice ts_field_map_slices[PRODUCTION_ID_COUNT] = {{", ); indent!(self); for (production_id, (row_id, length)) in field_map_ids.into_iter().enumerate() { @@ -728,6 +739,83 @@ impl Generator { add_line!(self, ""); } + fn add_supertype_map(&mut self) { + add_line!( + self, + "static const TSSymbol ts_supertype_symbols[SUPERTYPE_COUNT] = {{" + ); + indent!(self); + for supertype in self.supertype_map.keys() { + add_line!(self, "{supertype},"); + } + dedent!(self); + add_line!(self, "}};\n"); + + add_line!( + self, + "static const TSMapSlice ts_supertype_map_slices[] = {{", + ); + indent!(self); + let mut row_id = 0; + let mut supertype_ids = vec![0]; + let mut supertype_string_map = BTreeMap::new(); + for (supertype, subtypes) in &self.supertype_map { + supertype_string_map.insert( + supertype, + subtypes + .iter() + .flat_map(|s| match s { + ChildType::Normal(symbol) => vec![self.symbol_ids.get(symbol).cloned()], + ChildType::Aliased(alias) => { + self.alias_ids.get(alias).cloned().map_or_else( + || { + self.symbols_for_alias(alias) + .into_iter() + .map(|s| self.symbol_ids.get(&s).cloned()) + .collect() + }, + |a| vec![Some(a)], + ) + } + }) + .flatten() + .collect::>(), + ); + } + for (supertype, subtypes) in &supertype_string_map { + let length = subtypes.len(); + add_line!( + self, + "[{supertype}] = {{.index = {row_id}, .length = {length}}},", + ); + row_id += length; + supertype_ids.push(row_id); + } + dedent!(self); + add_line!(self, "}};"); + add_line!(self, ""); + + add_line!( + self, + "static const TSSymbol ts_supertype_map_entries[] = {{", + ); + indent!(self); + for (i, (_, subtypes)) in supertype_string_map.iter().enumerate() { + let row_index = supertype_ids[i]; + add_line!(self, "[{row_index}] ="); + indent!(self); + for subtype in subtypes { + add_whitespace!(self); + add!(self, "{subtype},\n"); + } + dedent!(self); + } + + dedent!(self); + add_line!(self, "}};"); + add_line!(self, ""); + } + fn add_lex_function(&mut self, name: &str, lex_table: LexTable) { add_line!( self, @@ -1462,6 +1550,9 @@ impl Generator { add_line!(self, ".state_count = STATE_COUNT,"); add_line!(self, ".large_state_count = LARGE_STATE_COUNT,"); add_line!(self, ".production_id_count = PRODUCTION_ID_COUNT,"); + if self.abi_version >= ABI_VERSION_WITH_RESERVED_WORDS { + add_line!(self, ".supertype_count = SUPERTYPE_COUNT,"); + } add_line!(self, ".field_count = FIELD_COUNT,"); add_line!( self, @@ -1483,6 +1574,11 @@ impl Generator { add_line!(self, ".field_map_slices = ts_field_map_slices,"); add_line!(self, ".field_map_entries = ts_field_map_entries,"); } + if !self.supertype_map.is_empty() && self.abi_version >= ABI_VERSION_WITH_RESERVED_WORDS { + add_line!(self, ".supertype_map_slices = ts_supertype_map_slices,"); + add_line!(self, ".supertype_map_entries = ts_supertype_map_entries,"); + add_line!(self, ".supertype_symbols = ts_supertype_symbols,"); + } add_line!(self, ".symbol_metadata = ts_symbol_metadata,"); add_line!(self, ".public_symbol_map = ts_symbol_map,"); add_line!(self, ".alias_map = ts_non_terminal_alias_map,"); @@ -1635,6 +1731,23 @@ impl Generator { } } + fn symbols_for_alias(&self, alias: &Alias) -> Vec { + self.parse_table + .symbols + .iter() + .copied() + .filter(move |symbol| { + self.default_aliases.get(symbol).map_or_else( + || { + let (name, kind) = self.metadata_for_symbol(*symbol); + name == alias.value && kind == alias.kind() + }, + |default_alias| default_alias == alias, + ) + }) + .collect() + } + fn sanitize_identifier(&self, name: &str) -> String { let mut result = String::with_capacity(name.len()); for c in name.chars() { @@ -1802,6 +1915,7 @@ pub fn render_c_code( lexical_grammar: LexicalGrammar, default_aliases: AliasMap, abi_version: usize, + supertype_symbol_map: BTreeMap>, ) -> String { assert!( (ABI_VERSION_MIN..=ABI_VERSION_MAX).contains(&abi_version), @@ -1819,6 +1933,7 @@ pub fn render_c_code( lexical_grammar, default_aliases, abi_version, + supertype_symbol_map, ..Default::default() } .generate() diff --git a/cli/src/tests/language_test.rs b/cli/src/tests/language_test.rs index 3def3d60..c3b00437 100644 --- a/cli/src/tests/language_test.rs +++ b/cli/src/tests/language_test.rs @@ -95,3 +95,100 @@ fn test_symbol_metadata_checks() { } } } + +#[test] +fn test_supertypes() { + let language = get_language("rust"); + let supertypes = language.supertypes(); + + assert_eq!(supertypes.len(), 5); + assert_eq!( + supertypes + .iter() + .filter_map(|&s| language.node_kind_for_id(s)) + .map(|s| s.to_string()) + .collect::>(), + vec![ + "_expression", + "_literal", + "_literal_pattern", + "_pattern", + "_type" + ] + ); + + for &supertype in supertypes { + let mut subtypes = language + .subtypes_for_supertype(supertype) + .iter() + .filter_map(|symbol| language.node_kind_for_id(*symbol)) + .collect::>(); + subtypes.sort_unstable(); + subtypes.dedup(); + + match language.node_kind_for_id(supertype) { + Some("_literal") => { + assert_eq!( + subtypes, + &[ + "boolean_literal", + "char_literal", + "float_literal", + "integer_literal", + "raw_string_literal", + "string_literal" + ] + ); + } + Some("_pattern") => { + assert_eq!( + subtypes, + &[ + "_", + "_literal_pattern", + "captured_pattern", + "const_block", + "identifier", + "macro_invocation", + "mut_pattern", + "or_pattern", + "range_pattern", + "ref_pattern", + "reference_pattern", + "remaining_field_pattern", + "scoped_identifier", + "slice_pattern", + "struct_pattern", + "tuple_pattern", + "tuple_struct_pattern", + ] + ); + } + Some("_type") => { + assert_eq!( + subtypes, + &[ + "abstract_type", + "array_type", + "bounded_type", + "dynamic_type", + "function_type", + "generic_type", + "macro_invocation", + "metavariable", + "never_type", + "pointer_type", + "primitive_type", + "reference_type", + "removed_trait_bound", + "scoped_type_identifier", + "tuple_type", + "type_identifier", + "unit_type" + ] + ); + } + _ => {} + } + } +} diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index 82393ab3..9ce3ee21 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -759,13 +759,6 @@ extern "C" { #[doc = " Get the number of valid states in this language."] pub fn ts_language_state_count(self_: *const TSLanguage) -> u32; } -extern "C" { - #[doc = " Get a node type string for the given numerical id."] - pub fn ts_language_symbol_name( - self_: *const TSLanguage, - symbol: TSSymbol, - ) -> *const ::core::ffi::c_char; -} extern "C" { #[doc = " Get the numerical id for the given node type string."] pub fn ts_language_symbol_for_name( @@ -794,6 +787,25 @@ extern "C" { name_length: u32, ) -> TSFieldId; } +extern "C" { + #[doc = " Get a list of all supertype symbols for the language."] + pub fn ts_language_supertypes(self_: *const TSLanguage, length: *mut u32) -> *const TSSymbol; +} +extern "C" { + #[doc = " Get a list of all subtype symbol ids for a given supertype symbol.\n\n See [`ts_language_supertypes`] for fetching all supertype symbols."] + pub fn ts_language_subtypes( + self_: *const TSLanguage, + supertype: TSSymbol, + length: *mut u32, + ) -> *const TSSymbol; +} +extern "C" { + #[doc = " Get a node type string for the given numerical id."] + pub fn ts_language_symbol_name( + self_: *const TSLanguage, + symbol: TSSymbol, + ) -> *const ::core::ffi::c_char; +} extern "C" { #[doc = " Check whether the given node type id belongs to named nodes, anonymous nodes,\n or a hidden nodes.\n\n See also [`ts_node_is_named`]. Hidden nodes are never returned from the API."] pub fn ts_language_symbol_type(self_: *const TSLanguage, symbol: TSSymbol) -> TSSymbolType; diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 2d31bf49..f8be963c 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -420,6 +420,36 @@ impl Language { unsafe { ffi::ts_language_state_count(self.0) as usize } } + /// Get a list of all supertype symbols for the language. + #[doc(alias = "ts_language_supertypes")] + #[must_use] + pub fn supertypes(&self) -> &[u16] { + let mut length = 0u32; + unsafe { + let ptr = ffi::ts_language_supertypes(self.0, core::ptr::addr_of_mut!(length)); + if length == 0 { + &[] + } else { + slice::from_raw_parts(ptr.cast_mut(), length as usize) + } + } + } + + /// Get a list of all subtype symbol names for a given supertype symbol. + #[doc(alias = "ts_language_supertype_map")] + #[must_use] + pub fn subtypes_for_supertype(&self, supertype: u16) -> &[u16] { + unsafe { + let mut length = 0u32; + let ptr = ffi::ts_language_subtypes(self.0, supertype, core::ptr::addr_of_mut!(length)); + if length == 0 { + &[] + } else { + slice::from_raw_parts(ptr.cast_mut(), length as usize) + } + } + } + /// Get the name of the node kind for the given numerical id. #[doc(alias = "ts_language_symbol_name")] #[must_use] diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 9bc15bdc..d037d838 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -1166,11 +1166,6 @@ uint32_t ts_language_symbol_count(const TSLanguage *self); */ uint32_t ts_language_state_count(const TSLanguage *self); -/** - * Get a node type string for the given numerical id. - */ -const char *ts_language_symbol_name(const TSLanguage *self, TSSymbol symbol); - /** * Get the numerical id for the given node type string. */ @@ -1196,6 +1191,27 @@ const char *ts_language_field_name_for_id(const TSLanguage *self, TSFieldId id); */ TSFieldId ts_language_field_id_for_name(const TSLanguage *self, const char *name, uint32_t name_length); +/** + * Get a list of all supertype symbols for the language. +*/ +const TSSymbol *ts_language_supertypes(const TSLanguage *self, uint32_t *length); + +/** + * Get a list of all subtype symbol ids for a given supertype symbol. + * + * See [`ts_language_supertypes`] for fetching all supertype symbols. + */ +const TSSymbol *ts_language_subtypes( + const TSLanguage *self, + TSSymbol supertype, + uint32_t *length +); + +/** + * Get a node type string for the given numerical id. + */ +const char *ts_language_symbol_name(const TSLanguage *self, TSSymbol symbol); + /** * Check whether the given node type id belongs to named nodes, anonymous nodes, * or a hidden nodes. diff --git a/lib/src/language.c b/lib/src/language.c index cd1d4f08..93cc21b2 100644 --- a/lib/src/language.c +++ b/lib/src/language.c @@ -24,6 +24,31 @@ uint32_t ts_language_state_count(const TSLanguage *self) { return self->state_count; } +const TSSymbol *ts_language_supertypes(const TSLanguage *self, uint32_t *length) { + if (self->version >= LANGUAGE_VERSION_WITH_RESERVED_WORDS) { + *length = self->supertype_count; + return self->supertype_symbols; + } else { + *length = 0; + return NULL; + } +} + +const TSSymbol *ts_language_subtypes( + const TSLanguage *self, + TSSymbol supertype, + uint32_t *length +) { + if (self->version < LANGUAGE_VERSION_WITH_RESERVED_WORDS || !ts_language_symbol_metadata(self, supertype).supertype) { + *length = 0; + return NULL; + } + + TSMapSlice slice = self->supertype_map_slices[supertype]; + *length = slice.length; + return &self->supertype_map_entries[slice.index]; +} + uint32_t ts_language_version(const TSLanguage *self) { return self->version; } diff --git a/lib/src/language.h b/lib/src/language.h index d8358abe..6832f8fe 100644 --- a/lib/src/language.h +++ b/lib/src/language.h @@ -236,7 +236,7 @@ static inline void ts_language_field_map( return; } - TSFieldMapSlice slice = self->field_map_slices[production_id]; + TSMapSlice slice = self->field_map_slices[production_id]; *start = &self->field_map_entries[slice.index]; *end = &self->field_map_entries[slice.index] + slice.length; } diff --git a/lib/src/parser.h b/lib/src/parser.h index acffd031..a61358d1 100644 --- a/lib/src/parser.h +++ b/lib/src/parser.h @@ -26,10 +26,11 @@ typedef struct { bool inherited; } TSFieldMapEntry; +// Used to index the field and supertype maps. typedef struct { uint16_t index; uint16_t length; -} TSFieldMapSlice; +} TSMapSlice; typedef struct { bool visible; @@ -115,7 +116,7 @@ struct TSLanguage { const TSParseActionEntry *parse_actions; const char * const *symbol_names; const char * const *field_names; - const TSFieldMapSlice *field_map_slices; + const TSMapSlice *field_map_slices; const TSFieldMapEntry *field_map_entries; const TSSymbolMetadata *symbol_metadata; const TSSymbol *public_symbol_map; @@ -138,6 +139,10 @@ struct TSLanguage { const char *name; const TSSymbol *reserved_words; uint16_t max_reserved_word_set_size; + uint32_t supertype_count; + const TSSymbol *supertype_symbols; + const TSMapSlice *supertype_map_slices; + const TSSymbol *supertype_map_entries; }; static inline bool set_contains(const TSCharacterRange *ranges, uint32_t len, int32_t lookahead) { diff --git a/lib/src/wasm_store.c b/lib/src/wasm_store.c index 77e5a360..b5e0a5c7 100644 --- a/lib/src/wasm_store.c +++ b/lib/src/wasm_store.c @@ -156,6 +156,10 @@ typedef struct { int32_t name; int32_t reserved_words; uint16_t max_reserved_word_set_size; + uint32_t supertype_count; + int32_t supertype_symbols; + int32_t supertype_map_slices; + int32_t supertype_map_entries; } LanguageInWasmMemory; // LexerInWasmMemory - The memory layout of a `TSLexer` when compiled to wasm32. @@ -1234,6 +1238,9 @@ const TSLanguage *ts_wasm_store_load_language( wasm_language.primary_state_ids, wasm_language.name, wasm_language.reserved_words, + wasm_language.supertype_symbols, + wasm_language.supertype_map_entries, + wasm_language.supertype_map_slices, wasm_language.external_token_count > 0 ? wasm_language.external_scanner.states : 0, wasm_language.external_token_count > 0 ? wasm_language.external_scanner.symbol_map : 0, wasm_language.external_token_count > 0 ? wasm_language.external_scanner.create : 0, @@ -1260,6 +1267,7 @@ const TSLanguage *ts_wasm_store_load_language( .large_state_count = wasm_language.large_state_count, .production_id_count = wasm_language.production_id_count, .field_count = wasm_language.field_count, + .supertype_count = wasm_language.supertype_count, .max_alias_sequence_length = wasm_language.max_alias_sequence_length, .keyword_capture_token = wasm_language.keyword_capture_token, .parse_table = copy( @@ -1295,14 +1303,14 @@ const TSLanguage *ts_wasm_store_load_language( if (language->field_count > 0 && language->production_id_count > 0) { language->field_map_slices = copy( &memory[wasm_language.field_map_slices], - wasm_language.production_id_count * sizeof(TSFieldMapSlice) + wasm_language.production_id_count * sizeof(TSMapSlice) ); // Determine the number of field map entries by finding the greatest index // in any of the slices. uint32_t field_map_entry_count = 0; for (uint32_t i = 0; i < wasm_language.production_id_count; i++) { - TSFieldMapSlice slice = language->field_map_slices[i]; + TSMapSlice slice = language->field_map_slices[i]; uint32_t slice_end = slice.index + slice.length; if (slice_end > field_map_entry_count) { field_map_entry_count = slice_end; @@ -1321,6 +1329,37 @@ const TSLanguage *ts_wasm_store_load_language( ); } + if (language->supertype_count > 0) { + language->supertype_symbols = copy( + &memory[wasm_language.supertype_symbols], + wasm_language.supertype_count * sizeof(TSSymbol) + ); + + // Determine the number of supertype map slices by finding the greatest + // supertype ID. + int largest_supertype = 0; + for (unsigned i = 0; i < language->supertype_count; i++) { + TSSymbol supertype = language->supertype_symbols[i]; + if (supertype > largest_supertype) { + largest_supertype = supertype; + } + } + + language->supertype_map_slices = copy( + &memory[wasm_language.supertype_map_slices], + (largest_supertype + 1) * sizeof(TSMapSlice) + ); + + TSSymbol last_supertype = language->supertype_symbols[language->supertype_count - 1]; + TSMapSlice last_slice = language->supertype_map_slices[last_supertype]; + uint32_t supertype_map_entry_count = last_slice.index + last_slice.length; + + language->supertype_map_entries = copy( + &memory[wasm_language.supertype_map_entries], + supertype_map_entry_count * sizeof(char *) + ); + } + if (language->max_alias_sequence_length > 0 && language->production_id_count > 0) { // The alias map contains symbols, alias counts, and aliases, terminated by a null symbol. int32_t alias_map_size = 0; @@ -1752,6 +1791,9 @@ void ts_wasm_language_release(const TSLanguage *self) { ts_free((void *)self->external_scanner.symbol_map); ts_free((void *)self->field_map_entries); ts_free((void *)self->field_map_slices); + ts_free((void *)self->supertype_symbols); + ts_free((void *)self->supertype_map_entries); + ts_free((void *)self->supertype_map_slices); ts_free((void *)self->field_names); ts_free((void *)self->lex_modes); ts_free((void *)self->name);