Merge pull request #3181 from tree-sitter/handle-wasm-oom

When loading languages via WASM, gracefully handle memory errors and leaks in external scanners
This commit is contained in:
Max Brunsfeld 2024-03-18 13:15:06 -07:00 committed by GitHub
commit 09b18fad5b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 1895 additions and 2463 deletions

2
.dockerignore Normal file
View file

@ -0,0 +1,2 @@
target
.git

10
Dockerfile Normal file
View file

@ -0,0 +1,10 @@
FROM rust:1.76-buster
WORKDIR /app
RUN apt-get update
RUN apt-get install -y nodejs
COPY . .
CMD cargo test --all-features

View file

@ -81,7 +81,6 @@ fn record_alloc(ptr: *mut c_void) {
}
fn record_dealloc(ptr: *mut c_void) {
assert!(!ptr.is_null(), "Zero pointer deallocation!");
RECORDER.with(|recorder| {
if recorder.enabled.load(SeqCst) {
recorder

View file

@ -29,7 +29,71 @@ fn test_wasm_stdlib_symbols() {
}
#[test]
fn test_load_wasm_language() {
fn test_load_wasm_ruby_language() {
allocations::record(|| {
let mut store = WasmStore::new(ENGINE.clone()).unwrap();
let mut parser = Parser::new();
let wasm = fs::read(WASM_DIR.join("tree-sitter-ruby.wasm")).unwrap();
let language = store.load_language("ruby", &wasm).unwrap();
parser.set_wasm_store(store).unwrap();
parser.set_language(&language).unwrap();
let tree = parser.parse("class A; end", None).unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(program (class name: (constant)))"
);
});
}
#[test]
fn test_load_wasm_html_language() {
allocations::record(|| {
let mut store = WasmStore::new(ENGINE.clone()).unwrap();
let mut parser = Parser::new();
let wasm = fs::read(WASM_DIR.join("tree-sitter-html.wasm")).unwrap();
let language = store.load_language("html", &wasm).unwrap();
parser.set_wasm_store(store).unwrap();
parser.set_language(&language).unwrap();
let tree = parser
.parse("<div><span></span><p></p></div>", None)
.unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(document (element (start_tag (tag_name)) (element (start_tag (tag_name)) (end_tag (tag_name))) (element (start_tag (tag_name)) (end_tag (tag_name))) (end_tag (tag_name))))"
);
});
}
#[test]
fn test_load_wasm_rust_language() {
allocations::record(|| {
let mut store = WasmStore::new(ENGINE.clone()).unwrap();
let mut parser = Parser::new();
let wasm = fs::read(WASM_DIR.join("tree-sitter-rust.wasm")).unwrap();
let language = store.load_language("rust", &wasm).unwrap();
parser.set_wasm_store(store).unwrap();
parser.set_language(&language).unwrap();
let tree = parser.parse("fn main() {}", None).unwrap();
assert_eq!(tree.root_node().to_sexp(), "(source_file (function_item name: (identifier) parameters: (parameters) body: (block)))");
});
}
#[test]
fn test_load_wasm_javascript_language() {
allocations::record(|| {
let mut store = WasmStore::new(ENGINE.clone()).unwrap();
let mut parser = Parser::new();
let wasm = fs::read(WASM_DIR.join("tree-sitter-javascript.wasm")).unwrap();
let language = store.load_language("javascript", &wasm).unwrap();
parser.set_wasm_store(store).unwrap();
parser.set_language(&language).unwrap();
let tree = parser.parse("const a = b\nconst c = d", None).unwrap();
assert_eq!(tree.root_node().to_sexp(), "(program (lexical_declaration (variable_declarator name: (identifier) value: (identifier))) (lexical_declaration (variable_declarator name: (identifier) value: (identifier))))");
});
}
#[test]
fn test_load_multiple_wasm_languages() {
allocations::record(|| {
let mut store = WasmStore::new(ENGINE.clone()).unwrap();
let mut parser = Parser::new();
@ -51,6 +115,9 @@ fn test_load_wasm_language() {
.unwrap();
let mut query_cursor = QueryCursor::new();
// First, parse with the store that originally loaded the languages.
// Then parse with a new parser and wasm store, so that the languages
// are added one-by-one, in between parses.
for mut parser in [parser, parser2] {
for _ in 0..2 {
let query_rust = Query::new(&language_rust, "(const_item) @foo").unwrap();
@ -158,3 +225,28 @@ fn test_load_wasm_errors() {
);
});
}
#[test]
fn test_wasm_oom() {
allocations::record(|| {
let mut store = WasmStore::new(ENGINE.clone()).unwrap();
let mut parser = Parser::new();
let wasm = fs::read(WASM_DIR.join("tree-sitter-html.wasm")).unwrap();
let language = store.load_language("html", &wasm).unwrap();
parser.set_wasm_store(store).unwrap();
parser.set_language(&language).unwrap();
let tag_name = "a-b".repeat(2 * 1024 * 1024);
let code = format!("<{tag_name}>hello world</{tag_name}>");
assert!(parser.parse(&code, None).is_none());
let tag_name = "a-b".repeat(20);
let code = format!("<{tag_name}>hello world</{tag_name}>");
parser.set_language(&language).unwrap();
let tree = parser.parse(&code, None).unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(document (element (start_tag (tag_name)) (text) (end_tag (tag_name))))"
);
});
}

View file

@ -67,6 +67,7 @@ pub fn compile_language_to_wasm(
"__cxa_atexit",
"abort",
"emscripten_notify_memory_growth",
"tree_sitter_debug_message",
"proc_exit",
];

View file

@ -4,18 +4,6 @@ use std::{env, fs};
fn main() {
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
println!("cargo:rerun-if-env-changed=TREE_SITTER_STATIC_ANALYSIS");
if env::var("TREE_SITTER_STATIC_ANALYSIS").is_ok() {
if let (Some(clang_path), Some(scan_build_path)) = (which("clang"), which("scan-build")) {
let clang_path = clang_path.to_str().unwrap();
let scan_build_path = scan_build_path.to_str().unwrap();
env::set_var(
"CC",
format!("{scan_build_path} -analyze-headers --use-analyzer={clang_path} cc",),
);
}
}
#[cfg(feature = "bindgen")]
generate_bindings(&out_dir);
@ -96,16 +84,3 @@ fn generate_bindings(out_dir: &Path) {
.write_to_file(&bindings_rs)
.unwrap_or_else(|_| panic!("Failed to write bindings into path: {bindings_rs:?}"));
}
fn which(exe_name: impl AsRef<Path>) -> Option<PathBuf> {
env::var_os("PATH").and_then(|paths| {
env::split_paths(&paths).find_map(|dir| {
let full_path = dir.join(&exe_name);
if full_path.is_file() {
Some(full_path)
} else {
None
}
})
})
}

View file

@ -110,6 +110,7 @@ struct TSParser {
Subtree old_tree;
TSRangeArray included_range_differences;
unsigned included_range_difference_index;
bool has_scanner_error;
};
typedef struct {
@ -337,6 +338,22 @@ static bool ts_parser__better_version_exists(
return false;
}
static bool ts_parser__call_main_lex_fn(TSParser *self, TSLexMode lex_mode) {
if (ts_language_is_wasm(self->language)) {
return ts_wasm_store_call_lex_main(self->wasm_store, lex_mode.lex_state);
} else {
return self->language->lex_fn(&self->lexer.data, lex_mode.lex_state);
}
}
static bool ts_parser__call_keyword_lex_fn(TSParser *self, TSLexMode lex_mode) {
if (ts_language_is_wasm(self->language)) {
return ts_wasm_store_call_lex_keyword(self->wasm_store, 0);
} else {
return self->language->keyword_lex_fn(&self->lexer.data, 0);
}
}
static void ts_parser__external_scanner_create(
TSParser *self
) {
@ -345,6 +362,9 @@ static void ts_parser__external_scanner_create(
self->external_scanner_payload = (void *)(uintptr_t)ts_wasm_store_call_scanner_create(
self->wasm_store
);
if (ts_wasm_store_has_error(self->wasm_store)) {
self->has_scanner_error = true;
}
} else if (self->language->external_scanner.create) {
self->external_scanner_payload = self->language->external_scanner.create();
}
@ -354,21 +374,17 @@ static void ts_parser__external_scanner_create(
static void ts_parser__external_scanner_destroy(
TSParser *self
) {
if (self->language && self->external_scanner_payload) {
if (ts_language_is_wasm(self->language)) {
if (self->wasm_store) {
ts_wasm_store_call_scanner_destroy(
self->wasm_store,
(uintptr_t)self->external_scanner_payload
);
}
} else if (self->language->external_scanner.destroy) {
self->language->external_scanner.destroy(
self->external_scanner_payload
);
}
self->external_scanner_payload = NULL;
if (
self->language &&
self->external_scanner_payload &&
self->language->external_scanner.destroy &&
!ts_language_is_wasm(self->language)
) {
self->language->external_scanner.destroy(
self->external_scanner_payload
);
}
self->external_scanner_payload = NULL;
}
static unsigned ts_parser__external_scanner_serialize(
@ -406,6 +422,9 @@ static void ts_parser__external_scanner_deserialize(
data,
length
);
if (ts_wasm_store_has_error(self->wasm_store)) {
self->has_scanner_error = true;
}
} else {
self->language->external_scanner.deserialize(
self->external_scanner_payload,
@ -419,13 +438,16 @@ static bool ts_parser__external_scanner_scan(
TSParser *self,
TSStateId external_lex_state
) {
if (ts_language_is_wasm(self->language)) {
return ts_wasm_store_call_scanner_scan(
bool result = ts_wasm_store_call_scanner_scan(
self->wasm_store,
(uintptr_t)self->external_scanner_payload,
external_lex_state * self->language->external_token_count
);
if (ts_wasm_store_has_error(self->wasm_store)) {
self->has_scanner_error = true;
}
return result;
} else {
const bool *valid_external_tokens = ts_language_enabled_external_tokens(
self->language,
@ -514,6 +536,7 @@ static Subtree ts_parser__lex(
ts_lexer_start(&self->lexer);
ts_parser__external_scanner_deserialize(self, external_token);
found_token = ts_parser__external_scanner_scan(self, lex_mode.external_lex_state);
if (self->has_scanner_error) return NULL_SUBTREE;
ts_lexer_finish(&self->lexer, &lookahead_end_byte);
if (found_token) {
@ -564,11 +587,7 @@ static Subtree ts_parser__lex(
current_position.extent.column
);
ts_lexer_start(&self->lexer);
if (ts_language_is_wasm(self->language)) {
found_token = ts_wasm_store_call_lex_main(self->wasm_store, lex_mode.lex_state);
} else {
found_token = self->language->lex_fn(&self->lexer.data, lex_mode.lex_state);
}
found_token = ts_parser__call_main_lex_fn(self, lex_mode);
ts_lexer_finish(&self->lexer, &lookahead_end_byte);
if (found_token) break;
@ -626,11 +645,7 @@ static Subtree ts_parser__lex(
ts_lexer_reset(&self->lexer, self->lexer.token_start_position);
ts_lexer_start(&self->lexer);
if (ts_language_is_wasm(self->language)) {
is_keyword = ts_wasm_store_call_lex_keyword(self->wasm_store, 0);
} else {
is_keyword = self->language->keyword_lex_fn(&self->lexer.data, 0);
}
is_keyword = ts_parser__call_keyword_lex_fn(self, lex_mode);
if (
is_keyword &&
@ -1527,6 +1542,7 @@ static bool ts_parser__advance(
if (needs_lex) {
needs_lex = false;
lookahead = ts_parser__lex(self, version, state);
if (self->has_scanner_error) return false;
if (lookahead.ptr) {
ts_parser__set_cached_token(self, position, last_external_token, lookahead);
@ -1811,6 +1827,7 @@ static unsigned ts_parser__condense_stack(TSParser *self) {
static bool ts_parser_has_outstanding_parse(TSParser *self) {
return (
self->external_scanner_payload ||
ts_stack_state(self->stack, 0) != 1 ||
ts_stack_node_count_since_error(self->stack, 0) != 0
);
@ -1830,6 +1847,9 @@ TSParser *ts_parser_new(void) {
self->dot_graph_file = NULL;
self->cancellation_flag = NULL;
self->timeout_duration = 0;
self->language = NULL;
self->has_scanner_error = false;
self->external_scanner_payload = NULL;
self->end_clock = clock_null();
self->operation_count = 0;
self->old_tree = NULL_SUBTREE;
@ -1870,7 +1890,7 @@ const TSLanguage *ts_parser_language(const TSParser *self) {
}
bool ts_parser_set_language(TSParser *self, const TSLanguage *language) {
ts_parser__external_scanner_destroy(self);
ts_parser_reset(self);
ts_language_delete(self->language);
self->language = NULL;
@ -1889,8 +1909,6 @@ bool ts_parser_set_language(TSParser *self, const TSLanguage *language) {
}
self->language = ts_language_copy(language);
ts_parser__external_scanner_create(self);
ts_parser_reset(self);
return true;
}
@ -1947,8 +1965,9 @@ const TSRange *ts_parser_included_ranges(const TSParser *self, uint32_t *count)
}
void ts_parser_reset(TSParser *self) {
if (self->language && self->language->external_scanner.deserialize) {
self->language->external_scanner.deserialize(self->external_scanner_payload, NULL, 0);
ts_parser__external_scanner_destroy(self);
if (self->wasm_store) {
ts_wasm_store_reset(self->wasm_store);
}
if (self->old_tree.ptr) {
@ -1965,6 +1984,7 @@ void ts_parser_reset(TSParser *self) {
self->finished_tree = NULL_SUBTREE;
}
self->accept_count = 0;
self->has_scanner_error = false;
}
TSTree *ts_parser_parse(
@ -1972,41 +1992,43 @@ TSTree *ts_parser_parse(
const TSTree *old_tree,
TSInput input
) {
TSTree *result = NULL;
if (!self->language || !input.read) return NULL;
if (ts_language_is_wasm(self->language)) {
if (self->wasm_store) {
ts_wasm_store_start(self->wasm_store, &self->lexer.data, self->language);
} else {
return NULL;
}
if (!self->wasm_store) return NULL;
ts_wasm_store_start(self->wasm_store, &self->lexer.data, self->language);
}
ts_lexer_set_input(&self->lexer, input);
array_clear(&self->included_range_differences);
self->included_range_difference_index = 0;
if (ts_parser_has_outstanding_parse(self)) {
LOG("resume_parsing");
} else if (old_tree) {
ts_subtree_retain(old_tree->root);
self->old_tree = old_tree->root;
ts_range_array_get_changed_ranges(
old_tree->included_ranges, old_tree->included_range_count,
self->lexer.included_ranges, self->lexer.included_range_count,
&self->included_range_differences
);
reusable_node_reset(&self->reusable_node, old_tree->root);
LOG("parse_after_edit");
LOG_TREE(self->old_tree);
for (unsigned i = 0; i < self->included_range_differences.size; i++) {
TSRange *range = &self->included_range_differences.contents[i];
LOG("different_included_range %u - %u", range->start_byte, range->end_byte);
}
} else {
reusable_node_clear(&self->reusable_node);
LOG("new_parse");
ts_parser__external_scanner_create(self);
if (self->has_scanner_error) goto exit;
if (old_tree) {
ts_subtree_retain(old_tree->root);
self->old_tree = old_tree->root;
ts_range_array_get_changed_ranges(
old_tree->included_ranges, old_tree->included_range_count,
self->lexer.included_ranges, self->lexer.included_range_count,
&self->included_range_differences
);
reusable_node_reset(&self->reusable_node, old_tree->root);
LOG("parse_after_edit");
LOG_TREE(self->old_tree);
for (unsigned i = 0; i < self->included_range_differences.size; i++) {
TSRange *range = &self->included_range_differences.contents[i];
LOG("different_included_range %u - %u", range->start_byte, range->end_byte);
}
} else {
reusable_node_clear(&self->reusable_node);
LOG("new_parse");
}
}
self->operation_count = 0;
@ -2035,7 +2057,11 @@ TSTree *ts_parser_parse(
ts_stack_position(self->stack, version).extent.column
);
if (!ts_parser__advance(self, version, allow_node_reuse)) return NULL;
if (!ts_parser__advance(self, version, allow_node_reuse)) {
if (self->has_scanner_error) goto exit;
return NULL;
}
LOG_STACK();
position = ts_stack_position(self->stack, version).bytes;
@ -2074,13 +2100,15 @@ TSTree *ts_parser_parse(
LOG("done");
LOG_TREE(self->finished_tree);
TSTree *result = ts_tree_new(
result = ts_tree_new(
self->finished_tree,
self->language,
self->lexer.included_ranges,
self->lexer.included_range_count
);
self->finished_tree = NULL_SUBTREE;
exit:
ts_parser_reset(self);
return result;
}

109
lib/src/wasm/stdlib.c Normal file
View file

@ -0,0 +1,109 @@
// This file implements a very simple allocator for external scanners running
// in WASM. Allocation is just bumping a static pointer and growing the heap
// as needed, and freeing is mostly a noop. But in the special case of freeing
// the last-allocated pointer, we'll reuse that pointer again.
#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
extern void tree_sitter_debug_message(const char *, size_t);
#define PAGESIZE 0x10000
#define MAX_HEAP_SIZE (4 * 1024 * 1024)
typedef struct {
size_t size;
char data[0];
} Region;
static Region *heap_end = NULL;
static Region *heap_start = NULL;
static Region *next = NULL;
// Get the region metadata for the given heap pointer.
static inline Region *region_for_ptr(void *ptr) {
return ((Region *)ptr) - 1;
}
// Get the location of the next region after the given region,
// if the given region had the given size.
static inline Region *region_after(Region *self, size_t len) {
char *address = self->data + len;
char *aligned = (char *)((uintptr_t)(address + 3) & ~0x3);
return (Region *)aligned;
}
static void *get_heap_end() {
return (void *)(__builtin_wasm_memory_size(0) * PAGESIZE);
}
static int grow_heap(size_t size) {
size_t new_page_count = ((size - 1) / PAGESIZE) + 1;
return __builtin_wasm_memory_grow(0, new_page_count) != SIZE_MAX;
}
// Clear out the heap, and move it to the given address.
void reset_heap(void *new_heap_start) {
heap_start = new_heap_start;
next = new_heap_start;
heap_end = get_heap_end();
}
void *malloc(size_t size) {
Region *region_end = region_after(next, size);
if (region_end > heap_end) {
if ((char *)region_end - (char *)heap_start > MAX_HEAP_SIZE) {
return NULL;
}
if (!grow_heap(size)) return NULL;
heap_end = get_heap_end();
}
void *result = &next->data;
next->size = size;
next = region_end;
return result;
}
void free(void *ptr) {
if (ptr == NULL) return;
Region *region = region_for_ptr(ptr);
Region *region_end = region_after(region, region->size);
// When freeing the last allocated pointer, re-use that
// pointer for the next allocation.
if (region_end == next) {
next = region;
}
}
void *calloc(size_t count, size_t size) {
void *result = malloc(count * size);
memset(result, 0, count * size);
return result;
}
void *realloc(void *ptr, size_t new_size) {
if (ptr == NULL) {
return malloc(new_size);
}
Region *region = region_for_ptr(ptr);
Region *region_end = region_after(region, region->size);
// When reallocating the last allocated region, return
// the same pointer, and skip copying the data.
if (region_end == next) {
next = region;
return malloc(new_size);
}
void *result = malloc(new_size);
memcpy(result, &region->data, region->size);
return result;
}

File diff suppressed because it is too large Load diff

View file

@ -15,14 +15,14 @@
#include "./wasm_store.h"
#include "./wasm/wasm-stdlib.h"
#define array_len(a) (sizeof(a) / sizeof(a[0]))
// The following symbols from the C and C++ standard libraries are available
// for external scanners to use.
const char *STDLIB_SYMBOLS[] = {
#include "./stdlib-symbols.txt"
};
#define STDLIB_SYMBOL_COUNT (sizeof(STDLIB_SYMBOLS) / sizeof(STDLIB_SYMBOLS[0]))
// The contents of the `dylink.0` custom section of a wasm module,
// as specified by the current WebAssembly dynamic linking ABI proposal.
typedef struct {
@ -70,6 +70,18 @@ typedef struct {
int32_t scanner_scan_fn_index;
} LanguageWasmInstance;
typedef struct {
uint32_t reset_heap;
uint32_t proc_exit;
uint32_t abort;
uint32_t assert_fail;
uint32_t notify_memory_growth;
uint32_t debug_message;
uint32_t at_exit;
uint32_t args_get;
uint32_t args_sizes_get;
} BuiltinFunctionIndices;
// TSWasmStore - A struct that allows a given `Parser` to use wasm-backed
// languages. This struct is mutable, and can only be used by one parser at a
// time.
@ -82,11 +94,14 @@ struct TSWasmStore {
LanguageWasmInstance *current_instance;
Array(LanguageWasmInstance) language_instances;
uint32_t current_memory_offset;
uint32_t current_memory_size;
uint32_t current_function_table_offset;
uint16_t *fn_indices;
uint32_t *stdlib_fn_indices;
BuiltinFunctionIndices builtin_fn_indices;
wasmtime_global_t stack_pointer_global;
wasm_globaltype_t *const_i32_type;
wasm_globaltype_t *var_i32_type;
bool has_error;
uint32_t lexer_address;
uint32_t serialization_buffer_address;
};
typedef Array(char) StringData;
@ -147,29 +162,8 @@ typedef struct {
static volatile uint32_t NEXT_LANGUAGE_ID;
// Linear memory layout:
// [ <-- stack | built-in data | heap --> | static data ]
#define STACK_SIZE (64 * 1024)
#define HEAP_SIZE (1024 * 1024)
#define INITIAL_MEMORY_SIZE (4 * 1024 * 1024 / MEMORY_PAGE_SIZE)
#define MAX_MEMORY_SIZE 32768
#define SERIALIZATION_BUFFER_ADDRESS (STACK_SIZE)
#define LEXER_ADDRESS (SERIALIZATION_BUFFER_ADDRESS + TREE_SITTER_SERIALIZATION_BUFFER_SIZE)
#define HEAP_START_ADDRESS (LEXER_ADDRESS + sizeof(LexerInWasmMemory))
#define DATA_START_ADDRESS (HEAP_START_ADDRESS + HEAP_SIZE)
enum FunctionIx {
NULL_IX = 0,
PROC_EXIT_IX,
ABORT_IX,
ASSERT_FAIL_IX,
NOTIFY_MEMORY_GROWTH_IX,
AT_EXIT_IX,
LEXER_ADVANCE_IX,
LEXER_MARK_END_IX,
LEXER_GET_COLUMN_IX,
LEXER_IS_AT_INCLUDED_RANGE_START_IX,
LEXER_EOF_IX,
};
// [ <-- stack | stdlib statics | lexer | serialization_buffer | language statics --> | heap --> ]
#define MAX_MEMORY_SIZE (128 * 1024 * 1024 / MEMORY_PAGE_SIZE)
/************************
* WasmDylinkMemoryInfo
@ -247,27 +241,32 @@ static bool wasm_dylink_info__parse(
* Native callbacks exposed to wasm modules
*******************************************/
static wasm_trap_t *callback__exit(
static wasm_trap_t *callback__abort(
void *env,
wasmtime_caller_t* caller,
wasmtime_val_raw_t *args_and_results,
size_t args_and_results_len
) {
fprintf(stderr, "wasm module called exit");
abort();
return wasmtime_trap_new("wasm module called abort", 24);
}
static wasm_trap_t *callback__notify_memory_growth(
static wasm_trap_t *callback__debug_message(
void *env,
wasmtime_caller_t* caller,
wasmtime_val_raw_t *args_and_results,
size_t args_and_results_len
) {
fprintf(stderr, "wasm module called exit");
abort();
wasmtime_context_t *context = wasmtime_caller_context(caller);
TSWasmStore *store = env;
assert(args_and_results_len == 2);
uint32_t string_address = args_and_results[0].i32;
uint32_t value = args_and_results[1].i32;
uint8_t *memory = wasmtime_memory_data(context, &store->memory);
printf("DEBUG: %s %u\n", &memory[string_address], value);
return NULL;
}
static wasm_trap_t *callback__at_exit(
static wasm_trap_t *callback__noop(
void *env,
wasmtime_caller_t* caller,
wasmtime_val_raw_t *args_and_results,
@ -291,7 +290,7 @@ static wasm_trap_t *callback__lexer_advance(
lexer->advance(lexer, skip);
uint8_t *memory = wasmtime_memory_data(context, &store->memory);
memcpy(&memory[LEXER_ADDRESS], &lexer->lookahead, sizeof(lexer->lookahead));
memcpy(&memory[store->lexer_address], &lexer->lookahead, sizeof(lexer->lookahead));
return NULL;
}
@ -347,12 +346,11 @@ static wasm_trap_t *callback__lexer_eof(
}
typedef struct {
uint32_t *storage_location;
wasmtime_func_unchecked_callback_t callback;
wasm_functype_t *type;
} FunctionDefinition;
#define array_len(a) (sizeof(a) / sizeof(a[0]))
static void *copy(const void *data, size_t size) {
void *result = ts_malloc(size);
memcpy(result, data, size);
@ -427,17 +425,6 @@ static inline wasm_functype_t* wasm_functype_new_4_0(
return wasm_functype_new(&params, &results);
}
static wasmtime_extern_t get_builtin_func_extern(
wasmtime_context_t *context,
wasmtime_table_t *table,
unsigned index
) {
wasmtime_val_t val;
bool exists = wasmtime_table_get(context, table, index, &val);
assert(exists);
return (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_FUNC, .of.func = val.of.funcref};
}
#define format(output, ...) \
do { \
size_t message_length = snprintf((char *)NULL, 0, __VA_ARGS__); \
@ -463,6 +450,19 @@ void language_id_delete(WasmLanguageId *self) {
}
}
static wasmtime_extern_t get_builtin_extern(
wasmtime_table_t *table,
unsigned index
) {
return (wasmtime_extern_t) {
.kind = WASMTIME_EXTERN_FUNC,
.of.func = (wasmtime_func_t) {
.store_id = table->store_id,
.index = index
}
};
}
static bool ts_wasm_store__provide_builtin_import(
TSWasmStore *self,
const wasm_name_t *import_name,
@ -484,18 +484,8 @@ static bool ts_wasm_store__provide_builtin_import(
error = wasmtime_global_new(context, self->const_i32_type, &value, &global);
assert(!error);
*import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = global};
} else if (name_eq(import_name, "__heap_base")) {
wasmtime_val_t value = WASM_I32_VAL(HEAP_START_ADDRESS);
wasmtime_global_t global;
error = wasmtime_global_new(context, self->var_i32_type, &value, &global);
assert(!error);
*import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = global};
} else if (name_eq(import_name, "__stack_pointer")) {
wasmtime_val_t value = WASM_I32_VAL(STACK_SIZE);
wasmtime_global_t global;
error = wasmtime_global_new(context, self->var_i32_type, &value, &global);
assert(!error);
*import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = global};
*import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_GLOBAL, .of.global = self->stack_pointer_global};
} else if (name_eq(import_name, "__indirect_function_table")) {
*import = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_TABLE, .of.table = self->function_table};
} else if (name_eq(import_name, "memory")) {
@ -504,15 +494,21 @@ static bool ts_wasm_store__provide_builtin_import(
// Builtin functions
else if (name_eq(import_name, "__assert_fail")) {
*import = get_builtin_func_extern(context, &self->function_table, ASSERT_FAIL_IX);
*import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.assert_fail);
} else if (name_eq(import_name, "__cxa_atexit")) {
*import = get_builtin_func_extern(context, &self->function_table, AT_EXIT_IX);
*import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.at_exit);
} else if (name_eq(import_name, "args_get")) {
*import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.args_get);
} else if (name_eq(import_name, "args_sizes_get")) {
*import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.args_sizes_get);
} else if (name_eq(import_name, "abort")) {
*import = get_builtin_func_extern(context, &self->function_table, ABORT_IX);
*import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.abort);
} else if (name_eq(import_name, "proc_exit")) {
*import = get_builtin_func_extern(context, &self->function_table, PROC_EXIT_IX);
*import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.proc_exit);
} else if (name_eq(import_name, "emscripten_notify_memory_growth")) {
*import = get_builtin_func_extern(context, &self->function_table, NOTIFY_MEMORY_GROWTH_IX);
*import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.notify_memory_growth);
} else if (name_eq(import_name, "tree_sitter_debug_message")) {
*import = get_builtin_extern(&self->function_table, self->builtin_fn_indices.debug_message);
} else {
return false;
}
@ -528,7 +524,8 @@ static bool ts_wasm_store__call_module_initializer(
) {
if (
name_eq(export_name, "_initialize") ||
name_eq(export_name, "__wasm_apply_data_relocs")
name_eq(export_name, "__wasm_apply_data_relocs") ||
name_eq(export_name, "__wasm_call_ctors")
) {
wasmtime_context_t *context = wasmtime_store_context(self->store);
wasmtime_func_t initialization_func = export->of.func;
@ -541,7 +538,7 @@ static bool ts_wasm_store__call_module_initializer(
}
TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) {
TSWasmStore *self = ts_malloc(sizeof(TSWasmStore));
TSWasmStore *self = ts_calloc(1, sizeof(TSWasmStore));
wasmtime_store_t *store = wasmtime_store_new(engine, self, NULL);
wasmtime_context_t *context = wasmtime_store_context(store);
wasmtime_error_t *error = NULL;
@ -549,10 +546,146 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) {
wasm_message_t message = WASM_EMPTY_VEC;
wasm_exporttype_vec_t export_types = WASM_EMPTY_VEC;
wasmtime_extern_t *imports = NULL;
wasmtime_module_t *stdlib_module = NULL;
wasm_memorytype_t *memory_type = NULL;
wasm_tabletype_t *table_type = NULL;
// Define functions called by scanners via function pointers on the lexer.
LexerInWasmMemory lexer = {
.lookahead = 0,
.result_symbol = 0,
};
FunctionDefinition lexer_definitions[] = {
{
(uint32_t *)&lexer.advance,
callback__lexer_advance,
wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32())
},
{
(uint32_t *)&lexer.mark_end,
callback__lexer_mark_end,
wasm_functype_new_1_0(wasm_valtype_new_i32())
},
{
(uint32_t *)&lexer.get_column,
callback__lexer_get_column,
wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())
},
{
(uint32_t *)&lexer.is_at_included_range_start,
callback__lexer_is_at_included_range_start,
wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())
},
{
(uint32_t *)&lexer.eof,
callback__lexer_eof,
wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())
},
};
// Define builtin functions that can be imported by scanners.
BuiltinFunctionIndices builtin_fn_indices;
FunctionDefinition builtin_definitions[] = {
{
&builtin_fn_indices.proc_exit,
callback__abort,
wasm_functype_new_1_0(wasm_valtype_new_i32())
},
{
&builtin_fn_indices.abort,
callback__abort,
wasm_functype_new_0_0()
},
{
&builtin_fn_indices.assert_fail,
callback__abort,
wasm_functype_new_4_0(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())
},
{
&builtin_fn_indices.notify_memory_growth,
callback__noop,
wasm_functype_new_1_0(wasm_valtype_new_i32())
},
{
&builtin_fn_indices.debug_message,
callback__debug_message,
wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32())
},
{
&builtin_fn_indices.at_exit,
callback__noop,
wasm_functype_new_3_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())
},
{
&builtin_fn_indices.args_get,
callback__noop,
wasm_functype_new_2_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())
},
{
&builtin_fn_indices.args_sizes_get,
callback__noop,
wasm_functype_new_2_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())
},
};
// Create all of the wasm functions.
unsigned builtin_definitions_len = array_len(builtin_definitions);
unsigned lexer_definitions_len = array_len(lexer_definitions);
for (unsigned i = 0; i < builtin_definitions_len; i++) {
FunctionDefinition *definition = &builtin_definitions[i];
wasmtime_func_t func;
wasmtime_func_new_unchecked(context, definition->type, definition->callback, self, NULL, &func);
*definition->storage_location = func.index;
wasm_functype_delete(definition->type);
}
for (unsigned i = 0; i < lexer_definitions_len; i++) {
FunctionDefinition *definition = &lexer_definitions[i];
wasmtime_func_t func;
wasmtime_func_new_unchecked(context, definition->type, definition->callback, self, NULL, &func);
*definition->storage_location = func.index;
wasm_functype_delete(definition->type);
}
// Compile the stdlib module.
error = wasmtime_module_new(engine, STDLIB_WASM, STDLIB_WASM_LEN, &stdlib_module);
if (error) {
wasmtime_error_message(error, &message);
wasm_error->kind = TSWasmErrorKindCompile;
format(
&wasm_error->message,
"failed to compile wasm stdlib: %.*s",
(int)message.size, message.data
);
goto error;
}
// Retrieve the stdlib module's imports.
wasm_importtype_vec_t import_types = WASM_EMPTY_VEC;
wasmtime_module_imports(stdlib_module, &import_types);
// Find the initial number of memory pages needed by the stdlib.
const wasm_memorytype_t *stdlib_memory_type;
for (unsigned i = 0; i < import_types.size; i++) {
wasm_importtype_t *import_type = import_types.data[i];
const wasm_name_t *import_name = wasm_importtype_name(import_type);
if (name_eq(import_name, "memory")) {
const wasm_externtype_t *type = wasm_importtype_type(import_type);
stdlib_memory_type = wasm_externtype_as_memorytype_const(type);
}
}
if (!stdlib_memory_type) {
wasm_error->kind = TSWasmErrorKindCompile;
format(
&wasm_error->message,
"wasm stdlib is missing the 'memory' import"
);
goto error;
}
// Initialize store's memory
wasm_limits_t memory_limits = {.min = INITIAL_MEMORY_SIZE, .max = MAX_MEMORY_SIZE};
wasm_memorytype_t *memory_type = wasm_memorytype_new(&memory_limits);
uint64_t initial_memory_pages = wasmtime_memorytype_minimum(stdlib_memory_type);
wasm_limits_t memory_limits = {.min = initial_memory_pages, .max = MAX_MEMORY_SIZE};
memory_type = wasm_memorytype_new(&memory_limits);
wasmtime_memory_t memory;
error = wasmtime_memory_new(context, memory_type, &memory);
if (error) {
@ -566,41 +699,13 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) {
goto error;
}
wasm_memorytype_delete(memory_type);
memory_type = NULL;
// Initialize lexer struct with function pointers in wasm memory.
uint8_t *memory_data = wasmtime_memory_data(context, &memory);
LexerInWasmMemory lexer = {
.lookahead = 0,
.result_symbol = 0,
.advance = LEXER_ADVANCE_IX,
.mark_end = LEXER_MARK_END_IX,
.get_column = LEXER_GET_COLUMN_IX,
.is_at_included_range_start = LEXER_IS_AT_INCLUDED_RANGE_START_IX,
.eof = LEXER_EOF_IX,
};
memcpy(&memory_data[LEXER_ADDRESS], &lexer, sizeof(lexer));
// Define builtin functions.
FunctionDefinition definitions[] = {
[NULL_IX] = {NULL, NULL},
[PROC_EXIT_IX] = {callback__exit, wasm_functype_new_1_0(wasm_valtype_new_i32())},
[ABORT_IX] = {callback__exit, wasm_functype_new_0_0()},
[ASSERT_FAIL_IX] = {callback__exit, wasm_functype_new_4_0(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())},
[NOTIFY_MEMORY_GROWTH_IX] = {callback__notify_memory_growth, wasm_functype_new_1_0(wasm_valtype_new_i32())},
[AT_EXIT_IX] = {callback__at_exit, wasm_functype_new_3_1(wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32(), wasm_valtype_new_i32())},
[LEXER_ADVANCE_IX] = {callback__lexer_advance, wasm_functype_new_2_0(wasm_valtype_new_i32(), wasm_valtype_new_i32())},
[LEXER_MARK_END_IX] = {callback__lexer_mark_end, wasm_functype_new_1_0(wasm_valtype_new_i32())},
[LEXER_GET_COLUMN_IX] = {callback__lexer_get_column, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())},
[LEXER_IS_AT_INCLUDED_RANGE_START_IX] = {callback__lexer_is_at_included_range_start, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())},
[LEXER_EOF_IX] = {callback__lexer_eof, wasm_functype_new_1_1(wasm_valtype_new_i32(), wasm_valtype_new_i32())},
};
unsigned definitions_len = array_len(definitions);
// Add builtin functions to the store's function table.
wasmtime_table_t function_table;
wasm_limits_t table_limits = {.min = definitions_len, .max = wasm_limits_max_default};
wasm_tabletype_t *table_type = wasm_tabletype_new(wasm_valtype_new(WASM_FUNCREF), &table_limits);
// Initialize store's function table
wasm_limits_t table_limits = {.min = 1, .max = wasm_limits_max_default};
table_type = wasm_tabletype_new(wasm_valtype_new(WASM_FUNCREF), &table_limits);
wasmtime_val_t initializer = {.kind = WASMTIME_FUNCREF};
wasmtime_table_t function_table;
error = wasmtime_table_new(context, table_type, &initializer, &function_table);
if (error) {
wasmtime_error_message(error, &message);
@ -613,68 +718,34 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) {
goto error;
}
wasm_tabletype_delete(table_type);
table_type = NULL;
uint32_t prev_size;
error = wasmtime_table_grow(context, &function_table, definitions_len, &initializer, &prev_size);
if (error) {
wasmtime_error_message(error, &message);
wasm_error->kind = TSWasmErrorKindAllocate;
format(
&wasm_error->message,
"failed to grow wasm table to initial size: %.*s",
(int)message.size, message.data
);
goto error;
}
unsigned stdlib_symbols_len = array_len(STDLIB_SYMBOLS);
for (unsigned i = 1; i < definitions_len; i++) {
FunctionDefinition *definition = &definitions[i];
wasmtime_func_t func;
wasmtime_func_new_unchecked(context, definition->type, definition->callback, self, NULL, &func);
wasmtime_val_t func_val = {.kind = WASMTIME_FUNCREF, .of.funcref = func};
error = wasmtime_table_set(context, &function_table, i, &func_val);
assert(!error);
wasm_functype_delete(definition->type);
}
// Define globals for the stack and heap start addresses.
wasm_globaltype_t *const_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_CONST);
wasm_globaltype_t *var_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_VAR);
wasmtime_val_t stack_pointer_value = WASM_I32_VAL(0);
wasmtime_global_t stack_pointer_global;
error = wasmtime_global_new(context, var_i32_type, &stack_pointer_value, &stack_pointer_global);
assert(!error);
*self = (TSWasmStore) {
.store = store,
.engine = engine,
.store = store,
.memory = memory,
.language_instances = array_new(),
.function_table = function_table,
.language_instances = array_new(),
.stdlib_fn_indices = ts_calloc(stdlib_symbols_len, sizeof(uint32_t)),
.builtin_fn_indices = builtin_fn_indices,
.stack_pointer_global = stack_pointer_global,
.current_memory_offset = 0,
.fn_indices = ts_calloc(STDLIB_SYMBOL_COUNT, sizeof(uint16_t)),
.current_memory_size = 64 * MEMORY_PAGE_SIZE,
.current_function_table_offset = definitions_len,
.const_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_CONST),
.var_i32_type = wasm_globaltype_new(wasm_valtype_new_i32(), WASM_VAR),
.current_function_table_offset = 0,
.const_i32_type = const_i32_type,
};
WasmDylinkInfo dylink_info;
if (!wasm_dylink_info__parse(STDLIB_WASM, STDLIB_WASM_LEN, &dylink_info)) {
wasm_error->kind = TSWasmErrorKindParse;
format(&wasm_error->message, "failed to parse wasm stdlib");
goto error;
}
wasmtime_module_t *stdlib_module;
error = wasmtime_module_new(engine, STDLIB_WASM, STDLIB_WASM_LEN, &stdlib_module);
if (error) {
wasmtime_error_message(error, &message);
wasm_error->kind = TSWasmErrorKindCompile;
format(
&wasm_error->message,
"failed to compile wasm stdlib: %.*s",
(int)message.size, message.data
);
goto error;
}
wasmtime_instance_t instance;
wasm_importtype_vec_t import_types = WASM_EMPTY_VEC;
wasmtime_module_imports(stdlib_module, &import_types);
// Set up the imports for the stdlib module.
imports = ts_calloc(import_types.size, sizeof(wasmtime_extern_t));
for (unsigned i = 0; i < import_types.size; i++) {
wasm_importtype_t *type = import_types.data[i];
@ -690,6 +761,8 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) {
}
}
// Instantiate the stdlib module.
wasmtime_instance_t instance;
error = wasmtime_instance_new(context, stdlib_module, imports, import_types.size, &instance, &trap);
ts_free(imports);
imports = NULL;
@ -715,14 +788,10 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) {
}
wasm_importtype_vec_delete(&import_types);
self->current_memory_offset = DATA_START_ADDRESS + dylink_info.memory_size;
self->current_function_table_offset += dylink_info.table_size;
for (unsigned i = 0; i < STDLIB_SYMBOL_COUNT; i++) {
self->fn_indices[i] = UINT16_MAX;
}
// Process the stdlib module's exports.
for (unsigned i = 0; i < stdlib_symbols_len; i++) {
self->stdlib_fn_indices[i] = UINT32_MAX;
}
wasmtime_module_exports(stdlib_module, &export_types);
for (unsigned i = 0; i < export_types.size; i++) {
wasm_exporttype_t *export_type = export_types.data[i];
@ -734,6 +803,12 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) {
bool exists = wasmtime_instance_export_nth(context, &instance, i, &export_name, &name_len, &export);
assert(exists);
if (export.kind == WASMTIME_EXTERN_GLOBAL) {
if (name_eq(name, "__stack_pointer")) {
self->stack_pointer_global = export.of.global;
}
}
if (export.kind == WASMTIME_EXTERN_FUNC) {
if (ts_wasm_store__call_module_initializer(self, name, &export, &trap)) {
if (trap) {
@ -749,17 +824,31 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) {
continue;
}
for (unsigned j = 0; j < array_len(STDLIB_SYMBOLS); j++) {
if (name_eq(name, "reset_heap")) {
self->builtin_fn_indices.reset_heap = export.of.func.index;
continue;
}
for (unsigned j = 0; j < stdlib_symbols_len; j++) {
if (name_eq(name, STDLIB_SYMBOLS[j])) {
self->fn_indices[j] = export.of.func.index;
self->stdlib_fn_indices[j] = export.of.func.index;
break;
}
}
}
}
for (unsigned i = 0; i < STDLIB_SYMBOL_COUNT; i++) {
if (self->fn_indices[i] == UINT16_MAX) {
if (self->builtin_fn_indices.reset_heap == UINT32_MAX) {
wasm_error->kind = TSWasmErrorKindInstantiate;
format(
&wasm_error->message,
"missing malloc reset function in wasm stdlib"
);
goto error;
}
for (unsigned i = 0; i < stdlib_symbols_len; i++) {
if (self->stdlib_fn_indices[i] == UINT32_MAX) {
wasm_error->kind = TSWasmErrorKindInstantiate;
format(
&wasm_error->message,
@ -771,11 +860,53 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) {
}
wasm_exporttype_vec_delete(&export_types);
wasmtime_module_delete(stdlib_module);
// Add all of the lexer callback functions to the function table. Store their function table
// indices on the in-memory lexer.
uint32_t table_index;
error = wasmtime_table_grow(context, &function_table, lexer_definitions_len, &initializer, &table_index);
if (error) {
wasmtime_error_message(error, &message);
wasm_error->kind = TSWasmErrorKindAllocate;
format(
&wasm_error->message,
"failed to grow wasm table to initial size: %.*s",
(int)message.size, message.data
);
goto error;
}
for (unsigned i = 0; i < lexer_definitions_len; i++) {
FunctionDefinition *definition = &lexer_definitions[i];
wasmtime_func_t func = {function_table.store_id, *definition->storage_location};
wasmtime_val_t func_val = {.kind = WASMTIME_FUNCREF, .of.funcref = func};
error = wasmtime_table_set(context, &function_table, table_index, &func_val);
assert(!error);
*(int32_t *)(definition->storage_location) = table_index;
table_index++;
}
self->current_function_table_offset = table_index;
self->lexer_address = initial_memory_pages * MEMORY_PAGE_SIZE;
self->serialization_buffer_address = self->lexer_address + sizeof(LexerInWasmMemory);
self->current_memory_offset = self->serialization_buffer_address + TREE_SITTER_SERIALIZATION_BUFFER_SIZE;
// Grow the memory enough to hold the builtin lexer and serialization buffer.
uint32_t new_pages_needed = (self->current_memory_offset - self->lexer_address - 1) / MEMORY_PAGE_SIZE + 1;
uint64_t prev_memory_size;
wasmtime_memory_grow(context, &memory, new_pages_needed, &prev_memory_size);
uint8_t *memory_data = wasmtime_memory_data(context, &memory);
memcpy(&memory_data[self->lexer_address], &lexer, sizeof(lexer));
return self;
error:
ts_free(self);
if (stdlib_module) wasmtime_module_delete(stdlib_module);
if (store) wasmtime_store_delete(store);
if (import_types.size) wasm_importtype_vec_delete(&import_types);
if (memory_type) wasm_memorytype_delete(memory_type);
if (table_type) wasm_tabletype_delete(table_type);
if (trap) wasm_trap_delete(trap);
if (error) wasmtime_error_delete(error);
if (message.size) wasm_byte_vec_delete(&message);
@ -786,9 +917,8 @@ error:
void ts_wasm_store_delete(TSWasmStore *self) {
if (!self) return;
ts_free(self->fn_indices);
ts_free(self->stdlib_fn_indices);
wasm_globaltype_delete(self->const_i32_type);
wasm_globaltype_delete(self->var_i32_type);
wasmtime_store_delete(self->store);
wasm_engine_delete(self->engine);
for (unsigned i = 0; i < self->language_instances.size; i++) {
@ -837,9 +967,10 @@ static bool ts_wasm_store__instantiate(
// Grow the memory to make room for the new data.
uint32_t needed_memory_size = self->current_memory_offset + dylink_info->memory_size;
if (needed_memory_size > self->current_memory_size) {
uint32_t current_memory_size = wasmtime_memory_data_size(context, &self->memory);
if (needed_memory_size > current_memory_size) {
uint32_t pages_to_grow = (
needed_memory_size - self->current_memory_size + MEMORY_PAGE_SIZE - 1) /
needed_memory_size - current_memory_size + MEMORY_PAGE_SIZE - 1) /
MEMORY_PAGE_SIZE;
uint64_t prev_memory_size;
error = wasmtime_memory_grow(context, &self->memory, pages_to_grow, &prev_memory_size);
@ -847,7 +978,6 @@ static bool ts_wasm_store__instantiate(
format(error_message, "invalid memory size %u", dylink_info->memory_size);
goto error;
}
self->current_memory_size += pages_to_grow * MEMORY_PAGE_SIZE;
}
// Construct the language function name as string.
@ -875,7 +1005,7 @@ static bool ts_wasm_store__instantiate(
bool defined_in_stdlib = false;
for (unsigned j = 0; j < array_len(STDLIB_SYMBOLS); j++) {
if (name_eq(import_name, STDLIB_SYMBOLS[j])) {
uint16_t address = self->fn_indices[j];
uint16_t address = self->stdlib_fn_indices[j];
imports[i] = (wasmtime_extern_t) {.kind = WASMTIME_EXTERN_FUNC, .of.func = {store_id, address}};
defined_in_stdlib = true;
break;
@ -1326,17 +1456,37 @@ bool ts_wasm_store_add_language(
return true;
}
void ts_wasm_store_reset_heap(TSWasmStore *self) {
wasmtime_context_t *context = wasmtime_store_context(self->store);
wasmtime_func_t func = {
self->function_table.store_id,
self->builtin_fn_indices.reset_heap
};
wasm_trap_t *trap = NULL;
wasmtime_val_t args[1] = {
{.of.i32 = self->current_memory_offset, .kind = WASMTIME_I32},
};
wasmtime_error_t *error = wasmtime_func_call(context, &func, args, 1, NULL, 0, &trap);
assert(!error);
assert(!trap);
}
bool ts_wasm_store_start(TSWasmStore *self, TSLexer *lexer, const TSLanguage *language) {
uint32_t instance_index;
if (!ts_wasm_store_add_language(self, language, &instance_index)) return false;
self->current_lexer = lexer;
self->current_instance = &self->language_instances.contents[instance_index];
self->has_error = false;
ts_wasm_store_reset_heap(self);
return true;
}
void ts_wasm_store_stop(TSWasmStore *self) {
void ts_wasm_store_reset(TSWasmStore *self) {
self->current_lexer = NULL;
self->current_instance = NULL;
self->has_error = false;
ts_wasm_store_reset_heap(self);
}
static void ts_wasm_store__call(
@ -1354,17 +1504,26 @@ static void ts_wasm_store__call(
wasm_trap_t *trap = NULL;
wasmtime_error_t *error = wasmtime_func_call_unchecked(context, &func, args_and_results, args_and_results_len, &trap);
assert(!error);
if (trap) {
wasm_message_t message;
wasm_trap_message(trap, &message);
fprintf(
stderr,
"trap when calling wasm lexing function %u: %.*s\n",
function_index,
(int)message.size, message.data
);
abort();
if (error) {
// wasm_message_t message;
// wasmtime_error_message(error, &message);
// fprintf(
// stderr,
// "error in wasm module: %.*s\n",
// (int)message.size, message.data
// );
wasmtime_error_delete(error);
self->has_error = true;
} else if (trap) {
// wasm_message_t message;
// wasm_trap_message(trap, &message);
// fprintf(
// stderr,
// "trap in wasm module: %.*s\n",
// (int)message.size, message.data
// );
wasm_trap_delete(trap);
self->has_error = true;
}
}
@ -1372,21 +1531,22 @@ static bool ts_wasm_store__call_lex_function(TSWasmStore *self, unsigned functio
wasmtime_context_t *context = wasmtime_store_context(self->store);
uint8_t *memory_data = wasmtime_memory_data(context, &self->memory);
memcpy(
&memory_data[LEXER_ADDRESS],
&memory_data[self->lexer_address],
&self->current_lexer->lookahead,
sizeof(self->current_lexer->lookahead)
);
wasmtime_val_raw_t args[2] = {
{.i32 = LEXER_ADDRESS},
{.i32 = self->lexer_address},
{.i32 = state},
};
ts_wasm_store__call(self, function_index, args, 2);
if (self->has_error) return false;
bool result = args[0].i32;
memcpy(
&self->current_lexer->lookahead,
&memory_data[LEXER_ADDRESS],
&memory_data[self->lexer_address],
sizeof(self->current_lexer->lookahead) + sizeof(self->current_lexer->result_symbol)
);
return result;
@ -1411,12 +1571,15 @@ bool ts_wasm_store_call_lex_keyword(TSWasmStore *self, TSStateId state) {
uint32_t ts_wasm_store_call_scanner_create(TSWasmStore *self) {
wasmtime_val_raw_t args[1] = {{.i32 = 0}};
ts_wasm_store__call(self, self->current_instance->scanner_create_fn_index, args, 1);
if (self->has_error) return 0;
return args[0].i32;
}
void ts_wasm_store_call_scanner_destroy(TSWasmStore *self, uint32_t scanner_address) {
wasmtime_val_raw_t args[1] = {{.i32 = scanner_address}};
ts_wasm_store__call(self, self->current_instance->scanner_destroy_fn_index, args, 1);
if (self->current_instance) {
wasmtime_val_raw_t args[1] = {{.i32 = scanner_address}};
ts_wasm_store__call(self, self->current_instance->scanner_destroy_fn_index, args, 1);
}
}
bool ts_wasm_store_call_scanner_scan(
@ -1428,7 +1591,7 @@ bool ts_wasm_store_call_scanner_scan(
uint8_t *memory_data = wasmtime_memory_data(context, &self->memory);
memcpy(
&memory_data[LEXER_ADDRESS],
&memory_data[self->lexer_address],
&self->current_lexer->lookahead,
sizeof(self->current_lexer->lookahead)
);
@ -1438,14 +1601,15 @@ bool ts_wasm_store_call_scanner_scan(
(valid_tokens_ix * sizeof(bool));
wasmtime_val_raw_t args[3] = {
{.i32 = scanner_address},
{.i32 = LEXER_ADDRESS},
{.i32 = self->lexer_address},
{.i32 = valid_tokens_address}
};
ts_wasm_store__call(self, self->current_instance->scanner_scan_fn_index, args, 3);
if (self->has_error) return false;
memcpy(
&self->current_lexer->lookahead,
&memory_data[LEXER_ADDRESS],
&memory_data[self->lexer_address],
sizeof(self->current_lexer->lookahead) + sizeof(self->current_lexer->result_symbol)
);
return args[0].i32;
@ -1461,15 +1625,17 @@ uint32_t ts_wasm_store_call_scanner_serialize(
wasmtime_val_raw_t args[2] = {
{.i32 = scanner_address},
{.i32 = SERIALIZATION_BUFFER_ADDRESS},
{.i32 = self->serialization_buffer_address},
};
ts_wasm_store__call(self, self->current_instance->scanner_serialize_fn_index, args, 2);
if (self->has_error) return 0;
uint32_t length = args[0].i32;
if (length > 0) {
memcpy(
((Lexer *)self->current_lexer)->debug_buffer,
&memory_data[SERIALIZATION_BUFFER_ADDRESS],
&memory_data[self->serialization_buffer_address],
length
);
}
@ -1487,7 +1653,7 @@ void ts_wasm_store_call_scanner_deserialize(
if (length > 0) {
memcpy(
&memory_data[SERIALIZATION_BUFFER_ADDRESS],
&memory_data[self->serialization_buffer_address],
buffer,
length
);
@ -1495,12 +1661,16 @@ void ts_wasm_store_call_scanner_deserialize(
wasmtime_val_raw_t args[3] = {
{.i32 = scanner_address},
{.i32 = SERIALIZATION_BUFFER_ADDRESS},
{.i32 = self->serialization_buffer_address},
{.i32 = length},
};
ts_wasm_store__call(self, self->current_instance->scanner_deserialize_fn_index, args, 3);
}
bool ts_wasm_store_has_error(const TSWasmStore *self) {
return self->has_error;
}
bool ts_language_is_wasm(const TSLanguage *self) {
return self->lex_fn == ts_wasm_store__sentinel_lex_fn;
}
@ -1569,7 +1739,7 @@ bool ts_wasm_store_start(
return false;
}
void ts_wasm_store_stop(TSWasmStore *self) {
void ts_wasm_store_reset(TSWasmStore *self) {
(void)self;
}
@ -1632,6 +1802,11 @@ void ts_wasm_store_call_scanner_deserialize(
(void)length;
}
bool ts_wasm_store_has_error(const TSWasmStore *self) {
(void)self;
return false;
}
bool ts_language_is_wasm(const TSLanguage *self) {
(void)self;
return false;

View file

@ -9,7 +9,8 @@ extern "C" {
#include "./parser.h"
bool ts_wasm_store_start(TSWasmStore *, TSLexer *, const TSLanguage *);
void ts_wasm_store_stop(TSWasmStore *);
void ts_wasm_store_reset(TSWasmStore *);
bool ts_wasm_store_has_error(const TSWasmStore *);
bool ts_wasm_store_call_lex_main(TSWasmStore *, TSStateId);
bool ts_wasm_store_call_lex_keyword(TSWasmStore *, TSStateId);

View file

@ -2,27 +2,33 @@
set -e
# Remove quotes, add leading underscores, remove newlines, remove trailing comma.
# Remove quotes and commas
EXPORTED_FUNCTIONS=$( \
cat lib/src/wasm/stdlib-symbols.txt | \
sed -e 's/"//g' | \
sed -e 's/^/_/g' | \
tr -d '\n"' | \
sed -e 's/,$//' \
tr -d ',"' \
)
emcc \
-o stdlib.wasm \
-Os \
--no-entry \
-s MAIN_MODULE=2 \
-s "EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}" \
-s 'ALLOW_MEMORY_GROWTH' \
-s 'TOTAL_MEMORY=4MB' \
-fvisibility=hidden \
-fno-exceptions \
-xc \
/dev/null
EXPORT_FLAGS=""
for function in ${EXPORTED_FUNCTIONS}; do
EXPORT_FLAGS+=" -Wl,--export=${function}"
done
target/wasi-sdk-21.0/bin/clang-17 \
-o stdlib.wasm \
-Os \
-fPIC \
-Wl,--no-entry \
-Wl,--stack-first \
-Wl,-z -Wl,stack-size=65536 \
-Wl,--import-undefined \
-Wl,--import-memory \
-Wl,--import-table \
-Wl,--strip-debug \
-Wl,--export=reset_heap \
-Wl,--export=__wasm_call_ctors \
-Wl,--export=__stack_pointer \
${EXPORT_FLAGS} \
lib/src/wasm/stdlib.c
xxd -C -i stdlib.wasm > lib/src/wasm/wasm-stdlib.h
mv stdlib.wasm target/

View file

@ -12,7 +12,7 @@ OPTIONS
-h Print this message
-a Compile C code with the Clang static analyzer
-a Compile C code with the Clang address sanitizer
-e Run only the corpus tests whose name contain the given string
@ -41,9 +41,17 @@ while getopts "adDghl:e:s:i:" option; do
exit
;;
a)
export RUSTFLAGS="-Z sanitizer=address"
# Specify a `--target` explicitly. For some reason, this is required for
# address sanitizer support.
export CFLAGS="-fsanitize=undefined,address"
# When the Tree-sitter C library is compiled with the address sanitizer, the address sanitizer
# runtime library needs to be linked into the final test executable. When using Xcode clang,
# the Rust linker doesn't know where to find that library, so we need to specify linker flags directly.
runtime_dir=$(cc -print-runtime-dir)
if [[ $runtime_dir == */Xcode.app/* ]]; then
export RUSTFLAGS="-C link-arg=-L${runtime_dir} -C link-arg=-lclang_rt.asan_osx_dynamic -C link-arg=-Wl,-rpath,${runtime_dir}"
fi
# Specify a `--target` explicitly. This is required for address sanitizer support.
toolchain=$(rustup show active-toolchain)
toolchain_regex='(stable|beta|nightly)-([_a-z0-9-]+).*'
if [[ $toolchain =~ $toolchain_regex ]]; then
@ -52,7 +60,8 @@ while getopts "adDghl:e:s:i:" option; do
else
echo "Failed to parse toolchain '${toolchain}'"
fi
test_flags="${test_flags} --target ${current_target}"
test_flags+=" --target ${current_target}"
;;
e)
export TREE_SITTER_EXAMPLE=${OPTARG}