Allow wasm languages to be deleted

This commit is contained in:
Max Brunsfeld 2023-12-27 14:54:38 -08:00
parent da16cb1459
commit 4a8e4b1963
10 changed files with 142 additions and 62 deletions

View file

@ -28,7 +28,7 @@ fn test_lookahead_iterator() {
let expected_symbols = ["identifier", "block_comment", "line_comment"];
let mut lookahead = language.lookahead_iterator(next_state).unwrap();
assert_eq!(lookahead.language(), language);
assert_eq!(*lookahead.language(), language);
assert!(lookahead.iter_names().eq(expected_symbols));
lookahead.reset_state(next_state);

View file

@ -2115,7 +2115,7 @@ fn test_query_cursor_next_capture_with_byte_range() {
allocations::record(|| {
let language = get_language("python");
let query = Query::new(
language,
&language,
"(function_definition name: (identifier) @function)
(attribute attribute: (identifier) @property)
((identifier) @variable)",
@ -2128,7 +2128,7 @@ fn test_query_cursor_next_capture_with_byte_range() {
// point_pos (0,0) (1,0) (1,5) (1,15)
let mut parser = Parser::new();
parser.set_language(language).unwrap();
parser.set_language(&language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();
@ -2149,7 +2149,7 @@ fn test_query_cursor_next_capture_with_point_range() {
allocations::record(|| {
let language = get_language("python");
let query = Query::new(
language,
&language,
"(function_definition name: (identifier) @function)
(attribute attribute: (identifier) @property)
((identifier) @variable)",
@ -2162,7 +2162,7 @@ fn test_query_cursor_next_capture_with_point_range() {
// point_pos (0,0) (1,0) (1,5) (1,15)
let mut parser = Parser::new();
parser.set_language(language).unwrap();
parser.set_language(&language).unwrap();
let tree = parser.parse(source, None).unwrap();
let mut cursor = QueryCursor::new();

View file

@ -1,4 +1,4 @@
use crate::tests::helpers::fixtures::WASM_DIR;
use crate::tests::helpers::{allocations, fixtures::WASM_DIR};
use lazy_static::lazy_static;
use std::fs;
use tree_sitter::{wasmtime::Engine, Parser, WasmError, WasmErrorKind, WasmStore};
@ -9,56 +9,59 @@ lazy_static! {
#[test]
fn test_load_wasm_language() {
let mut store = WasmStore::new(ENGINE.clone()).unwrap();
let mut parser = Parser::new();
allocations::record(|| {
let mut store = WasmStore::new(ENGINE.clone()).unwrap();
let mut parser = Parser::new();
let wasm_cpp = fs::read(&WASM_DIR.join(format!("tree-sitter-cpp.wasm"))).unwrap();
let wasm_rs = fs::read(&WASM_DIR.join(format!("tree-sitter-rust.wasm"))).unwrap();
let wasm_rb = fs::read(&WASM_DIR.join(format!("tree-sitter-ruby.wasm"))).unwrap();
let wasm_typescript = fs::read(&WASM_DIR.join(format!("tree-sitter-typescript.wasm"))).unwrap();
let wasm_cpp = fs::read(&WASM_DIR.join(format!("tree-sitter-cpp.wasm"))).unwrap();
let wasm_rs = fs::read(&WASM_DIR.join(format!("tree-sitter-rust.wasm"))).unwrap();
let wasm_rb = fs::read(&WASM_DIR.join(format!("tree-sitter-ruby.wasm"))).unwrap();
let wasm_typescript =
fs::read(&WASM_DIR.join(format!("tree-sitter-typescript.wasm"))).unwrap();
let language_rust = store.load_language("rust", &wasm_rs).unwrap();
let language_cpp = store.load_language("cpp", &wasm_cpp).unwrap();
let language_ruby = store.load_language("ruby", &wasm_rb).unwrap();
let language_typescript = store.load_language("typescript", &wasm_typescript).unwrap();
parser.set_wasm_store(store).unwrap();
let language_rust = store.load_language("rust", &wasm_rs).unwrap();
let language_cpp = store.load_language("cpp", &wasm_cpp).unwrap();
let language_ruby = store.load_language("ruby", &wasm_rb).unwrap();
let language_typescript = store.load_language("typescript", &wasm_typescript).unwrap();
parser.set_wasm_store(store).unwrap();
let mut parser2 = Parser::new();
parser2
.set_wasm_store(WasmStore::new(ENGINE.clone()).unwrap())
.unwrap();
// A second parser with its own, initially empty wasm store.
// `WasmStore::new` returns a `Result`, so it must be unwrapped before the
// store can be handed to `set_wasm_store`.
let mut parser2 = Parser::new();
parser2
    .set_wasm_store(WasmStore::new(ENGINE.clone()).unwrap())
    .unwrap();
for mut parser in [parser, parser2] {
for _ in 0..2 {
parser.set_language(language_cpp).unwrap();
let tree = parser.parse("A<B> c = d();", None).unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(translation_unit (declaration type: (template_type name: (type_identifier) arguments: (template_argument_list (type_descriptor type: (type_identifier)))) declarator: (init_declarator declarator: (identifier) value: (call_expression function: (identifier) arguments: (argument_list)))))"
);
for mut parser in [parser, parser2] {
for _ in 0..2 {
parser.set_language(&language_cpp).unwrap();
let tree = parser.parse("A<B> c = d();", None).unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(translation_unit (declaration type: (template_type name: (type_identifier) arguments: (template_argument_list (type_descriptor type: (type_identifier)))) declarator: (init_declarator declarator: (identifier) value: (call_expression function: (identifier) arguments: (argument_list)))))"
);
parser.set_language(language_rust).unwrap();
let tree = parser.parse("const A: B = c();", None).unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(source_file (const_item name: (identifier) type: (type_identifier) value: (call_expression function: (identifier) arguments: (arguments))))"
);
parser.set_language(&language_rust).unwrap();
let tree = parser.parse("const A: B = c();", None).unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(source_file (const_item name: (identifier) type: (type_identifier) value: (call_expression function: (identifier) arguments: (arguments))))"
);
parser.set_language(language_ruby).unwrap();
let tree = parser.parse("class A; end", None).unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(program (class name: (constant)))"
);
parser.set_language(&language_ruby).unwrap();
let tree = parser.parse("class A; end", None).unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(program (class name: (constant)))"
);
parser.set_language(language_typescript).unwrap();
let tree = parser.parse("class A {}", None).unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(program (class_declaration name: (type_identifier) body: (class_body)))"
);
parser.set_language(&language_typescript).unwrap();
let tree = parser.parse("class A {}", None).unwrap();
assert_eq!(
tree.root_node().to_sexp(),
"(program (class_declaration name: (type_identifier) body: (class_body)))"
);
}
}
}
});
}
#[test]

View file

@ -11,9 +11,9 @@ use std::{
ffi::CStr,
fmt, hash, iter,
marker::PhantomData,
mem::MaybeUninit,
mem::{self, MaybeUninit},
num::NonZeroU16,
ops,
ops::{self, Deref},
os::raw::{c_char, c_void},
ptr::{self, NonNull},
slice, str,
@ -51,6 +51,8 @@ pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/par
#[repr(transparent)]
pub struct Language(*const ffi::TSLanguage);
/// A reference to a [`Language`], tied to the lifetime `'a`.
///
/// This is the type returned by the `language()` accessors on [`Tree`],
/// [`Node`] and [`LookaheadIterator`]. It wraps the raw `TSLanguage` pointer
/// and dereferences to [`Language`].
pub struct LanguageRef<'a>(*const ffi::TSLanguage, PhantomData<&'a ()>);
/// A tree that represents the syntactic structure of a source code file.
#[doc(alias = "TSTree")]
pub struct Tree(NonNull<ffi::TSTree>);
@ -397,6 +399,14 @@ impl Drop for Language {
}
}
impl<'a> Deref for LanguageRef<'a> {
type Target = Language;
fn deref(&self) -> &Self::Target {
unsafe { mem::transmute(&self.0) }
}
}
impl Parser {
/// Create a new parser.
pub fn new() -> Parser {
@ -778,8 +788,11 @@ impl Tree {
/// Get the language that was used to parse the syntax tree.
#[doc(alias = "ts_tree_language")]
pub fn language(&self) -> Language {
Language(unsafe { ffi::ts_tree_language(self.0.as_ptr()) })
pub fn language(&self) -> LanguageRef {
LanguageRef(
unsafe { ffi::ts_tree_language(self.0.as_ptr()) },
PhantomData,
)
}
/// Edit the syntax tree to keep it in sync with source code that has been
@ -906,8 +919,8 @@ impl<'tree> Node<'tree> {
/// Get the [`Language`] that was used to parse this node's syntax tree.
#[doc(alias = "ts_node_language")]
pub fn language(&self) -> Language {
Language(unsafe { ffi::ts_node_language(self.0) })
pub fn language(&self) -> LanguageRef {
LanguageRef(unsafe { ffi::ts_node_language(self.0) }, PhantomData)
}
/// Check if this node is *named*.
@ -1485,8 +1498,11 @@ impl Drop for TreeCursor<'_> {
impl LookaheadIterator {
/// Get the current language of the lookahead iterator.
#[doc(alias = "ts_lookahead_iterator_language")]
pub fn language(&self) -> Language {
Language(unsafe { ffi::ts_lookahead_iterator_language(self.0.as_ptr()) })
pub fn language(&self) -> LanguageRef<'_> {
LanguageRef(
unsafe { ffi::ts_lookahead_iterator_language(self.0.as_ptr()) },
PhantomData,
)
}
/// Get the current symbol of the lookahead iterator.

View file

@ -1,13 +1,19 @@
#include "./language.h"
#include "./wasm.h"
#include "tree_sitter/api.h"
#include <string.h>
// Obtain another handle to a language. Wasm languages are reference-counted,
// so copying one retains it; any other language (or NULL) is returned as-is.
const TSLanguage *ts_language_copy(const TSLanguage *self) {
if (self && ts_language_is_wasm(self)) {
ts_wasm_language_retain(self);
}
return self;
}
void ts_language_delete(const TSLanguage *_self) {
// TODO - decrement reference count for wasm languages
void ts_language_delete(const TSLanguage *self) {
if (self && ts_language_is_wasm(self)) {
ts_wasm_language_release(self);
}
}
uint32_t ts_language_symbol_count(const TSLanguage *self) {

View file

@ -1869,6 +1869,7 @@ const TSLanguage *ts_parser_language(const TSParser *self) {
bool ts_parser_set_language(TSParser *self, const TSLanguage *language) {
ts_parser__external_scanner_destroy(self);
ts_language_delete(self->language);
self->language = NULL;
if (language) {
@ -1885,7 +1886,7 @@ bool ts_parser_set_language(TSParser *self, const TSLanguage *language) {
}
}
self->language = language;
self->language = ts_language_copy(language);
ts_parser__external_scanner_create(self);
ts_parser_reset(self);
return true;

View file

@ -2812,6 +2812,7 @@ void ts_query_delete(TSQuery *self) {
array_delete(&self->string_buffer);
array_delete(&self->negated_fields);
array_delete(&self->repeat_symbols_with_rootless_patterns);
ts_language_delete(self->language);
symbol_table_delete(&self->captures);
symbol_table_delete(&self->predicate_values);
for (uint32_t index = 0; index < self->capture_quantifiers.size; index++) {

View file

@ -12,7 +12,7 @@ TSTree *ts_tree_new(
) {
TSTree *result = ts_malloc(sizeof(TSTree));
result->root = root;
result->language = language;
result->language = ts_language_copy(language);
result->included_ranges = ts_calloc(included_range_count, sizeof(TSRange));
memcpy(result->included_ranges, included_ranges, included_range_count * sizeof(TSRange));
result->included_range_count = included_range_count;
@ -30,6 +30,7 @@ void ts_tree_delete(TSTree *self) {
SubtreePool pool = ts_subtree_pool_new(0);
ts_subtree_release(&pool, self->root);
ts_subtree_pool_delete(&pool);
ts_language_delete(self->language);
ts_free(self->included_ranges);
ts_free(self);
}

View file

@ -72,6 +72,7 @@ typedef struct {
// wasm store, so it can be shared by all users of a `TSLanguage`. A pointer to
// this is stored on the language itself.
typedef struct {
volatile uint32_t ref_count;
wasmtime_module_t *module;
uint32_t language_id;
const char *name;
@ -1074,7 +1075,7 @@ const TSLanguage *ts_wasm_store_load_language(
};
uint32_t address_count = array_len(addresses);
TSLanguage *language = ts_malloc(sizeof(TSLanguage));
TSLanguage *language = ts_calloc(1, sizeof(TSLanguage));
StringData symbol_name_buffer = array_new();
StringData field_name_buffer = array_new();
@ -1202,6 +1203,7 @@ const TSLanguage *ts_wasm_store_load_language(
.symbol_name_buffer = symbol_name_buffer.contents,
.field_name_buffer = field_name_buffer.contents,
.dylink_info = dylink_info,
.ref_count = 1,
};
// The lex functions are not used for wasm languages. Use those two fields
@ -1468,6 +1470,45 @@ bool ts_language_is_wasm(const TSLanguage *self) {
return self->lex_fn == ts_wasm_store__sentinel_lex_fn;
}
// Get the `LanguageWasmModule` associated with a wasm language. The pointer
// is read from the language's `keyword_lex_fn` field, so this must only be
// called on languages for which `ts_language_is_wasm` is true.
static inline LanguageWasmModule *ts_language__wasm_module(const TSLanguage *self) {
return (LanguageWasmModule *)self->keyword_lex_fn;
}
// Add one reference to a wasm language. The language must still be alive,
// i.e. its reference count must be non-zero on entry.
void ts_wasm_language_retain(const TSLanguage *self) {
  LanguageWasmModule *wasm_module = ts_language__wasm_module(self);
  assert(wasm_module->ref_count > 0);
  atomic_inc(&wasm_module->ref_count);
}
// Drop one reference to a wasm language. When the count reaches zero, the
// wasm module data and every table owned by the language are freed, followed
// by the language struct itself.
void ts_wasm_language_release(const TSLanguage *self) {
  LanguageWasmModule *wasm_module = ts_language__wasm_module(self);
  assert(wasm_module->ref_count > 0);

  // Other references remain: nothing to free yet.
  if (atomic_dec(&wasm_module->ref_count) != 0) return;

  // Tear down the per-language wasm module data.
  ts_free((void *)wasm_module->field_name_buffer);
  ts_free((void *)wasm_module->symbol_name_buffer);
  ts_free((void *)wasm_module->name);
  wasmtime_module_delete(wasm_module->module);
  ts_free(wasm_module);

  // Free the language's tables, then the language struct.
  ts_free((void *)self->alias_map);
  ts_free((void *)self->alias_sequences);
  ts_free((void *)self->external_scanner.symbol_map);
  ts_free((void *)self->field_map_entries);
  ts_free((void *)self->field_map_slices);
  ts_free((void *)self->field_names);
  ts_free((void *)self->lex_modes);
  ts_free((void *)self->parse_actions);
  ts_free((void *)self->parse_table);
  ts_free((void *)self->primary_state_ids);
  ts_free((void *)self->public_symbol_map);
  ts_free((void *)self->small_parse_table);
  ts_free((void *)self->small_parse_table_map);
  ts_free((void *)self->symbol_metadata);
  ts_free((void *)self->symbol_names);
  ts_free((void *)self);
}
#else
// If the WASM feature is not enabled, define dummy versions of all of the
@ -1556,4 +1597,12 @@ bool ts_language_is_wasm(const TSLanguage *self) {
return false;
}
// Dummy version used when the wasm feature is disabled: does nothing.
void ts_wasm_language_retain(const TSLanguage *self) {
(void)self;
}
// Dummy version used when the wasm feature is disabled: does nothing.
void ts_wasm_language_release(const TSLanguage *self) {
(void)self;
}
#endif

View file

@ -20,6 +20,9 @@ bool ts_wasm_store_call_scanner_scan(TSWasmStore *, uint32_t, uint32_t);
uint32_t ts_wasm_store_call_scanner_serialize(TSWasmStore *, uint32_t, char *);
void ts_wasm_store_call_scanner_deserialize(TSWasmStore *, uint32_t, const char *, unsigned);
void ts_wasm_language_retain(const TSLanguage *);
void ts_wasm_language_release(const TSLanguage *);
#ifdef __cplusplus
}
#endif