Merge pull request #3318 from tree-sitter/serialization-buffer-overflows

Improve handling of serialization buffer overflows
This commit is contained in:
Max Brunsfeld 2024-04-25 14:49:09 -07:00 committed by GitHub
commit 8040baed18
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 11 deletions

View file

@ -397,10 +397,12 @@ static unsigned ts_parser__external_scanner_serialize(
self->lexer.debug_buffer
);
} else {
return self->language->external_scanner.serialize(
uint32_t length = self->language->external_scanner.serialize(
self->external_scanner_payload,
self->lexer.debug_buffer
);
assert(length <= TREE_SITTER_SERIALIZATION_BUFFER_SIZE);
return length;
}
}

View file

@ -101,7 +101,6 @@ struct TSWasmStore {
wasm_globaltype_t *const_i32_type;
bool has_error;
uint32_t lexer_address;
uint32_t serialization_buffer_address;
};
typedef Array(char) StringData;
@ -162,7 +161,7 @@ typedef struct {
static volatile uint32_t NEXT_LANGUAGE_ID;
// Linear memory layout:
// [ <-- stack | stdlib statics | lexer | serialization_buffer | language statics --> | heap --> ]
// [ <-- stack | stdlib statics | lexer | language statics --> | serialization_buffer | heap --> ]
#define MAX_MEMORY_SIZE (128 * 1024 * 1024 / MEMORY_PAGE_SIZE)
/************************
@ -888,8 +887,7 @@ TSWasmStore *ts_wasm_store_new(TSWasmEngine *engine, TSWasmError *wasm_error) {
self->current_function_table_offset = table_index;
self->lexer_address = initial_memory_pages * MEMORY_PAGE_SIZE;
self->serialization_buffer_address = self->lexer_address + sizeof(LexerInWasmMemory);
self->current_memory_offset = self->serialization_buffer_address + TREE_SITTER_SERIALIZATION_BUFFER_SIZE;
self->current_memory_offset = self->lexer_address + sizeof(LexerInWasmMemory);
// Grow the memory enough to hold the builtin lexer. The serialization buffer is no longer
// part of this region; space for it is added via ts_wasm_store__heap_address().
uint32_t new_pages_needed = (self->current_memory_offset - self->lexer_address - 1) / MEMORY_PAGE_SIZE + 1;
@ -940,6 +938,14 @@ size_t ts_wasm_store_language_count(const TSWasmStore *self) {
return result;
}
// Compute the address in wasm linear memory at which the heap starts.
//
// The serialization buffer occupies TREE_SITTER_SERIALIZATION_BUFFER_SIZE bytes
// beginning at the store's current memory offset, so the heap starts at the
// first byte past that region.
static uint32_t ts_wasm_store__heap_address(TSWasmStore *self) {
  uint32_t serialization_buffer_start = self->current_memory_offset;
  return serialization_buffer_start + TREE_SITTER_SERIALIZATION_BUFFER_SIZE;
}
// Compute the address in wasm linear memory of the scanner serialization
// buffer, which is the store's current memory offset.
static uint32_t ts_wasm_store__serialization_buffer_address(TSWasmStore *self) {
  uint32_t address = self->current_memory_offset;
  return address;
}
static bool ts_wasm_store__instantiate(
TSWasmStore *self,
wasmtime_module_t *module,
@ -966,7 +972,7 @@ static bool ts_wasm_store__instantiate(
}
// Grow the memory to make room for the new data.
uint32_t needed_memory_size = self->current_memory_offset + dylink_info->memory_size;
uint32_t needed_memory_size = ts_wasm_store__heap_address(self) + dylink_info->memory_size;
uint32_t current_memory_size = wasmtime_memory_data_size(context, &self->memory);
if (needed_memory_size > current_memory_size) {
uint32_t pages_to_grow = (
@ -1475,7 +1481,7 @@ void ts_wasm_store_reset_heap(TSWasmStore *self) {
};
wasm_trap_t *trap = NULL;
wasmtime_val_t args[1] = {
{.of.i32 = self->current_memory_offset, .kind = WASMTIME_I32},
{.of.i32 = ts_wasm_store__heap_address(self), .kind = WASMTIME_I32},
};
wasmtime_error_t *error = wasmtime_func_call(context, &func, args, 1, NULL, 0, &trap);
@ -1633,20 +1639,25 @@ uint32_t ts_wasm_store_call_scanner_serialize(
) {
wasmtime_context_t *context = wasmtime_store_context(self->store);
uint8_t *memory_data = wasmtime_memory_data(context, &self->memory);
uint32_t serialization_buffer_address = ts_wasm_store__serialization_buffer_address(self);
wasmtime_val_raw_t args[2] = {
{.i32 = scanner_address},
{.i32 = self->serialization_buffer_address},
{.i32 = serialization_buffer_address},
};
ts_wasm_store__call(self, self->current_instance->scanner_serialize_fn_index, args, 2);
if (self->has_error) return 0;
uint32_t length = args[0].i32;
if (length > TREE_SITTER_SERIALIZATION_BUFFER_SIZE) {
self->has_error = true;
return 0;
}
if (length > 0) {
memcpy(
((Lexer *)self->current_lexer)->debug_buffer,
&memory_data[self->serialization_buffer_address],
&memory_data[serialization_buffer_address],
length
);
}
@ -1661,10 +1672,11 @@ void ts_wasm_store_call_scanner_deserialize(
) {
wasmtime_context_t *context = wasmtime_store_context(self->store);
uint8_t *memory_data = wasmtime_memory_data(context, &self->memory);
uint32_t serialization_buffer_address = ts_wasm_store__serialization_buffer_address(self);
if (length > 0) {
memcpy(
&memory_data[self->serialization_buffer_address],
&memory_data[serialization_buffer_address],
buffer,
length
);
@ -1672,7 +1684,7 @@ void ts_wasm_store_call_scanner_deserialize(
wasmtime_val_raw_t args[3] = {
{.i32 = scanner_address},
{.i32 = self->serialization_buffer_address},
{.i32 = serialization_buffer_address},
{.i32 = length},
};
ts_wasm_store__call(self, self->current_instance->scanner_deserialize_fn_index, args, 3);