diff --git a/cli/src/tests/wasm_language_test.rs b/cli/src/tests/wasm_language_test.rs index 210635df..61c468cd 100644 --- a/cli/src/tests/wasm_language_test.rs +++ b/cli/src/tests/wasm_language_test.rs @@ -115,6 +115,9 @@ fn test_load_multiple_wasm_languages() { .unwrap(); let mut query_cursor = QueryCursor::new(); + // First, parse with the store that originally loaded the languages. + // Then parse with a new parser and wasm store, so that the languages + // are added one-by-one, in between parses. for mut parser in [parser, parser2] { for _ in 0..2 { let query_rust = Query::new(&language_rust, "(const_item) @foo").unwrap(); diff --git a/lib/src/wasm_store.c b/lib/src/wasm_store.c index 2ba1a8e9..7137a7fb 100644 --- a/lib/src/wasm_store.c +++ b/lib/src/wasm_store.c @@ -99,7 +99,6 @@ struct TSWasmStore { BuiltinFunctionIndices builtin_fn_indices; wasmtime_global_t stack_pointer_global; wasm_globaltype_t *const_i32_type; - wasm_globaltype_t *var_i32_type; bool has_error; uint32_t lexer_address; uint32_t serialization_buffer_address; @@ -251,15 +250,6 @@ static bool wasm_dylink_info__parse( return wasmtime_trap_new("wasm module called abort", 24); } -static wasm_trap_t *callback__notify_memory_growth( - void *env, - wasmtime_caller_t* caller, - wasmtime_val_raw_t *args_and_results, - size_t args_and_results_len -) { - return NULL; -} - static wasm_trap_t *callback__debug_message( void *env, wasmtime_caller_t* caller, @@ -548,7 +538,7 @@ static bool ts_wasm_store__call_module_initializer( } TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { - TSWasmStore *self = ts_malloc(sizeof(TSWasmStore)); + TSWasmStore *self = ts_calloc(1, sizeof(TSWasmStore)); wasmtime_store_t *store = wasmtime_store_new(engine, self, NULL); wasmtime_context_t *context = wasmtime_store_context(store); wasmtime_error_t *error = NULL; @@ -556,68 +546,11 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { wasm_message_t message = 
WASM_EMPTY_VEC; wasm_exporttype_vec_t export_types = WASM_EMPTY_VEC; wasmtime_extern_t *imports = NULL; + wasmtime_module_t *stdlib_module = NULL; + wasm_memorytype_t *memory_type = NULL; + wasm_tabletype_t *table_type = NULL; - // Initialize store's memory - wasm_limits_t memory_limits = {.min = 4, .max = MAX_MEMORY_SIZE}; - wasm_memorytype_t *memory_type = wasm_memorytype_new(&memory_limits); - wasmtime_memory_t memory; - error = wasmtime_memory_new(context, memory_type, &memory); - if (error) { - wasmtime_error_message(error, &message); - wasm_error->kind = TSWasmErrorKindAllocate; - format( - &wasm_error->message, - "failed to allocate wasm memory: %.*s", - (int)message.size, message.data - ); - goto error; - } - wasm_memorytype_delete(memory_type); - - // Initialize store's function table - wasm_limits_t table_limits = {.min = 1, .max = wasm_limits_max_default}; - wasm_tabletype_t *table_type = wasm_tabletype_new(wasm_valtype_new(WASM_FUNCREF), &table_limits); - wasmtime_val_t initializer = {.kind = WASMTIME_FUNCREF}; - wasmtime_table_t function_table; - error = wasmtime_table_new(context, table_type, &initializer, &function_table); - if (error) { - wasmtime_error_message(error, &message); - wasm_error->kind = TSWasmErrorKindAllocate; - format( - &wasm_error->message, - "failed to allocate wasm table: %.*s", - (int)message.size, message.data - ); - goto error; - } - wasm_tabletype_delete(table_type); - - unsigned stdlib_symbols_len = array_len(STDLIB_SYMBOLS); - - // Define globals for the stack and heap start addresses. 
- wasm_globaltype_t *const_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_CONST); - wasm_globaltype_t *var_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_VAR); - - wasmtime_val_t stack_pointer_value = WASM_I32_VAL(0); - wasmtime_global_t stack_pointer_global; - error = wasmtime_global_new(context, var_i32_type, &stack_pointer_value, &stack_pointer_global); - assert(!error); - - *self = (TSWasmStore) { - .engine = engine, - .store = store, - .memory = memory, - .function_table = function_table, - .language_instances = array_new(), - .stdlib_fn_indices = ts_calloc(stdlib_symbols_len, sizeof(uint32_t)), - .stack_pointer_global = stack_pointer_global, - .current_memory_offset = 0, - .current_function_table_offset = 0, - .const_i32_type = const_i32_type, - .var_i32_type = var_i32_type, - }; - - // Define lexer callback functions. + // Define functions called by scanners via function pointers on the lexer. LexerInWasmMemory lexer = { .lookahead = 0, .result_symbol = 0, @@ -650,51 +583,52 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { }, }; - // Define builtin functions used by scanners. + // Define builtin functions that can be imported by scanners. 
+ BuiltinFunctionIndices builtin_fn_indices; FunctionDefinition builtin_definitions[] = { { - &self->builtin_fn_indices.proc_exit, + &builtin_fn_indices.proc_exit, callback__abort, wasm_functype_new_1_0(wasm_valtype_new_i32()) }, { - &self->builtin_fn_indices.abort, + &builtin_fn_indices.abort, callback__abort, wasm_functype_new_0_0() }, { - &self->builtin_fn_indices.assert_fail, + &builtin_fn_indices.assert_fail, callback__abort, wasm_functype_new_4_0(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32()) }, { - &self->builtin_fn_indices.notify_memory_growth, - callback__notify_memory_growth, + &builtin_fn_indices.notify_memory_growth, + callback__noop, wasm_functype_new_1_0(wasm_valtype_new_i32()) }, { - &self->builtin_fn_indices.debug_message, + &builtin_fn_indices.debug_message, callback__debug_message, wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32()) }, { - &self->builtin_fn_indices.at_exit, + &builtin_fn_indices.at_exit, callback__noop, wasm_functype_new_3_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32()) }, { - &self->builtin_fn_indices.args_get, + &builtin_fn_indices.args_get, callback__noop, wasm_functype_new_2_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32()) }, { - &self->builtin_fn_indices.args_sizes_get, + &builtin_fn_indices.args_sizes_get, callback__noop, wasm_functype_new_2_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32()) }, }; - // Create wasm functions. + // Create all of the wasm functions. unsigned builtin_definitions_len = array_len(builtin_definitions); unsigned lexer_definitions_len = array_len(lexer_definitions); for (unsigned i = 0; i < builtin_definitions_len; i++) { @@ -712,7 +646,7 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { wasm_functype_delete(definition->type); } - wasmtime_module_t *stdlib_module; + // Compile the stdlib module. 
error = wasmtime_module_new(engine, STDLIB_WASM, STDLIB_WASM_LEN, &stdlib_module); if (error) { wasmtime_error_message(error, &message); @@ -725,10 +659,93 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { goto error; } - wasmtime_instance_t instance; + // Retrieve the stdlib module's imports. wasm_importtype_vec_t import_types = WASM_EMPTY_VEC; wasmtime_module_imports(stdlib_module, &import_types); + // Find the initial number of memory pages needed by the stdlib. + const wasm_memorytype_t *stdlib_memory_type = NULL; // stays NULL unless a 'memory' import is found + for (unsigned i = 0; i < import_types.size; i++) { + wasm_importtype_t *import_type = import_types.data[i]; + const wasm_name_t *import_name = wasm_importtype_name(import_type); + if (name_eq(import_name, "memory")) { + const wasm_externtype_t *type = wasm_importtype_type(import_type); + stdlib_memory_type = wasm_externtype_as_memorytype_const(type); + } + } + if (!stdlib_memory_type) { + wasm_error->kind = TSWasmErrorKindCompile; + format( + &wasm_error->message, + "wasm stdlib is missing the 'memory' import" + ); + goto error; + } + + // Initialize store's memory + uint64_t initial_memory_pages = wasmtime_memorytype_minimum(stdlib_memory_type); + wasm_limits_t memory_limits = {.min = initial_memory_pages, .max = MAX_MEMORY_SIZE}; + memory_type = wasm_memorytype_new(&memory_limits); + wasmtime_memory_t memory; + error = wasmtime_memory_new(context, memory_type, &memory); + if (error) { + wasmtime_error_message(error, &message); + wasm_error->kind = TSWasmErrorKindAllocate; + format( + &wasm_error->message, + "failed to allocate wasm memory: %.*s", + (int)message.size, message.data + ); + goto error; + } + wasm_memorytype_delete(memory_type); + memory_type = NULL; + + // Initialize store's function table + wasm_limits_t table_limits = {.min = 1, .max = wasm_limits_max_default}; + table_type = wasm_tabletype_new(wasm_valtype_new(WASM_FUNCREF), &table_limits); + wasmtime_val_t initializer = {.kind = WASMTIME_FUNCREF}; + 
wasmtime_table_t function_table; + error = wasmtime_table_new(context, table_type, &initializer, &function_table); + if (error) { + wasmtime_error_message(error, &message); + wasm_error->kind = TSWasmErrorKindAllocate; + format( + &wasm_error->message, + "failed to allocate wasm table: %.*s", + (int)message.size, message.data + ); + goto error; + } + wasm_tabletype_delete(table_type); + table_type = NULL; + + unsigned stdlib_symbols_len = array_len(STDLIB_SYMBOLS); + + // Define globals for the stack and heap start addresses. + wasm_globaltype_t *const_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_CONST); + wasm_globaltype_t *var_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_VAR); + + wasmtime_val_t stack_pointer_value = WASM_I32_VAL(0); + wasmtime_global_t stack_pointer_global; + error = wasmtime_global_new(context, var_i32_type, &stack_pointer_value, &stack_pointer_global); + assert(!error); + + *self = (TSWasmStore) { + .engine = engine, + .store = store, + .memory = memory, + .function_table = function_table, + .language_instances = array_new(), + .stdlib_fn_indices = ts_calloc(stdlib_symbols_len, sizeof(uint32_t)), + .builtin_fn_indices = builtin_fn_indices, + .stack_pointer_global = stack_pointer_global, + .current_memory_offset = 0, + .current_function_table_offset = 0, + .const_i32_type = const_i32_type, + }; + + // Set up the imports for the stdlib module. imports = ts_calloc(import_types.size, sizeof(wasmtime_extern_t)); for (unsigned i = 0; i < import_types.size; i++) { wasm_importtype_t *type = import_types.data[i]; @@ -744,6 +761,8 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { } } + // Instantiate the stdlib module. 
+ wasmtime_instance_t instance; error = wasmtime_instance_new(context, stdlib_module, imports, import_types.size, &instance, &trap); ts_free(imports); imports = NULL; @@ -769,11 +788,10 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { } wasm_importtype_vec_delete(&import_types); + // Process the stdlib module's exports. for (unsigned i = 0; i < stdlib_symbols_len; i++) { self->stdlib_fn_indices[i] = UINT32_MAX; } - - // Process the stdlib module's exports. wasmtime_module_exports(stdlib_module, &export_types); for (unsigned i = 0; i < export_types.size; i++) { wasm_exporttype_t *export_type = export_types.data[i]; @@ -842,11 +860,12 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { } wasm_exporttype_vec_delete(&export_types); + wasmtime_module_delete(stdlib_module); - // Add lexer callback functions to the function table. Replace the func indices in the lexer - // struct with the function table indices which serve as function pointer addresses. - uint32_t prev_size; - error = wasmtime_table_grow(context, &function_table, lexer_definitions_len, &initializer, &prev_size); + // Add all of the lexer callback functions to the function table. Store their function table + // indices on the in-memory lexer. 
+ uint32_t table_index; + error = wasmtime_table_grow(context, &function_table, lexer_definitions_len, &initializer, &table_index); if (error) { wasmtime_error_message(error, &message); wasm_error->kind = TSWasmErrorKindAllocate; @@ -857,8 +876,6 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { ); goto error; } - - uint32_t table_index = prev_size; for (unsigned i = 0; i < lexer_definitions_len; i++) { FunctionDefinition *definition = &lexer_definitions[i]; wasmtime_func_t func = {function_table.store_id, *definition->storage_location}; @@ -870,17 +887,14 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { } self->current_function_table_offset = table_index; - self->lexer_address = 2 * MEMORY_PAGE_SIZE; + self->lexer_address = initial_memory_pages * MEMORY_PAGE_SIZE; self->serialization_buffer_address = self->lexer_address + sizeof(LexerInWasmMemory); self->current_memory_offset = self->serialization_buffer_address + TREE_SITTER_SERIALIZATION_BUFFER_SIZE; + // Grow the memory enough to hold the builtin lexer and serialization buffer. 
+ uint32_t new_pages_needed = (self->current_memory_offset - self->lexer_address - 1) / MEMORY_PAGE_SIZE + 1; uint64_t prev_memory_size; - wasmtime_memory_grow( - context, - &memory, - (self->current_memory_offset - (2 * MEMORY_PAGE_SIZE) - 1) / MEMORY_PAGE_SIZE + 1, - &prev_memory_size - ); + wasmtime_memory_grow(context, &memory, new_pages_needed, &prev_memory_size); uint8_t *memory_data = wasmtime_memory_data(context, &memory); memcpy(&memory_data[self->lexer_address], &lexer, sizeof(lexer)); @@ -888,7 +902,11 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { error: ts_free(self); + if (stdlib_module) wasmtime_module_delete(stdlib_module); if (store) wasmtime_store_delete(store); + if (import_types.size) wasm_importtype_vec_delete(&import_types); + if (memory_type) wasm_memorytype_delete(memory_type); + if (table_type) wasm_tabletype_delete(table_type); if (trap) wasm_trap_delete(trap); if (error) wasmtime_error_delete(error); if (message.size) wasm_byte_vec_delete(&message); @@ -901,7 +919,6 @@ void ts_wasm_store_delete(TSWasmStore *self) { if (!self) return; ts_free(self->stdlib_fn_indices); wasm_globaltype_delete(self->const_i32_type); - wasm_globaltype_delete(self->var_i32_type); wasmtime_store_delete(self->store); wasm_engine_delete(self->engine); for (unsigned i = 0; i < self->language_instances.size; i++) {