import { C, INTERNAL, Internal, assertInternal, SIZE_OF_INT, SIZE_OF_SHORT } from './constants';
import { LookaheadIterator } from './lookahead_iterator';
import { Node } from './node';
import { TRANSFER_BUFFER } from './parser';
import { CaptureQuantifier, QueryPredicate, PredicateStep, QueryProperties, Query, TextPredicate } from './query';

/** Step-type tag read from the native predicate table: the step references a capture. */
const PREDICATE_STEP_TYPE_CAPTURE = 1;

/** Step-type tag read from the native predicate table: the step references a string literal. */
const PREDICATE_STEP_TYPE_STRING = 2;

/** Matches identifier-like words; used to pick the offending word out of a bad query. */
const QUERY_WORD_REGEX = /[\w-]+/g;

/** Matches the name of the exported language constructor inside a language WASM module. */
const LANGUAGE_FUNCTION_REGEX = /^tree_sitter_\w+$/;

/** Type guard that narrows a predicate step to its string-literal variant. */
function isStringStep(step: PredicateStep): step is { type: 'string'; value: string } {
  return step.type === 'string';
}

/**
 * Decodes the `(count, pointer)` pair currently stored in the transfer buffer
 * into an array of 16-bit ids.
 *
 * The first `i32` in the transfer buffer is read as the element count and the
 * second as the address of the first `i16` element.
 */
function readIdsFromTransferBuffer(): number[] {
  const count = C.getValue(TRANSFER_BUFFER, 'i32');
  let address = C.getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32');
  const ids = new Array<number>(count);
  for (let i = 0; i < count; i++) {
    ids[i] = C.getValue(address, 'i16');
    address += SIZE_OF_SHORT;
  }
  return ids;
}

/**
 * Builds the error thrown when `_ts_query_new` fails.
 *
 * Must be called right after the failed `_ts_query_new` call, while the
 * transfer buffer still holds the error byte offset (first `i32`) and the
 * error id (second `i32`), and before `sourceAddress` is freed.
 *
 * The returned error carries two extra properties: `index` (offset of the
 * error within `source`, in UTF-16 code units) and `length` (length of the
 * offending word, or 0 for structural/syntax errors).
 */
function makeQueryError(source: string, sourceAddress: number): Error {
  const errorId = C.getValue(TRANSFER_BUFFER + SIZE_OF_INT, 'i32');
  const errorByte = C.getValue(TRANSFER_BUFFER, 'i32');
  // The native side reports a byte offset; decoding the UTF-8 prefix and
  // taking its length converts that to an index into the JS string.
  const errorIndex = C.UTF8ToString(sourceAddress, errorByte).length;
  const suffix = source.slice(errorIndex, errorIndex + 100).split('\n')[0];
  let word = suffix.match(QUERY_WORD_REGEX)?.[0] ?? '';

  let error: Error;
  // NOTE(review): the numeric ids presumably mirror `TSQueryError` in the
  // tree-sitter C API — confirm against api.h.
  switch (errorId) {
    case 2:
      error = new RangeError(`Bad node name '${word}'`);
      break;
    case 3:
      error = new RangeError(`Bad field name '${word}'`);
      break;
    case 4:
      error = new RangeError(`Bad capture name @${word}`);
      break;
    case 5:
      error = new TypeError(`Bad pattern structure at offset ${errorIndex}: '${suffix}'...`);
      word = '';
      break;
    default:
      error = new SyntaxError(`Bad syntax at offset ${errorIndex}: '${suffix}'...`);
      word = '';
      break;
  }

  return Object.assign(error, { index: errorIndex, length: word.length });
}

/**
 * An opaque object that defines how to parse a particular language.
 * The code for each `Language` is generated by the Tree-sitter CLI.
 */
export class Language {
  /** @internal */
  private [0] = 0; // Internal handle for WASM

  /**
   * A list of all node types in the language. The index of each type in this
   * array is its node type id.
   */
  types: string[];

  /**
   * A list of all field names in the language. The index of each field name in
   * this array is its field id.
   */
  fields: (string | null)[];

  /** @internal */
  constructor(internal: Internal, address: number) {
    assertInternal(internal);
    this[0] = address;

    this.types = new Array<string>(C._ts_language_symbol_count(this[0]));
    for (let i = 0, n = this.types.length; i < n; i++) {
      // Only symbols whose type is below 2 get a name; other slots stay empty.
      // NOTE(review): presumably 0/1 are the regular/anonymous symbol types — confirm.
      if (C._ts_language_symbol_type(this[0], i) < 2) {
        this.types[i] = C.UTF8ToString(C._ts_language_symbol_name(this[0], i));
      }
    }

    // One extra slot so that indices run from 0 through the field count inclusive.
    this.fields = new Array<string | null>(C._ts_language_field_count(this[0]) + 1);
    for (let i = 0, n = this.fields.length; i < n; i++) {
      const fieldName = C._ts_language_field_name_for_id(this[0], i);
      this.fields[i] = fieldName !== 0 ? C.UTF8ToString(fieldName) : null;
    }
  }

  /**
   * Gets the name of the language, or `null` if the native call returns no name.
   */
  get name(): string | null {
    const ptr = C._ts_language_name(this[0]);
    if (ptr === 0) return null;
    return C.UTF8ToString(ptr);
  }

  /**
   * Gets the version of the language.
   */
  get version(): number {
    return C._ts_language_version(this[0]);
  }

  /**
   * Gets the number of fields in the language.
   */
  get fieldCount(): number {
    return this.fields.length - 1;
  }

  /**
   * Gets the number of states in the language.
   */
  get stateCount(): number {
    return C._ts_language_state_count(this[0]);
  }

  /**
   * Get the field id for a field name, or `null` if the name is unknown.
   */
  fieldIdForName(fieldName: string): number | null {
    const result = this.fields.indexOf(fieldName);
    return result !== -1 ? result : null;
  }

  /**
   * Get the field name for a field id, or `null` if the id has no name.
   */
  fieldNameForId(fieldId: number): string | null {
    return this.fields[fieldId] ?? null;
  }

  /**
   * Get the node type id for a node type name.
   *
   * @param type - The node type name to look up.
   * @param named - Whether to look up a named (`true`) or anonymous (`false`) node type.
   * @returns The node type id, or `null` when the native lookup yields 0.
   */
  idForNodeType(type: string, named: boolean): number | null {
    const typeLength = C.lengthBytesUTF8(type);
    // +1 leaves room for the terminating NUL written by `stringToUTF8`.
    const typeAddress = C._malloc(typeLength + 1);
    C.stringToUTF8(type, typeAddress, typeLength + 1);
    const result = C._ts_language_symbol_for_name(this[0], typeAddress, typeLength, named ? 1 : 0);
    C._free(typeAddress);
    return result || null;
  }

  /**
   * Gets the number of node types in the language.
   */
  get nodeTypeCount(): number {
    return C._ts_language_symbol_count(this[0]);
  }

  /**
   * Get the node type name for a node type id, or `null` if there is none.
   */
  nodeTypeForId(typeId: number): string | null {
    const name = C._ts_language_symbol_name(this[0], typeId);
    return name ? C.UTF8ToString(name) : null;
  }

  /**
   * Check if a node type is named.
   *
   * @see {@link https://tree-sitter.github.io/tree-sitter/using-parsers/2-basic-parsing.html#named-vs-anonymous-nodes}
   */
  nodeTypeIsNamed(typeId: number): boolean {
    return Boolean(C._ts_language_type_is_named_wasm(this[0], typeId));
  }

  /**
   * Check if a node type is visible.
   */
  nodeTypeIsVisible(typeId: number): boolean {
    return Boolean(C._ts_language_type_is_visible_wasm(this[0], typeId));
  }

  /**
   * Get the supertypes ids of this language.
   *
   * @see {@link https://tree-sitter.github.io/tree-sitter/using-parsers/6-static-node-types.html?highlight=supertype#supertype-nodes}
   */
  get supertypes(): number[] {
    C._ts_language_supertypes_wasm(this[0]);
    return readIdsFromTransferBuffer();
  }

  /**
   * Get the subtype ids for a given supertype node id.
   */
  subtypes(supertype: number): number[] {
    C._ts_language_subtypes_wasm(this[0], supertype);
    return readIdsFromTransferBuffer();
  }

  /**
   * Get the next state id for a given state id and node type id.
   */
  nextState(stateId: number, typeId: number): number {
    return C._ts_language_next_state(this[0], stateId, typeId);
  }

  /**
   * Create a new lookahead iterator for this language and parse state.
   *
   * This returns `null` if state is invalid for this language.
   *
   * Iterating {@link LookaheadIterator} will yield valid symbols in the given
   * parse state. Newly created lookahead iterators will return the `ERROR`
   * symbol from {@link LookaheadIterator#currentType}.
   *
   * Lookahead iterators can be useful for generating suggestions and improving
   * syntax error diagnostics. To get symbols valid in an `ERROR` node, use the
   * lookahead iterator on its first leaf node state. For `MISSING` nodes, a
   * lookahead iterator created on the previous non-extra leaf node may be
   * appropriate.
   */
  lookaheadIterator(stateId: number): LookaheadIterator | null {
    const address = C._ts_lookahead_iterator_new(this[0], stateId);
    if (address) return new LookaheadIterator(INTERNAL, address, this);
    return null;
  }

  /**
   * Create a new query from a string containing one or more S-expression
   * patterns.
   *
   * The query is associated with a particular language, and can only be run
   * on syntax nodes parsed with that language. References to Queries can be
   * shared between multiple threads.
   *
   * @param source - The query source text.
   * @returns The compiled query, with its built-in text predicates
   * (`eq?`, `match?`, `any-of?` and their variants) and properties
   * (`set!`, `is?`, `is-not?`) already resolved.
   * @throws {RangeError} If the source names an unknown node type, field or capture.
   * @throws {TypeError} If a pattern has an invalid structure.
   * @throws {SyntaxError} If the source has any other compile error.
   * @throws {Error} If a built-in predicate is used with invalid arguments.
   *
   * @see {@link https://tree-sitter.github.io/tree-sitter/using-parsers/queries}
   */
  query(source: string): Query {
    const sourceLength = C.lengthBytesUTF8(source);
    const sourceAddress = C._malloc(sourceLength + 1);
    C.stringToUTF8(source, sourceAddress, sourceLength + 1);

    let address = 0;
    try {
      // On failure the error byte offset is written to TRANSFER_BUFFER and
      // the error id to TRANSFER_BUFFER + SIZE_OF_INT.
      address = C._ts_query_new(
        this[0],
        sourceAddress,
        sourceLength,
        TRANSFER_BUFFER,
        TRANSFER_BUFFER + SIZE_OF_INT
      );

      if (!address) {
        throw makeQueryError(source, sourceAddress);
      }

      const stringCount = C._ts_query_string_count(address);
      const captureCount = C._ts_query_capture_count(address);
      const patternCount = C._ts_query_pattern_count(address);
      const captureNames = new Array<string>(captureCount);
      const captureQuantifiers = new Array<CaptureQuantifier[]>(patternCount);
      const stringValues = new Array<string>(stringCount);

      for (let i = 0; i < captureCount; i++) {
        const nameAddress = C._ts_query_capture_name_for_id(
          address,
          i,
          TRANSFER_BUFFER
        );
        const nameLength = C.getValue(TRANSFER_BUFFER, 'i32');
        captureNames[i] = C.UTF8ToString(nameAddress, nameLength);
      }

      for (let i = 0; i < patternCount; i++) {
        const captureQuantifiersArray = new Array<CaptureQuantifier>(captureCount);
        for (let j = 0; j < captureCount; j++) {
          const quantifier = C._ts_query_capture_quantifier_for_id(address, i, j);
          captureQuantifiersArray[j] = quantifier as CaptureQuantifier;
        }
        captureQuantifiers[i] = captureQuantifiersArray;
      }

      for (let i = 0; i < stringCount; i++) {
        const valueAddress = C._ts_query_string_value_for_id(
          address,
          i,
          TRANSFER_BUFFER
        );
        const valueLength = C.getValue(TRANSFER_BUFFER, 'i32');
        stringValues[i] = C.UTF8ToString(valueAddress, valueLength);
      }

      const setProperties = new Array<QueryProperties>(patternCount);
      const assertedProperties = new Array<QueryProperties>(patternCount);
      const refutedProperties = new Array<QueryProperties>(patternCount);
      const predicates = new Array<QueryPredicate[]>(patternCount);
      const textPredicates = new Array<TextPredicate[]>(patternCount);

      for (let i = 0; i < patternCount; i++) {
        const predicatesAddress = C._ts_query_predicates_for_pattern(
          address,
          i,
          TRANSFER_BUFFER
        );
        const stepCount = C.getValue(TRANSFER_BUFFER, 'i32');

        predicates[i] = [];
        textPredicates[i] = [];

        // Steps accumulate until a step that is neither a capture nor a
        // string is read; that step closes the current predicate.
        const steps = new Array<PredicateStep>();

        let stepAddress = predicatesAddress;
        for (let j = 0; j < stepCount; j++) {
          // Each step is a pair of i32 values: (type, value id).
          const stepType = C.getValue(stepAddress, 'i32');
          stepAddress += SIZE_OF_INT;
          const stepValueId: number = C.getValue(stepAddress, 'i32');
          stepAddress += SIZE_OF_INT;

          if (stepType === PREDICATE_STEP_TYPE_CAPTURE) {
            const name = captureNames[stepValueId];
            steps.push({ type: 'capture', name });
          } else if (stepType === PREDICATE_STEP_TYPE_STRING) {
            steps.push({ type: 'string', value: stringValues[stepValueId] });
          } else if (steps.length > 0) {
            if (steps[0].type !== 'string') {
              throw new Error('Predicates must begin with a literal value');
            }

            const operator = steps[0].value;
            let isPositive = true;
            let matchAll = true;
            let captureName: string | undefined;

            switch (operator) {
              case 'any-not-eq?':
              case 'not-eq?':
                isPositive = false;
              // falls through
              case 'any-eq?':
              case 'eq?': {
                if (steps.length !== 3) {
                  throw new Error(
                    `Wrong number of arguments to \`#${operator}\` predicate. Expected 2, got ${steps.length - 1}`
                  );
                }
                if (steps[1].type !== 'capture') {
                  throw new Error(
                    `First argument of \`#${operator}\` predicate must be a capture. Got "${steps[1].value}"`
                  );
                }
                matchAll = !operator.startsWith('any-');
                if (steps[2].type === 'capture') {
                  // Capture-vs-capture comparison.
                  const captureName1 = steps[1].name;
                  const captureName2 = steps[2].name;
                  textPredicates[i].push((captures) => {
                    const nodes1: Node[] = [];
                    const nodes2: Node[] = [];
                    for (const c of captures) {
                      if (c.name === captureName1) nodes1.push(c.node);
                      if (c.name === captureName2) nodes2.push(c.node);
                    }
                    const compare = (n1: { text: string }, n2: { text: string }, positive: boolean) => {
                      return positive ? n1.text === n2.text : n1.text !== n2.text;
                    };
                    return matchAll
                      ? nodes1.every((n1) => nodes2.some((n2) => compare(n1, n2, isPositive)))
                      : nodes1.some((n1) => nodes2.some((n2) => compare(n1, n2, isPositive)));
                  });
                } else {
                  // Capture-vs-string comparison.
                  captureName = steps[1].name;
                  const stringValue = steps[2].value;
                  const matches = (n: Node) => n.text === stringValue;
                  const doesNotMatch = (n: Node) => n.text !== stringValue;
                  textPredicates[i].push((captures) => {
                    const nodes: Node[] = [];
                    for (const c of captures) {
                      if (c.name === captureName) nodes.push(c.node);
                    }
                    const test = isPositive ? matches : doesNotMatch;
                    return matchAll ? nodes.every(test) : nodes.some(test);
                  });
                }
                break;
              }

              case 'any-not-match?':
              case 'not-match?':
                isPositive = false;
              // falls through
              case 'any-match?':
              case 'match?': {
                if (steps.length !== 3) {
                  throw new Error(
                    `Wrong number of arguments to \`#${operator}\` predicate. Expected 2, got ${steps.length - 1}.`,
                  );
                }
                if (steps[1].type !== 'capture') {
                  throw new Error(
                    `First argument of \`#${operator}\` predicate must be a capture. Got "${steps[1].value}".`,
                  );
                }
                if (steps[2].type !== 'string') {
                  throw new Error(
                    `Second argument of \`#${operator}\` predicate must be a string. Got @${steps[2].name}.`,
                  );
                }
                captureName = steps[1].name;
                const regex = new RegExp(steps[2].value);
                matchAll = !operator.startsWith('any-');
                textPredicates[i].push((captures) => {
                  const texts: string[] = [];
                  for (const c of captures) {
                    if (c.name === captureName) texts.push(c.node.text);
                  }
                  const test = (text: string, positive: boolean) => {
                    return positive ? regex.test(text) : !regex.test(text);
                  };
                  // With no captured nodes, only the negated forms succeed.
                  if (texts.length === 0) return !isPositive;
                  return matchAll
                    ? texts.every((text) => test(text, isPositive))
                    : texts.some((text) => test(text, isPositive));
                });
                break;
              }

              case 'set!': {
                if (steps.length < 2 || steps.length > 3) {
                  throw new Error(
                    `Wrong number of arguments to \`#set!\` predicate. Expected 1 or 2. Got ${steps.length - 1}.`,
                  );
                }
                if (!steps.every(isStringStep)) {
                  throw new Error(
                    `Arguments to \`#set!\` predicate must be strings.".`,
                  );
                }
                if (!setProperties[i]) setProperties[i] = {};
                setProperties[i][steps[1].value] = steps[2]?.value ?? null;
                break;
              }

              case 'is?':
              case 'is-not?': {
                if (steps.length < 2 || steps.length > 3) {
                  throw new Error(
                    `Wrong number of arguments to \`#${operator}\` predicate. Expected 1 or 2. Got ${steps.length - 1}.`,
                  );
                }
                if (!steps.every(isStringStep)) {
                  throw new Error(
                    `Arguments to \`#${operator}\` predicate must be strings.".`,
                  );
                }
                const properties = operator === 'is?' ? assertedProperties : refutedProperties;
                if (!properties[i]) properties[i] = {};
                properties[i][steps[1].value] = steps[2]?.value ?? null;
                break;
              }

              case 'not-any-of?':
                isPositive = false;
              // falls through
              case 'any-of?': {
                if (steps.length < 2) {
                  throw new Error(
                    `Wrong number of arguments to \`#${operator}\` predicate. Expected at least 1. Got ${steps.length - 1}.`,
                  );
                }
                if (steps[1].type !== 'capture') {
                  throw new Error(
                    `First argument of \`#${operator}\` predicate must be a capture. Got "${steps[1].value}".`,
                  );
                }
                captureName = steps[1].name;

                const stringSteps = steps.slice(2);
                if (!stringSteps.every(isStringStep)) {
                  throw new Error(
                    `Arguments to \`#${operator}\` predicate must be strings.".`,
                  );
                }
                const values = stringSteps.map((s) => s.value);

                textPredicates[i].push((captures) => {
                  const texts: string[] = [];
                  for (const c of captures) {
                    if (c.name === captureName) texts.push(c.node.text);
                  }
                  // With no captured nodes, only the negated form succeeds.
                  if (texts.length === 0) return !isPositive;
                  return texts.every((text) => values.includes(text)) === isPositive;
                });
                break;
              }

              default:
                // Unknown operators are handed to the caller as general predicates.
                predicates[i].push({ operator, operands: steps.slice(1) });
            }

            // Start collecting the next predicate.
            steps.length = 0;
          }
        }

        Object.freeze(setProperties[i]);
        Object.freeze(assertedProperties[i]);
        Object.freeze(refutedProperties[i]);
      }

      return new Query(
        INTERNAL,
        address,
        captureNames,
        captureQuantifiers,
        textPredicates,
        predicates,
        setProperties,
        assertedProperties,
        refutedProperties,
      );
    } catch (error) {
      // No `Query` wrapper was handed out, so nothing else can release the
      // native query; delete it here to avoid leaking it.
      if (address) C._ts_query_delete(address);
      throw error;
    } finally {
      // The source buffer is not used after this method returns, so free it
      // on every path.
      C._free(sourceAddress);
    }
  }

  /**
   * Load a language from a WebAssembly module.
   * The module can be provided as a path to a file or as a buffer.
   *
   * @param input - The module bytes, or a location to read them from. A string
   * is read with `fs/promises` when running under Node.js and fetched with
   * `fetch` otherwise.
   * @returns The loaded language.
   * @throws {Error} If fetching fails with a non-OK status, or if the module
   * exports no `tree_sitter_*` language function.
   */
  static async load(input: string | Uint8Array): Promise<Language> {
    let bytes: Uint8Array;
    if (input instanceof Uint8Array) {
      bytes = input;
    } else {
      // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
      if (globalThis.process?.versions.node) {
        // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment, @typescript-eslint/no-require-imports
        const fs: typeof import('fs/promises') = require('fs/promises');
        bytes = await fs.readFile(input);
      } else {
        const response = await fetch(input);
        // Read the body before checking the status so it can be included in
        // the error message on failure.
        const buffer = await response.arrayBuffer();
        if (!response.ok) {
          const body = new TextDecoder('utf-8').decode(buffer);
          throw new Error(`Language.load failed with status ${response.status}.\n\n${body}`);
        }
        bytes = new Uint8Array(buffer);
      }
    }

    const mod = await C.loadWebAssemblyModule(bytes, { loadAsync: true });
    const symbolNames = Object.keys(mod);
    // Pick the `tree_sitter_<lang>` export, skipping external-scanner exports.
    const functionName = symbolNames.find((key) =>
      LANGUAGE_FUNCTION_REGEX.test(key) && !key.includes('external_scanner_')
    );
    if (!functionName) {
      console.log(`Couldn't find language function in WASM file. Symbols:\n${JSON.stringify(symbolNames, null, 2)}`);
      throw new Error('Language.load failed: no language function found in WASM file');
    }
    const languageAddress = mod[functionName]();
    return new Language(INTERNAL, languageAddress);
  }
}