From 13dd76e44403e1022271fb290d52d0c7177c7811 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 27 Nov 2023 17:46:37 -0800 Subject: [PATCH] Return an informative error on failing to construct a WasmStore --- cli/Cargo.toml | 1 + cli/loader/src/lib.rs | 2 +- cli/src/main.rs | 4 +- cli/src/tests/wasm_language_test.rs | 6 +- lib/binding_rust/bindings.rs | 22 ++- lib/binding_rust/wasm_language.rs | 50 ++++--- lib/include/tree_sitter/api.h | 21 ++- lib/src/wasm.c | 202 ++++++++++++++++++++-------- 8 files changed, 211 insertions(+), 97 deletions(-) diff --git a/cli/Cargo.toml b/cli/Cargo.toml index d5066668..0261fdcc 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -21,6 +21,7 @@ name = "benchmark" harness = false [features] +# default = ["wasm"] wasm = ["tree-sitter/wasm", "tree-sitter-loader/wasm"] [dependencies] diff --git a/cli/loader/src/lib.rs b/cli/loader/src/lib.rs index 5a40e524..9586c984 100644 --- a/cli/loader/src/lib.rs +++ b/cli/loader/src/lib.rs @@ -868,7 +868,7 @@ impl Loader { #[cfg(feature = "wasm")] pub fn use_wasm(&mut self, engine: tree_sitter::wasmtime::Engine) { - *self.wasm_store.lock().unwrap() = Some(tree_sitter::WasmStore::new(engine)) + *self.wasm_store.lock().unwrap() = Some(tree_sitter::WasmStore::new(engine).unwrap()) } pub fn get_scanner_path(&self, src_path: &Path) -> Option { diff --git a/cli/src/main.rs b/cli/src/main.rs index b6c4932e..d6b143d5 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -405,7 +405,7 @@ fn run() -> Result<()> { if matches.is_present("wasm") { let engine = tree_sitter::wasmtime::Engine::default(); parser - .set_wasm_store(tree_sitter::WasmStore::new(engine.clone())) + .set_wasm_store(tree_sitter::WasmStore::new(engine.clone()).unwrap()) .unwrap(); loader.use_wasm(engine); } @@ -502,7 +502,7 @@ fn run() -> Result<()> { if matches.is_present("wasm") { let engine = tree_sitter::wasmtime::Engine::default(); parser - .set_wasm_store(tree_sitter::WasmStore::new(engine.clone())) + .set_wasm_store(tree_sitter::WasmStore::new(engine.clone()).unwrap()) .unwrap(); loader.use_wasm(engine); } diff --git a/cli/src/tests/wasm_language_test.rs b/cli/src/tests/wasm_language_test.rs index 7cd3a7e8..161c7d3a 100644 --- a/cli/src/tests/wasm_language_test.rs +++ b/cli/src/tests/wasm_language_test.rs @@ -9,7 +9,7 @@ lazy_static! { #[test] fn test_load_wasm_language() { - let mut store = WasmStore::new(ENGINE.clone()); + let mut store = WasmStore::new(ENGINE.clone()).unwrap(); let mut parser = Parser::new(); let wasm_cpp = fs::read(&WASM_DIR.join(format!("tree-sitter-cpp.wasm"))).unwrap(); @@ -25,7 +25,7 @@ fn test_load_wasm_language() { let mut parser2 = Parser::new(); parser2 - .set_wasm_store(WasmStore::new(ENGINE.clone())) + .set_wasm_store(WasmStore::new(ENGINE.clone()).unwrap()) .unwrap(); for mut parser in [parser, parser2] { @@ -63,7 +63,7 @@ fn test_load_wasm_language() { #[test] fn test_load_wasm_errors() { - let mut store = WasmStore::new(ENGINE.clone()); + let mut store = WasmStore::new(ENGINE.clone()).unwrap(); let wasm = fs::read(&WASM_DIR.join(format!("tree-sitter-rust.wasm"))).unwrap(); let bad_wasm = &wasm[1..]; diff --git a/lib/binding_rust/bindings.rs b/lib/binding_rust/bindings.rs index ef3ba30b..b7c0f2ed 100644 --- a/lib/binding_rust/bindings.rs +++ b/lib/binding_rust/bindings.rs @@ -778,13 +778,24 @@ pub type TSWasmEngine = wasm_engine_t; pub struct TSWasmStore { _unused: [u8; 0], } -pub const TSWasmErrorParse: TSWasmError = 0; -pub const TSWasmErrorCompile: TSWasmError = 1; -pub const TSWasmErrorInstantiate: TSWasmError = 2; -pub type TSWasmError = ::std::os::raw::c_uint; +pub const TSWasmErrorKindNone: TSWasmErrorKind = 0; +pub const TSWasmErrorKindParse: TSWasmErrorKind = 1; +pub const TSWasmErrorKindCompile: TSWasmErrorKind = 2; +pub const TSWasmErrorKindInstantiate: TSWasmErrorKind = 3; +pub const TSWasmErrorKindAllocate: TSWasmErrorKind = 4; +pub type TSWasmErrorKind = ::std::os::raw::c_uint; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct TSWasmError { + pub kind: TSWasmErrorKind, + pub message: *mut ::std::os::raw::c_char, +} extern "C" { #[doc = " Create a Wasm store."] - pub fn ts_wasm_store_new(engine: *mut TSWasmEngine) -> *mut TSWasmStore; + pub fn ts_wasm_store_new( + engine: *mut TSWasmEngine, + error: *mut TSWasmError, + ) -> *mut TSWasmStore; } extern "C" { #[doc = " Free the memory associated with the given Wasm store."] @@ -798,7 +809,6 @@ extern "C" { wasm: *const ::std::os::raw::c_char, wasm_len: u32, error: *mut TSWasmError, - message: *mut *mut ::std::os::raw::c_char, ) -> *const TSLanguage; } extern "C" { diff --git a/lib/binding_rust/wasm_language.rs b/lib/binding_rust/wasm_language.rs index f0cc4f81..db777f07 100644 --- a/lib/binding_rust/wasm_language.rs +++ b/lib/binding_rust/wasm_language.rs @@ -39,40 +39,35 @@ pub enum WasmErrorKind { } impl WasmStore { - pub fn new(engine: wasmtime::Engine) -> Self { - let engine = Box::new(wasm_engine_t { engine }); - WasmStore(unsafe { - ffi::ts_wasm_store_new(Box::leak(engine) as *mut wasm_engine_t as *mut _) - }) + pub fn new(engine: wasmtime::Engine) -> Result { + unsafe { + let mut error = MaybeUninit::::uninit(); + let engine = Box::new(wasm_engine_t { engine }); + let store = ffi::ts_wasm_store_new( + Box::leak(engine) as *mut wasm_engine_t as *mut _, + error.as_mut_ptr(), + ); + if store.is_null() { + Err(WasmError::new(error.assume_init())) + } else { + Ok(WasmStore(store)) + } + } } pub fn load_language(&mut self, name: &str, bytes: &[u8]) -> Result { let name = CString::new(name).unwrap(); unsafe { let mut error = MaybeUninit::::uninit(); - let mut message = MaybeUninit::<*mut c_char>::uninit(); let language = ffi::ts_wasm_store_load_language( self.0, name.as_ptr(), bytes.as_ptr() as *const c_char, bytes.len() as u32, error.as_mut_ptr(), - message.as_mut_ptr(), ); - if language.is_null() { - let error = error.assume_init(); - let message = message.assume_init(); - let message = CString::from_raw(message); - Err(WasmError { - kind: match error { - ffi::TSWasmErrorParse => WasmErrorKind::Parse, - ffi::TSWasmErrorCompile => WasmErrorKind::Compile, - ffi::TSWasmErrorInstantiate => WasmErrorKind::Instantiate, - _ => WasmErrorKind::Other, - }, - message: message.into_string().unwrap(), - }) + Err(WasmError::new(error.assume_init())) } else { Ok(Language(language)) } @@ -80,6 +75,21 @@ impl WasmStore { } } +impl WasmError { + unsafe fn new(error: ffi::TSWasmError) -> Self { + let message = CString::from_raw(error.message); + Self { + kind: match error.kind { + ffi::TSWasmErrorKindParse => WasmErrorKind::Parse, + ffi::TSWasmErrorKindCompile => WasmErrorKind::Compile, + ffi::TSWasmErrorKindInstantiate => WasmErrorKind::Instantiate, + _ => WasmErrorKind::Other, + }, + message: message.into_string().unwrap(), + } + } +} + impl Language { pub fn is_wasm(&self) -> bool { unsafe { ffi::ts_language_is_wasm(self.0) } diff --git a/lib/include/tree_sitter/api.h b/lib/include/tree_sitter/api.h index 3129d5c3..eeecf317 100644 --- a/lib/include/tree_sitter/api.h +++ b/lib/include/tree_sitter/api.h @@ -1150,15 +1150,25 @@ typedef struct wasm_engine_t TSWasmEngine; typedef struct TSWasmStore TSWasmStore; typedef enum { - TSWasmErrorParse, - TSWasmErrorCompile, - TSWasmErrorInstantiate, + TSWasmErrorKindNone = 0, + TSWasmErrorKindParse, + TSWasmErrorKindCompile, + TSWasmErrorKindInstantiate, + TSWasmErrorKindAllocate, +} TSWasmErrorKind; + +typedef struct { + TSWasmErrorKind kind; + char *message; } TSWasmError; /** * Create a Wasm store. */ -TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine); +TSWasmStore *ts_wasm_store_new( + TSWasmEngine *engine, + TSWasmError *error +); /** * Free the memory associated with the given Wasm store. @@ -1177,8 +1187,7 @@ const TSLanguage *ts_wasm_store_load_language( const char *name, const char *wasm, uint32_t wasm_len, - TSWasmError *error, - char **message + TSWasmError *error ); /** diff --git a/lib/src/wasm.c b/lib/src/wasm.c index 81581ec6..04e14dc9 100644 --- a/lib/src/wasm.c +++ b/lib/src/wasm.c @@ -270,13 +270,13 @@ static bool wasm_dylink_info__parse( * Native callbacks exposed to wasm modules *******************************************/ -static wasm_trap_t *callback__exit( + static wasm_trap_t *callback__exit( void *env, wasmtime_caller_t* caller, wasmtime_val_raw_t *args_and_results, size_t args_and_results_len ) { - printf("exit called"); + fprintf(stderr, "wasm module called exit"); abort(); } @@ -286,18 +286,7 @@ static wasm_trap_t *callback__at_exit( wasmtime_val_raw_t *args_and_results, size_t args_and_results_len ) { - printf("atexit called"); - abort(); -} - -static wasm_trap_t *callback__assert_fail( - void *env, - wasmtime_caller_t* caller, - wasmtime_val_raw_t *args_and_results, - size_t args_and_results_len -) { - printf("assert failed called"); - abort(); + return NULL; } static wasm_trap_t *callback__lexer_advance( @@ -462,13 +451,18 @@ static wasmtime_extern_t get_builtin_func_extern( return (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_FUNC, .of.func = val.of.funcref}; } +#define format(output, ...) \ + do { \ + size_t message_length = snprintf((char *)NULL, 0, __VA_ARGS__); \ + *output = ts_malloc(message_length + 1); \ + snprintf(*output, message_length + 1, __VA_ARGS__); \ + } while (0) + static bool ts_wasm_store__provide_builtin_import( TSWasmStore *self, const wasm_name_t *import_name, wasmtime_extern_t *import ) { - if (import_name->size == 0) return false; - wasmtime_error_t *error = NULL; wasmtime_context_t *context = wasmtime_store_context(self->store); @@ -519,19 +513,49 @@ static bool ts_wasm_store__provide_builtin_import( return true; } -TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine) { +static bool ts_wasm_store__call_module_initializer( + TSWasmStore *self, + const wasm_name_t *export_name, + wasmtime_extern_t *export, + wasm_trap_t **trap +) { + if ( + name_eq(export_name, "_initialize") || + name_eq(export_name, "__wasm_apply_data_relocs") + ) { + wasmtime_context_t *context = wasmtime_store_context(self->store); + wasmtime_func_t initialization_func = export->of.func; + wasmtime_error_t *error = wasmtime_func_call(context, &initialization_func, NULL, 0, NULL, 0, trap); + assert(!error); + return true; + } else { + return false; + } +} + +TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) { TSWasmStore *self = ts_malloc(sizeof(TSWasmStore)); wasmtime_store_t *store = wasmtime_store_new(engine, self, NULL); wasmtime_context_t *context = wasmtime_store_context(store); wasmtime_error_t *error = NULL; wasm_trap_t *trap = NULL; + wasm_message_t message = WASM_EMPTY_VEC; // Initialize store's memory wasm_limits_t memory_limits = {.min = 256, .max = 256}; wasm_memorytype_t *memory_type = wasm_memorytype_new(&memory_limits); wasmtime_memory_t memory; error = wasmtime_memory_new(context, memory_type, &memory); - assert(!error); + if (error) { + wasmtime_error_message(error, &message); + wasm_error->kind = TSWasmErrorKindAllocate; + format( + &wasm_error->message, + "failed to allocate wasm memory: %.*s", + (int)message.size, message.data + ); + goto error; + } wasm_memorytype_delete(memory_type); // Initialize lexer struct with function pointers in wasm memory. @@ -552,8 +576,8 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine) { [NULL_IX] = {NULL, NULL}, [PROC_EXIT_IX] = {callback__exit, wasm_functype_new_1_0(wasm_valtype_new_i32())}, [ABORT_IX] = {callback__exit, wasm_functype_new_0_0()}, + [ASSERT_FAIL_IX] = {callback__exit, wasm_functype_new_4_0(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())}, [AT_EXIT_IX] = {callback__at_exit, wasm_functype_new_3_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())}, - [ASSERT_FAIL_IX] = {callback__assert_fail, wasm_functype_new_4_0(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())}, [LEXER_ADVANCE_IX] = {callback__lexer_advance, wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, [LEXER_MARK_END_IX] = {callback__lexer_mark_end, wasm_functype_new_1_0(wasm_valtype_new_i32())}, [LEXER_GET_COLUMN_IX] = {callback__lexer_get_column, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, @@ -568,12 +592,31 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine) { wasm_tabletype_t *table_type = wasm_tabletype_new(wasm_valtype_new(WASM_FUNCREF), &table_limits); wasmtime_val_t initializer = {.kind = WASMTIME_FUNCREF}; error = wasmtime_table_new(context, table_type, &initializer, &function_table); - assert(!error); + if (error) { + wasmtime_error_message(error, &message); + wasm_error->kind = TSWasmErrorKindAllocate; + format( + &wasm_error->message, + "failed to allocate wasm table: %.*s", + (int)message.size, message.data + ); + goto error; + } wasm_tabletype_delete(table_type); uint32_t prev_size; error = wasmtime_table_grow(context, &function_table, definitions_len, &initializer, &prev_size); - assert(!error); + if (error) { + wasmtime_error_message(error, &message); + wasm_error->kind = TSWasmErrorKindAllocate; + format( + &wasm_error->message, + "failed to grow wasm table to initial size: %.*s", + (int)message.size, message.data + ); + goto error; + } + for (unsigned i = 1; i < definitions_len; i++) { FunctionDefinition *definition = &definitions[i]; wasmtime_func_t func; @@ -598,13 +641,23 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine) { WasmDylinkInfo dylink_info; if (!wasm_dylink_info__parse(STDLIB_WASM, STDLIB_WASM_LEN, &dylink_info)) { - printf("failed to parse wasm dylink info\n"); - abort(); + wasm_error->kind = TSWasmErrorKindParse; + format(&wasm_error->message, "failed to parse wasm stdlib"); + goto error; } wasmtime_module_t *stdlib_module; error = wasmtime_module_new(engine, STDLIB_WASM, STDLIB_WASM_LEN, &stdlib_module); - assert(!error); + if (error) { + wasmtime_error_message(error, &message); + wasm_error->kind = TSWasmErrorKindCompile; + format( + &wasm_error->message, + "failed to compile wasm stdlib: %.*s", + (int)message.size, message.data + ); + goto error; + } wasmtime_instance_t instance; wasm_importtype_vec_t import_types = WASM_EMPTY_VEC; @@ -616,18 +669,36 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine) { wasm_importtype_t *type = import_types.data[i]; const wasm_name_t *import_name = wasm_importtype_name(type); if (!ts_wasm_store__provide_builtin_import(self, import_name, &imports[i])) { - printf("unexpected import name: %.*s\n", (int)import_name->size, import_name->data); - abort(); + wasm_error->kind = TSWasmErrorKindInstantiate; + format( + &wasm_error->message, + "unexpected import in wasm stdlib: %.*s\n", + (int)import_name->size, import_name->data + ); + goto error; } } error = wasmtime_instance_new(context, stdlib_module, imports, import_types.size, &instance, &trap); - assert(!error); + if (error) { + wasmtime_error_message(error, &message); + wasm_error->kind = TSWasmErrorKindInstantiate; + format( + &wasm_error->message, + "failed to instantiate wasm stdlib module: %.*s", + (int)message.size, message.data + ); + goto error; + } if (trap) { - wasm_message_t message; wasm_trap_message(trap, &message); - printf("error compiling standard library: %.*s\n", (int)message.size, message.data); - abort(); + wasm_error->kind = TSWasmErrorKindInstantiate; + format( + &wasm_error->message, + "trapped when instantiating wasm stdlib module: %.*s", + (int)message.size, message.data + ); + goto error; } wasm_importtype_vec_delete(&import_types); @@ -651,25 +722,39 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine) { bool exists = wasmtime_instance_export_nth(context, &instance, i, &export_name, &name_len, &export); assert(exists); - bool store_index = false; if (export.kind == WASMTIME_EXTERN_FUNC) { + if (ts_wasm_store__call_module_initializer(self, name, &export, &trap)) { + if (trap) { + wasm_trap_message(trap, &message); + wasm_error->kind = TSWasmErrorKindInstantiate; + format( + &wasm_error->message, + "trap when calling stdlib relocation function: %.*s\n", + (int)message.size, message.data + ); + goto error; + } + continue; + } + for (unsigned j = 0; j < array_len(STDLIB_SYMBOLS); j++) { if (name_eq(name, STDLIB_SYMBOLS[j])) { self->fn_indices[j] = export.of.func.index; - store_index = true; break; } } - if (!store_index) { - printf(" unused export name: %.*s\n", (int)name->size, name->data); - } } } for (unsigned i = 0; i < STDLIB_SYMBOL_COUNT; i++) { if (self->fn_indices[i] == UINT16_MAX) { - printf("undefined stdlib import: %s\n", STDLIB_SYMBOLS[i]); - abort(); + wasm_error->kind = TSWasmErrorKindInstantiate; + format( + &wasm_error->message, + "missing exported symbol in wasm stdlib: %s", + STDLIB_SYMBOLS[i] + ); + goto error; } } @@ -678,7 +763,11 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine) { error: ts_free(self); - wasm_exporttype_vec_delete(&export_types); + if (store) wasmtime_store_delete(store); + if (trap) wasm_trap_delete(trap); + if (error) wasmtime_error_delete(error); + if (message.size) wasm_byte_vec_delete(&message); + if (export_types.size) wasm_exporttype_vec_delete(&export_types); return NULL; } @@ -692,13 +781,6 @@ void ts_wasm_store_delete(TSWasmStore *self) { ts_free(self); } -#define format(output, ...) \ - do { \ - size_t message_length = snprintf((char *)NULL, 0, __VA_ARGS__); \ - *output = ts_malloc(message_length + 1); \ - snprintf(*output, message_length + 1, __VA_ARGS__); \ - } while (0) - static bool ts_wasm_store__instantiate( TSWasmStore *self, wasmtime_module_t *module, @@ -806,11 +888,8 @@ static bool ts_wasm_store__instantiate( bool exists = wasmtime_instance_export_nth(context, &instance, i, &export_name, &name_len, &export); assert(exists); - // Update pointers to reflect memory and function table offsets. - if (name_eq(name, "__wasm_apply_data_relocs")) { - wasmtime_func_t apply_relocation_func = export.of.func; - error = wasmtime_func_call(context, &apply_relocation_func, NULL, 0, NULL, 0, &trap); - assert(!error); + // If the module exports an initialization or data-relocation function, call it. + if (ts_wasm_store__call_module_initializer(self, name, &export, &trap)) { if (trap) { wasm_trap_message(trap, &message); format( @@ -885,16 +964,16 @@ const TSLanguage *ts_wasm_store_load_language( const char *language_name, const char *wasm, uint32_t wasm_len, - TSWasmError *wasm_error, - char **error_message + TSWasmError *wasm_error ) { WasmDylinkInfo dylink_info; wasmtime_module_t *module = NULL; wasmtime_error_t *error = NULL; + wasm_error->kind = TSWasmErrorKindNone; if (!wasm_dylink_info__parse((const unsigned char *)wasm, wasm_len, &dylink_info)) { - *wasm_error = TSWasmErrorParse; - format(error_message, "failed to parse dylink section of wasm module"); + wasm_error->kind = TSWasmErrorKindParse; + format(&wasm_error->message, "failed to parse dylink section of wasm module"); goto error; } @@ -903,8 +982,8 @@ const TSLanguage *ts_wasm_store_load_language( if (error) { wasm_message_t message; wasmtime_error_message(error, &message); - *wasm_error = TSWasmErrorCompile; - format(error_message, "error compiling wasm module: %.*s", (int)message.size, message.data); + wasm_error->kind = TSWasmErrorKindCompile; + format(&wasm_error->message, "error compiling wasm module: %.*s", (int)message.size, message.data); wasm_byte_vec_delete(&message); goto error; } @@ -919,9 +998,9 @@ const TSLanguage *ts_wasm_store_load_language( &dylink_info, &instance, &language_address, - error_message + &wasm_error->message )) { - *wasm_error = TSWasmErrorInstantiate; + wasm_error->kind = TSWasmErrorKindInstantiate; goto error; } @@ -1210,7 +1289,12 @@ static void ts_wasm_store__call( if (trap) { wasm_message_t message; wasm_trap_message(trap, &message); - printf("error calling function index %u: %s\n", function_index, message.data); + fprintf( + stderr, + "trap when calling wasm lexing function %u: %.*s\n", + function_index, + (int)message.size, message.data + ); abort(); } }