diff --git a/lib/src/wasm.c b/lib/src/wasm.c index 69b3db6b..fd0cccca 100644 --- a/lib/src/wasm.c +++ b/lib/src/wasm.c @@ -57,6 +57,9 @@ const char *STDLIB_SYMBOLS[STDLIB_SYMBOL_COUNT] = { "towupper" }; +#define BUILTIN_SYMBOL_COUNT 10 +#define MAX_IMPORT_COUNT (BUILTIN_SYMBOL_COUNT + STDLIB_SYMBOL_COUNT) + // LanguageWasmModule - Additional data associated with a wasm-backed // `TSLanguage`. This data is read-only and does not reference a particular // wasm store, so it can be shared by all users of a `TSLanguage`. A pointer to @@ -169,13 +172,15 @@ typedef struct { static volatile uint32_t NEXT_LANGUAGE_ID; -static const uint32_t STACK_SIZE = 64 * 1024; -static const uint32_t HEAP_SIZE = 1024 * 1024; -static const uint32_t SERIALIZATION_BUFFER_ADDRESS = STACK_SIZE - TREE_SITTER_SERIALIZATION_BUFFER_SIZE; -static const uint32_t LEXER_ADDRESS = SERIALIZATION_BUFFER_ADDRESS - sizeof(LexerInWasmMemory); -static const uint32_t INITIAL_STACK_POINTER_ADDRESS = LEXER_ADDRESS; -static const uint32_t HEAP_START_ADDRESS = STACK_SIZE; -static const uint32_t DATA_START_ADDRESS = STACK_SIZE + HEAP_SIZE; +// Linear memory layout: +// [ <-- stack grows down | fixed data | heap grows up --> | per-language static data ] +#define STACK_SIZE (64 * 1024) +#define HEAP_SIZE (1024 * 1024) +#define SERIALIZATION_BUFFER_ADDRESS (STACK_SIZE - TREE_SITTER_SERIALIZATION_BUFFER_SIZE) +#define LEXER_ADDRESS (SERIALIZATION_BUFFER_ADDRESS - sizeof(LexerInWasmMemory)) +#define INITIAL_STACK_POINTER_ADDRESS (LEXER_ADDRESS) +#define HEAP_START_ADDRESS (STACK_SIZE) +#define DATA_START_ADDRESS (STACK_SIZE + HEAP_SIZE) enum FunctionIx { NULL_IX = 0, @@ -583,8 +588,10 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine) { wasmtime_instance_t instance; wasm_importtype_vec_t import_types = WASM_EMPTY_VEC; wasmtime_module_imports(stdlib_module, &import_types); - wasmtime_extern_t imports[import_types.size]; - for (unsigned i = 0; i < import_types.size; i++) { + if (import_types.size > MAX_IMPORT_COUNT) goto error; + + wasmtime_extern_t imports[MAX_IMPORT_COUNT]; + for (unsigned i = 0; i < import_types.size && i < MAX_IMPORT_COUNT; i++) { wasm_importtype_t *type = import_types.data[i]; const wasm_name_t *import_name = wasm_importtype_name(type); if (!ts_wasm_store__provide_builtin_import(self, import_name, &imports[i])) { @@ -594,13 +601,13 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine) { } error = wasmtime_instance_new(context, stdlib_module, imports, import_types.size, &instance, &trap); - if (error) { + assert(!error); + if (trap) { wasm_message_t message; - wasmtime_error_message(error, &message); + wasm_trap_message(trap, &message); printf("error compiling standard library: %.*s\n", (int)message.size, message.data); abort(); } - assert(!error); wasm_importtype_vec_delete(&import_types); self->current_memory_offset += dylink_info.memory_size; @@ -647,6 +654,11 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine) { wasm_exporttype_vec_delete(&export_types); return self; + +error: + ts_free(self); + wasm_exporttype_vec_delete(&export_types); + return NULL; } void ts_wasm_store_delete(TSWasmStore *self) { @@ -673,7 +685,7 @@ static bool ts_wasm_store__instantiate( // Construct the language function name as string. unsigned prefix_len = strlen("tree_sitter_"); size_t name_len = strlen(language_name); - char language_function_name[prefix_len + name_len + 1]; + char *language_function_name = ts_malloc(prefix_len + name_len + 1); memcpy(&language_function_name[0], "tree_sitter_", prefix_len); memcpy(&language_function_name[prefix_len], language_name, name_len); language_function_name[prefix_len + name_len] = '\0'; @@ -683,13 +695,14 @@ static bool ts_wasm_store__instantiate( // Build the imports list for the module. wasm_importtype_vec_t import_types = WASM_EMPTY_VEC; wasmtime_module_imports(module, &import_types); - wasmtime_extern_t imports[import_types.size]; + if (import_types.size > MAX_IMPORT_COUNT) goto error; + wasmtime_extern_t imports[MAX_IMPORT_COUNT]; for (unsigned i = 0; i < import_types.size; i++) { const wasm_importtype_t *import_type = import_types.data[i]; const wasm_name_t *import_name = wasm_importtype_name(import_type); if (import_name->size == 0) { - return false; + goto error; } if (ts_wasm_store__provide_builtin_import(self, import_name, &imports[i])) { @@ -708,26 +721,25 @@ static bool ts_wasm_store__instantiate( if (!defined_in_stdlib) { printf("unexpected import '%.*s'\n", (int)import_name->size, import_name->data); - return false; + goto error; } } - wasm_importtype_vec_delete(&import_types); - wasmtime_instance_t instance; - error = wasmtime_instance_new(context, module, imports, array_len(imports), &instance, &trap); + error = wasmtime_instance_new(context, module, imports, import_types.size, &instance, &trap); + wasm_importtype_vec_delete(&import_types); if (error) { wasm_message_t message; wasmtime_error_message(error, &message); printf("error instantiating wasm module: %s\n", message.data); - return false; + goto error; } assert(!error); if (trap) { wasm_message_t message; wasm_trap_message(trap, &message); printf("error instantiating wasm module: %s\n", message.data); - return false; + goto error; } // Process the module's exports. @@ -752,7 +764,7 @@ static bool ts_wasm_store__instantiate( wasm_message_t message; wasm_trap_message(trap, &message); printf("error calling relocation function: %s\n", message.data); - abort(); + goto error; } } @@ -765,7 +777,7 @@ static bool ts_wasm_store__instantiate( if (language_extern.kind != WASMTIME_EXTERN_FUNC) { printf("failed to find function %s\n", language_function_name); - return false; + goto error; } // Invoke the language function to get the static address of the language object. @@ -777,13 +789,17 @@ static bool ts_wasm_store__instantiate( wasm_message_t message; wasm_trap_message(trap, &message); printf("error calling language function: %s\n", message.data); - return false; + goto error; } assert(language_address_val.kind == WASMTIME_I32); *result = instance; *language_address = language_address_val.of.i32; return true; + +error: + ts_free(language_function_name); + return false; } static bool ts_wasm_store__sentinel_lex_fn(TSLexer *_lexer, TSStateId state) {