feat(wasm): add Supertype API

This commit is contained in:
Amaan Qureshi 2025-01-04 22:13:39 -05:00
parent 86b507a842
commit ef39298342
5 changed files with 127 additions and 1 deletions

View file

@ -212,6 +212,20 @@ int ts_language_type_is_visible_wasm(const TSLanguage *self, TSSymbol typeId) {
return symbolType <= TSSymbolTypeAnonymous;
}
void ts_language_supertypes_wasm(const TSLanguage *self) {
uint32_t length;
const TSSymbol *supertypes = ts_language_supertypes(self, &length);
TRANSFER_BUFFER[0] = (const void *)length;
TRANSFER_BUFFER[1] = supertypes;
}
void ts_language_subtypes_wasm(const TSLanguage *self, TSSymbol supertype) {
uint32_t length;
const TSSymbol *subtypes = ts_language_subtypes(self, supertype, &length);
TRANSFER_BUFFER[0] = (const void *)length;
TRANSFER_BUFFER[1] = subtypes;
}
/******************/
/* Section - Tree */
/******************/

View file

@ -5,6 +5,7 @@
const C = Module;
const INTERNAL = {};
const SIZE_OF_SHORT = 2;
const SIZE_OF_INT = 4;
const SIZE_OF_CURSOR = 4 * SIZE_OF_INT;
const SIZE_OF_NODE = 5 * SIZE_OF_INT;
@ -858,6 +859,40 @@ class Language {
return C._ts_language_type_is_visible_wasm(this[0], typeId) ? true : false;
}
get supertypes() {
C._ts_language_supertypes_wasm(this[0]);
const count = getValue(TRANSFER_BUFFER, 'i32');
const buffer = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32');
const result = new Array(count);
if (count > 0) {
let address = buffer;
for (let i = 0; i < count; i++) {
result[i] = getValue(address, 'i16');
address += SIZE_OF_SHORT;
}
}
return result;
}
subtypes(supertype) {
C._ts_language_subtypes_wasm(this[0], supertype);
const count = getValue(TRANSFER_BUFFER, 'i32');
const buffer = getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32');
const result = new Array(count);
if (count > 0) {
let address = buffer;
for (let i = 0; i < count; i++) {
result[i] = getValue(address, 'i16');
address += SIZE_OF_SHORT;
}
}
return result;
}
nextState(stateId, typeId) {
return C._ts_language_next_state(this[0], stateId, typeId);
}

View file

@ -5,6 +5,8 @@
"ts_language_type_is_visible_wasm",
"ts_language_symbol_count",
"ts_language_state_count",
"ts_language_supertypes_wasm",
"ts_language_subtypes_wasm",
"ts_language_symbol_for_name",
"ts_language_symbol_name",
"ts_language_symbol_type",

View file

@ -12,4 +12,5 @@ module.exports = Parser.init().then(async () => ({
JavaScript: await Parser.Language.load(languageURL('javascript')),
JSON: await Parser.Language.load(languageURL('json')),
Python: await Parser.Language.load(languageURL('python')),
Rust: await Parser.Language.load(languageURL('rust')),
}));

View file

@ -2,7 +2,7 @@ const {assert} = require('chai');
let JavaScript;
describe('Language', () => {
before(async () => ({JavaScript} = await require('./helper')));
before(async () => ({JavaScript, Rust} = await require('./helper')));
describe('.fieldIdForName, .fieldNameForId', () => {
it('converts between the string and integer representations of fields', () => {
@ -41,6 +41,80 @@ describe('Language', () => {
assert.equal(null, JavaScript.idForNodeType('export_statement', false));
});
});
describe('Supertypes', () => {
it('gets the supertypes and subtypes of a parser', () => {
const supertypes = Rust.supertypes;
const names = supertypes.map((id) => Rust.nodeTypeForId(id));
assert.deepStrictEqual(
names,
['_expression', '_literal', '_literal_pattern', '_pattern', '_type'],
);
for (const id of supertypes) {
const name = Rust.nodeTypeForId(id);
const subtypes = Rust.subtypes(id);
let subtypeNames = subtypes.map((id) => Rust.nodeTypeForId(id));
subtypeNames = [...new Set(subtypeNames)].sort(); // Remove duplicates & sort
switch (name) {
case '_literal':
assert.deepStrictEqual(subtypeNames, [
'boolean_literal',
'char_literal',
'float_literal',
'integer_literal',
'raw_string_literal',
'string_literal',
]);
break;
case '_pattern':
assert.deepStrictEqual(subtypeNames, [
'_',
'_literal_pattern',
'captured_pattern',
'const_block',
'identifier',
'macro_invocation',
'mut_pattern',
'or_pattern',
'range_pattern',
'ref_pattern',
'reference_pattern',
'remaining_field_pattern',
'scoped_identifier',
'slice_pattern',
'struct_pattern',
'tuple_pattern',
'tuple_struct_pattern',
]);
break;
case '_type':
assert.deepStrictEqual(subtypeNames, [
'abstract_type',
'array_type',
'bounded_type',
'dynamic_type',
'function_type',
'generic_type',
'macro_invocation',
'metavariable',
'never_type',
'pointer_type',
'primitive_type',
'reference_type',
'removed_trait_bound',
'scoped_type_identifier',
'tuple_type',
'type_identifier',
'unit_type',
]);
break;
default:
break;
}
}
});
});
});
describe('Lookahead iterator', () => {