Merge pull request #2840 from tree-sitter/language-reference-count

Introduce APIs for managing the lifetimes of languages, allow WASM languages to be deleted
This commit is contained in:
Max Brunsfeld 2024-01-30 10:24:37 -08:00 committed by GitHub
commit 1d8975319c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 724 additions and 483 deletions

View file

@ -30,14 +30,14 @@ regex = "1.9.1"
[dependencies.wasmtime]
git = "https://github.com/bytecodealliance/wasmtime"
rev = "fa6fcd946b8f6d60c2d191a1b14b9399e261a76d"
rev = "v16.0.0"
optional = true
default-features = false
features = ["cranelift"]
[dependencies.wasmtime-c-api]
git = "https://github.com/bytecodealliance/wasmtime"
rev = "fa6fcd946b8f6d60c2d191a1b14b9399e261a76d"
rev = "v16.0.0"
optional = true
package = "wasmtime-c-api-impl"
default-features = false

View file

@ -664,6 +664,14 @@ extern "C" {
#[doc = " Set the maximum start depth for a query cursor.\n\n This prevents cursors from exploring children nodes at a certain depth.\n Note if a pattern includes many children, then they will still be checked.\n\n The zero max start depth value can be used as a special behavior and\n it helps to destructure a subtree by staying on a node and using captures\n for interested parts. Note that the zero max start depth only limit a search\n depth for a pattern's root node but other nodes that are parts of the pattern\n may be searched at any depth what defined by the pattern structure.\n\n Set to `UINT32_MAX` to remove the maximum start depth."]
pub fn ts_query_cursor_set_max_start_depth(self_: *mut TSQueryCursor, max_start_depth: u32);
}
extern "C" {
// Raw FFI declaration. The C implementation returns the pointer it was given;
// for wasm-backed languages it also retains the language so that a matching
// `ts_language_delete` call is expected for every copy.
#[doc = " Get another reference to the given language."]
pub fn ts_language_copy(self_: *const TSLanguage) -> *const TSLanguage;
}
extern "C" {
// Raw FFI declaration. The C implementation ignores null pointers and only
// releases wasm-backed languages; for any other language it does nothing.
#[doc = " Free any dynamically-allocated resources for this language, if\n this is the last reference."]
pub fn ts_language_delete(self_: *const TSLanguage);
}
extern "C" {
#[doc = " Get the number of distinct node types in the language."]
pub fn ts_language_symbol_count(self_: *const TSLanguage) -> u32;
@ -811,6 +819,10 @@ extern "C" {
error: *mut TSWasmError,
) -> *const TSLanguage;
}
extern "C" {
// Raw FFI declaration. The wasm-enabled C implementation counts only the
// store's language instances whose language has not been marked as deleted.
#[doc = " Get the number of languages instantiated in the given wasm store."]
pub fn ts_wasm_store_language_count(arg1: *const TSWasmStore) -> usize;
}
extern "C" {
#[doc = " Check if the language came from a Wasm module. If so, then in order to use\n this language with a Parser, that parser must have a Wasm store assigned."]
pub fn ts_language_is_wasm(arg1: *const TSLanguage) -> bool;

View file

@ -11,9 +11,9 @@ use std::{
ffi::CStr,
fmt, hash, iter,
marker::PhantomData,
mem::MaybeUninit,
mem::{self, MaybeUninit},
num::NonZeroU16,
ops,
ops::{self, Deref},
os::raw::{c_char, c_void},
ptr::{self, NonNull},
slice, str,
@ -47,10 +47,12 @@ pub const PARSER_HEADER: &'static str = include_str!("../include/tree_sitter/par
/// An opaque object that defines how to parse a particular language. The code for each
/// `Language` is generated by the Tree-sitter CLI.
#[doc(alias = "TSLanguage")]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[derive(Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Language(*const ffi::TSLanguage);
pub struct LanguageRef<'a>(*const ffi::TSLanguage, PhantomData<&'a ()>);
/// A tree that represents the syntactic structure of a source code file.
#[doc(alias = "TSTree")]
pub struct Tree(NonNull<ffi::TSTree>);
@ -385,6 +387,26 @@ impl Language {
}
}
impl Clone for Language {
    /// Create another handle to the same underlying language.
    ///
    /// This goes through `ts_language_copy`, whose C implementation returns
    /// the same pointer and, for wasm-backed languages, retains the language
    /// so that each handle can later be dropped independently.
    fn clone(&self) -> Self {
        // SAFETY: `ts_language_copy` accepts a null pointer and otherwise only
        // inspects the language it is given.
        // NOTE(review): assumes `self.0` is null or points to a live
        // `TSLanguage` — confirm against the constructors of `Language`.
        let copied = unsafe { ffi::ts_language_copy(self.0) };
        Self(copied)
    }
}
impl Drop for Language {
    /// Give up this handle's reference to the language.
    ///
    /// The C side frees a wasm-backed language once its last reference is
    /// released; for other languages the call does nothing.
    fn drop(&mut self) {
        let raw = self.0;
        // SAFETY: `ts_language_delete` ignores null pointers.
        // NOTE(review): assumes `raw` is null or a live `TSLanguage` whose
        // reference is owned by this handle — confirm against constructors.
        unsafe {
            ffi::ts_language_delete(raw);
        }
    }
}
impl<'a> Deref for LanguageRef<'a> {
type Target = Language;
fn deref(&self) -> &Self::Target {
unsafe { mem::transmute(&self.0) }
}
}
impl Parser {
/// Create a new parser.
pub fn new() -> Parser {
@ -403,7 +425,7 @@ impl Parser {
/// and compare it to this library's [`LANGUAGE_VERSION`](LANGUAGE_VERSION) and
/// [`MIN_COMPATIBLE_LANGUAGE_VERSION`](MIN_COMPATIBLE_LANGUAGE_VERSION) constants.
#[doc(alias = "ts_parser_set_language")]
pub fn set_language(&mut self, language: Language) -> Result<(), LanguageError> {
pub fn set_language(&mut self, language: &Language) -> Result<(), LanguageError> {
let version = language.version();
if version < MIN_COMPATIBLE_LANGUAGE_VERSION || version > LANGUAGE_VERSION {
Err(LanguageError { version })
@ -766,8 +788,11 @@ impl Tree {
/// Get the language that was used to parse the syntax tree.
#[doc(alias = "ts_tree_language")]
pub fn language(&self) -> Language {
Language(unsafe { ffi::ts_tree_language(self.0.as_ptr()) })
pub fn language(&self) -> LanguageRef {
LanguageRef(
unsafe { ffi::ts_tree_language(self.0.as_ptr()) },
PhantomData,
)
}
/// Edit the syntax tree to keep it in sync with source code that has been
@ -894,8 +919,8 @@ impl<'tree> Node<'tree> {
/// Get the [`Language`] that was used to parse this node's syntax tree.
#[doc(alias = "ts_node_language")]
pub fn language(&self) -> Language {
Language(unsafe { ffi::ts_node_language(self.0) })
pub fn language(&self) -> LanguageRef {
LanguageRef(unsafe { ffi::ts_node_language(self.0) }, PhantomData)
}
/// Check if this node is *named*.
@ -1473,8 +1498,11 @@ impl Drop for TreeCursor<'_> {
impl LookaheadIterator {
/// Get the current language of the lookahead iterator.
#[doc(alias = "ts_lookahead_iterator_language")]
pub fn language(&self) -> Language {
Language(unsafe { ffi::ts_lookahead_iterator_language(self.0.as_ptr()) })
pub fn language(&self) -> LanguageRef<'_> {
LanguageRef(
unsafe { ffi::ts_lookahead_iterator_language(self.0.as_ptr()) },
PhantomData,
)
}
/// Get the current symbol of the lookahead iterator.
@ -1553,7 +1581,7 @@ impl Query {
/// The query is associated with a particular language, and can only be run
/// on syntax nodes parsed with that language. References to Queries can be
/// shared between multiple threads.
pub fn new(language: Language, source: &str) -> Result<Self, QueryError> {
pub fn new(language: &Language, source: &str) -> Result<Self, QueryError> {
let mut error_offset = 0u32;
let mut error_type: ffi::TSQueryError = 0;
let bytes = source.as_bytes();

View file

@ -1,7 +1,7 @@
use crate::{ffi, Language, LanguageError, Parser};
use crate::{ffi, Language, LanguageError, Parser, FREE_FN};
use std::{
error,
ffi::CString,
ffi::{CStr, CString},
fmt,
mem::{self, MaybeUninit},
os::raw::c_char,
@ -73,11 +73,16 @@ impl WasmStore {
}
}
}
/// Get the number of languages instantiated in this wasm store.
///
/// In the C implementation, instances whose language has already been
/// deleted are not included in the count.
pub fn language_count(&self) -> usize {
    // The binding already returns `usize`, so no cast is needed.
    // SAFETY: NOTE(review) — assumes `self.0` is a valid store pointer for as
    // long as `self` is alive; confirm against `WasmStore::new` and its `Drop`.
    unsafe { ffi::ts_wasm_store_language_count(self.0) }
}
}
impl WasmError {
unsafe fn new(error: ffi::TSWasmError) -> Self {
let message = CString::from_raw(error.message);
let message = CStr::from_ptr(error.message).to_str().unwrap().to_string();
(FREE_FN)(error.message as *mut _);
Self {
kind: match error.kind {
ffi::TSWasmErrorKindParse => WasmErrorKind::Parse,
@ -85,7 +90,7 @@ impl WasmError {
ffi::TSWasmErrorKindInstantiate => WasmErrorKind::Instantiate,
_ => WasmErrorKind::Other,
},
message: message.into_string().unwrap(),
message,
}
}
}

View file

@ -1013,6 +1013,17 @@ void ts_query_cursor_set_max_start_depth(TSQueryCursor *self, uint32_t max_start
/* Section - Language */
/**********************/
/**
 * Get another reference to the given language.
 *
 * The returned pointer is the same pointer that was passed in. Languages
 * loaded from wasm are retained by this call; for all other languages, and
 * for NULL, the call has no further effect. Each reference obtained this way
 * should eventually be passed to `ts_language_delete`.
 */
const TSLanguage *ts_language_copy(const TSLanguage *self);
/**
 * Free any dynamically-allocated resources for this language, if
 * this is the last reference.
 *
 * Only languages loaded from wasm are released by this call. Passing NULL or
 * a language that did not come from a wasm module is a no-op.
 */
void ts_language_delete(const TSLanguage *self);
/**
* Get the number of distinct node types in the language.
*/
@ -1190,6 +1201,11 @@ const TSLanguage *ts_wasm_store_load_language(
TSWasmError *error
);
/**
 * Get the number of languages instantiated in the given wasm store.
 *
 * In the wasm-enabled implementation, instances whose language has already
 * been deleted are not counted.
 */
size_t ts_wasm_store_language_count(const TSWasmStore *);
/**
* Check if the language came from a Wasm module. If so, then in order to use
* this language with a Parser, that parser must have a Wasm store assigned.

View file

@ -1,6 +1,21 @@
#include "./language.h"
#include "./wasm.h"
#include "tree_sitter/api.h"
#include <string.h>
// Hand out another reference to `self`. Only wasm-backed languages are
// reference counted, so every other language (and NULL) is returned as-is.
const TSLanguage *ts_language_copy(const TSLanguage *self) {
  if (!self) return self;
  if (ts_language_is_wasm(self)) ts_wasm_language_retain(self);
  return self;
}
// Give up one reference to `self`. NULL and languages that did not come from
// a wasm module are ignored; wasm-backed languages are released.
void ts_language_delete(const TSLanguage *self) {
  if (!self) return;
  if (!ts_language_is_wasm(self)) return;
  ts_wasm_language_release(self);
}
// The reported symbol count is the language's own symbols plus its aliases.
uint32_t ts_language_symbol_count(const TSLanguage *self) {
  uint32_t total = self->symbol_count + self->alias_count;
  return total;
}

View file

@ -1869,6 +1869,7 @@ const TSLanguage *ts_parser_language(const TSParser *self) {
bool ts_parser_set_language(TSParser *self, const TSLanguage *language) {
ts_parser__external_scanner_destroy(self);
ts_language_delete(self->language);
self->language = NULL;
if (language) {
@ -1885,7 +1886,7 @@ bool ts_parser_set_language(TSParser *self, const TSLanguage *language) {
}
}
self->language = language;
self->language = ts_language_copy(language);
ts_parser__external_scanner_create(self);
ts_parser_reset(self);
return true;

View file

@ -2698,7 +2698,7 @@ TSQuery *ts_query_new(
.negated_fields = array_new(),
.repeat_symbols_with_rootless_patterns = array_new(),
.wildcard_root_pattern_count = 0,
.language = language,
.language = ts_language_copy(language),
};
array_push(&self->negated_fields, 0);
@ -2812,6 +2812,7 @@ void ts_query_delete(TSQuery *self) {
array_delete(&self->string_buffer);
array_delete(&self->negated_fields);
array_delete(&self->repeat_symbols_with_rootless_patterns);
ts_language_delete(self->language);
symbol_table_delete(&self->captures);
symbol_table_delete(&self->predicate_values);
for (uint32_t index = 0; index < self->capture_quantifiers.size; index++) {

View file

@ -12,7 +12,7 @@ TSTree *ts_tree_new(
) {
TSTree *result = ts_malloc(sizeof(TSTree));
result->root = root;
result->language = language;
result->language = ts_language_copy(language);
result->included_ranges = ts_calloc(included_range_count, sizeof(TSRange));
memcpy(result->included_ranges, included_ranges, included_range_count * sizeof(TSRange));
result->included_range_count = included_range_count;
@ -30,6 +30,7 @@ void ts_tree_delete(TSTree *self) {
SubtreePool pool = ts_subtree_pool_new(0);
ts_subtree_release(&pool, self->root);
ts_subtree_pool_delete(&pool);
ts_language_delete(self->language);
ts_free(self->included_ranges);
ts_free(self);
}

View file

@ -67,13 +67,22 @@ typedef struct {
uint32_t table_align;
} WasmDylinkInfo;
// WasmLanguageId - A pointer used to identify a language. This language id is
// reference-counted, so that its ownership can be shared between the language
// itself and the instances of the language that are held in wasm stores.
typedef struct {
// Number of holders of this id. Starts at 1 in `language_id_new`, is bumped by
// `language_id_clone`, and the id is freed when `language_id_delete` brings it
// down to zero.
volatile uint32_t ref_count;
// Zero while the language is alive. Incremented by `ts_wasm_language_release`
// when the language's last reference goes away, which lets wasm stores detect
// and discard the instances they still hold for it.
volatile uint32_t is_language_deleted;
} WasmLanguageId;
// LanguageWasmModule - Additional data associated with a wasm-backed
// `TSLanguage`. This data is read-only and does not reference a particular
// wasm store, so it can be shared by all users of a `TSLanguage`. A pointer to
// this is stored on the language itself.
typedef struct {
volatile uint32_t ref_count;
WasmLanguageId *language_id;
wasmtime_module_t *module;
uint32_t language_id;
const char *name;
char *symbol_name_buffer;
char *field_name_buffer;
@ -84,7 +93,7 @@ typedef struct {
// a `TSLanguage` in a particular wasm store. The wasm store holds one of
// these structs for each language that it has instantiated.
typedef struct {
uint32_t language_id;
WasmLanguageId *language_id;
wasmtime_instance_t instance;
int32_t external_states_address;
int32_t lex_main_fn_index;
@ -471,6 +480,24 @@ static wasmtime_extern_t get_builtin_func_extern(
snprintf(*output, message_length + 1, __VA_ARGS__); \
} while (0)
// Allocate a fresh language id with a single reference and the "deleted"
// flag cleared. The caller owns that reference and must eventually hand it
// to `language_id_delete`.
//
// The parameter list is spelled `(void)` so that the definition provides a
// real prototype; an empty `()` leaves the arguments unchecked in C.
WasmLanguageId *language_id_new(void) {
  WasmLanguageId *self = ts_malloc(sizeof(WasmLanguageId));
  self->ref_count = 1;
  // The flag is a uint32_t that is later bumped with `atomic_inc`, so it is
  // initialised with an integer zero rather than a boolean.
  self->is_language_deleted = 0;
  return self;
}
// Take an additional reference to an existing language id. The same pointer
// is returned so the call can be used inline in an initializer.
WasmLanguageId *language_id_clone(WasmLanguageId *self) {
  volatile uint32_t *count = &self->ref_count;
  atomic_inc(count);
  return self;
}
// Drop one reference to a language id, freeing it once nobody holds it.
void language_id_delete(WasmLanguageId *self) {
  // Someone else still holds the id: leave it allocated.
  if (atomic_dec(&self->ref_count) != 0) return;
  ts_free(self);
}
static bool ts_wasm_store__provide_builtin_import(
TSWasmStore *self,
const wasm_name_t *import_name,
@ -794,10 +821,25 @@ void ts_wasm_store_delete(TSWasmStore *self) {
wasm_globaltype_delete(self->var_i32_type);
wasmtime_store_delete(self->store);
wasm_engine_delete(self->engine);
for (unsigned i = 0; i < self->language_instances.size; i++) {
LanguageWasmInstance *instance = &self->language_instances.contents[i];
language_id_delete(instance->language_id);
}
array_delete(&self->language_instances);
ts_free(self);
}
// Count the language instances in this store whose language is still alive.
// Instances of deleted languages may linger in the array until the next
// load/add call clears them out, so they are skipped here.
size_t ts_wasm_store_language_count(const TSWasmStore *self) {
  size_t live_count = 0;
  for (unsigned index = 0; index < self->language_instances.size; index++) {
    const WasmLanguageId *language_id =
      self->language_instances.contents[index].language_id;
    if (language_id->is_language_deleted) continue;
    live_count++;
  }
  return live_count;
}
static bool ts_wasm_store__instantiate(
TSWasmStore *self,
wasmtime_module_t *module,
@ -1074,7 +1116,7 @@ const TSLanguage *ts_wasm_store_load_language(
};
uint32_t address_count = array_len(addresses);
TSLanguage *language = ts_malloc(sizeof(TSLanguage));
TSLanguage *language = ts_calloc(1, sizeof(TSLanguage));
StringData symbol_name_buffer = array_new();
StringData field_name_buffer = array_new();
@ -1196,12 +1238,13 @@ const TSLanguage *ts_wasm_store_load_language(
LanguageWasmModule *language_module = ts_malloc(sizeof(LanguageWasmModule));
*language_module = (LanguageWasmModule) {
.language_id = atomic_inc(&NEXT_LANGUAGE_ID),
.language_id = language_id_new(),
.module = module,
.name = name,
.symbol_name_buffer = symbol_name_buffer.contents,
.field_name_buffer = field_name_buffer.contents,
.dylink_info = dylink_info,
.ref_count = 1,
};
// The lex functions are not used for wasm languages. Use those two fields
@ -1210,10 +1253,19 @@ const TSLanguage *ts_wasm_store_load_language(
language->lex_fn = ts_wasm_store__sentinel_lex_fn;
language->keyword_lex_fn = (void *)language_module;
// Store some information about this store's specific instance of this
// language module, keyed by the language's id.
// Clear out any instances of languages that have been deleted.
for (unsigned i = 0; i < self->language_instances.size; i++) {
WasmLanguageId *id = self->language_instances.contents[i].language_id;
if (id->is_language_deleted) {
language_id_delete(id);
array_erase(&self->language_instances, i);
i--;
}
}
// Store this store's instance of this language module.
array_push(&self->language_instances, ((LanguageWasmInstance) {
.language_id = language_module->language_id,
.language_id = language_id_clone(language_module->language_id),
.instance = instance,
.external_states_address = wasm_language.external_scanner.states,
.lex_main_fn_index = wasm_language.lex_fn,
@ -1240,19 +1292,25 @@ bool ts_wasm_store_add_language(
wasmtime_context_t *context = wasmtime_store_context(self->store);
const LanguageWasmModule *language_module = (void *)language->keyword_lex_fn;
// Search for the information about this store's instance of the language module.
// Search for this store's instance of the language module. Also clear out any
// instances of languages that have been deleted.
bool exists = false;
array_search_sorted_by(
&self->language_instances,
.language_id,
language_module->language_id,
index,
&exists
);
for (unsigned i = 0; i < self->language_instances.size; i++) {
WasmLanguageId *id = self->language_instances.contents[i].language_id;
if (id->is_language_deleted) {
language_id_delete(id);
array_erase(&self->language_instances, i);
i--;
} else if (id == language_module->language_id) {
exists = true;
*index = i;
}
}
// If the language module has not been instantiated in this store, then add
// it to this store.
if (!exists) {
*index = self->language_instances.size;
char *message;
wasmtime_instance_t instance;
int32_t language_address;
@ -1272,8 +1330,8 @@ bool ts_wasm_store_add_language(
LanguageInWasmMemory wasm_language;
const uint8_t *memory = wasmtime_memory_data(context, &self->memory);
memcpy(&wasm_language, &memory[language_address], sizeof(LanguageInWasmMemory));
array_insert(&self->language_instances, *index, ((LanguageWasmInstance) {
.language_id = language_module->language_id,
array_push(&self->language_instances, ((LanguageWasmInstance) {
.language_id = language_id_clone(language_module->language_id),
.instance = instance,
.external_states_address = wasm_language.external_scanner.states,
.lex_main_fn_index = wasm_language.lex_fn,
@ -1468,6 +1526,50 @@ bool ts_language_is_wasm(const TSLanguage *self) {
return self->lex_fn == ts_wasm_store__sentinel_lex_fn;
}
// Recover the `LanguageWasmModule` attached to a wasm-backed language. Wasm
// languages do not use their lex function fields, so `keyword_lex_fn` is
// repurposed by `ts_wasm_store_load_language` to carry this pointer.
// NOTE(review): only meaningful when `ts_language_is_wasm(self)` is true —
// callers visible here check that first.
static inline LanguageWasmModule *ts_language__wasm_module(const TSLanguage *self) {
return (LanguageWasmModule *)self->keyword_lex_fn;
}
// Add one reference to a wasm-backed language. The count lives on the
// language's `LanguageWasmModule`, not on the `TSLanguage` itself.
void ts_wasm_language_retain(const TSLanguage *self) {
  LanguageWasmModule *wasm_module = ts_language__wasm_module(self);
  // Retaining a language whose count already hit zero would be a
  // use-after-free, so catch it in debug builds.
  assert(wasm_module->ref_count > 0);
  atomic_inc(&wasm_module->ref_count);
}
void ts_wasm_language_release(const TSLanguage *self) {
LanguageWasmModule *module = ts_language__wasm_module(self);
assert(module->ref_count > 0);
if (atomic_dec(&module->ref_count) == 0) {
// Update the language id to reflect that the language is deleted. This allows any wasm stores
// that hold wasm instances for this language to delete those instances.
atomic_inc(&module->language_id->is_language_deleted);
language_id_delete(module->language_id);
ts_free((void *)module->field_name_buffer);
ts_free((void *)module->symbol_name_buffer);
ts_free((void *)module->name);
wasmtime_module_delete(module->module);
ts_free(module);
ts_free((void *)self->alias_map);
ts_free((void *)self->alias_sequences);
ts_free((void *)self->external_scanner.symbol_map);
ts_free((void *)self->field_map_entries);
ts_free((void *)self->field_map_slices);
ts_free((void *)self->field_names);
ts_free((void *)self->lex_modes);
ts_free((void *)self->parse_actions);
ts_free((void *)self->parse_table);
ts_free((void *)self->primary_state_ids);
ts_free((void *)self->public_symbol_map);
ts_free((void *)self->small_parse_table);
ts_free((void *)self->small_parse_table_map);
ts_free((void *)self->symbol_metadata);
ts_free((void *)self->symbol_names);
ts_free((void *)self);
}
}
#else
// If the WASM feature is not enabled, define dummy versions of all of the
@ -1556,4 +1658,12 @@ bool ts_language_is_wasm(const TSLanguage *self) {
return false;
}
// Stub used when the wasm feature is not enabled: there are no wasm-backed
// languages to retain, so the argument is deliberately ignored.
void ts_wasm_language_retain(const TSLanguage *self) {
(void)self;
}
// Stub used when the wasm feature is not enabled: there are no wasm-backed
// languages to release, so the argument is deliberately ignored.
void ts_wasm_language_release(const TSLanguage *self) {
(void)self;
}
#endif

View file

@ -20,6 +20,9 @@ bool ts_wasm_store_call_scanner_scan(TSWasmStore *, uint32_t, uint32_t);
uint32_t ts_wasm_store_call_scanner_serialize(TSWasmStore *, uint32_t, char *);
void ts_wasm_store_call_scanner_deserialize(TSWasmStore *, uint32_t, const char *, unsigned);
// Increment the reference count of a wasm-backed language. This is a no-op
// when the wasm feature is not enabled.
void ts_wasm_language_retain(const TSLanguage *);
// Decrement the reference count of a wasm-backed language, freeing the
// language and its wasm module data when the count reaches zero. This is a
// no-op when the wasm feature is not enabled.
void ts_wasm_language_release(const TSLanguage *);
#ifdef __cplusplus
}
#endif