diff --git a/include/tree_sitter/runtime.h b/include/tree_sitter/runtime.h index b483ba7f..e21571e8 100644 --- a/include/tree_sitter/runtime.h +++ b/include/tree_sitter/runtime.h @@ -86,6 +86,8 @@ void ts_parser_print_dot_graphs(TSParser *, FILE *); void ts_parser_halt_on_error(TSParser *, bool); TSTree *ts_parser_parse(TSParser *, const TSTree *, TSInput); TSTree *ts_parser_parse_string(TSParser *, const TSTree *, const char *, uint32_t); +bool ts_parser_enabled(TSParser *); +void ts_parser_set_enabled(TSParser *, bool); TSTree *ts_tree_copy(const TSTree *); void ts_tree_delete(TSTree *); diff --git a/src/runtime/parser.c b/src/runtime/parser.c index 3ec5b969..26a377d7 100644 --- a/src/runtime/parser.c +++ b/src/runtime/parser.c @@ -60,6 +60,7 @@ struct TSParser { FILE *dot_graph_file; bool halt_on_error; unsigned accept_count; + volatile bool enabled; }; typedef struct { @@ -705,6 +706,15 @@ static void ts_parser__start(TSParser *self, TSInput input, const Subtree *previ self->accept_count = 0; } +static void ts_parser__stop(TSParser *self) { + ts_stack_clear(self->stack); + ts_parser__set_cached_token(self, 0, NULL, NULL); + if (self->finished_tree) { + ts_subtree_release(&self->tree_pool, self->finished_tree); + self->finished_tree = NULL; + } +} + static void ts_parser__accept(TSParser *self, StackVersion version, const Subtree *lookahead) { assert(lookahead->symbol == ts_builtin_sym_end); ts_stack_push(self->stack, version, lookahead, false, 1); @@ -1308,6 +1318,7 @@ TSParser *ts_parser_new() { self->reusable_node = reusable_node_new(); self->dot_graph_file = NULL; self->halt_on_error = false; + self->enabled = true; ts_parser__set_cached_token(self, 0, NULL, NULL); return self; } @@ -1362,6 +1373,14 @@ void ts_parser_halt_on_error(TSParser *self, bool should_halt_on_error) { self->halt_on_error = should_halt_on_error; } +bool ts_parser_enabled(TSParser *self) { + return self->enabled; +} + +void ts_parser_set_enabled(TSParser *self, bool enabled) { + self->enabled = enabled; +} + TSTree *ts_parser_parse(TSParser *self, const TSTree *old_tree, TSInput input) { if (!self->language) return NULL; ts_parser__start(self, input, old_tree ? old_tree->root : NULL); @@ -1372,6 +1391,11 @@ TSTree *ts_parser_parse(TSParser *self, const TSTree *old_tree, TSInput input) { for (StackVersion version = 0; version_count = ts_stack_version_count(self->stack), version < version_count; version++) { + if (!self->enabled) { + ts_parser__stop(self); + return NULL; + } + bool allow_node_reuse = version_count == 1; while (ts_stack_is_active(self->stack, version)) { LOG("process version:%d, version_count:%u, state:%d, row:%u, col:%u", @@ -1404,10 +1428,13 @@ TSTree *ts_parser_parse(TSParser *self, const TSTree *old_tree, TSInput input) { ts_parser__set_cached_token(self, 0, NULL, NULL); ts_subtree_balance(self->finished_tree, &self->tree_pool, self->language); + TSTree *result = ts_tree_new(self->finished_tree, self->language); LOG("done"); LOG_TREE(); + self->finished_tree = NULL; - return ts_tree_new(self->finished_tree, self->language); + ts_parser__stop(self); + return result; } TSTree *ts_parser_parse_string(TSParser *self, const TSTree *old_tree, diff --git a/test/runtime/parser_test.cc b/test/runtime/parser_test.cc index 7d0b2d1d..4e787c76 100644 --- a/test/runtime/parser_test.cc +++ b/test/runtime/parser_test.cc @@ -1,4 +1,5 @@ #include "test_helper.h" +#include #include "runtime/alloc.h" #include "runtime/language.h" #include "helpers/record_alloc.h" @@ -610,6 +611,51 @@ describe("Parser", [&]() { }); }); }); + + describe("set_enabled(enabled)", [&]() { + it("stops the in-progress parse if false is passed", [&]() { + ts_parser_set_language(parser, load_real_language("json")); + AssertThat(ts_parser_enabled(parser), IsTrue()); + + auto tree_future = std::async([parser]() { + size_t read_count = 0; + TSInput infinite_input = { + &read_count, + [](void *payload, uint32_t *bytes_read) { + size_t *read_count = static_cast(payload); + assert((*read_count)++ < 100000); + *bytes_read = 1; + return "["; + }, + [](void *payload, unsigned byte, TSPoint position) -> int { + return true; + }, + TSInputEncodingUTF8 + }; + + return ts_parser_parse(parser, nullptr, infinite_input); + }); + + auto cancel_future = std::async([parser]() { + ts_parser_set_enabled(parser, false); + }); + + cancel_future.wait(); + tree_future.wait(); + AssertThat(ts_parser_enabled(parser), IsFalse()); + AssertThat(tree_future.get(), Equals(nullptr)); + + TSTree *tree = ts_parser_parse_string(parser, nullptr, "[]", 2); + AssertThat(ts_parser_enabled(parser), IsFalse()); + AssertThat(tree, Equals(nullptr)); + + ts_parser_set_enabled(parser, true); + AssertThat(ts_parser_enabled(parser), IsTrue()); + tree = ts_parser_parse_string(parser, nullptr, "[]", 2); + AssertThat(tree, !Equals(nullptr)); + ts_tree_delete(tree); + }); + }); }); END_TEST