From 7dc81303f6126d53d3b4fd276b740297145d0fc5 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 7 Sep 2022 14:40:11 -0700 Subject: [PATCH] Get wasm store working with parser with stateless external scanners --- cli/loader/src/lib.rs | 14 +- lib/src/parser.c | 138 +++++++++++++----- lib/src/wasm.c | 323 ++++++++++++++++++++++++++++++++---------- lib/src/wasm.h | 15 +- 4 files changed, 366 insertions(+), 124 deletions(-) diff --git a/cli/loader/src/lib.rs b/cli/loader/src/lib.rs index f89e437c..a0fb3249 100644 --- a/cli/loader/src/lib.rs +++ b/cli/loader/src/lib.rs @@ -359,14 +359,16 @@ impl Loader { } } + if self.wasm_store.lock().unwrap().is_some() { + library_path.set_extension("wasm"); + } else { + library_path.set_extension(DYLIB_EXTENSION); + } + let recompile = needs_recompile(&library_path, &parser_path, &scanner_path) .with_context(|| "Failed to compare source and binary timestamps")?; if let Some(wasm_store) = self.wasm_store.lock().unwrap().as_mut() { - library_path.set_extension("wasm"); - - eprintln!("library_path: {:?}", &library_path); - if recompile { self.compile_parser_to_wasm( name, @@ -380,11 +382,8 @@ impl Loader { } let wasm_bytes = fs::read(&library_path)?; - Ok(wasm_store.load_language(name, &wasm_bytes)) } else { - library_path.set_extension(DYLIB_EXTENSION); - if recompile { self.compile_parser_to_dylib( header_path, @@ -548,6 +547,7 @@ impl Loader { "-s", &format!("EXPORTED_FUNCTIONS=[\"_tree_sitter_{}\"]", language_name), "-fno-exceptions", + "-fvisibility=hidden", "-I", ".", ]); diff --git a/lib/src/parser.c b/lib/src/parser.c index d6b0a654..12605b37 100644 --- a/lib/src/parser.c +++ b/lib/src/parser.c @@ -334,18 +334,103 @@ static bool ts_parser__better_version_exists( return false; } -static void ts_parser__restore_external_scanner( +static void ts_parser__external_scanner_create( + TSParser *self +) { + if (self->language && self->language->external_scanner.states) { + if (ts_language_is_wasm(self->language)) { + self->external_scanner_payload = (void *)(uintptr_t)ts_wasm_store_call_scanner_create( + self->wasm_store + ); + } else if (self->language->external_scanner.create) { + self->external_scanner_payload = self->language->external_scanner.create(); + } + } +} + +static void ts_parser__external_scanner_destroy( + TSParser *self +) { + if (self->language && self->external_scanner_payload) { + if (ts_language_is_wasm(self->language)) { + ts_wasm_store_call_scanner_destroy( + self->wasm_store, + (uintptr_t)self->external_scanner_payload + ); + } else if (self->language->external_scanner.destroy) { + self->language->external_scanner.destroy( + self->external_scanner_payload + ); + } + self->external_scanner_payload = NULL; + } +} + +static unsigned ts_parser__external_scanner_serialize( + TSParser *self +) { + if (ts_language_is_wasm(self->language)) { + return ts_wasm_store_call_scanner_serialize( + self->wasm_store, + (uintptr_t)self->external_scanner_payload, + self->lexer.debug_buffer + ); + } else { + return self->language->external_scanner.serialize( + self->external_scanner_payload, + self->lexer.debug_buffer + ); + } +} + +static void ts_parser__external_scanner_deserialize( TSParser *self, Subtree external_token ) { + const char *data = NULL; + uint32_t length = 0; if (external_token.ptr) { - self->language->external_scanner.deserialize( - self->external_scanner_payload, - ts_external_scanner_state_data(&external_token.ptr->external_scanner_state), - external_token.ptr->external_scanner_state.length + data = ts_external_scanner_state_data(&external_token.ptr->external_scanner_state); + length = external_token.ptr->external_scanner_state.length; + } + + if (ts_language_is_wasm(self->language)) { + ts_wasm_store_call_scanner_deserialize( + self->wasm_store, + (uintptr_t)self->external_scanner_payload, + data, + length ); } else { - self->language->external_scanner.deserialize(self->external_scanner_payload, NULL, 0); + self->language->external_scanner.deserialize( + self->external_scanner_payload, + data, + length + ); + } +} + +static bool ts_parser__external_scanner_scan( + TSParser *self, + TSStateId external_lex_state +) { + + if (ts_language_is_wasm(self->language)) { + return ts_wasm_store_call_scanner_scan( + self->wasm_store, + (uintptr_t)self->external_scanner_payload, + external_lex_state * self->language->external_token_count + ); + } else { + const bool *valid_external_tokens = ts_language_enabled_external_tokens( + self->language, + external_lex_state + ); + return self->language->external_scanner.scan( + self->external_scanner_payload, + &self->lexer.data, + valid_external_tokens + ); } } @@ -397,10 +482,6 @@ static Subtree ts_parser__lex( const Length start_position = ts_stack_position(self->stack, version); const Subtree external_token = ts_stack_last_external_token(self->stack, version); - const bool *valid_external_tokens = ts_language_enabled_external_tokens( - self->language, - lex_mode.external_lex_state - ); bool found_external_token = false; bool error_mode = parse_state == ERROR_STATE; @@ -418,7 +499,7 @@ static Subtree ts_parser__lex( bool found_token = false; Length current_position = self->lexer.current_position; - if (valid_external_tokens) { + if (lex_mode.external_lex_state != 0) { LOG( "lex_external state:%d, row:%u, column:%u", lex_mode.external_lex_state, @@ -426,19 +507,12 @@ static Subtree ts_parser__lex( current_position.extent.column ); ts_lexer_start(&self->lexer); - ts_parser__restore_external_scanner(self, external_token); - found_token = self->language->external_scanner.scan( - self->external_scanner_payload, - &self->lexer.data, - valid_external_tokens - ); + ts_parser__external_scanner_deserialize(self, external_token); + found_token = ts_parser__external_scanner_scan(self, lex_mode.external_lex_state); ts_lexer_finish(&self->lexer, &lookahead_end_byte); if (found_token) { - external_scanner_state_len = self->language->external_scanner.serialize( - self->external_scanner_payload, - self->lexer.debug_buffer - ); + external_scanner_state_len = ts_parser__external_scanner_serialize(self); external_scanner_state_changed = !ts_external_scanner_state_eq( ts_subtree_external_scanner_state(external_token), self->lexer.debug_buffer, @@ -487,7 +561,7 @@ static Subtree ts_parser__lex( ts_lexer_start(&self->lexer); found_token = false; if (ts_language_is_wasm(self->language)) { - found_token = ts_wasm_store_run_main_lex_function(self->wasm_store, lex_mode.lex_state); + found_token = ts_wasm_store_call_lex_main(self->wasm_store, lex_mode.lex_state); } else { found_token = self->language->lex_fn(&self->lexer.data, lex_mode.lex_state); } @@ -497,10 +571,6 @@ static Subtree ts_parser__lex( if (!error_mode) { error_mode = true; lex_mode = self->language->lex_modes[ERROR_STATE]; - valid_external_tokens = ts_language_enabled_external_tokens( - self->language, - lex_mode.external_lex_state - ); ts_lexer_reset(&self->lexer, start_position); continue; } @@ -553,7 +623,7 @@ static Subtree ts_parser__lex( ts_lexer_start(&self->lexer); if (ts_language_is_wasm(self->language)) { - is_keyword = ts_wasm_store_run_keyword_lex_function(self->wasm_store, 0); + is_keyword = ts_wasm_store_call_lex_keyword(self->wasm_store, 0); } else { is_keyword = self->language->keyword_lex_fn(&self->lexer.data, 0); } @@ -1799,18 +1869,10 @@ bool ts_parser_set_language(TSParser *self, const TSLanguage *language) { if (language->version > TREE_SITTER_LANGUAGE_VERSION) return false; if (language->version < TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION) return false; } - - if (self->external_scanner_payload && self->language->external_scanner.destroy) { - self->language->external_scanner.destroy(self->external_scanner_payload); - } - - if (language && language->external_scanner.create) { - self->external_scanner_payload = language->external_scanner.create(); - } else { - self->external_scanner_payload = NULL; - } - + ts_parser__external_scanner_destroy(self); self->language = language; + ts_wasm_store_start(self->wasm_store, &self->lexer.data, language); + ts_parser__external_scanner_create(self); ts_parser_reset(self); return true; } diff --git a/lib/src/wasm.c b/lib/src/wasm.c index 284f1272..f7019a15 100644 --- a/lib/src/wasm.c +++ b/lib/src/wasm.c @@ -1,5 +1,6 @@ #include #include +#include #include "tree_sitter/api.h" #include "./alloc.h" #include "./language.h" @@ -20,9 +21,14 @@ typedef struct { typedef struct { uint32_t language_id; wasmtime_instance_t instance; - uint32_t main_lex_fn_index; - uint32_t keyword_lex_fn_index; - uint32_t external_scan_index; + int32_t external_states_address; + int32_t lex_main_fn_index; + int32_t lex_keyword_fn_index; + int32_t scanner_create_fn_index; + int32_t scanner_destroy_fn_index; + int32_t scanner_serialize_fn_index; + int32_t scanner_deserialize_fn_index; + int32_t scanner_scan_fn_index; } LanguageWasmInstance; struct TSWasmStore { @@ -92,14 +98,26 @@ static volatile uint32_t NEXT_LANGUAGE_ID; static const uint32_t LEXER_ADDRESS = 32; static const uint32_t LEXER_END_ADDRESS = LEXER_ADDRESS + sizeof(LexerInWasmMemory); -static wasm_trap_t *advance_callback( +enum FunctionIx { + LEXER_ADVANCE_IX, + LEXER_MARK_END_IX, + LEXER_GET_COLUMN_IX, + LEXER_IS_AT_INCLUDED_RANGE_START_IX, + LEXER_EOF_IX, + ISWSPACE_IX, + ISWDIGIT_IX, + ISWALPHA_IX, + ISWALNUM_IX, +}; + +static wasm_trap_t *callback__lexer_advance( void *env, wasmtime_caller_t* caller, wasmtime_val_raw_t *args_and_results, - size_t num_args_and_results + size_t args_and_results_len ) { wasmtime_context_t *context = wasmtime_caller_context(caller); - assert(num_args_and_results == 2); + assert(args_and_results_len == 2); TSWasmStore *store = env; TSLexer *lexer = store->current_lexer; @@ -111,11 +129,11 @@ static wasm_trap_t *advance_callback( return NULL; } -static wasm_trap_t *mark_end_callback( +static wasm_trap_t *callback__lexer_mark_end( void *env, wasmtime_caller_t* caller, wasmtime_val_raw_t *args_and_results, - size_t num_args_and_results + size_t args_and_results_len ) { TSWasmStore *store = env; TSLexer *lexer = store->current_lexer; @@ -123,11 +141,11 @@ static wasm_trap_t *mark_end_callback( return NULL; } -static wasm_trap_t *get_column_callback( +static wasm_trap_t *callback__lexer_get_column( void *env, wasmtime_caller_t* caller, wasmtime_val_raw_t *args_and_results, - size_t num_args_and_results + size_t args_and_results_len ) { TSWasmStore *store = env; TSLexer *lexer = store->current_lexer; @@ -136,11 +154,11 @@ static wasm_trap_t *get_column_callback( return NULL; } -static wasm_trap_t *is_at_included_range_start_callback( +static wasm_trap_t *callback__lexer_is_at_included_range_start( void *env, wasmtime_caller_t* caller, wasmtime_val_raw_t *args_and_results, - size_t num_args_and_results + size_t args_and_results_len ) { TSWasmStore *store = env; TSLexer *lexer = store->current_lexer; @@ -149,11 +167,11 @@ static wasm_trap_t *is_at_included_range_start_callback( return NULL; } -static wasm_trap_t *eof_callback( +static wasm_trap_t *callback__lexer_eof( void *env, wasmtime_caller_t* caller, wasmtime_val_raw_t *args_and_results, - size_t num_args_and_results + size_t args_and_results_len ) { TSWasmStore *store = env; TSLexer *lexer = store->current_lexer; @@ -162,6 +180,23 @@ static wasm_trap_t *eof_callback( return NULL; } +#define DEFINE_CTYPE_CALLBACK(fn_name) \ + static wasm_trap_t *callback__##fn_name( \ + void *env, \ + wasmtime_caller_t* caller, \ + wasmtime_val_raw_t *args_and_results, \ + size_t args_and_results_len \ + ) { \ + int32_t result = fn_name(args_and_results[0].i32); \ + args_and_results[0].i32 = result; \ + return NULL; \ + } \ + +DEFINE_CTYPE_CALLBACK(iswspace); +DEFINE_CTYPE_CALLBACK(iswdigit); +DEFINE_CTYPE_CALLBACK(iswalpha); +DEFINE_CTYPE_CALLBACK(iswalnum); + typedef struct { wasmtime_func_unchecked_callback_t callback; wasm_functype_t *type; @@ -200,7 +235,6 @@ static void *copy_strings( } else { result[i] = string_data->contents + (uintptr_t)result[i]; } - printf(" string %u: %s\n", i, result[i]); } return result; } @@ -209,6 +243,17 @@ static bool name_eq(const wasm_name_t *name, const char *string) { return strncmp(string, name->data, name->size) == 0; } +static wasmtime_extern_t get_builtin_func_extern( + wasmtime_context_t *context, + wasmtime_table_t *table, + unsigned index +) { + wasmtime_val_t val; + bool exists = wasmtime_table_get(context, table, index, &val); + assert(exists); + return (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_FUNC, .of.func = val.of.funcref}; +} + TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine) { TSWasmStore *self = ts_malloc(sizeof(TSWasmStore)); wasmtime_store_t *store = wasmtime_store_new(engine, self, NULL); @@ -228,25 +273,29 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine) { LexerInWasmMemory lexer = { .lookahead = 0, .result_symbol = 0, - .advance = 0, - .mark_end = 1, - .get_column = 2, - .is_at_included_range_start = 3, - .eof = 4, + .advance = LEXER_ADVANCE_IX, + .mark_end = LEXER_MARK_END_IX, + .get_column = LEXER_GET_COLUMN_IX, + .is_at_included_range_start = LEXER_IS_AT_INCLUDED_RANGE_START_IX, + .eof = LEXER_EOF_IX, }; memcpy(&memory_data[LEXER_ADDRESS], &lexer, sizeof(lexer)); - // Define lexer functions. + // Define builtin functions. FunctionDefinition definitions[] = { - {advance_callback, wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, - {mark_end_callback, wasm_functype_new_1_0(wasm_valtype_new_i32())}, - {get_column_callback, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, - {is_at_included_range_start_callback, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, - {eof_callback, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, + [LEXER_ADVANCE_IX] = {callback__lexer_advance, wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, + [LEXER_MARK_END_IX] = {callback__lexer_mark_end, wasm_functype_new_1_0(wasm_valtype_new_i32())}, + [LEXER_GET_COLUMN_IX] = {callback__lexer_get_column, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, + [LEXER_IS_AT_INCLUDED_RANGE_START_IX] = {callback__lexer_is_at_included_range_start, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, + [LEXER_EOF_IX] = {callback__lexer_eof, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, + [ISWSPACE_IX] = {callback__iswspace, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, + [ISWDIGIT_IX] = {callback__iswdigit, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, + [ISWALPHA_IX] = {callback__iswalpha, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, + [ISWALNUM_IX] = {callback__iswalnum, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())}, }; unsigned definitions_len = array_len(definitions); - // Add lexer functions to the store's function table. + // Add builtin functions to the store's function table. wasmtime_table_t function_table; wasm_limits_t table_limits = {.min = definitions_len, .max = wasm_limits_max_default}; wasm_tabletype_t *table_type = wasm_tabletype_new(wasm_valtype_new(WASM_FUNCREF), &table_limits); @@ -326,24 +375,52 @@ static bool ts_wasm_store__instantiate( for (unsigned i = 0; i < import_types.size; i++) { const wasm_importtype_t *import_type = import_types.data[i]; const wasm_name_t *import_name = wasm_importtype_name(import_type); - - if (name_eq(import_name, "__memory_base")) { - imports[i] = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = memory_base_global}; - } else if (name_eq(import_name, "__table_base")) { - imports[i] = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = table_base_global}; - } else if (name_eq(import_name, "memory")) { - imports[i] = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_MEMORY, .of.memory = self->memory}; - } else if (name_eq(import_name, "__indirect_function_table")) { - imports[i] = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_TABLE, .of.table = self->function_table}; - } else { - printf("unexpected import '%.*s'\n", (int)import_name->size, import_name->data); + if (import_name->size == 0) { + printf("import with blank name\n"); return false; } + + switch (import_name->data[0]) { + case '_': + if (name_eq(import_name, "__memory_base")) { + imports[i] = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = memory_base_global}; + } else if (name_eq(import_name, "__table_base")) { + imports[i] = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = table_base_global}; + } else if (name_eq(import_name, "__indirect_function_table")) { + imports[i] = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_TABLE, .of.table = self->function_table}; + } else { + break; + } + continue; + case 'm': + if (name_eq(import_name, "memory")) { + imports[i] = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_MEMORY, .of.memory = self->memory}; + } else { + break; + } + continue; + case 'i': + if (name_eq(import_name, "iswspace")) { + imports[i] = get_builtin_func_extern(context, &self->function_table, ISWSPACE_IX); + } else if (name_eq(import_name, "iswdigit")) { + imports[i] = get_builtin_func_extern(context, &self->function_table, ISWDIGIT_IX); + } else if (name_eq(import_name, "iswalpha")) { + imports[i] = get_builtin_func_extern(context, &self->function_table, ISWALPHA_IX); + } else if (name_eq(import_name, "iswalnum")) { + imports[i] = get_builtin_func_extern(context, &self->function_table, ISWALNUM_IX); + } else { + break; + } + continue; + } + + printf("unexpected import '%.*s'\n", (int)import_name->size, import_name->data); + return false; } wasm_importtype_vec_delete(&import_types); wasmtime_instance_t instance; - error = wasmtime_instance_new(context, module, imports, 4, &instance, &trap); + error = wasmtime_instance_new(context, module, imports, array_len(imports), &instance, &trap); assert(!error); if (trap) { wasm_message_t message; @@ -469,7 +546,7 @@ const TSLanguage *ts_wasm_store_load_language( ), .parse_actions = copy( &memory[wasm_language.parse_actions], - 2800 * sizeof(TSParseActionEntry) // TODO - determine number of parse actions + 5655 * sizeof(TSParseActionEntry) // TODO - determine number of parse actions ), .symbol_names = copy_strings( memory, @@ -544,6 +621,14 @@ const TSLanguage *ts_wasm_store_load_language( ); } + if (language->external_token_count > 0) { + language->external_scanner.symbol_map = copy( + &memory[wasm_language.external_scanner.symbol_map], + wasm_language.external_token_count * sizeof(TSSymbol) + ); + language->external_scanner.states = (void *)(uintptr_t)wasm_language.external_scanner.states; + } + unsigned name_len = strlen(language_name); char *name = ts_malloc(name_len + 1); memcpy(name, language_name, name_len); @@ -569,26 +654,34 @@ const TSLanguage *ts_wasm_store_load_language( array_push(&self->language_instances, ((LanguageWasmInstance) { .language_id = language_module->language_id, .instance = instance, - .main_lex_fn_index = wasm_language.lex_fn, - .keyword_lex_fn_index = wasm_language.keyword_lex_fn, + .external_states_address = wasm_language.external_scanner.states, + .lex_main_fn_index = wasm_language.lex_fn, + .lex_keyword_fn_index = wasm_language.keyword_lex_fn, + .scanner_create_fn_index = wasm_language.external_scanner.create, + .scanner_destroy_fn_index = wasm_language.external_scanner.destroy, + .scanner_serialize_fn_index = wasm_language.external_scanner.serialize, + .scanner_deserialize_fn_index = wasm_language.external_scanner.deserialize, + .scanner_scan_fn_index = wasm_language.external_scanner.scan, })); return language; } -bool ts_wasm_store_start(TSWasmStore *self, TSLexer *lexer, const TSLanguage *language) { - if (!ts_language_is_wasm(language)) return false; +bool ts_wasm_store_add_language( + TSWasmStore *self, + const TSLanguage *language, + uint32_t *index +) { wasmtime_context_t *context = wasmtime_store_context(self->store); const LanguageWasmModule *language_module = (void *)language->keyword_lex_fn; // Search for the information about this store's instance of the language module. - uint32_t instance_index = 0; bool exists = false; array_search_sorted_by( &self->language_instances, .language_id, language_module->language_id, - &instance_index, + index, &exists ); @@ -610,14 +703,28 @@ bool ts_wasm_store_start(TSWasmStore *self, TSLexer *lexer, const TSLanguage *la LanguageInWasmMemory wasm_language; const uint8_t *memory = wasmtime_memory_data(context, &self->memory); memcpy(&wasm_language, &memory[language_address], sizeof(LanguageInWasmMemory)); - array_insert(&self->language_instances, instance_index, ((LanguageWasmInstance) { + array_insert(&self->language_instances, *index, ((LanguageWasmInstance) { .language_id = language_module->language_id, .instance = instance, - .main_lex_fn_index = wasm_language.lex_fn, - .keyword_lex_fn_index = wasm_language.keyword_lex_fn, + .external_states_address = wasm_language.external_scanner.states, + .lex_main_fn_index = wasm_language.lex_fn, + .lex_keyword_fn_index = wasm_language.keyword_lex_fn, + .scanner_create_fn_index = wasm_language.external_scanner.create, + .scanner_destroy_fn_index = wasm_language.external_scanner.destroy, + .scanner_serialize_fn_index = wasm_language.external_scanner.serialize, + .scanner_deserialize_fn_index = wasm_language.external_scanner.deserialize, + .scanner_scan_fn_index = wasm_language.external_scanner.scan, })); } + return true; +} + +bool ts_wasm_store_start(TSWasmStore *self, TSLexer *lexer, const TSLanguage *language) { + uint32_t instance_index; + if (!language) return false; + if (!ts_language_is_wasm(language)) return false; + if (!ts_wasm_store_add_language(self, language, &instance_index)) return false; self->current_lexer = lexer; self->current_instance = &self->language_instances.contents[instance_index]; return true; @@ -628,9 +735,25 @@ void ts_wasm_store_stop(TSWasmStore *self) { self->current_instance = NULL; } -bool ts_wasm_store_run_lex_function(TSWasmStore *self, TSStateId state, unsigned function_index) { +static void ts_wasm_store__call(TSWasmStore *self, int32_t function_index, wasmtime_val_raw_t *args_and_results) { wasmtime_context_t *context = wasmtime_store_context(self->store); + wasmtime_val_t value; + bool succeeded = wasmtime_table_get(context, &self->function_table, function_index, &value); + assert(succeeded); + assert(value.kind == WASMTIME_FUNCREF); + wasmtime_func_t func = value.of.funcref; + wasm_trap_t *trap = wasmtime_func_call_unchecked(context, &func, args_and_results); + if (trap) { + wasm_message_t message; + wasm_trap_message(trap, &message); + printf("error calling function index %u: %s\n", function_index, message.data); + abort(); + } +} + +static bool ts_wasm_store__call_lex_function(TSWasmStore *self, unsigned function_index, TSStateId state) { + wasmtime_context_t *context = wasmtime_store_context(self->store); uint8_t *memory_data = wasmtime_memory_data(context, &self->memory); memcpy( &memory_data[LEXER_ADDRESS], @@ -638,23 +761,71 @@ bool ts_wasm_store_run_lex_function(TSWasmStore *self, TSStateId state, unsigned sizeof(self->current_lexer->lookahead) ); - wasmtime_val_t lex_val; - bool succeeded = wasmtime_table_get(context, &self->function_table, function_index, &lex_val); - assert(succeeded); - assert(lex_val.kind == WASMTIME_FUNCREF); - wasmtime_func_t lex_func = lex_val.of.funcref; - wasmtime_val_raw_t args[2] = { {.i32 = LEXER_ADDRESS}, {.i32 = state}, }; - wasm_trap_t *trap = wasmtime_func_call_unchecked(context, &lex_func, args); - if (trap) { - wasm_message_t message; - wasm_trap_message(trap, &message); - printf("error calling lex function index %u: %s\n", function_index, message.data); - abort(); - } + ts_wasm_store__call(self, function_index, args); + bool result = args[0].i32; + + memcpy( + &self->current_lexer->lookahead, + &memory_data[LEXER_ADDRESS], + sizeof(self->current_lexer->lookahead) + sizeof(self->current_lexer->result_symbol) + ); + return result; +} + +bool ts_wasm_store_call_lex_main(TSWasmStore *self, TSStateId state) { + return ts_wasm_store__call_lex_function( + self, + self->current_instance->lex_main_fn_index, + state + ); +} + +bool ts_wasm_store_call_lex_keyword(TSWasmStore *self, TSStateId state) { + return ts_wasm_store__call_lex_function( + self, + self->current_instance->lex_keyword_fn_index, + state + ); +} + +uint32_t ts_wasm_store_call_scanner_create(TSWasmStore *self) { + wasmtime_val_raw_t args[1] = {{.i32 = 0}}; + ts_wasm_store__call(self, self->current_instance->scanner_create_fn_index, args); + return args[0].i32; +} + +void ts_wasm_store_call_scanner_destroy(TSWasmStore *self, uint32_t scanner_address) { + wasmtime_val_raw_t args[1] = {{.i32 = scanner_address}}; + ts_wasm_store__call(self, self->current_instance->scanner_destroy_fn_index, args); +} + +bool ts_wasm_store_call_scanner_scan( + TSWasmStore *self, + uint32_t scanner_address, + uint32_t valid_tokens_ix +) { + wasmtime_context_t *context = wasmtime_store_context(self->store); + uint8_t *memory_data = wasmtime_memory_data(context, &self->memory); + + memcpy( + &memory_data[LEXER_ADDRESS], + &self->current_lexer->lookahead, + sizeof(self->current_lexer->lookahead) + ); + + uint32_t valid_tokens_address = + self->current_instance->external_states_address + + (valid_tokens_ix * sizeof(bool)); + wasmtime_val_raw_t args[3] = { + {.i32 = scanner_address}, + {.i32 = LEXER_ADDRESS}, + {.i32 = valid_tokens_address} + }; + ts_wasm_store__call(self, self->current_instance->scanner_scan_fn_index, args); memcpy( &self->current_lexer->lookahead, @@ -664,20 +835,22 @@ bool ts_wasm_store_run_lex_function(TSWasmStore *self, TSStateId state, unsigned return args[0].i32; } -bool ts_wasm_store_run_main_lex_function(TSWasmStore *self, TSStateId state) { - return ts_wasm_store_run_lex_function( - self, - state, - self->current_instance->main_lex_fn_index - ); +uint32_t ts_wasm_store_call_scanner_serialize( + TSWasmStore *self, + uint32_t scanner_address, + char *buffer +) { + // TODO + return 0; } -bool ts_wasm_store_run_keyword_lex_function(TSWasmStore *self, TSStateId state) { - return ts_wasm_store_run_lex_function( - self, - state, - self->current_instance->keyword_lex_fn_index - ); +void ts_wasm_store_call_scanner_deserialize( + TSWasmStore *self, + uint32_t scanner_address, + const char *buffer, + unsigned length +) { + // TODO } bool ts_language_is_wasm(const TSLanguage *self) { diff --git a/lib/src/wasm.h b/lib/src/wasm.h index 47153c14..0e734e82 100644 --- a/lib/src/wasm.h +++ b/lib/src/wasm.h @@ -8,10 +8,17 @@ extern "C" { #include "tree_sitter/api.h" #include "tree_sitter/parser.h" -bool ts_wasm_store_start(TSWasmStore *self, TSLexer *lexer, const TSLanguage *language); -void ts_wasm_store_stop(TSWasmStore *self); -bool ts_wasm_store_run_main_lex_function(TSWasmStore *self, TSStateId state); -bool ts_wasm_store_run_keyword_lex_function(TSWasmStore *self, TSStateId state); +bool ts_wasm_store_start(TSWasmStore *, TSLexer *, const TSLanguage *); +void ts_wasm_store_stop(TSWasmStore *); + +bool ts_wasm_store_call_lex_main(TSWasmStore *, TSStateId); +bool ts_wasm_store_call_lex_keyword(TSWasmStore *, TSStateId); + +uint32_t ts_wasm_store_call_scanner_create(TSWasmStore *); +void ts_wasm_store_call_scanner_destroy(TSWasmStore *, uint32_t); +bool ts_wasm_store_call_scanner_scan(TSWasmStore *, uint32_t, uint32_t); +uint32_t ts_wasm_store_call_scanner_serialize(TSWasmStore *, uint32_t, char *); +void ts_wasm_store_call_scanner_deserialize(TSWasmStore *, uint32_t, const char *, unsigned); #ifdef __cplusplus }