From 4a8e4b19639f01a4faa677810ddecc114dbf91ba Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Wed, 27 Dec 2023 14:54:38 -0800 Subject: [PATCH] Allow wasm languages to be deleted --- cli/src/tests/language_test.rs | 2 +- cli/src/tests/query_test.rs | 8 +-- cli/src/tests/wasm_language_test.rs | 89 +++++++++++++++-------------- lib/binding_rust/lib.rs | 32 ++++++++--- lib/src/language.c | 12 +++- lib/src/parser.c | 3 +- lib/src/query.c | 1 + lib/src/tree.c | 3 +- lib/src/wasm.c | 51 ++++++++++++++++- lib/src/wasm.h | 3 + 10 files changed, 142 insertions(+), 62 deletions(-) diff --git a/cli/src/tests/language_test.rs b/cli/src/tests/language_test.rs index 1a69e491..726bcd5d 100644 --- a/cli/src/tests/language_test.rs +++ b/cli/src/tests/language_test.rs @@ -28,7 +28,7 @@ fn test_lookahead_iterator() { let expected_symbols = ["identifier", "block_comment", "line_comment"]; let mut lookahead = language.lookahead_iterator(next_state).unwrap(); - assert_eq!(lookahead.language(), language); + assert_eq!(*lookahead.language(), language); assert!(lookahead.iter_names().eq(expected_symbols)); lookahead.reset_state(next_state); diff --git a/cli/src/tests/query_test.rs b/cli/src/tests/query_test.rs index 53dd5d26..4b26a5d9 100644 --- a/cli/src/tests/query_test.rs +++ b/cli/src/tests/query_test.rs @@ -2115,7 +2115,7 @@ fn test_query_cursor_next_capture_with_byte_range() { allocations::record(|| { let language = get_language("python"); let query = Query::new( - language, + &language, "(function_definition name: (identifier) @function) (attribute attribute: (identifier) @property) ((identifier) @variable)", @@ -2128,7 +2128,7 @@ fn test_query_cursor_next_capture_with_byte_range() { // point_pos (0,0) (1,0) (1,5) (1,15) let mut parser = Parser::new(); - parser.set_language(language).unwrap(); + parser.set_language(&language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); @@ -2149,7 +2149,7 @@ fn 
test_query_cursor_next_capture_with_point_range() { allocations::record(|| { let language = get_language("python"); let query = Query::new( - language, + &language, "(function_definition name: (identifier) @function) (attribute attribute: (identifier) @property) ((identifier) @variable)", @@ -2162,7 +2162,7 @@ fn test_query_cursor_next_capture_with_point_range() { // point_pos (0,0) (1,0) (1,5) (1,15) let mut parser = Parser::new(); - parser.set_language(language).unwrap(); + parser.set_language(&language).unwrap(); let tree = parser.parse(source, None).unwrap(); let mut cursor = QueryCursor::new(); diff --git a/cli/src/tests/wasm_language_test.rs b/cli/src/tests/wasm_language_test.rs index 161c7d3a..d6930bb0 100644 --- a/cli/src/tests/wasm_language_test.rs +++ b/cli/src/tests/wasm_language_test.rs @@ -1,4 +1,4 @@ -use crate::tests::helpers::fixtures::WASM_DIR; +use crate::tests::helpers::{allocations, fixtures::WASM_DIR}; use lazy_static::lazy_static; use std::fs; use tree_sitter::{wasmtime::Engine, Parser, WasmError, WasmErrorKind, WasmStore}; @@ -9,56 +9,59 @@ lazy_static! 
{ #[test] fn test_load_wasm_language() { - let mut store = WasmStore::new(ENGINE.clone()).unwrap(); - let mut parser = Parser::new(); + allocations::record(|| { + let mut store = WasmStore::new(ENGINE.clone()).unwrap(); + let mut parser = Parser::new(); - let wasm_cpp = fs::read(&WASM_DIR.join(format!("tree-sitter-cpp.wasm"))).unwrap(); - let wasm_rs = fs::read(&WASM_DIR.join(format!("tree-sitter-rust.wasm"))).unwrap(); - let wasm_rb = fs::read(&WASM_DIR.join(format!("tree-sitter-ruby.wasm"))).unwrap(); - let wasm_typescript = fs::read(&WASM_DIR.join(format!("tree-sitter-typescript.wasm"))).unwrap(); + let wasm_cpp = fs::read(&WASM_DIR.join(format!("tree-sitter-cpp.wasm"))).unwrap(); + let wasm_rs = fs::read(&WASM_DIR.join(format!("tree-sitter-rust.wasm"))).unwrap(); + let wasm_rb = fs::read(&WASM_DIR.join(format!("tree-sitter-ruby.wasm"))).unwrap(); + let wasm_typescript = + fs::read(&WASM_DIR.join(format!("tree-sitter-typescript.wasm"))).unwrap(); - let language_rust = store.load_language("rust", &wasm_rs).unwrap(); - let language_cpp = store.load_language("cpp", &wasm_cpp).unwrap(); - let language_ruby = store.load_language("ruby", &wasm_rb).unwrap(); - let language_typescript = store.load_language("typescript", &wasm_typescript).unwrap(); - parser.set_wasm_store(store).unwrap(); + let language_rust = store.load_language("rust", &wasm_rs).unwrap(); + let language_cpp = store.load_language("cpp", &wasm_cpp).unwrap(); + let language_ruby = store.load_language("ruby", &wasm_rb).unwrap(); + let language_typescript = store.load_language("typescript", &wasm_typescript).unwrap(); + parser.set_wasm_store(store).unwrap(); - let mut parser2 = Parser::new(); - parser2 - .set_wasm_store(WasmStore::new(ENGINE.clone()).unwrap()) - .unwrap(); + let mut parser2 = Parser::new(); + parser2 + .set_wasm_store(WasmStore::new(ENGINE.clone()).unwrap()) + .unwrap(); - for mut parser in [parser, parser2] { - for _ in 0..2 { - parser.set_language(language_cpp).unwrap(); - let tree = 
parser.parse("A<B> c = d();", None).unwrap(); - assert_eq!( - tree.root_node().to_sexp(), - "(translation_unit (declaration type: (template_type name: (type_identifier) arguments: (template_argument_list (type_descriptor type: (type_identifier)))) declarator: (init_declarator declarator: (identifier) value: (call_expression function: (identifier) arguments: (argument_list)))))" - ); + for mut parser in [parser, parser2] { + for _ in 0..2 { + parser.set_language(&language_cpp).unwrap(); + let tree = parser.parse("A<B> c = d();", None).unwrap(); + assert_eq!( + tree.root_node().to_sexp(), + "(translation_unit (declaration type: (template_type name: (type_identifier) arguments: (template_argument_list (type_descriptor type: (type_identifier)))) declarator: (init_declarator declarator: (identifier) value: (call_expression function: (identifier) arguments: (argument_list)))))" + ); - parser.set_language(language_rust).unwrap(); - let tree = parser.parse("const A: B = c();", None).unwrap(); - assert_eq!( - tree.root_node().to_sexp(), - "(source_file (const_item name: (identifier) type: (type_identifier) value: (call_expression function: (identifier) arguments: (arguments))))" - ); + parser.set_language(&language_rust).unwrap(); + let tree = parser.parse("const A: B = c();", None).unwrap(); + assert_eq!( + tree.root_node().to_sexp(), + "(source_file (const_item name: (identifier) type: (type_identifier) value: (call_expression function: (identifier) arguments: (arguments))))" + ); - parser.set_language(language_ruby).unwrap(); - let tree = parser.parse("class A; end", None).unwrap(); - assert_eq!( - tree.root_node().to_sexp(), - "(program (class name: (constant)))" - ); + parser.set_language(&language_ruby).unwrap(); + let tree = parser.parse("class A; end", None).unwrap(); + assert_eq!( + tree.root_node().to_sexp(), + "(program (class name: (constant)))" + ); - parser.set_language(language_typescript).unwrap(); - let tree = parser.parse("class A {}", None).unwrap(); - 
assert_eq!( - tree.root_node().to_sexp(), - "(program (class_declaration name: (type_identifier) body: (class_body)))" - ); + parser.set_language(&language_typescript).unwrap(); + let tree = parser.parse("class A {}", None).unwrap(); + assert_eq!( + tree.root_node().to_sexp(), + "(program (class_declaration name: (type_identifier) body: (class_body)))" + ); + } } - } + }); } #[test] diff --git a/lib/binding_rust/lib.rs b/lib/binding_rust/lib.rs index 359502b4..7c80d4fb 100644 --- a/lib/binding_rust/lib.rs +++ b/lib/binding_rust/lib.rs @@ -11,9 +11,9 @@ use std::{ ffi::CStr, fmt, hash, iter, marker::PhantomData, - mem::MaybeUninit, + mem::{self, MaybeUninit}, num::NonZeroU16, - ops, + ops::{self, Deref}, os::raw::{c_char, c_void}, ptr::{self, NonNull}, slice, str, @@ -51,6 +51,8 @@ pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/par #[repr(transparent)] pub struct Language(*const ffi::TSLanguage); +pub struct LanguageRef<'a>(*const ffi::TSLanguage, PhantomData<&'a ()>); + /// A tree that represents the syntactic structure of a source code file. #[doc(alias = "TSTree")] pub struct Tree(NonNull<ffi::TSTree>); @@ -397,6 +399,14 @@ impl Drop for Language { } } +impl<'a> Deref for LanguageRef<'a> { + type Target = Language; + + fn deref(&self) -> &Self::Target { + unsafe { mem::transmute(&self.0) } + } +} + impl Parser { /// Create a new parser. pub fn new() -> Parser { @@ -778,8 +788,11 @@ impl Tree { /// Get the language that was used to parse the syntax tree. #[doc(alias = "ts_tree_language")] - pub fn language(&self) -> Language { - Language(unsafe { ffi::ts_tree_language(self.0.as_ptr()) }) + pub fn language(&self) -> LanguageRef { + LanguageRef( + unsafe { ffi::ts_tree_language(self.0.as_ptr()) }, + PhantomData, + ) } /// Edit the syntax tree to keep it in sync with source code that has been @@ -906,8 +919,8 @@ impl<'tree> Node<'tree> { /// Get the [`Language`] that was used to parse this node's syntax tree. 
#[doc(alias = "ts_node_language")] - pub fn language(&self) -> Language { - Language(unsafe { ffi::ts_node_language(self.0) }) + pub fn language(&self) -> LanguageRef { + LanguageRef(unsafe { ffi::ts_node_language(self.0) }, PhantomData) } /// Check if this node is *named*. @@ -1485,8 +1498,11 @@ impl Drop for TreeCursor<'_> { impl LookaheadIterator { /// Get the current language of the lookahead iterator. #[doc(alias = "ts_lookahead_iterator_language")] - pub fn language(&self) -> Language { - Language(unsafe { ffi::ts_lookahead_iterator_language(self.0.as_ptr()) }) + pub fn language(&self) -> LanguageRef<'_> { + LanguageRef( + unsafe { ffi::ts_lookahead_iterator_language(self.0.as_ptr()) }, + PhantomData, + ) } /// Get the current symbol of the lookahead iterator. diff --git a/lib/src/language.c b/lib/src/language.c index 4842c037..f5ec6083 100644 --- a/lib/src/language.c +++ b/lib/src/language.c @@ -1,13 +1,19 @@ #include "./language.h" +#include "./wasm.h" +#include "tree_sitter/api.h" #include const TSLanguage *ts_language_copy(const TSLanguage *self) { - // TODO - increment reference count for wasm languages + if (self && ts_language_is_wasm(self)) { + ts_wasm_language_retain(self); + } return self; } -void ts_language_delete(const TSLanguage *_self) { - // TODO - decrement reference count for wasm languages +void ts_language_delete(const TSLanguage *self) { + if (self && ts_language_is_wasm(self)) { + ts_wasm_language_release(self); + } } uint32_t ts_language_symbol_count(const TSLanguage *self) { diff --git a/lib/src/parser.c b/lib/src/parser.c index 39ad71b8..a94b17d0 100644 --- a/lib/src/parser.c +++ b/lib/src/parser.c @@ -1869,6 +1869,7 @@ const TSLanguage *ts_parser_language(const TSParser *self) { bool ts_parser_set_language(TSParser *self, const TSLanguage *language) { ts_parser__external_scanner_destroy(self); + ts_language_delete(self->language); self->language = NULL; if (language) { @@ -1885,7 +1886,7 @@ bool ts_parser_set_language(TSParser *self, 
const TSLanguage *language) { } } - self->language = language; + self->language = ts_language_copy(language); ts_parser__external_scanner_create(self); ts_parser_reset(self); return true; diff --git a/lib/src/query.c b/lib/src/query.c index 6da2be27..5f869c8d 100644 --- a/lib/src/query.c +++ b/lib/src/query.c @@ -2812,6 +2812,7 @@ void ts_query_delete(TSQuery *self) { array_delete(&self->string_buffer); array_delete(&self->negated_fields); array_delete(&self->repeat_symbols_with_rootless_patterns); + ts_language_delete(self->language); symbol_table_delete(&self->captures); symbol_table_delete(&self->predicate_values); for (uint32_t index = 0; index < self->capture_quantifiers.size; index++) { diff --git a/lib/src/tree.c b/lib/src/tree.c index 784c51fd..135e3923 100644 --- a/lib/src/tree.c +++ b/lib/src/tree.c @@ -12,7 +12,7 @@ TSTree *ts_tree_new( ) { TSTree *result = ts_malloc(sizeof(TSTree)); result->root = root; - result->language = language; + result->language = ts_language_copy(language); result->included_ranges = ts_calloc(included_range_count, sizeof(TSRange)); memcpy(result->included_ranges, included_ranges, included_range_count * sizeof(TSRange)); result->included_range_count = included_range_count; @@ -30,6 +30,7 @@ void ts_tree_delete(TSTree *self) { SubtreePool pool = ts_subtree_pool_new(0); ts_subtree_release(&pool, self->root); ts_subtree_pool_delete(&pool); + ts_language_delete(self->language); ts_free(self->included_ranges); ts_free(self); } diff --git a/lib/src/wasm.c b/lib/src/wasm.c index e30318e9..cc1b1f25 100644 --- a/lib/src/wasm.c +++ b/lib/src/wasm.c @@ -72,6 +72,7 @@ typedef struct { // wasm store, so it can be shared by all users of a `TSLanguage`. A pointer to // this is stored on the language itself. 
typedef struct { + volatile uint32_t ref_count; wasmtime_module_t *module; uint32_t language_id; const char *name; @@ -1074,7 +1075,7 @@ const TSLanguage *ts_wasm_store_load_language( }; uint32_t address_count = array_len(addresses); - TSLanguage *language = ts_malloc(sizeof(TSLanguage)); + TSLanguage *language = ts_calloc(1, sizeof(TSLanguage)); StringData symbol_name_buffer = array_new(); StringData field_name_buffer = array_new(); @@ -1202,6 +1203,7 @@ const TSLanguage *ts_wasm_store_load_language( .symbol_name_buffer = symbol_name_buffer.contents, .field_name_buffer = field_name_buffer.contents, .dylink_info = dylink_info, + .ref_count = 1, }; // The lex functions are not used for wasm languages. Use those two fields @@ -1468,6 +1470,45 @@ bool ts_language_is_wasm(const TSLanguage *self) { return self->lex_fn == ts_wasm_store__sentinel_lex_fn; } +static inline LanguageWasmModule *ts_language__wasm_module(const TSLanguage *self) { + return (LanguageWasmModule *)self->keyword_lex_fn; +} + +void ts_wasm_language_retain(const TSLanguage *self) { + LanguageWasmModule *module = ts_language__wasm_module(self); + assert(module->ref_count > 0); + atomic_inc(&module->ref_count); +} + +void ts_wasm_language_release(const TSLanguage *self) { + LanguageWasmModule *module = ts_language__wasm_module(self); + assert(module->ref_count > 0); + if (atomic_dec(&module->ref_count) == 0) { + ts_free((void *)module->field_name_buffer); + ts_free((void *)module->symbol_name_buffer); + ts_free((void *)module->name); + wasmtime_module_delete(module->module); + ts_free(module); + + ts_free((void *)self->alias_map); + ts_free((void *)self->alias_sequences); + ts_free((void *)self->external_scanner.symbol_map); + ts_free((void *)self->field_map_entries); + ts_free((void *)self->field_map_slices); + ts_free((void *)self->field_names); + ts_free((void *)self->lex_modes); + ts_free((void *)self->parse_actions); + ts_free((void *)self->parse_table); + ts_free((void *)self->primary_state_ids); 
+ ts_free((void *)self->public_symbol_map); + ts_free((void *)self->small_parse_table); + ts_free((void *)self->small_parse_table_map); + ts_free((void *)self->symbol_metadata); + ts_free((void *)self->symbol_names); + ts_free((void *)self); + } +} + #else // If the WASM feature is not enabled, define dummy versions of all of the @@ -1556,4 +1597,12 @@ bool ts_language_is_wasm(const TSLanguage *self) { return false; } +void ts_wasm_language_retain(const TSLanguage *self) { + (void)self; +} + +void ts_wasm_language_release(const TSLanguage *self) { + (void)self; +} + #endif diff --git a/lib/src/wasm.h b/lib/src/wasm.h index 0e734e82..849a3c1c 100644 --- a/lib/src/wasm.h +++ b/lib/src/wasm.h @@ -20,6 +20,9 @@ bool ts_wasm_store_call_scanner_scan(TSWasmStore *, uint32_t, uint32_t); uint32_t ts_wasm_store_call_scanner_serialize(TSWasmStore *, uint32_t, char *); void ts_wasm_store_call_scanner_deserialize(TSWasmStore *, uint32_t, const char *, unsigned); +void ts_wasm_language_retain(const TSLanguage *); +void ts_wasm_language_release(const TSLanguage *); + #ifdef __cplusplus } #endif